• AlphaFold3 का लक्ष्य केवल एक protein से आगे बढ़कर protein, nucleic acid और small molecule वाले complexes को सिर्फ sequence से predict करना है, इसलिए AF2 की तुलना में इसकी input representation और tokenization कहीं ज़्यादा जटिल हो जाती है
  • Input को token-level single/pair representations, atom-level representations, MSA और templates में बांटा जाता है; standard amino acid और nucleotide को 1 token, जबकि non-standard residues और अन्य molecules को प्रति atom 1 token के रूप में handle किया जाता है
  • Representation learning trunk, template module, MSA module और Pairformer के जरिए pair-bias attention, triangle operations और recycling से single representation s और pair representation z को बार-बार improve करता है
  • Structure prediction में AF2 के Invariant Point Attention की जगह atom coordinates के लिए conditional diffusion model का उपयोग होता है, और rotation/translation augmentation व denoising के जरिए सभी atoms के coordinate updates बनाए जाते हैं
  • Training में distogram, diffusion और confidence loss को जोड़ा जाता है, और AF2 व AF-Multimer results का उपयोग करने वाली cross-distillation से low-confidence regions की unfolded representation तक फिर से train की जाती है

AlphaFold3 का input scope और overall pipeline

  • AlphaFold3 का लक्ष्य AF2 की तरह केवल individual protein sequence predict करने या AF-Multimer की तरह सिर्फ protein complexes तक सीमित रहने के बजाय, protein और वैकल्पिक रूप से दूसरे proteins, nucleic acids और small molecules से जुड़ी structures को सिर्फ sequence से predict करना है
  • “Token” का मतलब input type के हिसाब से बदलता है
    • Protein: 1 standard amino acid = 1 token
    • DNA/RNA: 1 standard nucleotide = 1 token
    • Non-standard amino acid/nucleotide: 1 atom = 1 token
    • Other molecule: 1 atom = 1 token
  • 35 standard amino acids वाला protein वास्तव में 600 से ज़्यादा atoms रख सकता है, लेकिन उसे 35 tokens से represent किया जाता है; वहीं 35 atoms वाला ligand 35 tokens से represent होता है
  • Model मोटे तौर पर तीन stages से बना है
    • Input Preparation: user input sequences और search से मिली related sequences/structures को numeric tensors में बदलना
    • Representation Learning: single और pair representations को attention के कई variants से update करना
    • Structure Prediction: conditional diffusion से structure predict करना
  • Protein complex मुख्य रूप से दो representations में store होता है
    • single representation: complex के सभी tokens को खुद represent करता है
    • pair representation: सभी token pairs के बीच distance और संभावित interactions जैसे relations को represent करता है
  • मुख्य channel dimensions c_z=128, c_m=64, c_atom=128, c_atompair=16, c_token=768, c_s=384 हैं

Input preparation: sequence को 6 tensors में बदलने की प्रक्रिया

  • User द्वारा दिया गया input model trunk में जाने वाले 6 tensors में convert होता है
    • s: token-level single representation
    • z: token-level pair representation
    • q: atom-level single representation
    • p: atom-level pair representation
    • m: MSA representation
    • t: template representation
  • MSA और template search

    • AF3 protein और RNA sequences के लिए similar sequences खोजता है और उन्हें MSA के रूप में बनाता है; related structures को template के तौर पर शामिल करता है
    • MSA कई species में पाए जाने वाले similar protein sequences को align करके किसी खास position के conservation patterns और अलग-अलग positions के बीच variation correlations model को देता है
    • Similar proteins की known structures, homology modeling की तरह query protein structure का अनुमान लगाने में इस्तेमाल होती हैं
    • Search में training शामिल नहीं होती, और HMM-based methods इस्तेमाल किए जाते हैं
    • jackhmmer, HHBlits, nhmmer से कई protein/RNA databases search किए जाते हैं, और hmmsearch से Protein Data Bank में similar sequences खोजे जाते हैं
    • Computational complexity के कारण MSA size को N_MSA < 2^14 तक सीमित किया जाता है
    • हर protein chain में high-quality structure चुनी जाती है, और अधिकतम 4 को template के रूप में sample किया जाता है
    • AF-Multimer की तुलना में नया जोड़ा गया search element यह है कि RNA sequences को भी search target में शामिल किया जाता है
  • Template representation method

    • Template की 3D structure में हर token pair के बीच Euclidean distance calculate की जाती है
    • कई atoms वाले token के लिए representative “center atom” इस्तेमाल होता है
      • Amino acid: atom
      • Standard nucleotide: C1' atom
    • Distance values continuous नहीं, बल्कि distogram के रूप में discretize की जाती हैं
      • 3.15Å से 50.75Å तक 38 bins
      • इससे बड़ी distances के लिए 1 extra bin
    • Distogram में chain information, crystal structure में उस token के resolved होने या न होने की जानकारी, और हर amino acid के भीतर local distance information जोड़ी जाती है
    • Template matrix को mask किया जाता है ताकि वह सिर्फ same chain के भीतर की distances देखे; template selection से inter-chain interaction information हासिल करने की कोशिश नहीं की जाती

