कस्टम मॉडल अनुकूलन के लिए Llama-2 fine-tuning का केस स्टडी
(anyscale.com)- जब सामान्य-उद्देश्य वाले LLM किसी विशेष कार्य के लिए जरूरत से ज़्यादा बड़े साबित हों, तब Llama-2 को सीधे fine-tune करके छोटे और सस्ते मॉडल के साथ quality, cost और latency—तीनों में सुधार किया जा सकता है
- Llama-2 13B में fine-tuning के बाद ViGGO function representation accuracy 58%→98%, SQL generation 42%→89%, और GSM8k 28%→47% तक बढ़ गया
- ViGGO और SQL generation जैसे ऐसे कार्यों में, जहाँ output format बहुत महत्वपूर्ण है, छोटे Llama-2 मॉडल ने GPT-4 से बेहतर नतीजे दिए, लेकिन math reasoning में वह GPT-4 के स्तर तक नहीं पहुँच पाया
- प्रयोग Ray Train, Ray Data, DeepSpeed और Accelerate आधारित scripts पर किए गए; 7B·13B को 16xA10G पर और 70B को 32xA10G पर train किया गया
- performance improvement की कुंजी model size से ज़्यादा data quality और evaluation pipeline है, और prompt engineering बनाम fine-tuning के cost-quality trade-off को हर task के हिसाब से तुलना करनी चाहिए
तीन कार्यों में fine-tuning का असर
- GPT-4, Claude-2 जैसे बड़े general-purpose मॉडल तेज prototyping के लिए उपयोगी हैं, लेकिन support ticket summarization या classification जैसी सीमित जरूरतों में वे cost और performance दोनों के लिहाज़ से ज़रूरत से ज़्यादा हो सकते हैं
- इस प्रयोग में देखा गया कि Llama-2 मॉडल को तीन वास्तविक प्रकार के tasks के लिए full-parameter fine-tuning करने पर कितना सुधार मिलता है
- ViGGO: unstructured text से functional representation निकालना
- SQL-create-context: natural language और CREATE TABLE context से SQL बनाना
- GSM8k: प्राथमिक स्तर के गणित प्रश्न हल करना
- Llama-2 13B के लिए accuracy में बदलाव इस प्रकार था
- ViGGO function representation: 58% → 98%
- SQL generation: 42% → 89%
- GSM8k: 28% → 47%
- ViGGO और SQL generation में छोटे Llama-2 मॉडल ने GPT-4 से बेहतर नतीजे दिए, जबकि GSM8k जैसे math reasoning tasks में fine-tuning के बाद भी GPT-4 की performance तक नहीं पहुँचा
fine-tuning का तरीका और training infrastructure
- तीनों tasks में standard full-parameter fine-tuning का उपयोग किया गया
- training next-token prediction तरीके से हुई
- model के सभी parameters gradient update के लक्ष्य थे
- LoRA या transformer blocks के कुछ हिस्सों को freeze करने जैसे तरीके इस प्रयोग में शामिल नहीं थे
- experiment scripts Ray Train, Ray Data, DeepSpeed, Accelerate पर बनाए गए थे
- Llama-2 7B, 13B, 70B runs को support किया गया
- Ray Train का TorchTrainer कई worker processes और GPU resources पर training loop को distribute करता है
- data sharding Ray Train संभालता है, और हर worker
session.get_dataset_shard("train"),session.get_dataset_shard("valid")के जरिए अपने assigned data shard तक पहुँचता है
- model sharding DeepSpeed ZeRO stage 3 और optimizer state offloading से की गई
- क्योंकि model के हिस्से कई workers में बँटे होते हैं, इसलिए checkpoint save जैसी स्थिति में, जहाँ पूरे model तक पहुँच चाहिए,
accelerator.unwrap_model(model)से model को unwrap करना पड़ता है
- क्योंकि model के हिस्से कई workers में बँटे होते हैं, इसलिए checkpoint save जैसी स्थिति में, जहाँ पूरे model तक पहुँच चाहिए,
- compute resources इस प्रकार थे
- 7B·13B: 16xA10G
- 70B: 32xA10G, 4
g5.48xlargeinstances - Ray के साथ full-parameter fine-tuning के लिए A100 अनिवार्य नहीं है
- training अधिकतम 10 epochs तक चलाई गई, और validation set पर सबसे कम perplexity वाले checkpoint को चुना गया
special tokens से input-output संरचना तय करना
- fine-tuning data में instruction prompt की जगह special tokens से task structure को व्यक्त किया गया
- उदाहरण:
<START_Q>{question}<END_Q><START_A>{answer}<END_A>
- उदाहरण:
- special tokens model को input और output segments अलग पहचानने और output कहाँ रोकना है यह साफ़ तौर पर सीखने में मदद करते हैं
- उदाहरण में
<END_A>को stopping token के रूप में define किया गया, ताकि task पूरा होने पर output रुक जाए
- उदाहरण में
- Llama tokenizer सामान्यतः 32,000 token IDs output करता है
- चार special tokens जोड़ने पर यह 32,004 IDs output करता है
<START_Q>को 32000,<END_Q>को 32001 जैसे नए IDs दिए जाते हैं
- script
tokenizer.add_tokens(special_tokens, special_tokens=True)से special tokens जोड़ती है, औरmodel.resize_token_embeddings(len(tokenizer))से नए trainable parameters बनाती है
ViGGO: unstructured text को functional representation में बदलना
- ViGGO मूल रूप से attribute-value आधारित functional representation को natural language text में बदलने वाला English dataset है, लेकिन इस प्रयोग में दिशा उलटकर unstructured text को structured functional representation में बदला गया
- domain video game opinions का है
- resulting representation indexing और downstream applications में इस्तेमाल हो सकती है
- model को sentence के अनुरूप function और attribute values generate करने होते हैं
- function candidates में
inform,request,give_opinion,confirm,verify_attribute,suggest,request_explanation,recommend,request_attributeशामिल हैं - attribute candidates में
name,release_year,esrb,genres,platforms,available_on_steam,has_linux_release,has_mac_release,specifier,rating,player_perspective,has_multiplayer,developer,exp_release_dateआदि शामिल हैं
- function candidates में
- उदाहरण input
What's a really fast-paced game with multiplayer that you like to play?का expected outputrequest(has_multiplayer[yes], specifier[fast-paced])है - सामान्य models इच्छित output format को ठीक से follow नहीं कर पाए, और लंबे input context के कारण output generation से ज़्यादा समय input processing में लगने की समस्या भी रही
- यह task जटिल logical reasoning से ज़्यादा pattern recognition और बुनियादी language understanding पर आधारित है
- यह एक grounded task है, जिसमें ज़रूरी तथ्य input में ही मौजूद होते हैं
- few-shot prompting का मददगार होना इस बात का संकेत माना गया कि छोटे Llama-2 मॉडल भी fine-tuning से सुधर सकते हैं
ViGGO evaluation और परिणाम
- evaluation में केवल exact string match का उपयोग नहीं किया गया
- जाँचा गया कि output function सही है या नहीं
- attribute type सही है या नहीं
- function के भीतर attributes तय priority order का पालन करते हैं या नहीं
- GPT, Llama-2-chat जैसे instruction-following models के लिए prompt में attribute ordering rule स्पष्ट दिया गया था, इसलिए evaluation में उनसे यह नियम पालन करने की अपेक्षा की गई
- evaluation तेज़ करने के लिए Ray की batch inference API और Anyscale की Aviary का साथ में उपयोग हुआ
- LLM generation और post-processing को जोड़ा गया और कई machines पर distribute किया गया
- 7B और 13B models में fine-tuning के बाद accuracy में बड़ा सुधार हुआ
- GPT-4 की accuracy attribute priority को evaluation में शामिल करने पर काफी गिर गई
- fine-tuned models ने हमेशा priority follow की, और यह constraint जोड़ने पर भी उनकी accuracy नहीं बदली
- ViGGO के नतीजे दिखाते हैं कि structured format वाले tasks में fine-tuning एक स्थिर और कुशल तरीका हो सकती है
- यह सिर्फ regex या JSON format मिलान का मामला नहीं था, बल्कि किन arguments को शामिल करना है और उनकी order क्या होगी, यह भी तय करना था
- 7B·13B models से मिले नतीजों का मतलब है कि serving cost GPT-4 endpoint call से कम हो सकती है
SQL generation: natural language और table context से query बनाना
- SQL generation task में natural language query और SQL
CREATE TABLEstatements को input लेकर executable SQL query बनानी होती है - उपयोग किया गया dataset b-mc2/sql-create-context WikiSQL और Spider को मिलाकर बनाया गया Hugging Face dataset है
- हर data point में natural language query, SQL
CREATE TABLEstatement, और उसके अनुरूप SQL query शामिल होती है - कुल 78,577 data points थे
- हर data point में natural language query, SQL
- dataset में ground-truth SQL से जुड़ी समस्याएँ थीं
CREATE TABLEमें integer attributes कोVARCHARदिखाया गया था, लेकिन SQL query में उन्हें अक्सर integer की तरह treat किया गया- ऐसे सभी SQL queries हटा दिए गए जो integer attributes मानकर लिखी गई थीं, जिससे dataset लगभग 70k से घटकर 45k रह गया
- यह task भी natural language को SQL जैसी structured representation में बदलने का है, इसलिए fine-tuning के लिए उपयुक्त है
- ViGGO के विपरीत, यहाँ कई SQL queries सही execution result दे सकती हैं, इसलिए ambiguity ज़्यादा है
SQL evaluation और परिणाम
- SQL generation की evaluation में साधारण string comparison उपयुक्त नहीं है
- character-level comparison कई false negatives दे सकता है
- AST comparison भी variable naming order जैसी चीज़ों के प्रति संवेदनशील हो सकता है
- सबसे भरोसेमंद तरीका है fake dataset पर code चलाकर outputs की तुलना करना
- प्रयोग में OpenAI GPT-3.5 endpoint से सैकड़ों examples के लिए unit test हेतु fake tables बनवाई गईं
- GPT-3.5 ने question, table schema और ground truth देखकर 10 data points वाली fake tables बनाईं
sqlglot.executor.executeसे ground-truth SQL और model SQL दोनों चलाकर result compare किया गया
- GPT-3.5 से बनी data tables की quality जाँचने के लिए ground-truth SQL पहले चलाया गया
- अगर result table खाली थी या मूल table जितनी ही लंबाई की थी, तो उस example को हटा दिया गया
- इस प्रक्रिया में GPT द्वारा बनाई गई लगभग 50% data tables filter हो गईं
- fine-tuned Llama-2 7B और 13B ने 70B-chat और GPT-4 से बेहतर performance दी
- Llama chat models की आम गलती यह थी कि prompt instructions के बावजूद SQL को
<SQL>tags के अंदर लगातार एक-सा format में नहीं रखते थे - यह समस्या 7B·13B chat models में 70B की तुलना में अधिक आम थी
- Llama chat models की आम गलती यह थी कि prompt instructions के बावजूद SQL को
- SQL dataset की कुछ natural language queries पूरी तरह सही English में नहीं थीं, और संभव है कि इस noise ने GPT-4 के परिणामों को प्रभावित किया हो
- fine-tuned models dataset की इन अजीब विशेषताओं के अनुरूप जल्दी ढल गए
GSM8k: structure learning से कठिन math reasoning
- GSM8k math reasoning और comprehension को मापने वाला एक standard academic benchmark है
- जहाँ पिछले दो tasks मुख्यतः structure learning पर आधारित थे, वहीं GSM8k यह देखने का task है कि model math problems हल करने के लिए अपने reasoning process को कितना सुधार सकता है
- उदाहरण problem में पूछा जाता है कि अप्रैल में 48 चीज़ें बिकीं और मई में उसका आधा बिका, तो कुल बिक्री कितनी हुई; answer बीच के calculations के साथ
#### 72format में समाप्त होता है - मौजूदा LLM अक्सर अंतिम उत्तर को अंदर ही अंदर निकालकर सीधे नहीं देते, बल्कि output के हिस्से के रूप में reasoning process generate करते हैं ताकि आगे का token generation उसी logical process पर आधारित रहे
- इस task में सिर्फ calculation नहीं, बल्कि premises से intermediate conclusions होते हुए final answer तक पहुँचने वाली logical chain of thought चाहिए
GSM8k evaluation तरीका और baselines
- evaluation के लिए model output से final answer को स्थिर तरीके से निकालना ज़रूरी है
- सामान्य language models मनचाहा output format लगातार follow नहीं करते, इसलिए automated evaluation कठिन हो सकती है
- इसके लिए OpenAI function calling API का उपयोग किया गया
gpt-3.5-turbo-0613कोreport_answerfunction call के जरिए दूसरे models के outputs से अंतिम integer answer निकालने के लिए इस्तेमाल किया गया- उदाहरण के लिए, अगर model कहे “The answer is four”, तब भी उसे
4के रूप में parse किया जा सकता है
- इस तरीके की वैधता dataset answers पर जाँचकर पुष्टि की गई, लेकिन इसकी कमी यह है कि evaluation में OpenAI token cost लगती है
- fine-tuned models ने target answer pattern जल्दी सीख लिया, इसलिए गलत होने पर भी उनका output structure अनुमानित रहा
- fine-tuned model evaluation
#### {answer}regex से की गई, जिससे OpenAI endpoint post-processing से बचा गया
- fine-tuned model evaluation
- baselines इस प्रकार थे
- paper में प्रकाशित base pre-trained model का 8-shot prompting परिणाम
- Meta द्वारा RLHF के साथ general-purpose assistant बनाने के लिए train किए गए Llama-2 chat-tuned variants पर कई prompt-engineered templates
GSM8k परिणाम और दो-चरणीय fine-tuning
- base model fine-tuning ने GSM8k performance को लगातार बेहतर किया, लेकिन हर बार chat-tuned models से बहुत बेहतर परिणाम नहीं दिए
- संभव है chat models ने chat-tuning के दौरान math examples पर training पाई हो, इसलिए उनकी accuracy base models से अधिक रही
- fine-tuned models में prompting का तरीका हमेशा base model से बेहतर परिणाम नहीं देता
- उदाहरण के लिए, Llama-2-70B-chat का प्रदर्शन 8-shot examples वाले prompted base model से कम हो सकता है
- fine-tuned models ने 8-shot prompted base models से लगातार बेहतर परिणाम दिए
- serving cost के लिहाज़ से fine-tuned models लाभ में हो सकते हैं
- prompt-based तरीकों में हर request पर prompt tokens की cost जुड़ती है
- fine-tuned models में प्रभावी तौर पर सिर्फ question tokens की लागत जुड़ती है
- GSM8k training data अपेक्षाकृत छोटी, लगभग 8k entries की थी, इसलिए माना गया कि यह Llama-13B की पूरी क्षमता निकालने के लिए पर्याप्त नहीं है
- Llama-13B base model को पहले MathQA पर fine-tune करके फिर GSM8k पर दोबारा fine-tune करने वाले दो-चरणीय तरीके से अतिरिक्त सुधार मिला
- केवल GSM8k से fine-tuning करने पर base की तुलना में 10%p सुधार मिला
- MathQA के बाद GSM8k पर की गई दो-चरणीय fine-tuning से शुरुआती fine-tuning परिणाम पर अतिरिक्त 10%p, और base की तुलना में कुल 20%p सुधार मिला
- MathQA में 30,000 question-answer pairs हैं, लेकिन यह GSM8k की तुलना में अधिक noisy है और इसकी structure अलग है
- answer quality कम है और final answer multiple choice format में होता है
- फिर भी दो-चरणीय fine-tuning ने MathQA का उपयोग करके GSM8k के अंतिम परिणामों को बेहतर बनाने में प्रभावशीलता दिखाई
प्रैक्टिकल उपयोग में किन बातों को देखना चाहिए
- GPT-4, Claude-2 जैसे closed models prototyping और शुरुआती value validation में मजबूत हैं, लेकिन production LLM apps चलाने के लिए हमेशा पर्याप्त नहीं होते
- niche tasks के लिए LLM fine-tuning सिर्फ privacy नहीं, बल्कि latency, cost और quality के लिहाज़ से भी मूल्यवान हो सकती है
- ViGGO और SQL examples में quality के मामले में GPT-4 से बेहतर परिणाम भी मिले
- fine-tuning में सबसे महत्वपूर्ण फोकस infrastructure implementation details से ज़्यादा data collection और evaluation pipeline बनाना है
- evaluation pipeline अलग-अलग समाधानों के trade-off को business requirements के अनुसार compare करने की बुनियाद बनती है
- प्रयोग Anyscale fine-tuning और serving platform तथा Anyscale Endpoints का उपयोग करके किए गए
- यही प्रक्रिया अपने data और अपने cloud में दोहराने के लिए Ray पर बने Anyscale fine-tuning और serving solution से कॉन्फ़िगर की जा सकती है
1 टिप्पणियां
Hacker News की राय
कुछ हफ्ते पहले कोडिंग लाइव स्ट्रीम में मैंने अपने डेटासेट से Llama 2 को fine-tune करने पर काफी बात की थी, और यह Colab के single GPU पर किया था
मेरे मामले में डेटासेट मेरा अपना code था।
Fine-tuning Llama stream: https://www.youtube.com/watch?v=TYgtG2Th6fI&t=2282s
QLoRA fine-tuning के कुछ और sessions भी हैं, और मैं concepts को ऐसे perspective से समझाता हूं जैसे एक 8 साल के अनुभव वाला software engineer जो हाल में machine learning में आया है और self-taught है
QloRa fine-tuning stream: https://www.youtube.com/watch?v=LitybCiLhSc&t=4584s
मैं कोशिश कर रहा हूं कि अपने personal projects और अभी चल रहे AI-based startup में इसे कैसे approach करता हूं, इसे जितना हो सके आसान तरीके से समझाऊं। सबसे छोटे web development LLM को fine-tune करने वाली series को भी response ठीक लग रहा है, और streaming को करीब एक महीना हुआ है; आगे और ज्यादा पोस्ट करने का plan है
fine-tuned models को बांटकर रखने का तरीका भी मुझे ठीक से समझ नहीं आता। क्या Terraform LLM, SQL LLM, Python LLM अलग-अलग चाहिए, या बस एक “code” LLM काफी है?
implementation details बहुत ज्यादा चाहिए होती हैं, इसलिए जब तक use case meaningful न हो, accessibility कम हो जाती है। privateGPT शायद धीरे-धीरे उस मुकाम तक पहुंच जाएगा
यह वह हिस्सा है जिसे दूसरे tutorials अक्सर skip कर देते हैं। खासकर safety, accuracy जैसे अलग-अलग goals के हिसाब से इसे कैसे तैयार किया जाए, यह जानना चाहूंगा
Llama 2 में भी मुझे यही समस्या आ रही है। सिर्फ desired text output करवाना लगभग असंभव है; यह हमेशा response के आगे-पीछे कुछ न कुछ जोड़ देता है
जानना चाहूंगा कि इस समस्या को ठीक करने के लिए कोई prompt technique है क्या
airoboros backticks, explanations आदि से बचते हुए सिर्फ code output करवाने के लिए PLAINFORMAT token support करता है
https://huggingface.co/TheBloke/airoboros-l2-70B-GPT4-2.0-GG...
guarantee चाहिए तो छोटा dataset, लगभग 1 हजार examples, लेकर fine-tune करना और फिर वहां से improve करना सबसे अच्छा है
मेरा use case creative writing से ज्यादा text से information extraction/synthesis करने वाला simple task था। base model हर task के लिए फिट नहीं हो सकता
contentstring या JSON के अंदर output करेअगर JSON है तो start और end पहचान सकते हैं, इसलिए JSON के बाहर का content हटा सकते हैं
ऐसा लेख देखकर खुशी हुई। online model customization पर बहुत चर्चा थी, लेकिन यह लेख noise को काफी अच्छी तरह हटाता है
evaluation methodology भी पसंद आई और लेख भी अच्छी तरह लिखा हुआ लगता है
LoRA और quantized training को ज्यादा गंभीरता से न लिया जाना अजीब है। यह काफी सस्ता है, कम समय लेता है, और इसके काफी अच्छे होने के बहुत evidence भी हैं
इसे बाद में try करने वाले extra option की तरह किनारे नहीं रखना चाहिए
यह देखकर अच्छा लगा कि NER जैसी task ने best performance दी। मैं अभी fine-tuned BERT model से compare करने के लिए similar test शुरू करने ही वाला था
इस task का training cost कितना रहा होगा, यह जानना चाहूंगा
block size घटाया जा सकता था, लेकिन code न बदलना आसान था इसलिए वैसे ही रखा। 7B में 16xA10G पर प्रति epoch करीब 15 मिनट, 13B में करीब 25 मिनट लगे। इसलिए on-demand cost प्रति epoch 7B के लिए करीब $7.2, 13B के लिए करीब $12 है। ये values सिर्फ training में लगे समय पर आधारित हैं, cluster start/stop time शामिल नहीं है
लिखा है कि 7B और 13B के लिए 16xA10G, और 70B के लिए 32xA10G को 4 g5.48xlarge instances में बांटकर इस्तेमाल किया गया। Ray इस्तेमाल करने पर ऐसे models की full-parameter fine-tuning के लिए A100 जुटाने की जरूरत नहीं होती, और हर task के लिए यही process repeat किया गया। GSM8k dataset में context length 512 और प्रति epoch effective tokens 37 लाख के साथ example run दिखाया गया
उन्होंने कहा कि training अधिकतम 10 epochs तक की गई, और validation set पर minimum perplexity दिखाने वाला checkpoint चुना गया
एक मुश्किल यह है कि पर्याप्त बड़ा custom dataset बनाने के लिए छोटी-सी सेना जैसी manpower या बहुत मजबूत existing model चाहिए
आखिर में OpenAI का इस्तेमाल करना पड़ने की संभावना ज्यादा है, लेकिन OpenAI से दूसरे model के training material generate करना terms का violation है। क्या इस पर कभी lawsuit तक बात गई है? या लोग इसे बस unfair मानकर ignore कर देते हैं?
आजकल NER examples ज्यादा दिख रहे हैं, तो सोच रहा हूं कि ऐसे tasks के लिए spaCy क्यों नहीं इस्तेमाल करते
मैं Anyscale में काम करता हूं
लगता है इस blog को अच्छी attention मिली है, इसलिए इसे Ray Summit में शामिल करने का plan है: https://raysummit.anyscale.com/agenda
अगर Ray Summit में किस तरह का content और देखना चाहेंगे, इसके ideas हों तो बताइए
35 लाख tokens के आधार पर 7B में 1 epoch लगभग 14 मिनट, और 13B में 1 epoch लगभग 26 मिनट बताया गया है
कहा गया है कि 7B और 13B दोनों के लिए head node के तौर पर कम से कम 1xg5.16xlarge और worker nodes के तौर पर 15xg5.4xlarge चाहिए; AWS पर cost लगभग कितनी होगी, जानना चाहूंगा
us-east-1 में चलाएं तो लगभग $30 per hour समझें
https://instances.vantage.sh/?selected=g5.16xlarge,g5.4xlarg...
जानना चाहूंगा कि M1 Ultra 64GB पर Llama-2 को local fine-tune किया जा सकता है या नहीं। ज्यादातर resources cloud या Linux पर Nvidia CUDA इस्तेमाल करने वाले हैं, इसलिए कोई reference material हो तो अच्छा होगा
training के लिए RunPod credits थोड़ा खरीदने का plan है, और लगता है कुछ दर्जन dollars में हो जाएगा