# PPP_Demo / app.py
# Commit a5b8f58 (verified) by sudhir2016: "Create app.py"
import gradio as gr, numpy as np
import torch
from transformers import EsmTokenizer,EsmForMaskedLM
# 1. Load the ESM-2 tokenizer and masked-language-model once, at import time;
#    all inference in this app runs on the CPU.
device = torch.device("cpu")
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
model_mlm = EsmForMaskedLM.from_pretrained(model_name)
model_mlm = model_mlm.to(device)
# 2. Define the pseudo-perplexity (PPP) calculation function
def predict_ppp(sequence, *, batch_size=16) -> float:
    """
    Compute the ESM-2 pseudo-perplexity (PPP) of one amino acid sequence.

    Every residue is masked in turn and scored by the masked language model.
    The per-residue log-probabilities are summed into the pseudo-log-likelihood
    and converted to a perplexity:

        PLL = sum_i( log P(x_i | x_{~i}) )
        PPP = exp(-PLL / L)

    where L is the number of residue tokens (special tokens excluded).

    Args:
        sequence: Protein sequence in 1-letter code. All whitespace (spaces,
            line breaks pasted into the textbox) is removed and letters are
            upper-cased before tokenization.
        batch_size: Maximum number of masked copies of the sequence scored in
            one forward pass. Larger values are faster but use more memory.

    Returns:
        The pseudo-perplexity as a Python float.

    Raises:
        ValueError: If ``sequence`` is empty after removing whitespace, or if
            ``batch_size`` is smaller than 1.
    """
    # Normalise raw textbox input. Whitespace does not become residue tokens,
    # and ESM-2's vocabulary uses upper-case residue letters. Either mismatch
    # used to make a len(sequence)-based position count disagree with the
    # tokens actually produced.
    cleaned = "".join((sequence or "").split()).upper()
    if not cleaned:
        raise ValueError("Please enter a non-empty amino acid sequence.")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1.")

    # The tokenizer wraps the residues as <cls> x_1 ... x_L <eos>.
    input_ids = tokenizer(cleaned, return_tensors="pt")["input_ids"].to(device)

    # Derive L from the tokenized input so the masked positions cover exactly
    # the residue tokens (never <eos> or past the end of the tensor).
    num_residues = input_ids.shape[1] - 2
    positions = torch.arange(1, num_residues + 1, device=device)
    targets = input_ids[0, positions]

    pll_sum = 0.0
    with torch.no_grad():
        for start in range(0, num_residues, batch_size):
            chunk_pos = positions[start:start + batch_size]
            chunk_targets = targets[start:start + batch_size]
            rows = torch.arange(chunk_pos.numel(), device=device)

            # One copy of the sequence per position in this chunk; row r has
            # only residue chunk_pos[r] replaced by the tokenizer's mask token.
            masked = input_ids.repeat(chunk_pos.numel(), 1)
            masked[rows, chunk_pos] = tokenizer.mask_token_id

            logits = model_mlm(masked).logits  # (rows, seq_len, vocab_size)
            # Row r predicts position chunk_pos[r]; score the true residue there.
            log_probs = torch.log_softmax(logits[rows, chunk_pos], dim=-1)
            pll_sum += log_probs[rows, chunk_targets].sum(dtype=torch.float64).item()

    return float(np.exp(-pll_sum / num_residues))
# 3. Wrap predict_ppp in a simple Gradio web UI (one textbox in, text out).
demo = gr.Interface(
    fn=predict_ppp,
    inputs=[
        gr.Textbox(
            label="Enter Protein Amino Acid Sequence (1-letter code)",
            placeholder="ACDEFGHIKLMNPQRSTVWY",
        ),
    ],
    outputs="text",
    title="Nano Protein Language Model for Pseudo Perplexity (PPP) prediction of a protein sequence",
    description="Enter an amino acid sequence (using the 1-letter code) to predict its Pseudo Perplexity (PPP)",
    examples=[
        ["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"],  # Example sequence
    ],
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse predict_ppp) does not block on launch().
if __name__ == "__main__":
    demo.launch()