परमाणु-स्तर representation और Atom Transformer

  • reference conformer और atom-level representation

    • परमाणु-स्तर single representation q बनाने के लिए हर amino acid, nucleotide, ligand के लिए reference conformer की गणना की जाती है
    • conformer किसी molecule की 3D atomic arrangement है, जो single bonds के आसपास rotation को sample करके बनाई जाती है
    • standard amino acids के लिए lookup से मिलने वाले low-energy conformer का उपयोग किया जाता है, और छोटे molecules के लिए RDKit’s ETKDGv3 से 3D conformer बनाया जाता है
    • conformer की relative positions, atomic charges, atomic numbers, identifiers आदि को जोड़कर atom-level single representation c बनाया जाता है
    • c से atom-level pair representation p को initialize किया जाता है, और mask v का उपयोग किया जाता है ताकि इसमें सिर्फ reference conformer से गणना की गई atoms के बीच की distances रहें
    • q की शुरुआत c की copy के रूप में होती है, फिर Atom Transformer में इसे update किया जाता है
  • Atom Transformer की भूमिका

    • Atom Transformer एक module है जो परमाणु-स्तर attention करता है, और p व मूल representation c का उपयोग करके q को update करता है
    • c update नहीं होता, बल्कि starting representation की ओर जाने वाले residual connection की तरह इस्तेमाल होता है
    • बुनियादी structure transformer जैसा है, जिसमें LayerNorm, attention, MLP transition शामिल हैं, लेकिन हर step को c और p के additional inputs से adjust किया जाता है
  • Adaptive LayerNorm

    • Adaptive LayerNorm fixed gamma, beta learn करने के बजाय auxiliary input से gamma, beta generate करता है
    • Atom Transformer में rescale होने वाला target q है, और rescale parameters auxiliary input c से predict किए जाते हैं
  • Attention with Pair Bias

    • Atom-level attention with pair bias, self-attention का विस्तार है
    • query, key, value सभी single representation q से आते हैं, लेकिन query-key dot product के बाद pair representation p की linear projection को bias के रूप में जोड़ा जाता है
    • pair representation से q में information flow होती है, लेकिन इस step में q की information से p को update नहीं किया जाता
    • additional projection को sigmoid से गुजारकर बना gate, attention result से multiply होता है और control करता है कि residual stream में कौन-सी information छोड़ी जाए
    • atoms की संख्या tokens की संख्या से कहीं अधिक हो सकती है, इसलिए full attention के बजाय Sequence-local atom attention का उपयोग किया जाता है
    • 32 atoms की local group, 128 अन्य atoms पर attend कर सकती है
  • Conditioned Gating और Transition

    • Conditioned Gating मूल atom-level single matrix c से बनाए गए gate को data पर apply करता है
    • Conditioned Transition transformer के MLP के बराबर है, और इसे conditioned इसलिए कहा जाता है क्योंकि Adaptive LayerNorm और Conditional Gating c पर निर्भर करते हैं
    • AF3 transition block में ReLU के बजाय SwiGLU का उपयोग करता है
    • AF2 का ReLU-based transition 4x up-projection, ReLU, down-projection structure है
    • AF3 का SwiGLU दो up-projections में से एक पर swish nonlinearity apply करता है, फिर multiply करके down-project करता है

