AMD GPU पर Llama 405B का फाइन-ट्यूनिंग
(publish.obsidian.md)Felafax BlogTune Llama3 405B on AMD MI300x (हमारी यात्रा)
परिचय
- जैसे-जैसे open source मॉडल बड़े होते जा रहे हैं, बड़े पैमाने की AI training को संभालने के लिए शक्तिशाली infrastructure की आवश्यकता बढ़ रही है
- Felafax ने AMD GPU पर LLaMA 3.1 405B मॉडल का फाइन-ट्यूनिंग करके AMD hardware की दक्षता साबित की
- पूरा काम GitHub पर open source के रूप में सार्वजनिक किया गया है
- AMD MI300X GPU, NVIDIA AI hardware की तुलना में उच्च प्रदर्शन प्रदान करता है
- यह प्रोजेक्ट TensorWave के समर्थन से संभव हो पाया
JAX क्या है और इसे क्यों चुना गया
- JAX एक शक्तिशाली machine learning library है, जो NumPy-जैसे API, automatic differentiation, और Google's XLA compiler को जोड़ती है
- यह model parallelism के लिए उत्कृष्ट API प्रदान करती है, जिससे यह बड़े मॉडलों की training के लिए आदर्श बनती है
JAX के फायदे
- Pure functions: JAX pure functions लिखने को प्रोत्साहित करता है, जिससे code को compose करना, debug करना, और पढ़ना आसान होता है
- Advanced parallelism: JAX का flexible JIT API, advanced data और model parallelism को support करता है, जो बड़े पैमाने की training के लिए आवश्यक है
- Clean codebase: JAX की design philosophy अलग-अलग hardware platforms के बीच portable code लिखने को प्रोत्साहित करती है
JAX non-NVIDIA hardware पर क्यों बेहतर है
- Hardware-independent approach: JAX, XLA compiler का उपयोग करके computations को hardware-independent intermediate representation में compile करता है
- Platform-independent optimization: XLA compiler hardware से स्वतंत्र रूप से optimization करता है
- आसान portability: JAX का उपयोग करने पर NVIDIA से AMD पर जाने में code changes न्यूनतम रहते हैं
AMD GPU पर JAX सेटअप
- Docker image को pull करने, container शुरू करने और installation verify करने के बाद सेटअप पूरा किया गया
- 8 AMD MI300x GPU का उपयोग करके LLaMA 405B मॉडल को train किया गया
LLaMA 405B training: प्रदर्शन और scalability
- JAX का उपयोग करके AMD GPU पर LLaMA 405B मॉडल को train किया गया
- LoRA फाइन-ट्यूनिंग के माध्यम से model weights और LoRA parameters को bfloat16 precision में adjust किया गया
- मॉडल आकार: लगभग 800GB VRAM का उपयोग
- LoRA weights और optimizer state: लगभग 400GB VRAM का उपयोग
- कुल VRAM उपयोग: लगभग 1200GB
- Training speed: लगभग 35 tokens प्रति सेकंड
- Memory efficiency: लगभग 70% बनी रही
- Scalability: JAX के साथ 8 GPU पर लगभग linear scaling मिली
हमारी training setup
- LLaMA 3.1 को PyTorch से JAX में convert किया गया
- Model loading और parameter sharding के माध्यम से इसे कुशलतापूर्वक distribute किया गया
JAX में parameter sharding
- JAX की device mesh functionality का उपयोग करके मॉडल को 8 AMD GPU में कुशलतापूर्वक distribute किया गया
- Parameter sharding rules परिभाषित करके प्रत्येक tensor के dimensions को mesh axes के अनुसार shard किया गया
LoRA training implementation
- LoRA, weight updates को low-rank matrices में विभाजित करके trainable parameters की संख्या कम करता है
- LoRADense layer को implement किया गया, जिसमें LoRA parameters शामिल हैं
- LoRA parameters को कुशलतापूर्वक distribute करके memory usage और computation efficiency को optimize किया गया
निष्कर्ष
- AMD GPU और JAX का उपयोग करके LLaMA 3.1 405B मॉडल का फाइन-ट्यूनिंग करने का अनुभव बहुत सकारात्मक रहा
- JAX की मजबूत parallelism capabilities और hardware-independent approach का उपयोग करके मॉडल को कुशलतापूर्वक distribute किया गया
- इससे साबित हुआ कि AMD GPU बड़े पैमाने की AI training के लिए एक मजबूत विकल्प हैं
- पूरा code GitHub repository में देखा और सीधे चलाया जा सकता है
GN⁺ की संक्षिप्त टिप्पणी
- यह लेख AMD GPU और JAX का उपयोग करके बड़े AI मॉडलों को कुशलतापूर्वक train करने का तरीका बताता है
- यह रेखांकित करता है कि AMD hardware, NVIDIA की तुलना में अधिक cost-effective विकल्प हो सकता है
- JAX का hardware-independent approach, code portability बढ़ाता है और maintenance को आसान बनाता है
- बड़े मॉडल training में रुचि रखने वालों के लिए यह उपयोगी जानकारी और practical code प्रदान करता है
- समान क्षमताओं वाले प्रोजेक्ट्स में NVIDIA का CUDA और PyTorch शामिल हैं
1 टिप्पणियां
Hacker News की राय
हाल ही में 8xAMD MI300x GPU पर PyTorch की जगह JAX से llama3.1 405B मॉडल को fine-tune किया
JAX के advanced sharding API की वजह से अच्छा performance मिला, और इस्तेमाल की गई sharding तकनीकें ब्लॉग में लिखी हैं। कोड भी open कर दिया है: https://github.com/felafax/felafax
हम एक छोटा startup हैं जो NVIDIA के अलावा दूसरे hardware (TPU, AMD, Trainium) पर LLM fine-tuning और serving के लिए AI infrastructure बना रहा है
कई कंपनियां AMD GPU पर PyTorch चलाने की कोशिश कर रही हैं, लेकिन PyTorch
torch.cudaयाscaled_dot_product_attentionजैसे हिस्सों के साथ NVIDIA ecosystem में गहराई से जुड़ा है, इसलिए लगता है कि उसे “NVIDIA से अलग” करने के लिए काफी काम चाहिएJAX में model code hardware-independent HLO graph में compile होता है, फिर XLA compiler उसे optimize करता है और उसके बाद hardware-specific optimizations लगाता है, इसलिए हमें लगता है कि यह NVIDIA के अलावा दूसरे hardware के लिए बेहतर फिट है। वही LLaMA3 JAX code Google TPU और AMD GPU पर बिना बदलाव चला
कंपनी की strategy है कि पहले models को JAX में port किया जाए, फिर JAX framework और XLA kernels का इस्तेमाल करके NVIDIA के अलावा backends पर maximum performance निकाला जाए। इसलिए हमने पहले Llama 3.1 को PyTorch से JAX में लाया, और वही JAX model TPU और AMD GPU पर अच्छे से चलता है
निजी तौर पर PyTorch इस्तेमाल करने की मुख्य वजह यह है कि original model PyTorch में बना था। अलग-अलग model versions में logic समान दिखे तब भी, बेहद बड़े data scale पर बहुत छोटी floating-point errors जमा होकर model drift पैदा कर सकती हैं
बड़े models में ऐसी accuracy mismatch debug करना नरक के 10वें घेरे से भी ज्यादा तकलीफदेह काम जैसा था
hipblaslt, Composable Kernel FA वगैरहमैं JAX को बहुत अच्छी तरह नहीं जानता, लेकिन MI300x पर PyTorch training performance खराब होने की बड़ी वजहों में से एक अंदर इस्तेमाल होने वाली ROCm libraries की धीमी performance लगती है
यहां चलने का मतलब यह नहीं है कि drivers सेट करने में 2 हफ्ते लग जाएं और उसके बाद server को फिर कभी update न कर पाने वाली हालत हो
सामने आए technical issues भी जानना चाहूंगा
साफ कहूं तो यह performance काफी खराब है। शायद compilation ठीक से काम नहीं करा पाए, इसलिए ऐसा लग रहा है
405B model पर 35 tokens/sec मिल रहे हैं, जो करीब 85 teraflops के बराबर है। 8 MI300x GPU करीब 10.4 petaflops level के हैं, इसलिए MFU लगभग 0.8% है
decent training performance यानी 30~40% MFU से यह 40~50 गुना कम है, इसलिए AMD के लिए अच्छा होगा अगर bottleneck software stack ही निकले
GitHub page कहता है कि “Google Cloud TPU पर LLaMa3.1 को 30% कम cost में tune कर सकते हैं”, लेकिन performance का जिक्र नहीं है
शानदार काम। लगभग एक साल पहले AMD GPU और ROCm support के साथ थोड़ा काम किया था, और साफ था कि AMD को Nvidia तक पहुंचने के लिए अभी लंबा रास्ता तय करना है
JAX चुनने वाला approach दिलचस्प है, लेकिन machine learning की standard library जैसी बन चुकी PyTorch से हटते समय कौन-सी मुश्किलें आईं, यह जानना चाहूंगा
शुरुआत में लक्ष्य TPU पर LLaMA 3 fine-tune करना था, लेकिन PyTorch XLA भद्दा लगा, इसलिए model को JAX में फिर से लिखने का फैसला किया
जैसा ऊपर कहा, हमें लगता है कि JAX NVIDIA के अलावा GPU के लिए बेहतर platform है, और हम JAX+openXLA के ऊपर NVIDIA के अलावा GPU के लिए infrastructure बनाना चाहते हैं
अच्छा काम। पिछले weekend मैं भी 405B के inference वाले हिस्से से छेड़छाड़ कर रहा था [0]
torch.cudaइतना बुरा है, इस पर मुझे यकीन नहीं है। AMD के लिए PyTorch उसे replace/translate कर देता है। यह असल समस्या से ज्यादा naming issue जैसा हैसच में
rocm:pytorchcontainer लेनाrocm:jaxcontainer लेने जितना ही आसान हैpublished numbers ज्यादा नहीं हैं, जानना चाहूंगा कि MFU कितना निकला
[0] https://x.com/HotAisle/status/1837580046732874026
MFU calculate करना पड़ेगा। GPU और VRAM details repository में देख सकते हैं: https://dub.sh/amd-405b-res
अगले weekend training run फिर से try करते हुए पूरे training step को JIT compile करने और तब MFU calculate करने की योजना है
ZML में हमने जब measure किया, MI300X H100 से 30% तेज था। ये शानदार chips हैं
जानना चाहूंगा कि क्या कोई cloud provider है जहां 8xAMD MI300 host rent कर सकें
काम में AWS बहुत इस्तेमाल करता हूं, लेकिन AMD GPU एक बार try करना चाहता था
performance data कहां है?
code और VRAM constraints की वजह से 405B model का JIT compiled version नहीं चला पाए। इस हिस्से पर और investigation चाहिए
पूरा training run JAX eager mode में किया गया था, इसलिए performance improvement की काफी गुंजाइश है
eager mode में भी GPU utilization कुल मिलाकर लगभग 30~40% था, जो काफी ठीक है। JIT इस्तेमाल करने पर GPU utilization आसानी से 50~60% तक जा सकता है, ऐसा लगता है
अगर संभव हो तो memory constraints को पार करके JIT compiled version चलाने का तरीका explore करना दिलचस्प होगा। इससे और performance improvement मिल सकता है
JIT compiled training step, ज्यादा optimized data loading और sharding, gradient accumulation, activation checkpointing चाहिए
बनाना जारी रखेंगे और सभी improvements implement करने के बाद जल्द ही फिर ब्लॉग पोस्ट करेंगे
जानना चाहूंगा कि AMD GPU की बड़ी orders और supply shortage के जरिए यहां value निकालने के थोड़ा भी करीब पहुंचा है या नहीं
मेरा impression “नहीं” के ज्यादा करीब है
सामने वाले के पास बहुत बड़ा first-mover advantage है, और software side पर करने को साफ तौर पर बहुत काम है। समय लगेगा
note app Obsidian यह क्यों कर रहा है?