"""Gradio demo: ESM-2 pseudo-perplexity (PPP) of a protein sequence."""

import gradio as gr
import numpy as np
import torch
from transformers import EsmForMaskedLM, EsmTokenizer

# 1. Load the tokenizer and the masked-LM head once, at import time.
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
device = torch.device("cpu")
model_mlm = EsmForMaskedLM.from_pretrained(model_name).to(device)
# Inference only: make evaluation mode explicit so dropout is disabled.
model_mlm.eval()


# 2. Define the PPP calculation function
def predict_ppp(sequence) -> float:
    """Return the ESM-2 pseudo-perplexity (PPP) of one amino-acid sequence.

    Each residue is masked in turn and the model's log-probability of the
    true residue at that position is accumulated into the
    pseudo-log-likelihood:

        PLL = sum_i log P(x_i | x_{~i})
        PPP = exp(-PLL / N)

    where N is the number of residue tokens.

    Args:
        sequence: Amino-acid sequence in 1-letter code. Whitespace is
            removed and letters are upper-cased before tokenization.

    Returns:
        The pseudo-perplexity as a Python float (lower means the model finds
        the sequence more plausible).

    Raises:
        gr.Error: If the sequence is empty after whitespace removal, or
            tokenizes to no residue tokens.
    """
    # Pasted sequences often carry spaces or line breaks; drop them so that
    # every remaining character is meant to be a residue.
    cleaned = "".join(sequence.split()).upper() if sequence else ""
    if not cleaned:
        raise gr.Error("Please enter a non-empty amino acid sequence.")

    # Tokenize the sequence.
    # Assumes the tokenizer wraps the residues in exactly one leading (<cls>)
    # and one trailing (<eos>) special token.
    input_ids = tokenizer(cleaned, return_tensors="pt")["input_ids"].to(device)

    # Count residues from the tokenized length rather than len(sequence):
    # the two differ whenever the input does not map to one token per
    # character, and indexing by string length would then run out of bounds
    # or mask a special token.
    num_residues = input_ids.shape[1] - 2
    if num_residues < 1:
        raise gr.Error("The sequence contains no residues to score.")

    mask_token_id = tokenizer.mask_token_id

    # Accumulator for the pseudo-log-likelihood.
    pll_sum = 0.0

    with torch.no_grad():
        # Positions 1..num_residues skip the first and last special tokens.
        for position in range(1, num_residues + 1):
            # Mask only the current residue in a fresh copy of the input.
            masked_input = input_ids.clone()
            masked_input[0, position] = mask_token_id

            # logits shape: (batch_size, seq_len, vocab_size)
            logits = model_mlm(masked_input).logits

            # Normalise the logits at the masked position to log-probabilities
            # and read off the entry for the residue that was hidden.
            log_probs = torch.log_softmax(logits[0, position], dim=-1)
            target_token_id = input_ids[0, position].item()
            pll_sum += log_probs[target_token_id].item()

    # Convert to a built-in float so the return value matches the annotation.
    return float(np.exp(-pll_sum / num_residues))


# 3. Build and launch the web interface
demo = gr.Interface(
    fn=predict_ppp,
    inputs=[
        gr.Textbox(
            label="Enter Protein Amino Acid Sequence (1-letter code)",
            placeholder="ACDEFGHIKLMNPQRSTVWY",
        ),
    ],
    outputs="text",
    title="Nano Protein Language Model for Pseudo Perplexity (PPP) prediction of a protein sequence",
    description="Enter an amino acid sequence (using the 1-letter code) to predict its Pseudo Perplexity (PPP)",
    examples=[
        ["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"],  # Example sequence
    ],
)

demo.launch()