Felafax BlogTune Llama3 405B on AMD MI300x (हमारी यात्रा)
परिचय
- जैसे-जैसे open source मॉडल बड़े होते जा रहे हैं, बड़े पैमाने की AI training को संभालने के लिए शक्तिशाली infrastructure की आवश्यकता बढ़ रही है
- Felafax ने AMD GPU पर LLaMA 3.1 405B मॉडल का फाइन-ट्यूनिंग करके AMD hardware की दक्षता साबित की
- पूरा काम GitHub पर open source के रूप में सार्वजनिक किया गया है
- AMD MI300X GPU, NVIDIA AI hardware की तुलना में उच्च प्रदर्शन प्रदान करता है
- यह प्रोजेक्ट TensorWave के समर्थन से संभव हो पाया
JAX क्या है और इसे क्यों चुना गया
- JAX एक शक्तिशाली machine learning library है, जो NumPy-जैसे API, automatic differentiation, और Google's XLA compiler को जोड़ती है
- यह model parallelism के लिए उत्कृष्ट API प्रदान करती है, जिससे यह बड़े मॉडलों की training के लिए आदर्श बनती है
JAX के फायदे
- Pure functions: JAX pure functions लिखने को प्रोत्साहित करता है, जिससे code को compose करना, debug करना, और पढ़ना आसान होता है
- Advanced parallelism: JAX का flexible JIT API, advanced data और model parallelism को support करता है, जो बड़े पैमाने की training के लिए आवश्यक है
- Clean codebase: JAX की design philosophy अलग-अलग hardware platforms के बीच portable code लिखने को प्रोत्साहित करती है
JAX non-NVIDIA hardware पर क्यों बेहतर है
- Hardware-independent approach: JAX, XLA compiler का उपयोग करके computations को hardware-independent intermediate representation में compile करता है
- Platform-independent optimization: XLA compiler hardware से स्वतंत्र रूप से optimization करता है
- आसान portability: JAX का उपयोग करने पर NVIDIA से AMD पर जाने में code changes न्यूनतम रहते हैं
AMD GPU पर JAX सेटअप
- Docker image को pull करने, container शुरू करने और installation verify करने के बाद सेटअप पूरा किया गया
- 8 AMD MI300x GPU का उपयोग करके LLaMA 405B मॉडल को train किया गया
LLaMA 405B training: प्रदर्शन और scalability
- JAX का उपयोग करके AMD GPU पर LLaMA 405B मॉडल को train किया गया
- LoRA फाइन-ट्यूनिंग के माध्यम से model weights और LoRA parameters को bfloat16 precision में adjust किया गया
- मॉडल आकार: लगभग 800GB VRAM का उपयोग
- LoRA weights और optimizer state: लगभग 400GB VRAM का उपयोग
- कुल VRAM उपयोग: लगभग 1200GB
- Training speed: लगभग 35 tokens प्रति सेकंड
- Memory efficiency: लगभग 70% बनी रही
- Scalability: JAX के साथ 8 GPU पर लगभग linear scaling मिली
हमारी training setup
- LLaMA 3.1 को PyTorch से JAX में convert किया गया
- Model loading और parameter sharding के माध्यम से इसे कुशलतापूर्वक distribute किया गया
JAX में parameter sharding
- JAX की device mesh functionality का उपयोग करके मॉडल को 8 AMD GPU में कुशलतापूर्वक distribute किया गया
- Parameter sharding rules परिभाषित करके प्रत्येक tensor के dimensions को mesh axes के अनुसार shard किया गया
LoRA training implementation
- LoRA, weight updates को low-rank matrices में विभाजित करके trainable parameters की संख्या कम करता है
- LoRADense layer को implement किया गया, जिसमें LoRA parameters शामिल हैं
- LoRA parameters को कुशलतापूर्वक distribute करके memory usage और computation efficiency को optimize किया गया
निष्कर्ष
- AMD GPU और JAX का उपयोग करके LLaMA 3.1 405B मॉडल का फाइन-ट्यूनिंग करने का अनुभव बहुत सकारात्मक रहा
- JAX की मजबूत parallelism capabilities और hardware-independent approach का उपयोग करके मॉडल को कुशलतापूर्वक distribute किया गया
- इससे साबित हुआ कि AMD GPU बड़े पैमाने की AI training के लिए एक मजबूत विकल्प हैं
- पूरा code GitHub repository में देखा और सीधे चलाया जा सकता है
GN⁺ की संक्षिप्त टिप्पणी
- यह लेख AMD GPU और JAX का उपयोग करके बड़े AI मॉडलों को कुशलतापूर्वक train करने का तरीका बताता है
- यह रेखांकित करता है कि AMD hardware, NVIDIA की तुलना में अधिक cost-effective विकल्प हो सकता है
- JAX का hardware-independent approach, code portability बढ़ाता है और maintenance को आसान बनाता है
- बड़े मॉडल training में रुचि रखने वालों के लिए यह उपयोगी जानकारी और practical code प्रदान करता है
- समान क्षमताओं वाले प्रोजेक्ट्स में NVIDIA का CUDA और PyTorch शामिल हैं
1 टिप्पणियां
Hacker News प्रतिक्रिया
JAX का उपयोग करके 8xAMD MI300x GPU पर Llama3.1 405B मॉडल को fine-tune करने की उपलब्धि साझा की गई
memory constraints को पार करके JIT-compiled version चलाने के तरीकों की पड़ताल करने का सुझाव
AMD GPU और ROCm support के बारे में अनुभव साझा किया गया
405B मॉडल के inference पहलू पर किए गए प्रयोग का अनुभव साझा किया गया
rocm:pytorchcontainer का उपयोग करनाrocm:jaxcontainer जितना ही आसान हैperformance data की कमी पर सवाल
यह सवाल कि Obsidian (note-taking app) यह काम क्यों कर रहा है
@dang से URL में username शामिल करने का अनुरोध