अपने मॉडल को कैसे scale करें: TPU पर LLMs के लिए एक सिस्टम दृष्टिकोण
(jax-ml.github.io)- बड़े पैमाने पर deep learning performance को optimize करना अक्सर ‘alchemy’ जैसा लगता है, लेकिन वास्तव में कुछ समझने योग्य सरल सिद्धांतों के ज़रिए model efficiency बढ़ाई जा सकती है
- एक single accelerator से लेकर दसियों हज़ार accelerators तक, अपेक्षाकृत सरल सिद्धांत हर जगह लागू होते हैं, और इन्हें समझकर निम्नलिखित उपयोगी काम किए जा सकते हैं:
- मोटे तौर पर समझना कि मॉडल का हर हिस्सा सैद्धांतिक optimum के कितना करीब पहुँचा है
- अलग-अलग scale पर विभिन्न parallelization techniques चुनने का आधार तैयार करना
- बड़े Transformer models को train और run करने के लिए ज़रूरी लागत और समय का अनुमान लगाना
- ऐसे algorithms डिज़ाइन करना जो खास hardware की विशेषताओं का लाभ उठाएँ
- मौजूदा algorithm performance की सीमाओं को स्पष्ट रूप से समझकर hardware डिज़ाइन करना
- आवश्यक background knowledge
- LLM और Transformer architecture की बुनियादी अवधारणाओं की समझ आवश्यक है
- बड़े पैमाने के operation modes की समझ अनिवार्य नहीं है
- LLM training की बुनियादी जानकारी और JAX का अनुभव हो तो और बेहतर है
- Transformer architecture पर blog posts और JAX में LLM scaling पर slides देखने की सिफारिश की जाती है
- लक्ष्य
- यह क्षमता विकसित करना कि दिए गए hardware पर मॉडल को किस तरह parallelize करना बेहतर होगा, इसका अनुमान लगाया जा सके
- training और inference में लगने वाले समय और लागत का मोटा हिसाब लगा सकने की क्षमता विकसित करना
इसमें रुचि क्यों लेनी चाहिए
- सिर्फ 3~4 साल पहले तक, ज़्यादातर ML researchers को ऐसे बड़े पैमाने के scale optimization को गहराई से समझने की ज़रूरत नहीं थी
- अब स्थिति यह है कि “छोटे” models भी hardware limits के बहुत करीब चल रहे हैं, इसलिए efficient large-scale execution को समझना ज़रूरी हो गया है
- ML का इतिहास system innovation और software improvements के परस्पर विकास की धारा के रूप में देखा जा सकता है
- हाल के Transformer models hardware limits तक पहुँचकर काम कर रहे हैं, इसलिए अगर model efficiency समझ में न आए तो नई architectures या research वास्तविक उपयोग में असफल हो सकती हैं
- benchmark पर 20% performance gain मिलने पर भी, यदि hardware efficiency 20% घट जाए तो अंततः उसकी practical utility कम हो जाती है
- model scaling का मुख्य लक्ष्य यह है कि chips (accelerators) की संख्या बढ़ाने पर throughput रैखिक रूप से बढ़े
- इसे "strong scaling" कहा जाता है
- chips जोड़ने से computation time कम होता है, लेकिन chips के बीच communication cost उत्पन्न होती है
- यदि communication, computation से अधिक समय लेने लगे तो सिस्टम "communication bound" हो जाता है और strong scaling संभव नहीं रहती
- यदि hardware को पर्याप्त अच्छी तरह समझकर यह अनुमान लगाया जा सके कि ऐसे bottlenecks कहाँ आएँगे, तो मॉडल को इस तरह डिज़ाइन या पुनर्गठित किया जा सकता है कि उन्हें रोका जा सके
- इस पुस्तक का लक्ष्य यह समझाना है कि TPU (और GPU) hardware कैसे काम करता है और Transformer architecture कैसे विकसित हुआ ताकि वह मौजूदा hardware पर अच्छी तरह चल सके
- आशा है कि यह नई architectures डिज़ाइन करने वाले researchers और मौजूदा पीढ़ी के LLMs को तेज़ी से चलाने की कोशिश कर रहे engineers, दोनों के लिए उपयोगी होगी
समग्र अवलोकन
- यह लेख निम्न प्रकार से संगठित है
- सेक्शन 1 में roofline analysis के माध्यम से मॉडल की performance limits तय करने वाले तत्वों (communication, computation, memory) को समझाया गया है
- सेक्शन 2, सेक्शन 3 में TPU और GPU की आंतरिक संरचना तथा chips के बीच connection के तरीकों पर चर्चा की गई है
- इसके ज़रिए निम्न प्रश्नों के उत्तर दिए जाते हैं
- किसी विशेष आकार की matrix multiplication सैद्धांतिक रूप से कितनी तेज़ की जा सकती है
- किस बिंदु पर computation memory bandwidth या communication bandwidth से सीमित होने लगती है
- TPU cluster किस संरचना में जुड़ा होता है, और एक chip से दूसरी chip तक data ले जाने में लगभग कितना समय लगता है
- distributed matrices का multiplication दक्षता से कैसे किया जा सकता है
- इसके ज़रिए निम्न प्रश्नों के उत्तर दिए जाते हैं
- सेक्शन 4 में Transformer architecture के सूत्रों (matrix sizes, parameter counts, FLOPs) को विस्तार से समझाया गया है
- सेक्शन 5 और सेक्शन 7 मुख्य भाग हैं, जिनमें कई chips पर मॉडल को parallelize करने के विभिन्न तरीके प्रस्तुत किए गए हैं
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- ZeRO, Rematerialisation, Host offload, Gradient accumulation जैसी memory-saving techniques भी शामिल हैं
- सेक्शन 6, सेक्शन 8 में TPU पर LLaMA-3 model की training और inference को उदाहरण के रूप में लेकर वास्तविक लागत, समय और configuration दिखाए गए हैं
- अंत में सेक्शन 9, सेक्शन 10 में JAX में model profiling, debugging, और parallel processing लागू करने के व्यावहारिक तरीके बताए गए हैं
विस्तार से: पुस्तक के मुख्य सेक्शनों का सार
-
भाग 1: Preliminaries
-
सेक्शन 1: सरल Roofline analysis का परिचय
- algorithm को सीमित करने वाले तीन तत्व: computation, communication, memory
- इनके आधार पर computation speed की upper bound का अनुमान लगाना सीखा जाता है
-
सेक्शन 2: TPU को देखने का दृष्टिकोण
- TPU किस तरह computation करता है
- systolic array संरचना क्या होती है
- TPU memory और communication bandwidth कैसे उपलब्ध कराता है, इसकी बुनियादी समझ
-
सेक्शन 3: distributed matrices और distributed multiplication
- model parameters को कई chips में बाँटकर store करने की तकनीक (Sharding)
- distributed matrix operations के दौरान उत्पन्न communication और bottlenecks को संभालने के तरीके
-
-
भाग 2: Transformers
-
सेक्शन 4: आवश्यक Transformer formulas का संकलन
- Transformer में matrix multiplication वास्तव में किस रूप में होती है
- parameter counts, FLOPs, KV cache size आदि की गणना कैसे की जाए
- यह समझना कि Attention operations, Feed-Forward blocks की तुलना में कितना अधिक computation मांगते हैं
-
सेक्शन 5: Transformer training parallelization strategies
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel techniques का परिचय
- ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload जैसी memory-saving approaches
- किसी विशेष model size और chip count के अनुरूप parallelization को configure करने की अवधारणा
-
सेक्शन 6: LLaMA 3 TPU training application
- यदि वास्तविक TPU environment में LLaMA 3 model को train किया जाए, तो लगने वाले समय और लागत का अनुमान
- batch size, parallelization method, memory usage आदि के ठोस उदाहरण
-
सेक्शन 7: Transformer inference पर सब कुछ
- inference के समय latency एक महत्वपूर्ण नए factor के रूप में सामने आती है
- KV cache आदि के कारण memory usage और communication समस्याएँ
- model serving के लिए कई chips को कैसे बाँटा और जोड़ा जाए, इस पर चर्चा
-
सेक्शन 8: LLaMA 3 TPU serving application
- TPU v5e पर LLaMA 3 को serve करने की स्थिति मानकर, अनुमानित लागत, latency और throughput trade-offs का विश्लेषण
-
-
भाग 3: Practical Tutorials
-
सेक्शन 9: TPU code profiling कैसे करें
- JAX+XLA stack की समझ
- वास्तविक performance degradation issues की पहचान और उनके समाधान
- JAX/TensorBoard profiler का उपयोग
-
सेक्शन 10: JAX के साथ TPU programming
- JAX के parallelization API(primitives) का उपयोग कैसे करें
- examples और समस्याओं के माध्यम से parallel computation की अवधारणाएँ सीखना
-
सेक्शन 11: निष्कर्ष और अतिरिक्त सामग्री
- TPU और LLM पर आगे पढ़ने के लिए सामग्री
- पूरे विषय का संक्षिप्त समापन और भविष्य की दिशा का उल्लेख
-
1 टिप्पणियां
Hacker News टिप्पणियाँ