7 पॉइंट द्वारा GN⁺ 2025-02-07 | 1 टिप्पणियां | WhatsApp पर शेयर करें
  • बड़े पैमाने पर deep learning performance को optimize करना अक्सर ‘alchemy’ जैसा लगता है, लेकिन वास्तव में कुछ समझने योग्य सरल सिद्धांतों के ज़रिए model efficiency बढ़ाई जा सकती है
  • एक single accelerator से लेकर दसियों हज़ार accelerators तक, अपेक्षाकृत सरल सिद्धांत हर जगह लागू होते हैं, और इन्हें समझकर निम्नलिखित उपयोगी काम किए जा सकते हैं:
    • मोटे तौर पर समझना कि मॉडल का हर हिस्सा सैद्धांतिक optimum के कितना करीब पहुँचा है
    • अलग-अलग scale पर विभिन्न parallelization techniques चुनने का आधार तैयार करना
    • बड़े Transformer models को train और run करने के लिए ज़रूरी लागत और समय का अनुमान लगाना
    • ऐसे algorithms डिज़ाइन करना जो खास hardware की विशेषताओं का लाभ उठाएँ
    • मौजूदा algorithm performance की सीमाओं को स्पष्ट रूप से समझकर hardware डिज़ाइन करना
  • आवश्यक background knowledge
    • LLM और Transformer architecture की बुनियादी अवधारणाओं की समझ आवश्यक है
    • बड़े पैमाने के operation modes की समझ अनिवार्य नहीं है
    • LLM training की बुनियादी जानकारी और JAX का अनुभव हो तो और बेहतर है
    • Transformer architecture पर blog posts और JAX में LLM scaling पर slides देखने की सिफारिश की जाती है
  • लक्ष्य
    • यह क्षमता विकसित करना कि दिए गए hardware पर मॉडल को किस तरह parallelize करना बेहतर होगा, इसका अनुमान लगाया जा सके
    • training और inference में लगने वाले समय और लागत का मोटा हिसाब लगा सकने की क्षमता विकसित करना

इसमें रुचि क्यों लेनी चाहिए

  • सिर्फ 3~4 साल पहले तक, ज़्यादातर ML researchers को ऐसे बड़े पैमाने के scale optimization को गहराई से समझने की ज़रूरत नहीं थी
    • अब स्थिति यह है कि “छोटे” models भी hardware limits के बहुत करीब चल रहे हैं, इसलिए efficient large-scale execution को समझना ज़रूरी हो गया है
    • ML का इतिहास system innovation और software improvements के परस्पर विकास की धारा के रूप में देखा जा सकता है
    • हाल के Transformer models hardware limits तक पहुँचकर काम कर रहे हैं, इसलिए अगर model efficiency समझ में न आए तो नई architectures या research वास्तविक उपयोग में असफल हो सकती हैं
    • benchmark पर 20% performance gain मिलने पर भी, यदि hardware efficiency 20% घट जाए तो अंततः उसकी practical utility कम हो जाती है
  • model scaling का मुख्य लक्ष्य यह है कि chips (accelerators) की संख्या बढ़ाने पर throughput रैखिक रूप से बढ़े
    • इसे "strong scaling" कहा जाता है
    • chips जोड़ने से computation time कम होता है, लेकिन chips के बीच communication cost उत्पन्न होती है
    • यदि communication, computation से अधिक समय लेने लगे तो सिस्टम "communication bound" हो जाता है और strong scaling संभव नहीं रहती
    • यदि hardware को पर्याप्त अच्छी तरह समझकर यह अनुमान लगाया जा सके कि ऐसे bottlenecks कहाँ आएँगे, तो मॉडल को इस तरह डिज़ाइन या पुनर्गठित किया जा सकता है कि उन्हें रोका जा सके
  • इस पुस्तक का लक्ष्य यह समझाना है कि TPU (और GPU) hardware कैसे काम करता है और Transformer architecture कैसे विकसित हुआ ताकि वह मौजूदा hardware पर अच्छी तरह चल सके
    • आशा है कि यह नई architectures डिज़ाइन करने वाले researchers और मौजूदा पीढ़ी के LLMs को तेज़ी से चलाने की कोशिश कर रहे engineers, दोनों के लिए उपयोगी होगी