परमाणु representation को token representation में aggregate करना

  • representation learning stage बाद में token-level पर काम करता है, इसलिए atom-level representation को token-level representation में aggregate किया जाता है
  • atom-level representation को बड़े dimension में projection करने के बाद, एक ही token से संबंधित atoms का average लिया जाता है
  • यह average aggregation standard amino acids और nucleotides जैसे मामलों में apply होता है, जहाँ कई atoms एक token से जुड़े होते हैं; प्रति atom 1 token वाले inputs वैसे ही बने रहते हैं
  • token-level single input में MSA से मिले statistics भी जोड़े जाते हैं
    • amino acid type
    • उस position का MSA amino acid distribution
    • उस token का deletion mean
  • ligand atoms जैसे tokens जिनमें MSA नहीं होता, उनके लिए ये values 0 हो जाती हैं
  • इस तरह बना s_inputs projection से गुजरकर s_init बनता है, और representation learning stage में update होता है
  • pair representation z_init एक 3D tensor है जो token pairs के संबंध store करता है, और हर z_i,j c_z=128 dimension का vector है
  • z_i,j initialization में s_i, s_j की projection, relative positional encoding, और user द्वारा specified tokens के बीच की bond information जोड़ी जाती है

Representation learning: Template, MSA, Pairformer

  • representation learning वह trunk है जो model computation का अधिकांश हिस्सा लेता है, और इसका उद्देश्य token-level single representation s और pair representation z को improve करना है
  • single sequence representation सिर्फ एक protein sequence नहीं, बल्कि structure के भीतर सभी atoms या tokens को जोड़कर बनी sequence को दर्शाता है
  • Template Module

    • हर template linear projection से गुजरता है और pair representation z की linear projection के साथ जोड़ा जाता है
    • combined matrix Pairformer Stack से गुजरता है
    • कई template results का average लिया जाता है और फिर वे एक linear layer से गुजरते हैं
    • अंतिम linear layer में ReLU का उपयोग होता है, और यह AF3 में उन कम जगहों में से एक है जहाँ ReLU को nonlinearity के रूप में इस्तेमाल किया जाता है
  • MSA Module

    • MSA Module AF2 के Evoformer से बहुत मिलता-जुलता है, और MSA representation m व pair representation z दोनों को साथ-साथ improve करता है
    • पूरी MSA row का उपयोग करने के बजाय subsampling किया जाता है, फिर single representation की projection को MSA में जोड़ा जाता है
    • Outer Product Mean MSA information को pair representation में डालने वाला operation है
      • हर token index i,j के लिए सभी evolutionary sequences पर m_s,i और m_s,j का outer product calculate किया जाता है
      • इसे पूरी sequence पर average करके flatten किया जाता है, फिर projection के बाद z_i,j में जोड़ा जाता है
      • model में evolutionary sequences के बीच information share होने का यह एकमात्र point है
    • Row-wise gated self-attention using only pair bias pair representation का उपयोग करके MSA को update करता है
      • query और key से attention score बनाने के बजाय, pair representation z को matrix में project करके tokens के बीच attention score के रूप में इस्तेमाल किया जाता है
      • हर MSA row पर स्वतंत्र रूप से apply होने के कारण इस step में evolutionary sequences के बीच information share नहीं होती
    • MSA module के अंत में triangle update और triangle attention से pair representation को फिर update किया जाता है

