2 पॉइंट द्वारा GN⁺ 2024-09-24 | 1 टिप्पणियां | WhatsApp पर शेयर करें

Felafax BlogTune Llama3 405B on AMD MI300x (हमारी यात्रा)

परिचय

  • जैसे-जैसे open source मॉडल बड़े होते जा रहे हैं, बड़े पैमाने की AI training को संभालने के लिए शक्तिशाली infrastructure की आवश्यकता बढ़ रही है
  • Felafax ने AMD GPU पर LLaMA 3.1 405B मॉडल का फाइन-ट्यूनिंग करके AMD hardware की दक्षता साबित की
  • पूरा काम GitHub पर open source के रूप में सार्वजनिक किया गया है
  • AMD MI300X GPU, NVIDIA AI hardware की तुलना में उच्च प्रदर्शन प्रदान करता है
  • यह प्रोजेक्ट TensorWave के समर्थन से संभव हो पाया

JAX क्या है और इसे क्यों चुना गया

  • JAX एक शक्तिशाली machine learning library है, जो NumPy-जैसे API, automatic differentiation, और Google's XLA compiler को जोड़ती है
  • यह model parallelism के लिए उत्कृष्ट API प्रदान करती है, जिससे यह बड़े मॉडलों की training के लिए आदर्श बनती है

JAX के फायदे

  • Pure functions: JAX pure functions लिखने को प्रोत्साहित करता है, जिससे code को compose करना, debug करना, और पढ़ना आसान होता है
  • Advanced parallelism: JAX का flexible JIT API, advanced data और model parallelism को support करता है, जो बड़े पैमाने की training के लिए आवश्यक है
  • Clean codebase: JAX की design philosophy अलग-अलग hardware platforms के बीच portable code लिखने को प्रोत्साहित करती है

JAX non-NVIDIA hardware पर क्यों बेहतर है

  • Hardware-independent approach: JAX, XLA compiler का उपयोग करके computations को hardware-independent intermediate representation में compile करता है
  • Platform-independent optimization: XLA compiler hardware से स्वतंत्र रूप से optimization करता है
  • आसान portability: JAX का उपयोग करने पर NVIDIA से AMD पर जाने में code changes न्यूनतम रहते हैं

AMD GPU पर JAX सेटअप

  • Docker image को pull करने, container शुरू करने और installation verify करने के बाद सेटअप पूरा किया गया
  • 8 AMD MI300x GPU का उपयोग करके LLaMA 405B मॉडल को train किया गया

LLaMA 405B training: प्रदर्शन और scalability

  • JAX का उपयोग करके AMD GPU पर LLaMA 405B मॉडल को train किया गया
  • LoRA फाइन-ट्यूनिंग के माध्यम से model weights और LoRA parameters को bfloat16 precision में adjust किया गया
  • मॉडल आकार: लगभग 800GB VRAM का उपयोग
  • LoRA weights और optimizer state: लगभग 400GB VRAM का उपयोग
  • कुल VRAM उपयोग: लगभग 1200GB
  • Training speed: लगभग 35 tokens प्रति सेकंड
  • Memory efficiency: लगभग 70% बनी रही
  • Scalability: JAX के साथ 8 GPU पर लगभग linear scaling मिली

हमारी training setup

  • LLaMA 3.1 को PyTorch से JAX में convert किया गया
  • Model loading और parameter sharding के माध्यम से इसे कुशलतापूर्वक distribute किया गया

JAX में parameter sharding

  • JAX की device mesh functionality का उपयोग करके मॉडल को 8 AMD GPU में कुशलतापूर्वक distribute किया गया
  • Parameter sharding rules परिभाषित करके प्रत्येक tensor के dimensions को mesh axes के अनुसार shard किया गया

LoRA training implementation

  • LoRA, weight updates को low-rank matrices में विभाजित करके trainable parameters की संख्या कम करता है
  • LoRADense layer को implement किया गया, जिसमें LoRA parameters शामिल हैं
  • LoRA parameters को कुशलतापूर्वक distribute करके memory usage और computation efficiency को optimize किया गया

निष्कर्ष

  • AMD GPU और JAX का उपयोग करके LLaMA 3.1 405B मॉडल का फाइन-ट्यूनिंग करने का अनुभव बहुत सकारात्मक रहा
  • JAX की मजबूत parallelism capabilities और hardware-independent approach का उपयोग करके मॉडल को कुशलतापूर्वक distribute किया गया
  • इससे साबित हुआ कि AMD GPU बड़े पैमाने की AI training के लिए एक मजबूत विकल्प हैं
  • पूरा code GitHub repository में देखा और सीधे चलाया जा सकता है

