2 पॉइंट द्वारा GN⁺ 2024-09-24 | 1 टिप्पणियां | WhatsApp पर शेयर करें

Felafax BlogTune Llama3 405B on AMD MI300x (हमारी यात्रा)

परिचय

  • जैसे-जैसे open source मॉडल बड़े होते जा रहे हैं, बड़े पैमाने की AI training को संभालने के लिए शक्तिशाली infrastructure की आवश्यकता बढ़ रही है
  • Felafax ने AMD GPU पर LLaMA 3.1 405B मॉडल का फाइन-ट्यूनिंग करके AMD hardware की दक्षता साबित की
  • पूरा काम GitHub पर open source के रूप में सार्वजनिक किया गया है
  • AMD MI300X GPU, NVIDIA AI hardware की तुलना में उच्च प्रदर्शन प्रदान करता है
  • यह प्रोजेक्ट TensorWave के समर्थन से संभव हो पाया

JAX क्या है और इसे क्यों चुना गया

  • JAX एक शक्तिशाली machine learning library है, जो NumPy-जैसे API, automatic differentiation, और Google's XLA compiler को जोड़ती है
  • यह model parallelism के लिए उत्कृष्ट API प्रदान करती है, जिससे यह बड़े मॉडलों की training के लिए आदर्श बनती है

JAX के फायदे

  • Pure functions: JAX pure functions लिखने को प्रोत्साहित करता है, जिससे code को compose करना, debug करना, और पढ़ना आसान होता है
  • Advanced parallelism: JAX का flexible JIT API, advanced data और model parallelism को support करता है, जो बड़े पैमाने की training के लिए आवश्यक है
  • Clean codebase: JAX की design philosophy अलग-अलग hardware platforms के बीच portable code लिखने को प्रोत्साहित करती है

JAX non-NVIDIA hardware पर क्यों बेहतर है

  • Hardware-independent approach: JAX, XLA compiler का उपयोग करके computations को hardware-independent intermediate representation में compile करता है
  • Platform-independent optimization: XLA compiler hardware से स्वतंत्र रूप से optimization करता है
  • आसान portability: JAX का उपयोग करने पर NVIDIA से AMD पर जाने में code changes न्यूनतम रहते हैं

AMD GPU पर JAX सेटअप

  • Docker image को pull करने, container शुरू करने और installation verify करने के बाद सेटअप पूरा किया गया
  • 8 AMD MI300x GPU का उपयोग करके LLaMA 405B मॉडल को train किया गया

LLaMA 405B training: प्रदर्शन और scalability

  • JAX का उपयोग करके AMD GPU पर LLaMA 405B मॉडल को train किया गया
  • LoRA फाइन-ट्यूनिंग के माध्यम से model weights और LoRA parameters को bfloat16 precision में adjust किया गया
  • मॉडल आकार: लगभग 800GB VRAM का उपयोग
  • LoRA weights और optimizer state: लगभग 400GB VRAM का उपयोग
  • कुल VRAM उपयोग: लगभग 1200GB
  • Training speed: लगभग 35 tokens प्रति सेकंड
  • Memory efficiency: लगभग 70% बनी रही
  • Scalability: JAX के साथ 8 GPU पर लगभग linear scaling मिली

हमारी training setup

  • LLaMA 3.1 को PyTorch से JAX में convert किया गया
  • Model loading और parameter sharding के माध्यम से इसे कुशलतापूर्वक distribute किया गया

JAX में parameter sharding

  • JAX की device mesh functionality का उपयोग करके मॉडल को 8 AMD GPU में कुशलतापूर्वक distribute किया गया
  • Parameter sharding rules परिभाषित करके प्रत्येक tensor के dimensions को mesh axes के अनुसार shard किया गया

LoRA training implementation

  • LoRA, weight updates को low-rank matrices में विभाजित करके trainable parameters की संख्या कम करता है
  • LoRADense layer को implement किया गया, जिसमें LoRA parameters शामिल हैं
  • LoRA parameters को कुशलतापूर्वक distribute करके memory usage और computation efficiency को optimize किया गया

निष्कर्ष

  • AMD GPU और JAX का उपयोग करके LLaMA 3.1 405B मॉडल का फाइन-ट्यूनिंग करने का अनुभव बहुत सकारात्मक रहा
  • JAX की मजबूत parallelism capabilities और hardware-independent approach का उपयोग करके मॉडल को कुशलतापूर्वक distribute किया गया
  • इससे साबित हुआ कि AMD GPU बड़े पैमाने की AI training के लिए एक मजबूत विकल्प हैं
  • पूरा code GitHub repository में देखा और सीधे चलाया जा सकता है