Pairformer और triangle operations

  • Template और MSA से z को update करने के बाद, template और MSA का अब उपयोग नहीं होता; केवल s और z Pairformer में input किए जाते हैं
  • Pairformer 48 blocks की repetition के ज़रिए final s_trunk और z_trunk बनाता है
  • triangle operation की intuition

    • triangle update और triangle attention ऐसी structure हैं जो triangle inequality की intuition को model में reflect करने की कोशिश करती हैं
    • pair tensor का z_i,j भौतिक दूरी खुद तो नहीं है, लेकिन token i और j के relation को रखता है, इसलिए i-j, j-k, i-k के तीनों relations को आपस में consistent रहने के लिए update किया जाता है
    • triangle inequality को model के भीतर सीधे enforce नहीं किया जाता; बल्कि सभी triplet (i,j,k) को देखते हुए z_i,j को update करने के तरीके से induce किया जाता है
    • z को directed adjacency matrix की तरह देखा जा सकता है, इसलिए outgoing edge और incoming edge directions को अलग-अलग process किया जाता है
  • Triangle Updates

    • outgoing update में हर z_i,j को उसी row के दूसरे element z_i,k और तीसरे edge z_j,k का उपयोग करके update किया जाता है
    • implementation में z की तीन projections a, b, g बनाई जाती हैं, row i और row j की element-wise multiplication को k पर sum किया जाता है, फिर gate g apply किया जाता है
    • incoming update row और column को swap किया हुआ रूप है, जिसमें z_i,j को उसी column के दूसरे elements z_k,j और z_k,i के ज़रिए update किया जाता है
  • Triangle Attention

    • triangle attention, 2D matrix की row और column पर independent attention apply करने वाले axial attention में triangle principle जोड़ा हुआ रूप है
    • “starting node” case में z_i,j और z_i,k की query-key comparison में z_j,k को bias के रूप में जोड़ा जाता है
    • “ending node” case column basis पर काम करता है, और z_i,jz_k,i के attention score को z_k,j से bias करता है
  • Single Attention with Pair Bias

    • triangle step और transition block के बाद, single representation s को updated pair representation z का उपयोग करने वाले single attention with pair bias से update किया जाता है
    • क्योंकि यह token-level पर काम करता है, इसलिए atom-level में इस्तेमाल होने वाले block-wise sparse attention के बजाय full attention का उपयोग करता है

Structure prediction: atomic coordinates को diffusion से denoising करना

  • diffusion model का basic तरीका

    • AF3 final structure prediction को atom-level diffusion से करता है
    • diffusion model real data में step-by-step random noise जोड़ता है, और model को यह predict करने के लिए train करता है कि कौन-सा noise add हुआ था
    • inference में यह complete random noise से शुरू करता है, और हर step पर model द्वारा predict किया गया noise हटाते हुए denoised datapoint बनाता है
    • conditional diffusion current noisy generation, current timestep representation, और condition vector को input के रूप में लेकर condition के हिसाब से result बनाता है
    • AF3 में denoising का target सभी atoms के x,y,z coordinates वाली matrix x है
  • AF2 के IPA के बजाय rotation-translation augmentation

    • AF3, AF2 के Invariant Point Attention का उपयोग नहीं करता; इसके बजाय हर timestep पर prediction में मौजूद पूरे complex को randomly rotate और translate करता है
    • यह augmentation model को यह सीखने देता है कि कोई भी rotation और translation उसी structure के रूप में valid है, और यह AF2 के IPA की तुलना में ज्यादा simple approach है
    • rotation current generation के सभी atomic coordinates के mean को center मानकर apply किया जाता है, और translation हर dimension में N(0,1) Gaussian से sample किया जाता है
    • coordinates में छोटा noise भी add किया जाता है ताकि ज्यादा diverse generations induce हों
    • inference में कई generations को confidence head से score किया जा सकता है, और सबसे high score वाली generation return की जा सकती है
  • Diffusion Module के चार stages

    • हर denoising step कई conditioning representations का उपयोग करता है
      • trunk outputs s_trunk, z_trunk
      • input embedder से बने initial representations s_inputs, c_inputs
    • diffusion process token और atom spaces के बीच आते-जाते हुए चार stages से बना है
        1. token-level conditioning tensor तैयार करना
        1. atom-level conditioning tensor तैयार करना, Atom Transformer apply करना, token-level पर aggregate करना
        1. token-level attention apply करना
        1. atom-level attention से per-atom noise update predict करना
    • token-level conditioning में z_trunk और relative positional encoding को combine करके transition block से pass कराया जाता है
    • single representation में s_inputs और s_trunk को combine किया जाता है, और diffusion timestep के अनुसार Fourier embedding जोड़ी जाती है
    • atom-level stage में initial c, p को current token-level representation से update किया जाता है, और current coordinates x को data variance से scale करके dimensionless coordinate r बनाया जाता है
    • final atom-level stage में linear layer q को R^3 में map करके सभी atoms के coordinate update r_update बनाती है
    • update को data variance और noise schedule को ध्यान में रखकर x_update में rescale किया जाता है, फिर current coordinates x_l पर apply किया जाता है

Loss function और confidence head

  • कुल loss तीन terms का weighted sum है