GN⁺ की संक्षिप्त टिप्पणी

  • यह लेख AMD GPU और JAX का उपयोग करके बड़े AI मॉडलों को कुशलतापूर्वक train करने का तरीका बताता है
  • यह रेखांकित करता है कि AMD hardware, NVIDIA की तुलना में अधिक cost-effective विकल्प हो सकता है
  • JAX का hardware-independent approach, code portability बढ़ाता है और maintenance को आसान बनाता है
  • बड़े मॉडल training में रुचि रखने वालों के लिए यह उपयोगी जानकारी और practical code प्रदान करता है
  • समान क्षमताओं वाले प्रोजेक्ट्स में NVIDIA का CUDA और PyTorch शामिल हैं

1 टिप्पणियां

 
GN⁺ 2024-09-24
Hacker News की राय
  • हाल ही में 8xAMD MI300x GPU पर PyTorch की जगह JAX से llama3.1 405B मॉडल को fine-tune किया
    JAX के advanced sharding API की वजह से अच्छा performance मिला, और इस्तेमाल की गई sharding तकनीकें ब्लॉग में लिखी हैं। कोड भी open कर दिया है: https://github.com/felafax/felafax
    हम एक छोटा startup हैं जो NVIDIA के अलावा दूसरे hardware (TPU, AMD, Trainium) पर LLM fine-tuning और serving के लिए AI infrastructure बना रहा है
    कई कंपनियां AMD GPU पर PyTorch चलाने की कोशिश कर रही हैं, लेकिन PyTorch torch.cuda या scaled_dot_product_attention जैसे हिस्सों के साथ NVIDIA ecosystem में गहराई से जुड़ा है, इसलिए लगता है कि उसे “NVIDIA से अलग” करने के लिए काफी काम चाहिए
    JAX में model code hardware-independent HLO graph में compile होता है, फिर XLA compiler उसे optimize करता है और उसके बाद hardware-specific optimizations लगाता है, इसलिए हमें लगता है कि यह NVIDIA के अलावा दूसरे hardware के लिए बेहतर फिट है। वही LLaMA3 JAX code Google TPU और AMD GPU पर बिना बदलाव चला
    कंपनी की strategy है कि पहले models को JAX में port किया जाए, फिर JAX framework और XLA kernels का इस्तेमाल करके NVIDIA के अलावा backends पर maximum performance निकाला जाए। इसलिए हमने पहले Llama 3.1 को PyTorch से JAX में लाया, और वही JAX model TPU और AMD GPU पर अच्छे से चलता है

    • AMD GPU पर PyTorch को CUDA code बदले बिना चलाने में कोई खास दिक्कत नहीं हुई। MosaicML ब्लॉग भी देखने लायक है: https://www.databricks.com/blog/training-llms-scale-amd-mi25...
    • जानना चाहूंगा कि Llama 3.1 की JAX porting accuracy को कैसे verify कर रहे हैं
      निजी तौर पर PyTorch इस्तेमाल करने की मुख्य वजह यह है कि original model PyTorch में बना था। अलग-अलग model versions में logic समान दिखे तब भी, बेहद बड़े data scale पर बहुत छोटी floating-point errors जमा होकर model drift पैदा कर सकती हैं
      बड़े models में ऐसी accuracy mismatch debug करना नरक के 10वें घेरे से भी ज्यादा तकलीफदेह काम जैसा था
    • जानना चाहूंगा कि JAX के पास matrix multiplication या FlashAttention की अपनी implementation है, या वह PyTorch की तरह ROCm implementation इस्तेमाल करता है। जैसे hipblaslt, Composable Kernel FA वगैरह
      मैं JAX को बहुत अच्छी तरह नहीं जानता, लेकिन MI300x पर PyTorch training performance खराब होने की बड़ी वजहों में से एक अंदर इस्तेमाल होने वाली ROCm libraries की धीमी performance लगती है
    • जानना चाहूंगा कि क्या यह 7900 XTX जैसे consumer cards पर भी चलता है
      यहां चलने का मतलब यह नहीं है कि drivers सेट करने में 2 हफ्ते लग जाएं और उसके बाद server को फिर कभी update न कर पाने वाली हालत हो
    • अगर migration है, तो क्या उसी model की PyTorch version से तुलना वाले actual numbers हैं? लेख की comparison table ज्यादा technical पहलू जैसी दिखती है
      सामने आए technical issues भी जानना चाहूंगा
  • साफ कहूं तो यह performance काफी खराब है। शायद compilation ठीक से काम नहीं करा पाए, इसलिए ऐसा लग रहा है
    405B model पर 35 tokens/sec मिल रहे हैं, जो करीब 85 teraflops के बराबर है। 8 MI300x GPU करीब 10.4 petaflops level के हैं, इसलिए MFU लगभग 0.8% है
    decent training performance यानी 30~40% MFU से यह 40~50 गुना कम है, इसलिए AMD के लिए अच्छा होगा अगर bottleneck software stack ही निकले

    • मैं भी ठीक यही पूछना चाहता था
      GitHub page कहता है कि “Google Cloud TPU पर LLaMa3.1 को 30% कम cost में tune कर सकते हैं”, लेकिन performance का जिक्र नहीं है
  • शानदार काम। लगभग एक साल पहले AMD GPU और ROCm support के साथ थोड़ा काम किया था, और साफ था कि AMD को Nvidia तक पहुंचने के लिए अभी लंबा रास्ता तय करना है
    JAX चुनने वाला approach दिलचस्प है, लेकिन machine learning की standard library जैसी बन चुकी PyTorch से हटते समय कौन-सी मुश्किलें आईं, यह जानना चाहूंगा

    • कुछ हफ्ते पहले हमने अपनी journey समझाते हुए Show HN पोस्ट किया था: https://news.ycombinator.com/item?id=41512142
      शुरुआत में लक्ष्य TPU पर LLaMA 3 fine-tune करना था, लेकिन PyTorch XLA भद्दा लगा, इसलिए model को JAX में फिर से लिखने का फैसला किया
      जैसा ऊपर कहा, हमें लगता है कि JAX NVIDIA के अलावा GPU के लिए बेहतर platform है, और हम JAX+openXLA के ऊपर NVIDIA के अलावा GPU के लिए infrastructure बनाना चाहते हैं
    • अपने Debian 12 system पर AMD ROCm चलवा नहीं पा रहा हूं, इसलिए लगता है कि Ollama GPU की जगह CPU इस्तेमाल कर रहा है। अभी लंबा रास्ता बाकी दिखता है
  • अच्छा काम। पिछले weekend मैं भी 405B के inference वाले हिस्से से छेड़छाड़ कर रहा था [0]
    torch.cuda इतना बुरा है, इस पर मुझे यकीन नहीं है। AMD के लिए PyTorch उसे replace/translate कर देता है। यह असल समस्या से ज्यादा naming issue जैसा है
    सच में rocm:pytorch container लेना rocm:jax container लेने जितना ही आसान है
    published numbers ज्यादा नहीं हैं, जानना चाहूंगा कि MFU कितना निकला
    [0] https://x.com/HotAisle/status/1837580046732874026

    • अच्छा
      MFU calculate करना पड़ेगा। GPU और VRAM details repository में देख सकते हैं: https://dub.sh/amd-405b-res
      अगले weekend training run फिर से try करते हुए पूरे training step को JIT compile करने और तब MFU calculate करने की योजना है
  • ZML में हमने जब measure किया, MI300X H100 से 30% तेज था। ये शानदार chips हैं

  • जानना चाहूंगा कि क्या कोई cloud provider है जहां 8xAMD MI300 host rent कर सकें
    काम में AWS बहुत इस्तेमाल करता हूं, लेकिन AMD GPU एक बार try करना चाहता था

    • संदर्भ के लिए, हमारी company 8xMI300x rent पर दे रही है, चाहें तो संपर्क कर सकते हैं
    • Oracle देता है। बाकी भी शायद follow करेंगे, लेकिन मुझे लगता है छोटे providers से deal करना ज्यादा reasonable होगा
  • performance data कहां है?

    • GitHub repository में GPU और VRAM utilization data जोड़ दिया है: https://github.com/felafax/felafax?tab=readme-ov-file#amd-40...
      code और VRAM constraints की वजह से 405B model का JIT compiled version नहीं चला पाए। इस हिस्से पर और investigation चाहिए
      पूरा training run JAX eager mode में किया गया था, इसलिए performance improvement की काफी गुंजाइश है
      eager mode में भी GPU utilization कुल मिलाकर लगभग 30~40% था, जो काफी ठीक है। JIT इस्तेमाल करने पर GPU utilization आसानी से 50~60% तक जा सकता है, ऐसा लगता है
  • अगर संभव हो तो memory constraints को पार करके JIT compiled version चलाने का तरीका explore करना दिलचस्प होगा। इससे और performance improvement मिल सकता है

    • सहमत। अभी निकालने के लिए बहुत performance बची है
      JIT compiled training step, ज्यादा optimized data loading और sharding, gradient accumulation, activation checkpointing चाहिए
      बनाना जारी रखेंगे और सभी improvements implement करने के बाद जल्द ही फिर ब्लॉग पोस्ट करेंगे
  • जानना चाहूंगा कि AMD GPU की बड़ी orders और supply shortage के जरिए यहां value निकालने के थोड़ा भी करीब पहुंचा है या नहीं
    मेरा impression “नहीं” के ज्यादा करीब है

    • तंज समझ में आया। लेकिन अगर इस समय AI के hardware और software को पूरी तरह एक ही supplier पर छोड़ने का इरादा नहीं है, तो alternatives की तरफ बढ़ना शुरू करना होगा
      सामने वाले के पास बहुत बड़ा first-mover advantage है, और software side पर करने को साफ तौर पर बहुत काम है। समय लगेगा
  • note app Obsidian यह क्यों कर रहा है?

    • ऐसा नहीं है। यह company documents publish करने के लिए Obsidian Publish इस्तेमाल कर रही है