PyTorch मर चुका है। JAX अमर रहे
(neel04.github.io)- PyTorch उत्पादकता में कमी और डेवलपमेंट समय की बर्बादी का कारण इसलिए बनता है क्योंकि "फ़्रेमवर्क खुद खराब है" ऐसा नहीं, बल्कि इसलिए कि इसे मौजूदा उपयोग-परिदृश्यों के हिसाब से डिज़ाइन नहीं किया गया था।
PyTorch का दर्शन
- PyTorch का दर्शन है dynamic, डिबग करना आसान, और Pythonic होना
- इसके विपरीत TensorFlow 1.x ने XLA compiler का आक्रामक उपयोग कर static लेकिन high-performance फ़्रेमवर्क बनने की कोशिश की
- TensorFlow डेवलपर्स ने समझा कि कम्युनिटी 1.x API को पसंद नहीं करती, इसलिए Keras को main interface बनाने और XLA compiler की भूमिका कम करने का फैसला किया
- PyTorch ने अपनी जड़ों को बनाए रखा, और TensorFlow के static व delayed approach के विपरीत
torch.Tensorको तुरंत evaluate करने वाला अधिक dynamic "eager execution" approach अपनाया - इसका असर हुआ और बहुत सा research PyTorch पर शिफ्ट हो गया
- 2021 में GPT-3 आने के बाद performance और scalability मुख्य चिंता बन गए
- PyTorch ने इन मांगों का कुछ हद तक अच्छा जवाब दिया, लेकिन चूंकि इसे इस दर्शन को ध्यान में रखकर डिज़ाइन नहीं किया गया था, इसलिए तकनीकी कर्ज बढ़ता गया और इसकी नींव डगमगाने लगी
- PyTorch डेवलपर्स कोई भी compromise नहीं चाहते थे और उन्होंने एक साथ दो रास्ते अपनाने का विकल्प चुना
- XLA compiler को performance और stability वाले default backend की तरह इस्तेमाल करना
torch.compilestack बनाना ताकि ज़रूरत पड़ने पर यूज़र को compiler invoke करने की आज़ादी मिले
- दीर्घकालिक रणनीति की कमी एक गंभीर समस्या है
- PyTorch compiler-केंद्रित दर्शन (JAX जैसा) के प्रति पूरी तरह प्रतिबद्ध नहीं होना चाहता, लेकिन कोई अच्छा विकल्प भी दिखाई नहीं देता
- इस समस्या के लिए प्रतिस्पर्धी उत्पादों के समाधान क्या हैं?
JAX का compiler-आधारित विकास
- JAX, TensorFlow के शक्तिशाली compiler stack XLA का उपयोग करता है
- XLA एक शक्तिशाली compiler है, लेकिन end user के लिए यह सब abstracted रहता है
- जब तक function pure है,
@jax.jitdecorator का उपयोग करके function को JIT compile किया जा सकता है और XLA में चलाया जा सकता है - XLA generated graph की correctness verify करने, JAX में sharding-आधारित automatic parallelism के लिए GSPMD partitioner, graph optimization, operator और kernel fusion, latency-hiding scheduling, asynchronous communication overlap,
tritonजैसे अन्य backends के लिए code generation आदि सब कुछ पर्दे के पीछे संभालता है - बस JAX की सीमाओं का पालन करें, बाकी XLA अपने आप कर देता है
- उदाहरण के लिए parallelization करते समय
torch.distributed.barrier()जैसे communication primitive की ज़रूरत नहीं पड़ती - DDP support बहुत सरल code से संभव है
- XLA का approach यह है कि computation sharding का अनुसरण करती है। इसलिए यदि input array किसी axis के अनुसार sharded है, तो XLA downstream computation के लिए उसे अपने आप संभालता है
- "compiler-आधारित development" का विचार Rust compiler के काम करने के तरीके जैसा है
- PyTorch की सीमाएँ
- PyTorch डेवलपर्स ने flexibility और freedom की मूल विचारधारा बनाए रखने के बजाय नई सुविधाओं के लिए compiler stack को integrate और उस पर depend करने का जो विकल्प चुना, वह असंतोषजनक है
- PyTorch 2.x के आधिकारिक roadmap के अनुसार, XLA को Torch के साथ पूरी तरह integrate करने की दीर्घकालिक योजना साफ़ दिखाई देती है
- यह एक भयानक विचार है। यह वैसा ही है जैसे कहना कि Rust compiler में ज़बरदस्ती C++ code ठूंसना, Rust को सीधे इस्तेमाल करने से बेहतर अनुभव देगा
- JAX के विपरीत Torch को XLA को केंद्र में रखकर डिज़ाइन नहीं किया गया था
- अगर PyTorch XLA-आधारित compiler stack इस्तेमाल करने का निर्णय लेता है, तो क्या आदर्श फ़्रेमवर्क वही नहीं होगा जिसे उसी के इर्द-गिर्द खास तौर पर डिज़ाइन और बनाया गया हो?
- भले ही PyTorch मनचाहे compiler backend चुनने वाला "multi-backend" approach अपनाए, क्या इससे fragmentation और खराब नहीं होगी, और क्या यह सभी compiler stacks की सीमाओं का सम्मान करने की कोशिश में API को पूरी तरह बर्बाद नहीं कर देगा?
- जिसने भी TPU पर Torch/XLA इस्तेमाल किया है, वह गंभीर PTSD से पीड़ित है
Multi-Backend विफल हो चुका है
- PyTorch ने एक साथ सब कुछ करने की कोशिश की और बुरी तरह असफल रहा
- "multi-backend" डिज़ाइन निर्णय ने इस समस्या को घातांकीय रूप से और बदतर बना दिया
- सिद्धांत में यह मनचाहा stack चुनने की आज़ादी जैसा लगता है, लेकिन व्यवहार में यह समझना मुश्किल tracebacks और incompatibility समस्याओं की उलझी हुई अव्यवस्था है
- backends के बीच constraints और PyTorch API का टकराव
- इन backends को काम में लाना ही मुश्किल नहीं है, बल्कि इनकी अपेक्षित constraints PyTorch की flexible और Pythonic API के साथ अच्छी तरह मेल नहीं खातीं
- API consistency बनाए रखने और backend limitations का पालन करने के बीच एक trade-off है
- नतीजतन, डेवलपर्स किसी एक backend के साथ वास्तव में integrate/commit करने के बजाय code generation पर अधिक निर्भर होने लगते हैं
- PyTorch की रणनीति की कमी
- PyTorch meaningful trade-offs से इनकार करता है, इसलिए हर निर्णय एक compromise जैसा लगता है
- न consistency है, न कोई समग्र रणनीति
- अंततः यह यूज़र्स में बहुत निराशा पैदा करता है और ऐसे features के बेतरतीब मिश्रण जैसा लगता है जो एक-दूसरे के साथ फिट नहीं बैठते
- ecosystem को मारने का इससे तेज़ तरीका नहीं है
- JAX approach का अनुसरण क्यों नहीं करना चाहिए
- PyTorch को JAX के "integrated compiler and backend" approach का अनुकरण नहीं करना चाहिए
- क्योंकि JAX को XLA के साथ काम करने के लिए स्पष्ट रूप से डिज़ाइन किया गया था
- PyTorch frontend को JAX वाले frontend से बदल देना कोई रणनीति नहीं हो सकती
- XLA के ऊपर JAX से बेहतर API सोचना व्यावहारिक रूप से असंभव है
- डेवलपर्स नए और अलग विचारों को आज़माएँ, इसके लिए उन्हें दोष नहीं दिया जा सकता
- लेकिन यदि PyTorch को समय की कसौटी पर टिकना है, तो उसे अपनी नींव मजबूत करने पर अधिक ध्यान देना होगा, बजाय इसके कि वह ऐसे चमकदार नए features दे जो आदर्श tutorial स्थितियों के बाहर तुरंत टूट जाएँ
PyTorch का विखंडन और JAX की functional programming
- JAX का functional API
- JAX functions pure होने चाहिए। यानी उनमें global side effects नहीं होने चाहिए
- गणितीय functions की तरह, एक ही data मिलने पर वे execution context की परवाह किए बिना हमेशा वही output लौटाएँ
- इस डिज़ाइन दर्शन की वजह से JAX functions composable होते हैं और एक-दूसरे के साथ अच्छी interoperability रखते हैं
- development complexity कम हो जाती है, और functions को specific signature और अच्छी तरह परिभाषित ठोस कार्यों के रूप में परिभाषित किया जाता है
- type सही हो तो function तुरंत काम करेगा, इसकी गारंटी रहती है
- यह scientific computing, खासकर deep learning में आवश्यक कार्यों के प्रकार के लिए उपयुक्त है
- optax API उदाहरण
- functional approach की वजह से optax में "chain" जैसी चीज़ है
- इसमें कई functions शामिल होते हैं जो gradients पर क्रम से लागू किए जाते हैं
- मूल building block
GradientTransformationहै - इससे एक शक्तिशाली और expressive API बनती है
- उदाहरण के लिए gradients clip करना, gradients की EMA लेना, या optimizers को combine करना जैसी चीज़ें बहुत आसान हो जाती हैं
- functional design के फ़ायदे
- functional design का एक और शानदार परिणाम
vmapहै - इसका मतलब 'vectorized' map है, और यह ठीक वही करता है
- आप हर चीज़ को map कर सकते हैं, और अगर वह
vmapहै तो XLA अपने आप fuse और optimize कर देता है - functions लिखते समय batch dimension के बारे में सोचने की ज़रूरत नहीं होती
- बस पूरे code को
vmapकर दीजिए - इसका मतलब है कि
ein-*operations की ज़रूरत कम पड़ती है - 2D/3D tensor manipulation को समझना अधिक intuitive हो जाता है और readability भी काफी बेहतर रहती है
- चूंकि आप केवल individual components को isolate करके reasoning करते हैं, इसलिए सही काम करने वाला जटिल code लिखना आसान हो जाता है
- बस purity constraints का सम्मान करें और सही signatures रखें, फिर composability जैसे बाकी सभी लाभ मिलते हैं
- functional design का एक और शानदार परिणाम
- PyTorch ecosystem की समस्याएँ
torchमें, चाहे आप कोई भी stack इस्तेमाल करें (FSDP + multi-node + torch.compileआदि), कुछ न कुछ टूटने की संभावना हमेशा रहती है- कई चीज़ों का सही तरीके से साथ काम करना ज़रूरी होता है, और किसी एक component के फेल होते ही रात 3 बजे तक debugging करनी पड़ सकती है
- PyTorch द्वारा दी गई दर्जनों सुविधाओं के हर संयोजन का परीक्षण करना संभव नहीं है, इसलिए development के दौरान न पकड़े गए bugs हमेशा रहेंगे
- पर्याप्त प्रयास के बिना, सही से काम करने वाला code लिखना लगभग असंभव है
torchecosystem बहुत फूला हुआ और bug-prone हो गया है- साझा abstraction न होने के कारण नए libraries और frameworks सामने आते हैं जिन्हें दूसरे "solutions" के साथ interface करने के लिए डिज़ाइन ही नहीं किया गया
- यह जल्दी ही dependencies और
requirements.txtकी अव्यवस्था में बदल जाता है - GitHub issues या forum discussions का 70-80% सिर्फ इसलिए होता है क्योंकि अलग-अलग libraries में errors आते हैं
- इसे सुलझाने का लगभग कोई तरीका नहीं है
- समाधान का अभाव
- यह OOP और डिज़ाइन की समस्या है
- लगता है कि PyTree जैसी कोई बुनियादी और PyTorch-जैसी object abstraction के लिए common foundation बनाने में मदद कर सकती थी
- functional programming paradigm को अपनाना भी संभव नहीं है
- ऐसा करने पर यह JAX के कमज़ोर performance वाले संस्करण की ओर सिमट जाएगा, और सभी मौजूदा torch codebases की backward compatibility टूट जाएगी
- PyTorch इस मामले में पूरी तरह टूटा हुआ दिखता है
JAX की reproducibility बढ़त
- seed handling
- PyTorch में seed handling आदर्श नहीं है
- आम तौर पर इसके लिए कई lines of code चलानी पड़ती हैं
- इसे भूलना या गलत configure करना आसान है
- JAX explicit keys बनाता है और randomness की ज़रूरत वाले हर function में उन्हें pass करना अनिवार्य करता है
- यह approach समस्या को पूरी तरह खत्म कर देता है क्योंकि RNG हमेशा statically seeded रहता है
- JAX के पास NumPy का अपना version (
jax.numpy) है, इसलिए अलग से seed सेट करने की ज़रूरत नहीं होती - ऐसे छोटे QoL निर्णय पूरे फ़्रेमवर्क के user experience को बहुत बेहतर बना सकते हैं
- portability
- PyTorch codebase इस्तेमाल करते समय सबसे बड़ी समस्याओं में से एक portability की कमी है
- CUDA/GPU के लिए लिखा गया codebase TPU, NPU, AMD GPU जैसे non-Nvidia hardware पर चलने पर अच्छा काम नहीं करता
- 1-node के लिए लिखे गए PyTorch code को multi-node पर port करना कठिन है
- multi-node में अक्सर दर्जनों घंटे का development time और काफी code changes लगते हैं
- JAX का compiler-केंद्रित approach यहाँ लाभ देता है
- XLA device backends के बीच switching संभालता है और बहुत कम code changes के साथ GPU/TPU/multi-node/multi-slice पर अच्छा काम करता है
- इससे hardware vendors के लिए अपने devices को support करना आसान होता है और devices के बीच switching भी सरल होती है
- हर किसी के पास एक जैसा hardware नहीं होता, इसलिए विभिन्न प्रकार के hardware पर portable codebases deep learning को beginners/intermediates के लिए अधिक सुलभ बनाने की दिशा में एक छोटा लेकिन महत्वपूर्ण कदम हो सकते हैं
- automatic scaling
- ऐसा codebase जो अपने आप अच्छी तरह auto-scale कर सके, reproducibility के लिए बहुत मददगार होता है
- आदर्श स्थिति में यह न्यूनतम code changes के साथ, networking boundaries की परवाह किए बिना अपने आप होना चाहिए
- JAX यह काम अच्छी तरह करता है
- JAX code लिखते समय communication primitives specify करने या हर जगह
torch.distributed.barrier()लगाने की ज़रूरत नहीं होती - XLA उपलब्ध hardware को ध्यान में रखकर यह सब अपने आप insert करता है
- JAX जिन भी devices को detect कर सकता है, वे networking, topology, configuration आदि की परवाह किए बिना अपने आप उपयोग में आ जाते हैं
- यह अपने आप computation को synchronize और stage करता है तथा optimization passes लागू करता है ताकि kernels की asynchronous execution अधिकतम हो और latency न्यूनतम रहे
- इंसान को सिर्फ उन tensors की sharding specify करनी होती है जिन्हें devices में बाँटना है, जैसे input array की batch dimension
- XLA के "computation follows sharding" approach की वजह से बाकी सब अपने आप तय हो जाता है
- इससे validation किए गए large-scale experiments को शौकिया तौर पर भी अपेक्षाकृत आसानी से चलाया, परखा और संभावित रूप से दोहराया जा सकता है
- यह भूले हुए विचारों की फिर से खोज को आसान बना सकता है, और न्यूनतम प्रयास से बड़े पैमाने पर functions के रूप में उनका परीक्षण करना संभव होने से ऐसे experiments को प्रोत्साहन मिल सकता है
JAX की कमियाँ
- governance structure
- इस समय XLA, TensorFlow governance के अधीन है
- PyTorch जैसी अलग organizational body स्थापित करने पर चर्चा हुई है, लेकिन ठोस प्रयास बहुत कम हुए हैं
- Google की unpopular products बंद कर देने वाली प्रतिष्ठा के कारण उस पर भरोसा बहुत अधिक नहीं है
- JAX तकनीकी रूप से DeepMind project है और Google की समग्र AI रणनीति में केंद्रीय महत्व रखता है, लेकिन अलग governance body पूरे ecosystem के लिए दीर्घकाल में बड़ा लाभ दे सकती है
- एक अलग governance body project development को दिशा दे सकती है
- इससे ठोस संरचना मिलेगी और Google की कुख्यात bureaucracy से अलग होकर एक साथ कई समस्याओं से बचा जा सकेगा
- ज़रूरी नहीं कि JAX को इस तरह की औपचारिक संरचना अनिवार्य रूप से चाहिए, लेकिन यह आश्वासन अच्छा होगा कि Google के शीर्ष प्रबंधन के फैसलों से इतर भी JAX का विकास लंबे समय तक जारी रहेगा
- इससे कंपनियों और बड़े research labs में adoption निश्चित रूप से बढ़ेगा, जो ऐसे tools में resources लगाने से हिचकते हैं जिन्हें कभी भविष्य में maintain न किया जाए
- XLA का open source transition
- लंबे समय तक XLA एक closed-source project था
- लेकिन इसे open source बनाने के प्रयास हुए, और अब OpenXLA अंदरूनी XLA builds से कहीं बेहतर performance दिखाता है
- फिर भी XLA के internals पर documentation अब भी कम है
- अधिकतर resources live talks और कभी-कभार papers तक सीमित हैं, और वे भी अक्सर पुराने होते हैं
- अगर planned features के लिए publicly accessible roadmap हो, तो लोगों के लिए प्रगति ट्रैक करना और खास दिलचस्प चीज़ों में योगदान देना आसान होगा
- XLA compiler stack के हर चरण का विश्लेषण और विवरण देने वाले Edward Yang शैली के छोटे blog posts हों, तो practitioners बेहतर आकलन कर सकेंगे कि XLA क्या कर सकता है और क्या नहीं
- यह resource-intensive है और शायद effort कहीं और बेहतर ढंग से लगाया जा सकता है, यह समझा जा सकता है, लेकिन जब लोग tools को समझते हैं तो उन पर अधिक भरोसा करते हैं, और इसका पूरे ecosystem में सकारात्मक ripple effect होता है, जिससे सबको फायदा होता है
- ecosystem integration
flaxJAX ecosystem का एक सिरदर्द है- इसका API intuitive नहीं है, syntax संक्षिप्त है, और PyTorch से आने वाले beginners के लिए यह बिल्कुल नरक जैसा है
equinoxइस्तेमाल करना बेहतर हैflaxकी कमियों को सुधारने के लिए development team ने कोशिशें की हैं, लेकिन आखिरकार यह समय की बर्बादी है- यदि आपको equinox-शैली का API चाहिए, तो
equinoxही इस्तेमाल करें flaxऐसी बहुत कम चीज़ें बेहतर करता है जिन्हेंequinoxमें दोहराना मुश्किल हो- वर्तमान में JAX ecosystem का बड़ा हिस्सा
flaxके आसपास डिज़ाइन किया गया है equinoxमूल रूप से PyTree के साथ interface करता है, इसलिए यह सभी libraries के साथ interoperable है, हालांकि थोड़ाeqx.partitionऔर filter चाहिए- इस status quo को बदलने की इच्छा है।
equinoxको हर जगह first-class support मिलना चाहिए - यह विवादास्पद राय हो सकती है, लेकिन यह classic sunk-cost fallacy है
equinoxउस तरह बेहतर काम करता है जैसा JAX framework को हमेशा करना चाहिए थाequinoxdocumentation में संक्षेपित तुलना के अनुसार,equinoxflaxसे बेहतर है- यह अच्छा है कि JAX ecosystem maintainers
equinoxकी लोकप्रियता को पहचान रहे हैं और उसी अनुसार समायोजन कर रहे हैं, लेकिन Google औरflaxटीम से भी आधिकारिक रूप से अधिक समर्थन देखने की इच्छा है - यदि आप JAX आज़माना चाहते हैं, तो
equinoxइस्तेमाल करना बेहतर है
- sharp edges
- API design decisions और XLA limitations के कारण JAX में कुछ "sharp edges" हैं जिनसे सावधान रहना चाहिए
- अच्छी तरह लिखी गई documentation में इन्हें बहुत संक्षेप में समझाया गया है
- JAX इस्तेमाल करने से पहले कम-से-कम एक बार इसे पढ़ लेना बेहतर है
- हमेशा की तरह RTFM बहुत समय और ऊर्जा बचा सकता है
निष्कर्ष
- इस ब्लॉग पोस्ट का उद्देश्य उस बार-बार दोहराए जाने वाले मिथक को सुधारना था कि वास्तविक research workloads, खासकर GPU पर, PyTorch सबसे उपयुक्त है। अब ऐसा नहीं है
- वास्तव में यह तर्क इतनी दूर तक जाता है कि पूरे क्षेत्र के लिए सभी PyTorch code को JAX में port करना अत्यधिक लाभकारी होगा
- automatic parallelization, reproducibility, और साफ़ functional API जैसी चीज़ें मामूली features नहीं हैं, और ये कई research codebases के लिए बहुत मददगार होंगी
- यदि आप इस क्षेत्र को थोड़ा भी बेहतर बनाना चाहते हैं, तो अपने codebase को JAX में फिर से लिखने पर विचार करें
8 टिप्पणियां
दुनिया चलती रहती है. हा हा
2022 में PyTorch और TensorFlow की तुलना
मैं torch और onnx के साथ ही काम चलाऊँगा
अंडरग्रैजुएट छात्र ने लिखा है.. ओहो
अगर Hugging Face नहीं होता, तो PyTorch सच में गया था lol
JAX ज़िंदाबाद! मैंने इसे हाल ही में इस्तेमाल किया, और NNX API मुझे बहुत पसंद आया।
JAX की सबसे बड़ी समस्या यह है कि यह Google का है। Google open source प्रोजेक्ट्स को छोड़ देने के लिए काफ़ी मशहूर है (Tflite, Android Things, Dart, Angular, Bazel वगैरह)। TensorFlow का भी किसी बिंदु के बाद अपडेट ठीक से आना कम हो गया। दूसरी ओर Torch की शुरुआत विशाल open source इकोसिस्टम चलाने वाली Facebook से हुई, और इसका संचालन काफ़ी अच्छे से हुआ है, साथ ही यह पहले से ही Torch Foundation के तहत चल रहा है। Torch की कमियाँ निश्चित रूप से सही हैं, लेकिन यह सवाल कि उस open source को टिकाऊ तरीके से कौन चलाता है, इस मामले में JAX शुरुआत से ही बड़ा जोखिम लेकर आता दिखता है।
कम से कम Dart तो Flutter की वजह से कुछ समय तक अच्छी तरह जिंदा रहेगा, ऐसा लगता है।
Facebook कम से कम React, Django वगैरह जैसी अपनी इस्तेमाल की जाने वाली tech stack के लिए वफादारी से(?) लगातार योगदान देता हुआ लगता है, लेकिन Google तो लगता है कि कोई चीज़ ज़रा भी पुरानी पड़ जाए तो उसे फटे पुराने कपड़े की तरह फेंक देता है...