HN पर जारी: Sparse Autoencoders का उपयोग करके Llama 3.2 की interpretability पर अध्ययन
(github.com/PaulPauls)प्रोजेक्ट अवलोकन
- आधुनिक बड़े भाषा मॉडल (LLM) कई फीचर्स को एक ही neuron में ओवरलैप करके concepts को encode करते हैं, और हर neuron की activation दूसरे neurons की activation के अनुसार कई interpret किए जा सकने वाले अर्थ रख सकती है। इसे superposition कहा जाता है.
- Sparse Autoencoders (SAE) प्रशिक्षित LLM में डाले जाते हैं और activations को बहुत बड़े sparse latent space में project करके ओवरलैप्ड representations को अलग करते हैं और उन्हें interpret किए जा सकने वाले फीचर्स में बदलते हैं.
- इस प्रोजेक्ट का लक्ष्य Anthropic, OpenAI, और Google DeepMind द्वारा सफलतापूर्वक किए गए शोध को पुन:निर्मित करके interpret किए जा सकने वाले फीचर्स निकालना है.
- यह Llama 3.2-3B मॉडल के लिए कार्यात्मक और interpret किए जा सकने वाले Sparse Autoencoder बनाने हेतु पूरा pipeline प्रदान करता है.
मुख्य विशेषताएँ
- PyTorch में लिखे गए activation capture से लेकर SAE training, feature interpretation, और validation तक का पूरा end-to-end pipeline प्रदान करता है.
- बड़े भाषा मॉडलों से residual activations कैप्चर करके उन्हें SAE training dataset के रूप में उपयोग करता है.
- training data को कुशलतापूर्वक preprocess करता है, और multi-GPU का उपयोग करने वाली बड़े पैमाने की distributed training का समर्थन करता है.
- SAE training के दौरान auxiliary loss लागू करके dead latent variables को रोकता है और training dynamics को स्थिर करता है.
- Weights & Biases के माध्यम से SAE training की व्यापक logging, visualization, और checkpoints प्रदान करता है.
- interpretability analysis tools के माध्यम से सीखे गए फीचर्स के अर्थ के विश्लेषण का समर्थन करता है.
- Llama 3.1/3.2 के शुद्ध PyTorch implementation के माध्यम से बिना बाहरी dependencies के सामान्य उपयोग और परिणामों की validation संभव बनाता है.
- text और chat completion tasks के माध्यम से model behavior पर SAE के प्रभाव को validate करता है, और निकाले गए semantic features को adjust किया जा सकता है.
जारी किए गए संसाधन
-
OpenWebText sentence dataset:
- activation capture में उपयोग किए गए OpenWebText dataset का custom version.
- मूल text को बनाए रखते हुए, तेज़ access के लिए individual sentences को Parquet format में संग्रहीत किया गया है.
- sentence splitting के लिए NLTK 3.9.1 के "Punkt" tokenizer का उपयोग.
-
कैप्चर की गई Llama 3.2-3B activations:
- Llama 3.2-3B की 23वीं layer की residual activations वाली 2.5 करोड़ sentences.
- 4TB raw data को compress करके 3.2TB किया गया है और 100 archives में विभाजित किया गया है.
-
SAE training logs:
- Weights & Biases के माध्यम से training, validation, और debug metrics visualization logs.
- 10 epochs और 10,000 logged steps शामिल.
-
प्रशिक्षित 65,536 latent SAE model:
- 10 epochs के बाद अंतिम रूप से प्रशिक्षित SAE model.
प्रोजेक्ट संरचना
1. डेटा कैप्चर
capture_activations.py: LLM residual activations कैप्चर.openwebtext_sentences_dataset.py: sentence-level processing के लिए custom dataset.
2. SAE training
sae.py: मुख्य SAE model implementation.sae_preprocessing.py: SAE training के लिए data preprocessing.sae_training.py: distributed SAE training implementation.
3. interpretability
capture_top_activating_sentences.py: feature activation को अधिकतम करने वाले sentences की पहचान.interpret_top_sentences_send_batches.py: interpretation के लिए batches बनाना और भेजना.interpret_top_sentences_retrieve_batches.py: interpretation results एकत्र करना.interpret_top_sentences_parse_responses.py: interpretation results का विश्लेषण.
4. validation और testing
llama_3_inference.py: मुख्य inference implementation.llama_3_inference_text_completion_test.py: text completion test.llama_3_inference_chat_completion_test.py: chat completion test.llama_3_inference_text_completion_gradio.py: interactive testing के लिए Gradio interface.
1 टिप्पणियां
Hacker News राय
LLMs की mechanistic interpretability उस समस्या को हल करती है जिसमें मॉडल अपने बारे में समझाते समय भरोसेमंद जवाब बना लेते हैं। मॉडल जितना अधिक शक्तिशाली होता है, वह "झूठ" को सही ठहराने में उतना ही अधिक convincing हो सकता है, इसलिए self-detection tests में उसका स्कोर और कम हो सकता है। लक्ष्य consistency है, truth नहीं
Sparse Autoencoders (SAEs) पर शोध में देखा गया कि loss curve की निचली सीमा power law के रूप में scale होती है। auxiliary loss के जरिए dead latents की समस्या को पूरी तरह हल किया जा सका, और training iterations के दौरान smooth sine-wave pattern देखा गया
mechanistic interpretability को लेकर एक सवाल उठाया गया: यह चिंता है कि भविष्य का AI, अपनी training की निगरानी करते हुए, ambiguity का इस्तेमाल कर mechanistic interpretability observers को धोखा देने वाले मॉडल बना सकता है
SAEs के evaluation की कठिनाइयों पर एक blog post पढ़कर यह जिज्ञासा हुई कि इस समस्या को कैसे हल किया गया। repository में उस approach को समझने योग्य हिस्से खोजने की इच्छा जताई गई
लगता है कि यह काम alignment पर सकारात्मक प्रभाव डाल सकता है, लेकिन अभी details की जांच नहीं की गई है। यह भी जिज्ञासा है कि समय, लागत और जोखिम की भरपाई के लिए कितना भुगतान किया जाना चाहिए
documentation पर बहुत समय देने के लिए धन्यवाद
यह बहुत शानदार काम है, और यह जानने की जिज्ञासा है कि क्या इसे SAELens के साथ integrate करने की योजना है