L_loss = L_distogram * α_distogram + L_diffusion * α_diffusion + L_confidence * α_confidence

  • L_distogram

    • L_distogram token-level पर predicted distogram की accuracy का मूल्यांकन करता है
    • atomic coordinates से token coordinates बनाते समय हर token के center atom coordinates का उपयोग किया जाता है
    • distogram distance को categorical value की तरह treat किया जाता है, और predicted distogram व वास्तविक distogram की तुलना cross entropy से की जाती है
  • L_diffusion

    • L_diffusion atom position को target करने वाले कई terms का weighted sum है
    • L_MSE center atom नहीं, बल्कि सभी atoms के लिए positions के बीच mean squared error calculate करता है, और DNA, RNA, ligand atoms को upweight किया जाता है
    • L_bond protein-ligand bond में शामिल atom pair की bond length accuracy बढ़ाने के लिए एक अतिरिक्त MSE term है
    • शुरुआती training stage में α_bond=0 होता है, इसलिए इसे बाद में introduce किया जाता है
    • L_smooth_LDDT local distance accuracy को smooth और differentiable बनाने वाला loss है
      • thresholds के रूप में 4Å, 2Å, 1Å, 0.5Å — ये चार values इस्तेमाल होती हैं
      • nucleotide atom pairs 30Å से ज्यादा दूर हों तो ignore किए जाते हैं
      • protein या ligand atom pairs 15Å से ज्यादा दूर हों तो ignore किए जाते हैं
  • L_confidence

    • L_confidence structure accuracy को सीधे बढ़ाने के बजाय, model को अपनी prediction की accuracy estimate करना सिखाता है
    • यह चार confidence metrics से जुड़े losses से बना है
      • pLDDT: पास के atoms के लिए local distance accuracy
      • PAE: token pair का predicted alignment error
      • PDE: token pair के बीच predicted distance error
      • experimentally resolved prediction: हर atom experimental structure में resolved हुआ है या नहीं, इसकी prediction
    • भले ही predicted structure inaccurate होने के कारण PAE high हो, अगर model PAE को भी high predict करता है तो संबंधित PAE loss कम हो सकता है
    • confidence prediction diffusion के intermediate stage में generate होती है
    • confidence loss का gradient केवल confidence prediction head को update करता है, और model के बाकी हिस्सों पर असर नहीं डालता

अतिरिक्त training techniques और optimization

  • Recycling

    • AF3, AF2 की तरह weight recycling का उपयोग करता है
    • model को और deep बनाने के बजाय वही weights कई बार reuse करके representation को धीरे-धीरे improve करता है
    • diffusion भी inference में timestep information का उपयोग करता है और हर timestep पर वही weights reuse करता है, इसलिए उसमें recycling अंतर्निहित है
  • Cross-distillation

    • AF3 अपने द्वारा बनाए गए synthetic training data के अलावा AF2 और AF-Multimer द्वारा बनाए गए synthetic data का भी उपयोग करता है
    • diffusion-based generation पर shift करने के बाद, AF2 में low-confidence और disordered regions को visually अलग दिखाने वाली “spaghetti” shape गायब हो जाने की समस्या थी
    • AF2 और AF-Multimer generation को AF3 training data में शामिल करके, AF3 यह सीखता है कि जिन regions को लेकर AF2 confident नहीं था, वहां unfolded region कैसे output किया जाए
    • distillation dataset में AF2 और AF-Multimer जिन nucleic acids और small molecules को handle नहीं कर सकते, उन्हें remove किया जाता है
    • previous model predicted structure बनाने के बाद जब original से alignment करता है, तो हटाए गए molecules फिर से add किए जाते हैं
    • अगर फिर से add किया गया molecule atom clash बनाता है, तो पूरी structure को exclude किया जाता है, ताकि model clash allow करना न सीखे
  • Cropping और training stage

    • model में input sequence length पर कोई explicit limit नहीं है, लेकिन कई operations N_tokens^3 के हिसाब से बढ़ते हैं, जिससे memory और compute requirements बढ़ जाती हैं
    • efficiency के लिए protein को random crop किया जाता है
    • चूंकि कई chains के बीच interaction को model करना होता है, इसलिए crop में chains को साथ शामिल करना चाहिए
    • तीन cropping methods इस्तेमाल होती हैं
      • contiguous cropping: हर chain से continuous amino acid sequence चुनना
      • spatial cropping: reference atom तक की distance के आधार पर amino acids चुनना
      • spatial interface cropping: binding interface के atom तक की distance के आधार पर चुनना
    • random crop 384 से trained model को लंबे sequence पर भी apply किया जा सकता है, लेकिन longer sequence handling capability बढ़ाने के लिए बड़े sequence length पर repeated fine-tuning किया जाता है
  • Clashing और batch size

    • AF3 loss में overlapping atoms के लिए clash penalty शामिल नहीं है
    • diffusion-based structure module theoretically दो atoms को same location पर predict कर सकता है, लेकिन training के बाद यह समस्या छोटी रहती है
    • generated structures की ranking में clashing penalty का उपयोग किया जाता है
    • diffusion process जटिल दिखता है, लेकिन trunk की तुलना में इसकी computation cost कम है
    • training efficiency के लिए trunk के बाद batch size expand किया जाता है
    • हर input structure embedding और trunk से एक बार गुजरती है, और उसके बाद 48 independent data-augmented structures parallel में train की जाती हैं

