-
Attention का महत्व
- Attention, Transformer आर्किटेक्चर की एक मुख्य लेयर है, और बड़े language models तथा long-context applications में bottleneck पैदा करती है।
- FlashAttention और FlashAttention-2 ने GPU पर memory read/write को न्यूनतम करके Attention को तेज़ करने का तरीका विकसित किया।
- इसके कारण LLMs की context length में काफ़ी बढ़ोतरी हुई।
-
FlashAttention-3 की प्रमुख तकनीकें
- asynchronous processing का उपयोग: Tensor Cores और TMA की asynchronous क्षमता का उपयोग करके computation और data movement को overlap किया जाता है।
- block-level operations: block स्तर पर matrix multiplication और softmax operations को बारी-बारी से चलाया जाता है।
- low precision processing: FP8 low precision support का उपयोग करके performance बेहतर की जाती है।
-
FlashAttention-3 के performance improvements
- GPU utilization efficiency: H100 GPU की अधिकतम performance का 75% तक उपयोग करते हुए, यह पिछले संस्करणों की तुलना में 1.5-2 गुना तेज़ है।
- low precision performance: FP8 का उपयोग करके processing speed बढ़ती है और memory usage घटता है।
- long-context processing: Attention mechanism को तेज़ करके लंबे text को अधिक कुशलता से प्रोसेस किया जा सकता है।
-
FlashAttention का सारांश
- FlashAttention, Attention computation को फिर से व्यवस्थित करता है और tiling व recomputation का उपयोग करके speed को काफ़ी बढ़ाता है तथा memory usage घटाता है।
- Tiling के ज़रिए input blocks लोड किए जाते हैं, उन blocks पर Attention चलाया जाता है, और फिर output update किया जाता है।
- बीच के Attention matrix को memory में लिखे बिना, memory read/write की मात्रा कम की जाती है।
-
Hopper GPU की नई hardware capabilities
- WGMMA: नए Tensor Cores का उपयोग करके उच्च throughput प्रदान करता है।
- TMA: global memory और shared memory के बीच data transfer को तेज़ करने वाला hardware unit।
- FP8 low precision: FP8 का उपयोग करके Tensor Core throughput को दोगुना किया जाता है।
-
asynchronous processing: GEMM और Softmax को overlap करना
- overlap की आवश्यकता: GEMM और softmax को parallel में चलाकर performance को अधिकतम किया जाता है।
- ping-pong scheduling: दो warp groups बारी-बारी से GEMM और softmax चलाकर performance बेहतर करते हैं।
- warp group के भीतर overlap: एक ही warp group के भीतर GEMM और softmax को parallel में चलाकर throughput बढ़ाया जाता है।
-
low precision: incoherent processing से quantization error में कमी
- incoherent processing: Hadamard transform का उपयोग करके quantization error घटाया जाता है।
- प्रयोग के परिणाम: incoherent processing के ज़रिए quantization error को 2.6 गुना कम किया गया।
-
Attention benchmarks
- FP16: FlashAttention-2 की तुलना में लगभग 1.6-1.8 गुना तेज़।
- FP8: अधिकतम 1.2 PFLOPS तक पहुँचा।
GN⁺ का सार
- FlashAttention-3, GPU की नई hardware capabilities का उपयोग करके Attention mechanism की performance को काफ़ी बेहतर बनाता है।
- यह long-context को कुशलता से प्रोसेस कर सकता है, जिससे बड़े language models की performance अधिकतम होती है।
- PyTorch जैसे प्रमुख frameworks में इसके integrate होने की संभावना अधिक है, इसलिए भविष्य के AI research और applications पर इसका बड़ा प्रभाव पड़ सकता है।
- इसी तरह की capabilities देने वाले projects में Triton और cuDNN शामिल हैं।
1 टिप्पणियां
Hacker News राय
लगता है कि Tri Dao ने FA3 पर काम अप्रैल 2022 से शुरू किया था
यह जानने की जिज्ञासा है कि Flash Attention algorithm hardware पर कितना निर्भर है
यह जिज्ञासा है कि क्या compiler अपने-आप FlashAttention जैसी optimization खोज सकते हैं
ROCm/AMD MI300x पर port करना चाहने वाले लोग संपर्क करें
TMA (Tensor Memory Accelerator) एक hardware unit है जो global memory और shared memory के बीच data transfer को तेज करता है
FlashAttention-3 को Hopper GPU (जैसे H100) के लिए optimize किया गया है
कहा गया है कि आधुनिक LLM में sigmoid जैसी activation functions बहुत धीमी हैं
यह सवाल है कि variable masking होने पर Flash Attention, masking न होने की तुलना में 5 गुना धीमा क्यों है
यह जिज्ञासा है कि क्या FlashAttention, LLM के attention operations को replace कर सकता है
महंगे hardware की आवश्यकता है