Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| # In[1]: | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| from torch.nn import init, MarginRankingLoss | |
| from torch.optim import Adam | |
| from distutils.version import LooseVersion | |
| from torch.utils.data import Dataset, DataLoader | |
| from torch.autograd import Variable | |
| import math | |
| from transformers import AutoConfig, AutoModel, AutoTokenizer | |
| import nltk | |
| import re | |
| import torch.optim as optim | |
| from transformers import AutoModelForMaskedLM | |
| import torch.nn.functional as F | |
| import random | |
| # In[2]: | |
| # eng_dict = [] | |
| # with open('eng_dict.txt', 'r') as file: | |
| # # Read each line from the file and append it to the list | |
| # for line in file: | |
| # # Remove leading and trailing whitespace (e.g., newline characters) | |
| # cleaned_line = line.strip() | |
| # eng_dict.append(cleaned_line) | |
| # In[14]: | |
# Cumulative distribution of identifier lengths (in sub-tokens), used when the
# caller asks for a randomly sampled length. Entries are
# (sub-token count, cumulative probability); the order matters because the
# cumulative values only increase along this sequence.
_SUBTOKEN_COUNT_CDF = (
    (2, 0.4363429005892416),
    (1, 0.6672580202327398),
    (4, 0.7476060740459144),
    (3, 0.9618703668504087),
    (6, 0.9701028532809564),
    (7, 0.9729244545819342),
    (8, 0.9739508754144756),
    (5, 0.9994508859743607),
    (9, 0.9997507867114407),
    (10, 0.9999112969650892),
    (11, 0.9999788802297832),
    (0, 0.9999831041838266),
    (12, 0.9999873281378701),
    (22, 0.9999957760459568),
    (14, 1.0000000000000002),
)

_MODEL_NAME = "microsoft/graphcodebert-base"
_CHECKPOINT_PATH = "model_26_2"
_SEQ_LEN = 512                    # model positions per chunk
_CHUNK_BODY_LEN = _SEQ_LEN - 2    # room left after the two framing tokens
# NOTE(review): 101/102 are BERT's [CLS]/[SEP] ids, while GraphCodeBERT uses a
# RoBERTa vocabulary (<s>=0, <pad>=1, </s>=2). They are kept unchanged on the
# assumption that model_26_2 was fine-tuned with this same framing -- confirm
# against the training code before switching to tokenizer.cls_token_id etc.
_CLS_ID = 101
_SEP_ID = 102
_PAD_ID = 0
_TOP_K = 5                        # candidates inspected per sub-token position
_NO_MASK_MESSAGE = "No [MASK] found: replace a variable name with [MASK] and try again."

# Lazily populated by _load_model so the weights are read once per process,
# not once per request.
_tokenizer = None
_model = None


def _load_model():
    """Return the (tokenizer, fine-tuned model) pair, loading it on first use."""
    global _tokenizer, _model
    if _model is None:
        tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
        model = AutoModelForMaskedLM.from_pretrained(_MODEL_NAME)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location="cpu"))
        model.eval()
        _tokenizer, _model = tokenizer, model
    return _tokenizer, _model


def _parse_token_count(ny):
    """Parse the "number of tokens" input; blank means 0 (sample randomly).

    Raises:
        ValueError: if ny is not an integer or is negative.
    """
    text = str(ny).strip()
    count = int(text) if text else 0
    if count < 0:
        raise ValueError("Number of tokens must be a non-negative integer, got %r" % (ny,))
    return count


def _sample_subtoken_count():
    """Draw a name length from _SUBTOKEN_COUNT_CDF."""
    draw = random.random()
    for count, cumulative in _SUBTOKEN_COUNT_CDF:
        if draw < cumulative:
            return count
    return _SUBTOKEN_COUNT_CDF[-1][0]  # unreachable: last bound exceeds 1.0


def _build_chunks(flat_ids):
    """Frame and pad a flat 1-D id tensor into model-ready rows.

    Each row is [CLS] + up to _CHUNK_BODY_LEN ids + [SEP], right-padded to
    _SEQ_LEN. Returns (input_ids, attention_mask), each (num_chunks, _SEQ_LEN);
    the mask is 1 on real and framing tokens and 0 on padding.
    """
    cls = torch.tensor([_CLS_ID], dtype=flat_ids.dtype)
    sep = torch.tensor([_SEP_ID], dtype=flat_ids.dtype)
    id_rows, mask_rows = [], []
    for body in flat_ids.split(_CHUNK_BODY_LEN):
        framed = torch.cat([cls, body, sep])
        pad_len = _SEQ_LEN - framed.shape[0]
        id_rows.append(torch.cat([framed, torch.full((pad_len,), _PAD_ID, dtype=framed.dtype)]))
        mask_rows.append(torch.cat([torch.ones_like(framed), torch.zeros(pad_len, dtype=framed.dtype)]))
    return torch.stack(id_rows), torch.stack(mask_rows)


