# Hugging Face Spaces page header (status: Sleeping)
import gradio as gr
import numpy as np
import torch
from transformers import EsmForMaskedLM, EsmTokenizer

# Load the pretrained ESM-2 checkpoint and its tokenizer once, at import time,
# so every request reuses the same model instance.
model_name = "facebook/esm2_t6_8M_UR50D"
device = torch.device("cpu")  # inference runs on CPU
tokenizer = EsmTokenizer.from_pretrained(model_name)
model_mlm = EsmForMaskedLM.from_pretrained(model_name).to(device)
# Pseudo-perplexity (PPP) scoring
def predict_ppp(sequence: str, batch_size: int = 32) -> float:
    """Return the ESM-2 pseudo-perplexity (PPP) of one amino-acid sequence.

    Each residue is masked in turn, and the model's log-probability of the
    true residue at the masked position is summed into the
    pseudo-log-likelihood

        PLL = sum_i log P(x_i | x_{~i})

    which is then converted to pseudo-perplexity

        PPP = exp(-PLL / L)

    where L is the number of residue tokens. A lower PPP means the model
    assigns a higher average probability to the true residues.

    Args:
        sequence: Protein sequence in one-letter amino-acid code. All
            whitespace (e.g. line breaks from a pasted FASTA body) is removed
            and letters are upper-cased before tokenization.
        batch_size: Number of single-mask copies of the sequence scored per
            forward pass. Larger values are faster but use more memory.

    Returns:
        The pseudo-perplexity as a Python float.

    Raises:
        gr.Error: If ``sequence`` contains no residues (shown in the UI).
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    # ESM's vocabulary uses upper-case residue letters; whitespace or
    # lower-case letters would otherwise not map one-to-one onto residue tokens.
    sequence = "".join(str(sequence).split()).upper()
    if not sequence:
        raise gr.Error("Please enter a non-empty amino acid sequence.")

    # The tokenizer adds <cls> at the start and <eos> at the end.
    input_ids = tokenizer(sequence, return_tensors="pt")["input_ids"].to(device)

    # Count residues from the tokens actually produced (not len(sequence)),
    # so every masked position is guaranteed to lie inside the tensor.
    num_residues = input_ids.shape[1] - 2
    positions = torch.arange(1, num_residues + 1, device=device)
    targets = input_ids[0, positions]  # true token id at each residue position

    pll_sum = 0.0
    with torch.no_grad():
        for start in range(0, num_residues, batch_size):
            batch_pos = positions[start:start + batch_size]
            rows = torch.arange(batch_pos.numel(), device=device)

            # One copy of the sequence per position, each with exactly one
            # residue replaced by the tokenizer's mask token.
            masked = input_ids.repeat(batch_pos.numel(), 1)
            masked[rows, batch_pos] = tokenizer.mask_token_id

            logits = model_mlm(masked).logits  # (batch, tokens, vocab)
            # Row r holds the prediction at the position masked in copy r.
            log_probs = torch.log_softmax(logits[rows, batch_pos], dim=-1)
            pll_sum += log_probs[rows, targets[start:start + batch_size]].sum().item()

    return float(np.exp(-pll_sum / num_residues))
# Gradio UI: one text box in, the PPP score out as text.
demo = gr.Interface(
    fn=predict_ppp,
    inputs=[
        gr.Textbox(
            label="Enter Protein Amino Acid Sequence (1-letter code)",
            placeholder="ACDEFGHIKLMNPQRSTVWY",
        ),
    ],
    outputs="text",
    title="Nano Protein Language Model for Pseudo Perplexity (PPP) prediction of a protein sequence",
    description="Enter an amino acid sequence (using the 1-letter code) to predict its Pseudo Perplexity (PPP)",
    examples=[
        ["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"],  # Example sequence
    ],
)

# Start the web server only when run as a script (e.g. `python app.py`),
# not as a side effect of importing this module.
if __name__ == "__main__":
    demo.launch()