ML दृष्टिकोण से AF3 डिज़ाइन

  • Retrieval-Augmented Generation जैसी संरचना

    • AF3 में MSA और template खोज का स्वभाव language model के RAG जैसा है
    • AlphaFold क्षेत्र में structural template इस्तेमाल करने का तरीका RAG शब्द से काफी पहले से homology modeling के रूप में इस्तेमाल होता रहा है
    • AF3 ने AF2 की तुलना में MSA processing का हिस्सा कम किया है, लेकिन MSA और template अब भी शामिल हैं
    • ESMFold जैसे कुछ protein prediction model retrieval को हटाकर fully parametric inference इस्तेमाल करते हैं
  • Pair-Bias Attention

    • AF2 का प्रमुख component रहा Pair-Bias Attention AF3 में और व्यापक रूप से इस्तेमाल होता है
    • query, key, value एक ही source से आते हैं, लेकिन attention map में किसी दूसरे source से आया bias term जोड़ा जाता है
    • यह full cross-attention की तुलना में information sharing का हल्का तरीका है
    • क्योंकि pair representation स्वाभाविक रूप से attention map से मिलता-जुलता है, यह structure protein modeling के लिए अच्छी तरह फिट हो सकता है
  • Self-supervised training में कमी

    • ESM परिवार के model self-supervised pre-training के जरिए MSA embedding को replace करने के तरीके में मजबूत साबित हुए थे
    • AF2 में MSA के masked token की prediction करने वाला एक अतिरिक्त task था, लेकिन AF3 में इसे हटा दिया गया
    • AF3 ने MSA processing compute घटाया है, और MSA के लिए self-supervised language modeling pre-training का इस्तेमाल नहीं करता
    • संभावित कारण यह हो सकते हैं कि massive pre-training compute उपयोग के लिहाज़ से inefficient रही हो, छोटा MSA module pre-trained embedding से बेहतर रहा हो, या amino acid·DNA/RNA·ligand के मिले-जुले hybrid atom-token structure और pre-trained embedding का combination मेल न खाता हो
  • Classification और Regression का मिश्रण

    • AF3, AF2 की तरह MSE और binned classification loss को साथ में इस्तेमाल करता है
    • distogram bin में सिर्फ एक bin की गलती होने पर भी, दूर तक गलत होने वाली स्थिति की तरह ही कोई credit नहीं मिलता—यह classification loss की खासियत है
    • इस design choice का आधार स्पष्ट नहीं है, लेकिन संभव है कि gradient कई MSE loss की तुलना में ज्यादा stable रहा हो
  • recurrent architecture से मिलते-जुलते तत्व

    • AF3 में सामान्य transformer की तुलना में recurrent network की याद दिलाने वाले कई तत्व हैं
    • gating residual stream में information flow को नियंत्रित करता है, और LSTM या GRU के gate जैसा है
    • recycling और diffusion एक ही weight को बार-बार apply करके prediction को धीरे-धीरे improve करते हैं
    • adaptive compute time की तरह, repeated update ऐसी संरचना से जुड़ा है जो कठिन input पर अधिक processing लागू कर सकती है
    • AF2 ablation में recycling का महत्व दिखा था, लेकिन gating के महत्व पर ज्यादा चर्चा नहीं हुई थी

अभी कोई टिप्पणी नहीं है.

अभी कोई टिप्पणी नहीं है.