def greet(X, ny):
    """Suggest a name for the variable hidden behind ``[MASK]`` in Java code.

    Args:
        X: Java source in which the variable's occurrences are written as the
            literal ``[MASK]``; the model's predictions for all occurrences are
            averaged.
        ny: Number of sub-tokens the suggested name should have, as an int or
            numeric string. ``0`` or blank samples a length from
            ``_SUBTOKEN_COUNT_CDF`` (a sampled 0 is bumped to 1).

    Returns:
        The suggested name followed by `` (with PLL value of: <v>)``, where
        ``<v>`` is the mean negative log-probability of the chosen sub-tokens,
        rounded to 2 decimals. If X contains no ``[MASK]``, a short message
        saying so.

    Raises:
        ValueError: if ny is not an integer or is negative.
    """
    requested = _parse_token_count(ny)
    # Check before loading the model: without a mask there is nothing to predict.
    if "[MASK]" not in X:
        return _NO_MASK_MESSAGE
    # A name needs at least one sub-token; the sampled tail can yield 0.
    n_sub = requested if requested > 0 else max(1, _sample_subtoken_count())

    tokenizer, model = _load_model()

    # Space-pad each [MASK] so it tokenizes on its own, then expand it into
    # n_sub model mask tokens (one per predicted sub-token).
    text = X.replace("[MASK]", " [MASK] ")
    text = text.replace("[MASK]", " ".join([tokenizer.mask_token] * n_sub))
    flat_ids = tokenizer.encode_plus(text, add_special_tokens=False, return_tensors="pt")["input_ids"][0]

    # Start index (in the unchunked sequence) of every run of mask tokens.
    # Working on the flat sequence keeps a run that straddles a chunk
    # boundary from being counted twice.
    is_mask = (flat_ids == tokenizer.mask_token_id).tolist()
    run_starts = [j for j, m in enumerate(is_mask) if m and (j == 0 or not is_mask[j - 1])]
    if not run_starts:
        return _NO_MASK_MESSAGE

    input_ids, attention_mask = _build_chunks(flat_ids)
    with torch.no_grad():
        # Element 0 of a masked-LM output is the logits, (num_chunks, _SEQ_LEN, vocab).
        logits = model(input_ids, attention_mask=attention_mask)[0]

    def logits_at(j):
        # Flat index j sits in chunk j // _CHUNK_BODY_LEN, shifted by one for [CLS].
        return logits[j // _CHUNK_BODY_LEN, j % _CHUNK_BODY_LEN + 1]

    name_parts = []
    nll_total = 0.0
    for t in range(n_sub):
        # Average the t-th sub-token's logits over every occurrence of the variable.
        avg_logits = torch.stack([logits_at(start + t) for start in run_starts]).mean(dim=0)
        log_probs = F.log_softmax(avg_logits, dim=0)
        candidates = torch.topk(avg_logits, k=_TOP_K, dim=0).indices
        # Keep the most likely purely alphabetic piece. If none of the top-k
        # qualifies, add no text but still score the top-1 candidate.
        chosen = candidates[0]
        for cand in candidates:
            piece = tokenizer.decode(cand.item()).strip()
            if piece.isalpha():  # False for "", unlike all() over an empty string
                name_parts.append(piece)
                chosen = cand
                break
        nll_total -= log_probs[chosen].item()

    pll = round(nll_total / n_sub, 2)
    return "".join(name_parts) + " (with PLL value of: " + str(pll) + ")"
# Page title and description shown by the Gradio interface built below.
title = "Rename a variable in a Java class"
description = """This model is a fine-tuned GraphCodeBERT model fine-tuned to output higher-quality variable names for Java classes. Long classes are handled by the
model. Replace any variable name with a "[MASK]" to get an identifier renaming.
"""
# Clickable examples for the interface. Each entry is [java_source, token_count],
# in the same order as greet's inputs (X, ny):
# - every occurrence of the hidden variable is written as [MASK];
# - token_count is the number of sub-tokens to predict, where "0" asks greet to
#   sample a length at random.
ex = [["""import java.io.*;
public class x {
public static void main(String[] args) {
String f = "file.txt";
BufferedReader [MASK] = null;
String l;
try {
[MASK] = new BufferedReader(new FileReader(f));
while ((l = [MASK].readLine()) != null) {
System.out.println(l);
}
} catch (IOException e) {
e.printStackTrace();
} finally {
try {
if ([MASK] != null) [MASK].close();
} catch (IOException ex) {
ex.printStackTrace();
}
}
}
}""", "0"], ["""import java.net.*;
import java.io.*;
public class s {
public static void main(String[] args) throws IOException {
ServerSocket [MASK] = new ServerSocket(8000);
try {
Socket s = [MASK].accept();
PrintWriter pw = new PrintWriter(s.getOutputStream(), true);
BufferedReader br = new BufferedReader(new InputStreamReader(s.getInputStream()));
String i;
while ((i = br.readLine()) != null) {
pw.println(i);
}
} finally {
if ([MASK] != null) [MASK].close();
}
}
}""", "2"], ["""import java.io.*;
import java.util.*;
public class y {
public static void main(String[] args) {
String [MASK] = "data.csv";
String l = "";
String cvsSplitBy = ",";
try (BufferedReader br = new BufferedReader(new FileReader([MASK]))) {
while ((l = br.readLine()) != null) {
String[] z = l.split(cvsSplitBy);
System.out.println("Values [field-1= " + z[0] + " , field-2=" + z[1] + "]");
}
} catch (IOException e) {
e.printStackTrace();
}
}
}""", "2"]]
# Input widgets, in the same order as greet's parameters:
# the Java snippet (X) and the requested sub-token count (ny).
textbox = gr.Textbox(
    lines=10,
    label="Type Java code snippet:",
    placeholder="replace variable with [MASK]",
)
textbox1 = gr.Textbox(
    lines=1,
    label="Number of tokens in name:",
    placeholder="0 for randomly sampled number of tokens",
)

# Wire greet to the two text inputs and a plain-text output, then serve it.
demo = gr.Interface(
    fn=greet,
    inputs=[textbox, textbox1],
    outputs="text",
    title=title,
    description=description,
    examples=ex,
)
demo.launch()
| # In[ ]: | |