GN⁺ की संक्षिप्त टिप्पणी

  • यह लेख AMD GPU और JAX का उपयोग करके बड़े AI मॉडलों को कुशलतापूर्वक train करने का तरीका बताता है
  • यह रेखांकित करता है कि AMD hardware, NVIDIA की तुलना में अधिक cost-effective विकल्प हो सकता है
  • JAX का hardware-independent approach, code portability बढ़ाता है और maintenance को आसान बनाता है
  • बड़े मॉडल training में रुचि रखने वालों के लिए यह उपयोगी जानकारी और practical code प्रदान करता है
  • समान क्षमताओं वाले प्रोजेक्ट्स में NVIDIA का CUDA और PyTorch शामिल हैं

1 टिप्पणियां

 
GN⁺ 2024-09-24
Hacker News प्रतिक्रिया
  • JAX का उपयोग करके 8xAMD MI300x GPU पर Llama3.1 405B मॉडल को fine-tune करने की उपलब्धि साझा की गई

    • JAX के उन्नत sharding API की बदौलत शानदार performance हासिल की गई
    • ब्लॉग पोस्ट और open source code का लिंक उपलब्ध: GitHub link
    • यह एक startup है जो NVIDIA hardware के बजाय TPU, AMD, और Trainium पर LLM को fine-tune और serve करने के लिए AI infrastructure बना रहा है
    • कई कंपनियाँ AMD GPU पर PyTorch चलाने की कोशिश कर रही हैं, लेकिन इसे कठिन रास्ता माना गया है
    • PyTorch, NVIDIA ecosystem से गहराई से जुड़ा हुआ है, इसलिए इसे non-NVIDIA hardware पर चलाने के लिए काफी बदलाव करने पड़ते हैं
    • उनका मानना है कि JAX, non-NVIDIA hardware के लिए अधिक उपयुक्त है
    • JAX में ML model code को hardware-independent HLO graph में compile किया जाता है, और XLA compiler hardware-specific optimization करता है
    • वही JAX code Google TPU और AMD GPU पर बिना किसी बदलाव के चल सकता है
    • कंपनी की रणनीति JAX में models को port करना और XLA kernels का उपयोग करके non-NVIDIA backend से अधिकतम performance निकालना है
    • उन्होंने सबसे पहले Llama 3.1 को PyTorch से JAX में port किया, और अब वही JAX model TPU और AMD GPU दोनों पर अच्छी तरह काम करता है
    • वे अपने vision और repository पर राय सुनना चाहते हैं
  • memory constraints को पार करके JIT-compiled version चलाने के तरीकों की पड़ताल करने का सुझाव

    • इससे अतिरिक्त performance improvement मिल सकता है
  • AMD GPU और ROCm support के बारे में अनुभव साझा किया गया

    • एक साल पहले AMD GPU और ROCm support आज़माया था, लेकिन लगा कि AMD को NVIDIA की बराबरी तक पहुँचने में अभी काफी समय लगेगा
    • JAX चुनना एक दिलचस्प approach है, लेकिन यह जानने की जिज्ञासा है कि PyTorch से हटने में कौन-सी मुश्किलें आईं
  • 405B मॉडल के inference पहलू पर किए गए प्रयोग का अनुभव साझा किया गया

    • उनका मानना है कि 'torch.cuda' उतना बुरा नहीं है
    • AMD version का PyTorch इसे translate कर देता है, इसलिए यह सिर्फ नाम का मामला है
    • rocm:pytorch container का उपयोग करना rocm:jax container जितना ही आसान है
    • यह भी बताया गया कि performance data ज़्यादा प्रकाशित नहीं किया गया है
    • वे MFU (model utilization) आँकड़ों के बारे में जानना चाहते हैं
  • performance data की कमी पर सवाल

    • AMD GPU के बड़े पैमाने पर ऑर्डर के कारण उससे value निकाल पाने की संभावना पर संदेह जताया गया
    • समग्र impression "नहीं" का है
  • यह सवाल कि Obsidian (note-taking app) यह काम क्यों कर रहा है

    • शुरुआत में लगा कि यह Obsidian की पोस्ट है
    • यह भी पूछा गया कि GitHub.com और GitHub.io के बीच अब तक स्पष्ट अंतर क्यों नहीं किया गया
  • @dang से URL में username शामिल करने का अनुरोध

    • यह पोस्ट Obsidian खुद के बारे में नहीं, बल्कि user-generated blog के बारे में है