sudhir2016 committed on
Commit
a5b8f58
·
verified ·
1 Parent(s): 0d05e09

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr, numpy as np
2
+
3
+ import torch
4
+ from transformers import EsmTokenizer,EsmForMaskedLM
5
+
6
# Load the tokenizer and the masked-LM model once, at import time, so that
# every request served by the app reuses the same objects.
model_name = "facebook/esm2_t6_8M_UR50D"

# Inference runs on the CPU.
device = torch.device("cpu")

tokenizer = EsmTokenizer.from_pretrained(model_name)

model_mlm = EsmForMaskedLM.from_pretrained(model_name)
model_mlm = model_mlm.to(device)
12
+
13
+ # 2. Define the PLL Calculation Function
14
def predict_ppp(sequence) -> float:
    """Return the ESM-2 pseudo-perplexity (PPP) of one amino-acid sequence.

    Every residue is masked in turn and the masked language model scores the
    true residue at that position. The pseudo-log-likelihood is

        PLL = sum_i log P(x_i | x_{~i})

    and the value returned is ``exp(-PLL / n)``, where ``n`` is the number of
    residue tokens that were scored.

    Args:
        sequence: Amino-acid sequence in 1-letter code. Whitespace (including
            line breaks) is removed and letters are upper-cased before
            tokenization.

    Returns:
        The pseudo-perplexity as a plain Python float.

    Raises:
        ValueError: If ``sequence`` is ``None`` or contains no residues once
            whitespace has been removed.
    """
    # Normalise user input. Pasted sequences often contain spaces or line
    # breaks; upper-casing assumes the model vocabulary uses upper-case
    # residue letters (TODO confirm against the tokenizer vocabulary).
    residues = "".join((sequence or "").split()).upper()
    if not residues:
        raise ValueError("Please enter a non-empty amino acid sequence.")

    # The tokenizer adds one leading <cls> and one trailing <eos> token.
    input_ids = tokenizer(residues, return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model_mlm.device)

    # Count residues from the tokenizer output rather than len(sequence): the
    # two differ whenever a character does not map to exactly one token, and
    # indexing by len(sequence) would then mask <eos> or run off the end.
    num_residues = input_ids.shape[1] - 2
    if num_residues < 1:
        raise ValueError("Please enter a non-empty amino acid sequence.")

    # Token positions of the residues (skipping <cls> at 0 and <eos> at the end).
    positions = torch.arange(1, num_residues + 1, device=input_ids.device)

    # Score several masked copies per forward pass instead of one at a time.
    # All copies have the same length, so no padding is needed.
    mask_batch_size = 16
    mask_id = tokenizer.mask_token_id  # looked up, not hard-coded

    pll_sum = 0.0
    with torch.no_grad():
        for chunk in positions.split(mask_batch_size):
            rows = torch.arange(len(chunk), device=input_ids.device)

            # Row r is the full sequence with only position chunk[r] masked.
            masked_batch = input_ids.repeat(len(chunk), 1)
            masked_batch[rows, chunk] = mask_id

            # logits: (batch, seq_len, vocab_size)
            logits = model_mlm(masked_batch).logits

            # Log-probabilities at each row's own masked position.
            log_probs = torch.log_softmax(logits[rows, chunk], dim=-1)

            # Log-probability assigned to the residue that was actually there.
            targets = input_ids[0, chunk]
            pll_sum += log_probs[rows, targets].sum().item()

    # Pseudo-perplexity: exp of the negative mean log-likelihood per residue.
    return float(np.exp(-pll_sum / num_residues))
63
+
64
# Single text box in which the user types or pastes the sequence.
sequence_input = gr.Textbox(
    label="Enter Protein Amino Acid Sequence (1-letter code)",
    placeholder="ACDEFGHIKLMNPQRSTVWY",
)

# One row per example; each row holds one value per input component.
example_rows = [
    ["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"],  # Example sequence
]

# Wire predict_ppp to the text box and show its result as text.
demo = gr.Interface(
    fn=predict_ppp,
    inputs=[sequence_input],
    outputs="text",
    title="Nano Protein Language Model for Pseudo Perplexity (PPP) prediction of a protein sequence",
    description="Enter an amino acid sequence (using the 1-letter code) to predict its Pseudo Perplexity (PPP)",
    examples=example_rows,
)

# Start the web app.
demo.launch()