समग्र अवलोकन

  • यह लेख निम्न प्रकार से संगठित है
  • सेक्शन 1 में roofline analysis के माध्यम से मॉडल की performance limits तय करने वाले तत्वों (communication, computation, memory) को समझाया गया है
  • सेक्शन 2, सेक्शन 3 में TPU और GPU की आंतरिक संरचना तथा chips के बीच connection के तरीकों पर चर्चा की गई है
    • इसके ज़रिए निम्न प्रश्नों के उत्तर दिए जाते हैं
      • किसी विशेष आकार की matrix multiplication सैद्धांतिक रूप से कितनी तेज़ की जा सकती है
      • किस बिंदु पर computation memory bandwidth या communication bandwidth से सीमित होने लगती है
      • TPU cluster किस संरचना में जुड़ा होता है, और एक chip से दूसरी chip तक data ले जाने में लगभग कितना समय लगता है
      • distributed matrices का multiplication दक्षता से कैसे किया जा सकता है
  • सेक्शन 4 में Transformer architecture के सूत्रों (matrix sizes, parameter counts, FLOPs) को विस्तार से समझाया गया है
  • सेक्शन 5 और सेक्शन 7 मुख्य भाग हैं, जिनमें कई chips पर मॉडल को parallelize करने के विभिन्न तरीके प्रस्तुत किए गए हैं
    • Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
    • ZeRO, Rematerialisation, Host offload, Gradient accumulation जैसी memory-saving techniques भी शामिल हैं
  • सेक्शन 6, सेक्शन 8 में TPU पर LLaMA-3 model की training और inference को उदाहरण के रूप में लेकर वास्तविक लागत, समय और configuration दिखाए गए हैं
  • अंत में सेक्शन 9, सेक्शन 10 में JAX में model profiling, debugging, और parallel processing लागू करने के व्यावहारिक तरीके बताए गए हैं

विस्तार से: पुस्तक के मुख्य सेक्शनों का सार

  • भाग 1: Preliminaries

  • भाग 2: Transformers

    • सेक्शन 4: आवश्यक Transformer formulas का संकलन

      • Transformer में matrix multiplication वास्तव में किस रूप में होती है
      • parameter counts, FLOPs, KV cache size आदि की गणना कैसे की जाए
      • यह समझना कि Attention operations, Feed-Forward blocks की तुलना में कितना अधिक computation मांगते हैं
    • सेक्शन 5: Transformer training parallelization strategies

      • Data parallel, Tensor parallel, Pipeline parallel, Expert parallel techniques का परिचय
      • ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload जैसी memory-saving approaches
      • किसी विशेष model size और chip count के अनुरूप parallelization को configure करने की अवधारणा
    • सेक्शन 6: LLaMA 3 TPU training application

      • यदि वास्तविक TPU environment में LLaMA 3 model को train किया जाए, तो लगने वाले समय और लागत का अनुमान
      • batch size, parallelization method, memory usage आदि के ठोस उदाहरण
    • सेक्शन 7: Transformer inference पर सब कुछ

      • inference के समय latency एक महत्वपूर्ण नए factor के रूप में सामने आती है
      • KV cache आदि के कारण memory usage और communication समस्याएँ
      • model serving के लिए कई chips को कैसे बाँटा और जोड़ा जाए, इस पर चर्चा
    • सेक्शन 8: LLaMA 3 TPU serving application

      • TPU v5e पर LLaMA 3 को serve करने की स्थिति मानकर, अनुमानित लागत, latency और throughput trade-offs का विश्लेषण
  • भाग 3: Practical Tutorials

1 टिप्पणियां

 
GN⁺ 2025-02-07
Hacker News टिप्पणियाँ
  • ऐसी उम्मीद है कि आने वाले कुछ वर्षों में JAX, pytorch/cuda की जगह ले लेगा। Deepseek टीम के साथ PTX मुद्दा यह दिखाता है कि हार्डवेयर performance का पूरा लाभ उठाने के लिए low-level approach में निवेश करना कितना मूल्यवान है
    • इसे Google के भीतर performance work के लिए एक guidebook की तरह इस्तेमाल किया गया था। इसका public होना चौंकाने वाला है, लेकिन लगता है कि Gemini से जुड़े विवरण हटा दिए गए हैं
    • इस guide की अच्छी बात यह है कि JAX/XLA की वजह से इसे सीधे GPU पर ले जाया जा सकता है
    • एक राय है जिसमें पूछा गया है कि JAX, AST के बजाय tracing का उपयोग क्यों करता है
    • लेखक के tweet thread का लिंक साझा किया गया है
    • कोई Jekyll site को PDF में बदलने का तरीका ढूंढ रहा है
    • इसे शानदार लेख कहकर प्रशंसा और धन्यवाद व्यक्त किया गया है
    • एक राय है जिसमें पूछा गया है कि इतने शानदार animations कैसे बनाए जाते हैं