- CustomBERTModel.py +33 -0
- Untitled.ipynb +0 -0
- __pycache__/metrics.cpython-312.pyc +0 -0
- __pycache__/recalibration.cpython-312.pyc +0 -0
- __pycache__/visualization.cpython-312.pyc +0 -0
- app.py +48 -0
- data_preprocessor.py +170 -0
- hint_fine_tuning.py +382 -0
- main.py +322 -0
- metrics.py +149 -0
- new_fine_tuning/README.md +197 -0
- new_fine_tuning/__pycache__/metrics.cpython-312.pyc +0 -0
- new_fine_tuning/__pycache__/recalibration.cpython-312.pyc +0 -0
- new_fine_tuning/__pycache__/visualization.cpython-312.pyc +0 -0
- new_hint_fine_tuned.py +131 -0
- new_test_saved_finetuned_model.py +613 -0
- plot.png +0 -0
- prepare_pretraining_input_vocab_file.py +0 -0
- ratio_proportion_change3_2223/sch_largest_100-coded/pretraining/vocab.txt +34 -0
- recalibration.py +82 -0
- src/__pycache__/attention.cpython-312.pyc +0 -0
- src/__pycache__/bert.cpython-312.pyc +0 -0
- src/__pycache__/classifier_model.cpython-312.pyc +0 -0
- src/__pycache__/dataset.cpython-312.pyc +0 -0
- src/__pycache__/embedding.cpython-312.pyc +0 -0
- src/__pycache__/seq_model.cpython-312.pyc +0 -0
- src/__pycache__/transformer.cpython-312.pyc +0 -0
- src/__pycache__/transformer_component.cpython-312.pyc +0 -0
- src/__pycache__/vocab.cpython-312.pyc +0 -0
- src/attention.py +21 -1
- src/bert.py +35 -0
- src/classifier_model.py +52 -1
- src/dataset.py +385 -0
- src/pretrainer.py +713 -0
- src/reference_code/bert_reference_code.py +1622 -0
- src/reference_code/evaluate_embeddings.py +136 -0
- src/reference_code/metrics.py +149 -0
- src/reference_code/pretrainer-old.py +696 -0
- src/reference_code/test.py +493 -0
- src/reference_code/utils.py +369 -0
- src/reference_code/visualization.py +78 -0
- src/seq_model.py +15 -0
- src/transformer.py +11 -0
- src/vocab.py +17 -0
- test.py +8 -0
- test.txt +0 -0
- test_hint_fine_tuned.py +45 -0
- test_saved_model.py +234 -0
- visualization.py +78 -0
CustomBERTModel.py
ADDED
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
from src.bert import BERT

class CustomBERTModel(nn.Module):
    def __init__(self, vocab_size, output_dim, pre_trained_model_path):
        super(CustomBERTModel, self).__init__()
        hidden_size = 768
        self.bert = BERT(vocab_size=vocab_size, hidden=hidden_size, n_layers=4, attn_heads=8, dropout=0.1)

        # Load the pre-trained model's state_dict
        checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
        if isinstance(checkpoint, dict):
            self.bert.load_state_dict(checkpoint)
        else:
            raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")

        # Fully connected layer with input size 768 (matching BERT hidden size)
        self.fc = nn.Linear(hidden_size, output_dim)

    def forward(self, sequence, segment_info):
        sequence = sequence.to(next(self.parameters()).device)
        segment_info = segment_info.to(sequence.device)

        x = self.bert(sequence, segment_info)
        print(f"BERT output shape: {x.shape}")

        cls_embeddings = x[:, 0]  # Extract CLS token embeddings
        print(f"CLS Embeddings shape: {cls_embeddings.shape}")

        logits = self.fc(cls_embeddings)  # Pass tensor of size (batch_size, 768) to the fully connected layer

        return logits
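A minimal usage sketch for the class above: the constructor and forward signatures come from the file itself, while the vocabulary size, checkpoint path, and batch shapes are illustrative assumptions only.

# Hypothetical smoke test for CustomBERTModel; vocab size, checkpoint path,
# and batch shapes below are assumptions, not values taken from this repo.
import torch
from CustomBERTModel import CustomBERTModel

model = CustomBERTModel(
    vocab_size=1200,                       # assumed vocabulary size
    output_dim=2,                          # two-class classification head
    pre_trained_model_path="path/to/bert_trained.seq_encoder.model",  # placeholder checkpoint
)
model.eval()

batch_size, seq_len = 4, 50
sequence = torch.randint(0, 1200, (batch_size, seq_len))            # dummy token ids
segment_info = torch.zeros(batch_size, seq_len, dtype=torch.long)   # single-segment input

with torch.no_grad():
    logits = model(sequence, segment_info)   # expected shape: (batch_size, 2)
print(logits.shape)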
Untitled.ipynb
ADDED
The diff for this file is too large to render.

__pycache__/metrics.cpython-312.pyc
ADDED
Binary file (9.14 kB).

__pycache__/recalibration.cpython-312.pyc
ADDED
Binary file (5.49 kB).

__pycache__/visualization.cpython-312.pyc
ADDED
Binary file (5.27 kB).
app.py
CHANGED
@@ -101,15 +101,48 @@ import shutil
 import matplotlib.pyplot as plt
 from sklearn.metrics import roc_curve, auc
 # Define the function to process the input file and model selection
+<<<<<<< HEAD
+def process_file(file,label,info, model_name):
+=======
 def process_file(file,label, model_name):
+>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
     with open(file.name, 'r') as f:
         content = f.read()
     saved_test_dataset = "train.txt"
     saved_test_label = "train_label.txt"
+<<<<<<< HEAD
+    saved_train_info="train_info.txt"
+=======
+>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
 
     # Save the uploaded file content to a specified location
     shutil.copyfile(file.name, saved_test_dataset)
     shutil.copyfile(label.name, saved_test_label)
+<<<<<<< HEAD
+    shutil.copyfile(info.name, saved_train_info)
+    # For demonstration purposes, we'll just return the content with the selected model name
+    # if(model_name=="highGRschool10"):
+    #     checkpoint="ratio_proportion_change3/output/FS/bert_fine_tuned.model.ep32"
+    # elif(model_name=="lowGRschoolAll"):
+    #     checkpoint="ratio_proportion_change3/output/IS/bert_fine_tuned.model.ep14"
+    # elif(model_name=="fullTest"):
+    #     checkpoint="ratio_proportion_change3/output/correctness/bert_fine_tuned.model.ep48"
+    # else:
+    #     checkpoint=None
+
+    # print(checkpoint)
+    subprocess.run([
+        "python", "new_test_saved_finetuned_model.py",
+        "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
+        "-finetune_task", model_name,
+        "-test_dataset_path","../../../../train.txt",
+        # "-test_label_path","../../../../train_label.txt",
+        "-finetuned_bert_classifier_checkpoint",
+        "ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42",
+        "-e",str(1),
+        "-b",str(5)
+    ], shell=True)
+=======
     # For demonstration purposes, we'll just return the content with the selected model name
     if(model_name=="FS"):
         checkpoint="ratio_proportion_change3/output/FS/bert_fine_tuned.model.ep32"
@@ -126,6 +159,7 @@ def process_file(file,label, model_name):
     subprocess.run(["python", "src/test_saved_model.py",
                     "--finetuned_bert_checkpoint",checkpoint
                     ])
+>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
     result = {}
     with open("result.txt", 'r') as file:
         for line in file:
@@ -160,7 +194,11 @@ def process_file(file,label, model_name):
     return text_output,plot_path
 
 # List of models for the dropdown menu
+<<<<<<< HEAD
+models = ["highGRschool10", "lowGRschoolAll", "fullTest"]
+=======
 models = ["FS", "IS", "CORRECTNESS","EFFECTIVENESS"]
+>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
 
 # Create the Gradio interface
 with gr.Blocks(css="""
@@ -350,15 +388,25 @@ tbody.svelte-18wv37q>tr.svelte-18wv37q:nth-child(odd) {
     with gr.Row():
         file_input = gr.File(label="Upload a test file", file_types=['.txt'], elem_classes="file-box")
         label_input = gr.File(label="Upload test labels", file_types=['.txt'], elem_classes="file-box")
+<<<<<<< HEAD
+        info_input = gr.File(label="Upload test info", file_types=['.txt'], elem_classes="file-box")
+
+    model_dropdown = gr.Dropdown(choices=models, label="Select Finetune Task", elem_classes="dropdown-menu")
+=======
 
     model_dropdown = gr.Dropdown(choices=models, label="Select Model", elem_classes="dropdown-menu")
+>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
 
     with gr.Row():
         output_text = gr.Textbox(label="Output Text")
         output_image = gr.Image(label="Output Plot")
 
     btn = gr.Button("Submit")
+<<<<<<< HEAD
+    btn.click(fn=process_file, inputs=[file_input,label_input,info_input, model_dropdown], outputs=[output_text,output_image])
+=======
     btn.click(fn=process_file, inputs=[file_input,label_input, model_dropdown], outputs=[output_text,output_image])
+>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
 
 # Launch the app
 demo.launch()
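The committed app.py keeps both sides of an unresolved merge, so the module cannot import until one branch of each <<<<<<< HEAD / ======= / >>>>>>> block is chosen. Below is a minimal, self-contained sketch of the HEAD-side wiring, assuming the third "test info" upload and the three finetune-task choices are the intended behavior; the stub process_file body is a placeholder for the subprocess call shown in the diff above.

# Sketch of the HEAD-side resolution of the Gradio wiring (assumptions:
# gradio is installed, and the three-upload signature is the side to keep).
import gradio as gr

models = ["highGRschool10", "lowGRschoolAll", "fullTest"]

def process_file(file, label, info, model_name):
    # Placeholder body; the committed version shells out to
    # new_test_saved_finetuned_model.py and reads result.txt / plot.png.
    return f"Selected finetune task: {model_name}", None

with gr.Blocks() as demo:
    with gr.Row():
        file_input = gr.File(label="Upload a test file", file_types=['.txt'])
        label_input = gr.File(label="Upload test labels", file_types=['.txt'])
        info_input = gr.File(label="Upload test info", file_types=['.txt'])
    model_dropdown = gr.Dropdown(choices=models, label="Select Finetune Task")
    with gr.Row():
        output_text = gr.Textbox(label="Output Text")
        output_image = gr.Image(label="Output Plot")
    btn = gr.Button("Submit")
    btn.click(fn=process_file,
              inputs=[file_input, label_input, info_input, model_dropdown],
              outputs=[output_text, output_image])

demo.launch()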
data_preprocessor.py
ADDED
@@ -0,0 +1,170 @@
import time
import pandas as pd

import sys

class DataPreprocessor:
    def __init__(self, input_file_path):
        self.input_file_path = input_file_path
        self.unique_students = None
        self.unique_problems = None
        self.unique_prob_hierarchy = None
        self.unique_steps = None
        self.unique_kcs = None

    def analyze_dataset(self):
        file_iterator = self.load_file_iterator()

        start_time = time.time()
        self.unique_students = {"st"}
        self.unique_problems = {"pr"}
        self.unique_prob_hierarchy = {"ph"}
        self.unique_kcs = {"kc"}
        for chunk_data in file_iterator:
            for student_id, std_groups in chunk_data.groupby('Anon Student Id'):
                self.unique_students.update({student_id})
                prob_hierarchy = std_groups.groupby('Level (Workspace Id)')
                for hierarchy, hierarchy_groups in prob_hierarchy:
                    self.unique_prob_hierarchy.update({hierarchy})
                    prob_name = hierarchy_groups.groupby('Problem Name')
                    for problem_name, prob_name_groups in prob_name:
                        self.unique_problems.update({problem_name})
                        sub_skills = prob_name_groups['KC Model(MATHia)']
                        for a in sub_skills:
                            if str(a) != "nan":
                                temp = a.split("~~")
                                for kc in temp:
                                    self.unique_kcs.update({kc})
        self.unique_students.remove("st")
        self.unique_problems.remove("pr")
        self.unique_prob_hierarchy.remove("ph")
        self.unique_kcs.remove("kc")
        end_time = time.time()
        print("Time Taken to analyze dataset = ", end_time - start_time)
        print("Length of unique students->", len(self.unique_students))
        print("Length of unique problems->", len(self.unique_problems))
        print("Length of unique problem hierarchy->", len(self.unique_prob_hierarchy))
        print("Length of Unique Knowledge components ->", len(self.unique_kcs))

    def analyze_dataset_by_section(self, workspace_name):
        file_iterator = self.load_file_iterator()

        start_time = time.time()
        self.unique_students = {"st"}
        self.unique_problems = {"pr"}
        self.unique_prob_hierarchy = {"ph"}
        self.unique_steps = {"s"}
        self.unique_kcs = {"kc"}
        # with open("workspace_info.txt", 'a') as f:
        #     sys.stdout = f
        for chunk_data in file_iterator:
            for student_id, std_groups in chunk_data.groupby('Anon Student Id'):
                prob_hierarchy = std_groups.groupby('Level (Workspace Id)')
                for hierarchy, hierarchy_groups in prob_hierarchy:
                    if workspace_name == hierarchy:
                        # print("Workspace : ", hierarchy)
                        self.unique_students.update({student_id})
                        self.unique_prob_hierarchy.update({hierarchy})
                        prob_name = hierarchy_groups.groupby('Problem Name')
                        for problem_name, prob_name_groups in prob_name:
                            self.unique_problems.update({problem_name})
                            step_names = prob_name_groups['Step Name']
                            sub_skills = prob_name_groups['KC Model(MATHia)']
                            for step in step_names:
                                if str(step) != "nan":
                                    self.unique_steps.update({step})
                            for a in sub_skills:
                                if str(a) != "nan":
                                    temp = a.split("~~")
                                    for kc in temp:
                                        self.unique_kcs.update({kc})
        self.unique_problems.remove("pr")
        self.unique_prob_hierarchy.remove("ph")
        self.unique_steps.remove("s")
        self.unique_kcs.remove("kc")
        end_time = time.time()
        print("Time Taken to analyze dataset = ", end_time - start_time)
        print("Workspace-> ",workspace_name)
        print("Length of unique students->", len(self.unique_students))
        print("Length of unique problems->", len(self.unique_problems))
        print("Length of unique problem hierarchy->", len(self.unique_prob_hierarchy))
        print("Length of unique step names ->", len(self.unique_steps))
        print("Length of unique knowledge components ->", len(self.unique_kcs))
        # f.close()
        # sys.stdout = sys.__stdout__

    def analyze_dataset_by_school(self, workspace_name, school_id=None):
        file_iterator = self.load_file_iterator(sep=",")

        start_time = time.time()
        self.unique_schools = set()
        self.unique_class = set()
        self.unique_students = set()
        self.unique_problems = set()
        self.unique_steps = set()
        self.unique_kcs = set()
        self.unique_actions = set()
        self.unique_outcomes = set()
        self.unique_new_steps_w_action_attempt = set()
        self.unique_new_steps_w_kcs = set()
        self.unique_new_steps_w_action_attempt_kcs = set()

        for chunk_data in file_iterator:
            for school, school_group in chunk_data.groupby('CF (Anon School Id)'):
                # if school and school == school_id:
                self.unique_schools.add(school)
                for class_id, class_group in school_group.groupby('CF (Anon Class Id)'):
                    self.unique_class.add(class_id)
                    for student_id, std_group in class_group.groupby('Anon Student Id'):
                        self.unique_students.add(student_id)
                        for prob, prob_group in std_group.groupby('Problem Name'):
                            self.unique_problems.add(prob)

                            step_names = set(prob_group['Step Name'])
                            sub_skills = set(prob_group['KC Model(MATHia)'])
                            actions = set(prob_group['Action'])
                            outcomes = set(prob_group['Outcome'])

                            self.unique_steps.update(step_names)
                            self.unique_kcs.update(sub_skills)
                            self.unique_actions.update(actions)
                            self.unique_outcomes.update(outcomes)

                            for step in step_names:
                                if pd.isna(step):
                                    step_group = prob_group[pd.isna(prob_group['Step Name'])]
                                else:
                                    step_group = prob_group[prob_group['Step Name']==step]

                                for kc in set(step_group['KC Model(MATHia)']):
                                    new_step = f"{step}:{kc}"
                                    self.unique_new_steps_w_kcs.add(new_step)

                                for action, action_group in step_group.groupby('Action'):
                                    for attempt, attempt_group in action_group.groupby('Attempt At Step'):
                                        new_step = f"{step}:{action}:{attempt}"
                                        self.unique_new_steps_w_action_attempt.add(new_step)

                                        for kc in set(attempt_group["KC Model(MATHia)"]):
                                            new_step = f"{step}:{action}:{attempt}:{kc}"
                                            self.unique_new_steps_w_action_attempt_kcs.add(new_step)


        end_time = time.time()
        print("Time Taken to analyze dataset = ", end_time - start_time)
        print("Workspace-> ",workspace_name)
        print("Length of unique students->", len(self.unique_students))
        print("Length of unique problems->", len(self.unique_problems))
        print("Length of unique classes->", len(self.unique_class))
        print("Length of unique step names ->", len(self.unique_steps))
        print("Length of unique knowledge components ->", len(self.unique_kcs))
        print("Length of unique actions ->", len(self.unique_actions))
        print("Length of unique outcomes ->", len(self.unique_outcomes))
        print("Length of unique new step names with actions and attempts ->", len(self.unique_new_steps_w_action_attempt))
        print("Length of unique new step names with actions, attempts and kcs ->", len(self.unique_new_steps_w_action_attempt_kcs))
        print("Length of unique new step names with kcs ->", len(self.unique_new_steps_w_kcs))

    def load_file_iterator(self, sep="\t"):
        chunk_iterator = pd.read_csv(self.input_file_path, sep=sep, header=0, iterator=True, chunksize=1000000)
        return chunk_iterator
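A short usage sketch for the class above: the column names ('Anon Student Id', 'Level (Workspace Id)', 'KC Model(MATHia)', ...) come from the code itself, while the input path and workspace name are placeholder assumptions.

# Hypothetical driver for DataPreprocessor; the input path and workspace
# name below are placeholders, not files that ship with this repo.
from data_preprocessor import DataPreprocessor

preprocessor = DataPreprocessor(input_file_path="data/mathia_log.csv")  # assumed comma-separated export

# Tab-separated exports: overall statistics across all workspaces.
# preprocessor.analyze_dataset()

# Comma-separated school-level export: per-workspace statistics
# (prints counts of students, problems, classes, steps, KCs, actions, outcomes).
preprocessor.analyze_dataset_by_school("ratio_proportion_change3")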
hint_fine_tuning.py
ADDED
@@ -0,0 +1,382 @@
import argparse
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, TensorDataset
from src.dataset import TokenizerDataset
from src.bert import BERT
from src.pretrainer import BERTFineTuneTrainer1
from src.vocab import Vocab
import pandas as pd


# class CustomBERTModel(nn.Module):
#     def __init__(self, vocab_size, output_dim, pre_trained_model_path):
#         super(CustomBERTModel, self).__init__()
#         hidden_size = 768
#         self.bert = BERT(vocab_size=vocab_size, hidden=hidden_size, n_layers=12, attn_heads=12, dropout=0.1)
#         checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
#         if isinstance(checkpoint, dict):
#             self.bert.load_state_dict(checkpoint)
#         elif isinstance(checkpoint, BERT):
#             self.bert = checkpoint
#         else:
#             raise TypeError(f"Expected state_dict or BERT instance, got {type(checkpoint)} instead.")
#         self.fc = nn.Linear(hidden_size, output_dim)

#     def forward(self, sequence, segment_info):
#         sequence = sequence.to(next(self.parameters()).device)
#         segment_info = segment_info.to(sequence.device)

#         if sequence.size(0) == 0 or sequence.size(1) == 0:
#             raise ValueError("Input sequence tensor has 0 elements. Check data preprocessing.")

#         x = self.bert(sequence, segment_info)
#         print(f"BERT output shape: {x.shape}")

#         if x.size(0) == 0 or x.size(1) == 0:
#             raise ValueError("BERT output tensor has 0 elements. Check input dimensions.")

#         cls_embeddings = x[:, 0]
#         logits = self.fc(cls_embeddings)
#         return logits

# class CustomBERTModel(nn.Module):
#     def __init__(self, vocab_size, output_dim, pre_trained_model_path):
#         super(CustomBERTModel, self).__init__()
#         hidden_size = 764 # Ensure this is 768
#         self.bert = BERT(vocab_size=vocab_size, hidden=hidden_size, n_layers=12, attn_heads=12, dropout=0.1)

#         # Load the pre-trained model's state_dict
#         checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
#         if isinstance(checkpoint, dict):
#             self.bert.load_state_dict(checkpoint)
#         else:
#             raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")

#         # Fully connected layer with input size 768
#         self.fc = nn.Linear(hidden_size, output_dim)

#     def forward(self, sequence, segment_info):
#         sequence = sequence.to(next(self.parameters()).device)
#         segment_info = segment_info.to(sequence.device)

#         x = self.bert(sequence, segment_info)
#         print(f"BERT output shape: {x.shape}") # Should output (batch_size, seq_len, 768)

#         cls_embeddings = x[:, 0] # Extract CLS token embeddings
#         print(f"CLS Embeddings shape: {cls_embeddings.shape}") # Should output (batch_size, 768)

#         logits = self.fc(cls_embeddings) # Should now pass a tensor of size (batch_size, 768) to `fc`

#         return logits


# for test
class CustomBERTModel(nn.Module):
    def __init__(self, vocab_size, output_dim, pre_trained_model_path):
        super(CustomBERTModel, self).__init__()
        self.hidden = 764 # Ensure this is defined correctly
        self.bert = BERT(vocab_size=vocab_size, hidden=self.hidden, n_layers=12, attn_heads=12, dropout=0.1)

        # Load the pre-trained model's state_dict
        checkpoint = torch.load(pre_trained_model_path, map_location=torch.device('cpu'))
        if isinstance(checkpoint, dict):
            self.bert.load_state_dict(checkpoint)
        else:
            raise TypeError(f"Expected state_dict, got {type(checkpoint)} instead.")

        self.fc = nn.Linear(self.hidden, output_dim)

    def forward(self, sequence, segment_info):
        x = self.bert(sequence, segment_info)
        cls_embeddings = x[:, 0] # Extract CLS token embeddings
        logits = self.fc(cls_embeddings) # Pass to fully connected layer
        return logits

def preprocess_labels(label_csv_path):
    try:
        labels_df = pd.read_csv(label_csv_path)
        labels = labels_df['last_hint_class'].values.astype(int)
        return torch.tensor(labels, dtype=torch.long)
    except Exception as e:
        print(f"Error reading dataset file: {e}")
        return None


def preprocess_data(data_path, vocab, max_length=128):
    try:
        with open(data_path, 'r') as f:
            sequences = f.readlines()
    except Exception as e:
        print(f"Error reading data file: {e}")
        return None, None

    if len(sequences) == 0:
        raise ValueError(f"No sequences found in data file {data_path}. Check the file content.")

    tokenized_sequences = []

    for sequence in sequences:
        sequence = sequence.strip()
        if sequence:
            encoded = vocab.to_seq(sequence, seq_len=max_length)
            encoded = encoded[:max_length] + [vocab.vocab.get('[PAD]', 0)] * (max_length - len(encoded))
            segment_label = [0] * max_length

            tokenized_sequences.append({
                'input_ids': torch.tensor(encoded),
                'segment_label': torch.tensor(segment_label)
            })

    if not tokenized_sequences:
        raise ValueError("Tokenization resulted in an empty list. Check the sequences and tokenization logic.")

    tokenized_sequences = [t for t in tokenized_sequences if len(t['input_ids']) == max_length]

    if not tokenized_sequences:
        raise ValueError("All tokenized sequences are of unexpected length. This suggests an issue with the tokenization logic.")

    input_ids = torch.cat([t['input_ids'].unsqueeze(0) for t in tokenized_sequences], dim=0)
    segment_labels = torch.cat([t['segment_label'].unsqueeze(0) for t in tokenized_sequences], dim=0)

    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Segment labels shape: {segment_labels.shape}")

    return input_ids, segment_labels


def collate_fn(batch):
    inputs = []
    labels = []
    segment_labels = []

    for item in batch:
        if item is None:
            continue

        if isinstance(item, dict):
            inputs.append(item['input_ids'].unsqueeze(0))
            labels.append(item['label'].unsqueeze(0))
            segment_labels.append(item['segment_label'].unsqueeze(0))

    if len(inputs) == 0 or len(segment_labels) == 0:
        print("Empty batch encountered. Returning None to skip this batch.")
        return None

    try:
        inputs = torch.cat(inputs, dim=0)
        labels = torch.cat(labels, dim=0)
        segment_labels = torch.cat(segment_labels, dim=0)
    except Exception as e:
        print(f"Error concatenating tensors: {e}")
        return None

    return {
        'input': inputs,
        'label': labels,
        'segment_label': segment_labels
    }

def custom_collate_fn(batch):
    processed_batch = collate_fn(batch)

    if processed_batch is None or len(processed_batch['input']) == 0:
        # Return a valid batch with at least one element instead of an empty one
        return {
            'input': torch.zeros((1, 128), dtype=torch.long),
            'label': torch.zeros((1,), dtype=torch.long),
            'segment_label': torch.zeros((1, 128), dtype=torch.long)
        }

    return processed_batch


def train_without_progress_status(trainer, epoch, shuffle):
    for epoch_idx in range(epoch):
        print(f"EP_train:{epoch_idx}:")
        for batch in trainer.train_data:
            if batch is None:
                continue

            # Check if batch is a string (indicating an issue)
            if isinstance(batch, str):
                print(f"Error: Received a string instead of a dictionary in batch: {batch}")
                raise ValueError(f"Unexpected string in batch: {batch}")

            # Validate the batch structure before passing to iteration
            if isinstance(batch, dict):
                # Verify that all expected keys are present and that the values are tensors
                if all(key in batch for key in ['input_ids', 'segment_label', 'labels']):
                    if all(isinstance(batch[key], torch.Tensor) for key in batch):
                        try:
                            print(f"Batch Structure: {batch}") # Debugging batch before iteration
                            trainer.iteration(epoch_idx, batch)
                        except Exception as e:
                            print(f"Error during batch processing: {e}")
                            sys.stdout.flush()
                            raise e # Propagate the exception for better debugging
                    else:
                        print(f"Error: Expected all values in batch to be tensors, but got: {batch}")
                        raise ValueError("Batch contains non-tensor values.")
                else:
                    print(f"Error: Batch missing expected keys. Batch keys: {batch.keys()}")
                    raise ValueError("Batch does not contain expected keys.")
            else:
                print(f"Error: Expected batch to be a dictionary but got {type(batch)} instead.")
                raise ValueError(f"Invalid batch structure: {batch}")

# def main(opt):
#     # device = torch.device("cpu")
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#     vocab = Vocab(opt.vocab_file)
#     vocab.load_vocab()

#     input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=128)
#     labels = preprocess_labels(opt.dataset)

#     if input_ids is None or segment_labels is None or labels is None:
#         print("Error in preprocessing data. Exiting.")
#         return

#     dataset = TensorDataset(input_ids, segment_labels, torch.tensor(labels, dtype=torch.long))
#     val_size = len(dataset) - int(0.8 * len(dataset))
#     val_dataset, train_dataset = random_split(dataset, [val_size, len(dataset) - val_size])

#     train_dataloader = DataLoader(
#         train_dataset,
#         batch_size=32,
#         shuffle=True,
#         collate_fn=custom_collate_fn
#     )
#     val_dataloader = DataLoader(
#         val_dataset,
#         batch_size=32,
#         shuffle=False,
#         collate_fn=custom_collate_fn
#     )

#     custom_model = CustomBERTModel(
#         vocab_size=len(vocab.vocab),
#         output_dim=2,
#         pre_trained_model_path=opt.pre_trained_model_path
#     ).to(device)

#     trainer = BERTFineTuneTrainer1(
#         bert=custom_model.bert,
#         vocab_size=len(vocab.vocab),
#         train_dataloader=train_dataloader,
#         test_dataloader=val_dataloader,
#         lr=5e-5,
#         num_labels=2,
#         with_cuda=torch.cuda.is_available(),
#         log_freq=10,
#         workspace_name=opt.output_dir,
#         log_folder_path=opt.log_folder_path
#     )

#     trainer.train(epoch=20)

#     # os.makedirs(opt.output_dir, exist_ok=True)
#     # output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model.pth')
#     # torch.save(custom_model.state_dict(), output_model_file)
#     # print(f'Model saved to {output_model_file}')

#     os.makedirs(opt.output_dir, exist_ok=True)
#     output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_2.pth')
#     torch.save(custom_model, output_model_file)
#     print(f'Model saved to {output_model_file}')


def main(opt):
    # Set device to GPU if available, otherwise use CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(torch.cuda.is_available()) # Should return True if GPU is available
    print(torch.cuda.device_count())

    # Load vocabulary
    vocab = Vocab(opt.vocab_file)
    vocab.load_vocab()

    # Preprocess data and labels
    input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=128)
    labels = preprocess_labels(opt.dataset)

    if input_ids is None or segment_labels is None or labels is None:
        print("Error in preprocessing data. Exiting.")
        return

    # Transfer tensors to the correct device (GPU/CPU)
    input_ids = input_ids.to(device)
    segment_labels = segment_labels.to(device)
    labels = torch.tensor(labels, dtype=torch.long).to(device)

    # Create TensorDataset and split into train and validation sets
    dataset = TensorDataset(input_ids, segment_labels, labels)
    val_size = len(dataset) - int(0.8 * len(dataset))
    val_dataset, train_dataset = random_split(dataset, [val_size, len(dataset) - val_size])

    # Create DataLoaders for training and validation
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=custom_collate_fn
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        collate_fn=custom_collate_fn
    )

    # Initialize custom BERT model and move it to the device
    custom_model = CustomBERTModel(
        vocab_size=len(vocab.vocab),
        output_dim=2,
        pre_trained_model_path=opt.pre_trained_model_path
    ).to(device)

    # Initialize the fine-tuning trainer
    trainer = BERTFineTuneTrainer1(
        bert=custom_model.bert,
        vocab_size=len(vocab.vocab),
        train_dataloader=train_dataloader,
        test_dataloader=val_dataloader,
        lr=5e-5,
        num_labels=2,
        with_cuda=torch.cuda.is_available(),
        log_freq=10,
        workspace_name=opt.output_dir,
        log_folder_path=opt.log_folder_path
    )

    # Train the model
    trainer.train(epoch=20)

    # Save the model to the specified output directory
    # os.makedirs(opt.output_dir, exist_ok=True)
    # output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_2.pth')
    # torch.save(custom_model.state_dict(), output_model_file)
    # print(f'Model saved to {output_model_file}')
    os.makedirs(opt.output_dir, exist_ok=True)
    output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_2.pth')
    torch.save(custom_model, output_model_file)
    print(f'Model saved to {output_model_file}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Fine-tune BERT model.')
    parser.add_argument('--dataset', type=str, default='/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv', help='Path to the dataset file.')
    parser.add_argument('--data_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt', help='Path to the input sequence file.')
    parser.add_argument('--output_dir', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification', help='Directory to save the fine-tuned model.')
    parser.add_argument('--pre_trained_model_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/output/pretrain:1800ms:64hs:4l:8a:50s:64b:1000e:-5lr/bert_trained.seq_encoder.model.ep68', help='Path to the pre-trained BERT model.')
    parser.add_argument('--vocab_file', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt', help='Path to the vocabulary file.')
    parser.add_argument('--log_folder_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/logs/oct_logs', help='Path to the folder for saving logs.')


    opt = parser.parse_args()
    main(opt)
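A sketch of how the script above might be launched: the flags match the argparse definitions in the file, but every path is a placeholder (the committed defaults point at the author's /home/jupyter layout), and it assumes the src/ package with BERT, Vocab, and BERTFineTuneTrainer1 is importable from the working directory.

# Hypothetical invocation of hint_fine_tuning.py via subprocess; all paths
# are placeholders for files the user must supply.
import subprocess

subprocess.run([
    "python", "hint_fine_tuning.py",
    "--dataset", "data/er_train.csv",              # CSV with a last_hint_class column
    "--data_path", "data/er.txt",                  # one tokenized sequence per line
    "--vocab_file", "pretraining/vocab.txt",
    "--pre_trained_model_path", "output/bert_trained.seq_encoder.model",
    "--output_dir", "output/hint_classification",
    "--log_folder_path", "logs",
], check=True)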
main.py
ADDED
@@ -0,0 +1,322 @@
import argparse

from torch.utils.data import DataLoader
import torch
import torch.nn as nn

from src.bert import BERT
from src.pretrainer import BERTTrainer, BERTFineTuneTrainer, BERTAttention
from src.dataset import PretrainerDataset, TokenizerDataset
from src.vocab import Vocab

import time
import os
import tqdm
import pickle

def train():
    parser = argparse.ArgumentParser()

    parser.add_argument('-workspace_name', type=str, default=None)
    parser.add_argument('-code', type=str, default=None, help="folder for pretraining outputs and logs")
    parser.add_argument('-finetune_task', type=str, default=None, help="folder inside finetuning")
    parser.add_argument("-attention", type=bool, default=False, help="analyse attention scores")
    parser.add_argument("-diff_test_folder", type=bool, default=False, help="use for different test folder")
    parser.add_argument("-embeddings", type=bool, default=False, help="get and analyse embeddings")
    parser.add_argument('-embeddings_file_name', type=str, default=None, help="file name of embeddings")
    parser.add_argument("-pretrain", type=bool, default=False, help="pretraining: true, or false")
    # parser.add_argument('-opts', nargs='+', type=str, default=None, help='List of optional steps')
    parser.add_argument("-max_mask", type=int, default=0.15, help="% of input tokens selected for masking")
    # parser.add_argument("-p", "--pretrain_dataset", type=str, default="pretraining/pretrain.txt", help="pretraining dataset for bert")
    # parser.add_argument("-pv", "--pretrain_val_dataset", type=str, default="pretraining/test.txt", help="pretraining validation dataset for bert")
    # default="finetuning/test.txt",
    parser.add_argument("-vocab_path", type=str, default="pretraining/vocab.txt", help="built vocab model path with bert-vocab")

    parser.add_argument("-train_dataset_path", type=str, default="train.txt", help="fine tune train dataset for progress classifier")
    parser.add_argument("-val_dataset_path", type=str, default="val.txt", help="test set for evaluate fine tune train set")
    parser.add_argument("-test_dataset_path", type=str, default="test.txt", help="test set for evaluate fine tune train set")
    parser.add_argument("-num_labels", type=int, default=2, help="Number of labels")
    parser.add_argument("-train_label_path", type=str, default="train_label.txt", help="fine tune train dataset for progress classifier")
    parser.add_argument("-val_label_path", type=str, default="val_label.txt", help="test set for evaluate fine tune train set")
    parser.add_argument("-test_label_path", type=str, default="test_label.txt", help="test set for evaluate fine tune train set")
    ##### change Checkpoint for finetuning
    parser.add_argument("-pretrained_bert_checkpoint", type=str, default=None, help="checkpoint of saved pretrained bert model") #."output_feb09/bert_trained.model.ep40"
    parser.add_argument('-check_epoch', type=int, default=None)

    parser.add_argument("-hs", "--hidden", type=int, default=64, help="hidden size of transformer model") #64
    parser.add_argument("-l", "--layers", type=int, default=4, help="number of layers") #4
    parser.add_argument("-a", "--attn_heads", type=int, default=4, help="number of attention heads") #8
    parser.add_argument("-s", "--seq_len", type=int, default=50, help="maximum sequence length")

    parser.add_argument("-b", "--batch_size", type=int, default=500, help="number of batch_size") #64
    parser.add_argument("-e", "--epochs", type=int, default=50)#1501, help="number of epochs") #501
    # Use 50 for pretrain, and 10 for fine tune
    parser.add_argument("-w", "--num_workers", type=int, default=4, help="dataloader worker size")

    # Later run with cuda
    parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false")
    parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
    # parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
    parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
    # parser.add_argument("--on_memory", type=bool, default=False, help="Loading on memory: true or false")

    parser.add_argument("--dropout", type=float, default=0.1, help="dropout of network")
    parser.add_argument("--lr", type=float, default=1e-05, help="learning rate of adam") #1e-3
    parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
    parser.add_argument("--adam_beta2", type=float, default=0.98, help="adam first beta value") #0.999

    parser.add_argument("-o", "--output_path", type=str, default="bert_trained.seq_encoder.model", help="ex)output/bert.model")
    # parser.add_argument("-o", "--output_path", type=str, default="output/bert_fine_tuned.model", help="ex)output/bert.model")

    args = parser.parse_args()
    for k,v in vars(args).items():
        if 'path' in k:
            if v:
                if k == "output_path":
                    if args.code:
                        setattr(args, f"{k}", args.workspace_name+f"/output/{args.code}/"+v)
                    elif args.finetune_task:
                        setattr(args, f"{k}", args.workspace_name+f"/output/{args.finetune_task}/"+v)
                    else:
                        setattr(args, f"{k}", args.workspace_name+"/output/"+v)
                elif k != "vocab_path":
                    if args.pretrain:
                        setattr(args, f"{k}", args.workspace_name+"/pretraining/"+v)
                    else:
                        if args.code:
                            setattr(args, f"{k}", args.workspace_name+f"/{args.code}/"+v)
                        elif args.finetune_task:
                            if args.diff_test_folder and "test" in k:
                                setattr(args, f"{k}", args.workspace_name+f"/finetuning/"+v)
                            else:
                                setattr(args, f"{k}", args.workspace_name+f"/finetuning/{args.finetune_task}/"+v)
                        else:
                            setattr(args, f"{k}", args.workspace_name+"/finetuning/"+v)
                else:
                    setattr(args, f"{k}", args.workspace_name+"/"+v)

            print(f"args.{k} : {getattr(args, f'{k}')}")

    print("Loading Vocab", args.vocab_path)
    vocab_obj = Vocab(args.vocab_path)
    vocab_obj.load_vocab()
    print("Vocab Size: ", len(vocab_obj.vocab))

    if args.attention:
        print(f"Attention aggregate...... code: {args.code}, dataset: {args.finetune_task}")
        if args.code:
            new_folder = f"{args.workspace_name}/plots/{args.code}/"
            if not os.path.exists(new_folder):
                os.makedirs(new_folder)

        train_dataset = TokenizerDataset(args.train_dataset_path, None, vocab_obj, seq_len=args.seq_len)
        train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        print("Load Pre-trained BERT model")
        cuda_condition = torch.cuda.is_available() and args.with_cuda
        device = torch.device("cuda:0" if cuda_condition else "cpu")
        bert = torch.load(args.pretrained_bert_checkpoint, map_location=device)
        trainer = BERTAttention(bert, vocab_obj, train_dataloader = train_data_loader, workspace_name = args.workspace_name, code=args.code, finetune_task = args.finetune_task)
        trainer.getAttention()

    elif args.embeddings:
        print("Get embeddings... and cluster... ")
        train_dataset = TokenizerDataset(args.test_dataset_path, None, vocab_obj, seq_len=args.seq_len)
        train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        print("Load Pre-trained BERT model")
        cuda_condition = torch.cuda.is_available() and args.with_cuda
        device = torch.device("cuda:0" if cuda_condition else "cpu")
        bert = torch.load(args.pretrained_bert_checkpoint).to(device)
        available_gpus = list(range(torch.cuda.device_count()))
        if torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            bert = nn.DataParallel(bert, device_ids=available_gpus)

        data_iter = tqdm.tqdm(enumerate(train_data_loader),
                              desc="Model: %s" % (args.pretrained_bert_checkpoint.split("/")[-1]),
                              total=len(train_data_loader), bar_format="{l_bar}{r_bar}")
        all_embeddings = []
        for i, data in data_iter:
            data = {key: value.to(device) for key, value in data.items()}
            embedding = bert(data["input"], data["segment_label"])
            # print(embedding.shape, embedding[:, 0].shape)
            embeddings = [h for h in embedding[:,0].cpu().detach().numpy()]
            all_embeddings.extend(embeddings)

        new_emb_folder = f"{args.workspace_name}/embeddings"
        if not os.path.exists(new_emb_folder):
            os.makedirs(new_emb_folder)
        pickle.dump(all_embeddings, open(f"{new_emb_folder}/{args.embeddings_file_name}.pkl", "wb"))
    else:
        if args.pretrain:
            print("Pre-training......")
            print("Loading Pretraining Train Dataset", args.train_dataset_path)
            print(f"Workspace: {args.workspace_name}")
            pretrain_dataset = PretrainerDataset(args.train_dataset_path, vocab_obj, seq_len=args.seq_len, max_mask = args.max_mask)

            print("Loading Pretraining Validation Dataset", args.val_dataset_path)
            pretrain_valid_dataset = PretrainerDataset(args.val_dataset_path, vocab_obj, seq_len=args.seq_len, max_mask = args.max_mask) \
                if args.val_dataset_path is not None else None

            print("Loading Pretraining Test Dataset", args.test_dataset_path)
            pretrain_test_dataset = PretrainerDataset(args.test_dataset_path, vocab_obj, seq_len=args.seq_len, max_mask = args.max_mask) \
                if args.test_dataset_path is not None else None

            print("Creating Dataloader")
            pretrain_data_loader = DataLoader(pretrain_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
            pretrain_val_data_loader = DataLoader(pretrain_valid_dataset, batch_size=args.batch_size, num_workers=args.num_workers)\
                if pretrain_valid_dataset is not None else None
            pretrain_test_data_loader = DataLoader(pretrain_test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)\
                if pretrain_test_dataset is not None else None

            print("Building BERT model")
            bert = BERT(len(vocab_obj.vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads, dropout=args.dropout)

            if args.pretrained_bert_checkpoint:
                print(f"BERT model : {args.pretrained_bert_checkpoint}")
                bert = torch.load(args.pretrained_bert_checkpoint)

            new_log_folder = f"{args.workspace_name}/logs"
            new_output_folder = f"{args.workspace_name}/output"
            if args.code: # is sent almost all the time
                new_log_folder = f"{args.workspace_name}/logs/{args.code}"
                new_output_folder = f"{args.workspace_name}/output/{args.code}"

            if not os.path.exists(new_log_folder):
                os.makedirs(new_log_folder)
            if not os.path.exists(new_output_folder):
                os.makedirs(new_output_folder)

            print(f"Creating BERT Trainer .... masking: True, max_mask: {args.max_mask}")
            trainer = BERTTrainer(bert, len(vocab_obj.vocab), train_dataloader=pretrain_data_loader,
                                  val_dataloader=pretrain_val_data_loader, test_dataloader=pretrain_test_data_loader,
                                  lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                                  with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq,
                                  log_folder_path=new_log_folder)

            start_time = time.time()
            print(f'Pretraining Starts, Time: {time.strftime("%D %T", time.localtime(start_time))}')
            # if need to pretrain from a check-point, need :check_epoch
            repoch = range(args.check_epoch, args.epochs) if args.check_epoch else range(args.epochs)
            counter = 0
            patience = 20
            for epoch in repoch:
                print(f'Training Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
                trainer.train(epoch)
                print(f'Training Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')

                if pretrain_val_data_loader is not None:
                    print(f'Validation Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
                    trainer.val(epoch)
                    print(f'Validation Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')

                if trainer.save_model: # or epoch%10 == 0 and epoch > 4
                    trainer.save(epoch, args.output_path)
                    counter = 0
                    if pretrain_test_data_loader is not None:
                        print(f'Test Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
                        trainer.test(epoch)
                        print(f'Test Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
                else:
                    counter +=1
                    if counter >= patience:
                        print(f"Early stopping at epoch {epoch}")
                        break

            end_time = time.time()
            print("Time Taken to pretrain model = ", end_time - start_time)
            print(f'Pretraining Ends, Time: {time.strftime("%D %T", time.localtime(end_time))}')
        else:
            print("Fine Tuning......")
            print("Loading Train Dataset", args.train_dataset_path)
            train_dataset = TokenizerDataset(args.train_dataset_path, args.train_label_path, vocab_obj, seq_len=args.seq_len)

            # print("Loading Validation Dataset", args.val_dataset_path)
            # val_dataset = TokenizerDataset(args.val_dataset_path, args.val_label_path, vocab_obj, seq_len=args.seq_len) \
            #     if args.val_dataset_path is not None else None

            print("Loading Test Dataset", args.test_dataset_path)
            test_dataset = TokenizerDataset(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len) \
                if args.test_dataset_path is not None else None

            print("Creating Dataloader...")
            train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
            # val_data_loader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
            #     if val_dataset is not None else None
            test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
                if test_dataset is not None else None

            print("Load Pre-trained BERT model")
            # bert = BERT(len(vocab_obj.vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
            cuda_condition = torch.cuda.is_available() and args.with_cuda
            device = torch.device("cuda:0" if cuda_condition else "cpu")
            bert = torch.load(args.pretrained_bert_checkpoint, map_location=device)

            # if args.finetune_task == "SL":
            #     if args.workspace_name == "ratio_proportion_change4":
            #         num_labels = 9
            #     elif args.workspace_name == "ratio_proportion_change3":
            #         num_labels = 9
            #     elif args.workspace_name == "scale_drawings_3":
            #         num_labels = 9
            #     elif args.workspace_name == "sales_tax_discounts_two_rates":
            #         num_labels = 3
            #     else:
            #         num_labels = 2
            # # num_labels = 1
            # print(f"Number of Labels : {args.num_labels}")
            new_log_folder = f"{args.workspace_name}/logs"
            new_output_folder = f"{args.workspace_name}/output"
            if args.finetune_task: # is sent almost all the time
                new_log_folder = f"{args.workspace_name}/logs/{args.finetune_task}"
                new_output_folder = f"{args.workspace_name}/output/{args.finetune_task}"

            if not os.path.exists(new_log_folder):
                os.makedirs(new_log_folder)
            if not os.path.exists(new_output_folder):
                os.makedirs(new_output_folder)

            print("Creating BERT Fine Tune Trainer")
            trainer = BERTFineTuneTrainer(bert, len(vocab_obj.vocab),
                                          train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                                          lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                                          with_cuda=args.with_cuda, cuda_devices = args.cuda_devices, log_freq=args.log_freq,
                                          workspace_name = args.workspace_name, num_labels=args.num_labels, log_folder_path=new_log_folder)

            print("Fine-tune training Start....")
            start_time = time.time()
            repoch = range(args.check_epoch, args.epochs) if args.check_epoch else range(args.epochs)
|
| 289 |
+
counter = 0
|
| 290 |
+
patience = 10
|
| 291 |
+
for epoch in repoch:
|
| 292 |
+
print(f'Training Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
| 293 |
+
trainer.train(epoch)
|
| 294 |
+
print(f'Training Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
| 295 |
+
|
| 296 |
+
if test_data_loader is not None:
|
| 297 |
+
print(f'Test Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
| 298 |
+
trainer.test(epoch)
|
| 299 |
+
# pickle.dump(trainer.probability_list, open(f"{args.workspace_name}/output/aaai/change4_mid_prob_{epoch}.pkl","wb"))
|
| 300 |
+
print(f'Test Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
| 301 |
+
|
| 302 |
+
# if val_data_loader is not None:
|
| 303 |
+
# print(f'Validation Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
| 304 |
+
# trainer.val(epoch)
|
| 305 |
+
# print(f'Validation Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
| 306 |
+
|
| 307 |
+
if trainer.save_model: # or epoch%10 == 0
|
| 308 |
+
trainer.save(epoch, args.output_path)
|
| 309 |
+
counter = 0
|
| 310 |
+
else:
|
| 311 |
+
counter +=1
|
| 312 |
+
if counter >= patience:
|
| 313 |
+
print(f"Early stopping at epoch {epoch}")
|
| 314 |
+
break
|
| 315 |
+
|
| 316 |
+
end_time = time.time()
|
| 317 |
+
print("Time Taken to fine-tune model = ", end_time - start_time)
|
| 318 |
+
print(f'Pretraining Ends, Time: {time.strftime("%D %T", time.localtime(end_time))}')
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
if __name__ == "__main__":
|
| 322 |
+
train()
|
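Both training loops above share the same checkpoint-gated early-stopping pattern: the patience counter resets whenever the trainer reports an improved model (`trainer.save_model`), and training stops after `patience` epochs without improvement. The following minimal sketch isolates that bookkeeping; `improved` is a hypothetical stand-in for the trainer's improvement flag, not part of the repository.

# Minimal sketch of the early-stopping bookkeeping used by both loops above.
# `improved` is a hypothetical callback standing in for the trainer.save_model flag.
def run_with_early_stopping(epochs, patience, improved):
    counter = 0
    for epoch in range(epochs):
        # ... train / evaluate one epoch here ...
        if improved(epoch):      # analogous to trainer.save_model being True
            counter = 0          # checkpoint kept, patience counter reset
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

# Example: validation stops improving after epoch 3, so with patience=5 training stops at epoch 7.
run_with_early_stopping(epochs=20, patience=5, improved=lambda e: e < 3)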
metrics.py
ADDED
@@ -0,0 +1,149 @@
import numpy as np
from scipy.special import softmax


class CELoss(object):

    def compute_bin_boundaries(self, probabilities=np.array([])):
        # uniform bin spacing
        if probabilities.size == 0:
            bin_boundaries = np.linspace(0, 1, self.n_bins + 1)
            self.bin_lowers = bin_boundaries[:-1]
            self.bin_uppers = bin_boundaries[1:]
        else:
            # size of bins
            bin_n = int(self.n_data / self.n_bins)

            bin_boundaries = np.array([])

            probabilities_sort = np.sort(probabilities)

            for i in range(0, self.n_bins):
                bin_boundaries = np.append(bin_boundaries, probabilities_sort[i * bin_n])
            bin_boundaries = np.append(bin_boundaries, 1.0)

            self.bin_lowers = bin_boundaries[:-1]
            self.bin_uppers = bin_boundaries[1:]

    def get_probabilities(self, output, labels, logits):
        # if not probabilities, apply softmax
        if logits:
            self.probabilities = softmax(output, axis=1)
        else:
            self.probabilities = output

        self.labels = labels
        self.confidences = np.max(self.probabilities, axis=1)
        self.predictions = np.argmax(self.probabilities, axis=1)
        self.accuracies = np.equal(self.predictions, labels)

    def binary_matrices(self):
        idx = np.arange(self.n_data)
        # make matrices of zeros
        pred_matrix = np.zeros([self.n_data, self.n_class])
        label_matrix = np.zeros([self.n_data, self.n_class])
        # self.acc_matrix = np.zeros([self.n_data, self.n_class])
        pred_matrix[idx, self.predictions] = 1
        label_matrix[idx, self.labels] = 1

        self.acc_matrix = np.equal(pred_matrix, label_matrix)

    def compute_bins(self, index=None):
        self.bin_prop = np.zeros(self.n_bins)
        self.bin_acc = np.zeros(self.n_bins)
        self.bin_conf = np.zeros(self.n_bins)
        self.bin_score = np.zeros(self.n_bins)

        if index is None:
            confidences = self.confidences
            accuracies = self.accuracies
        else:
            confidences = self.probabilities[:, index]
            accuracies = self.acc_matrix[:, index]

        for i, (bin_lower, bin_upper) in enumerate(zip(self.bin_lowers, self.bin_uppers)):
            # calculate |confidence - accuracy| in each bin
            in_bin = np.greater(confidences, bin_lower.item()) * np.less_equal(confidences, bin_upper.item())
            self.bin_prop[i] = np.mean(in_bin)

            if self.bin_prop[i].item() > 0:
                self.bin_acc[i] = np.mean(accuracies[in_bin])
                self.bin_conf[i] = np.mean(confidences[in_bin])
                self.bin_score[i] = np.abs(self.bin_conf[i] - self.bin_acc[i])


class MaxProbCELoss(CELoss):
    def loss(self, output, labels, n_bins=15, logits=True):
        self.n_bins = n_bins
        super().compute_bin_boundaries()
        super().get_probabilities(output, labels, logits)
        super().compute_bins()


# http://people.cs.pitt.edu/~milos/research/AAAI_Calibration.pdf
class ECELoss(MaxProbCELoss):

    def loss(self, output, labels, n_bins=15, logits=True):
        super().loss(output, labels, n_bins, logits)
        return np.dot(self.bin_prop, self.bin_score)


class MCELoss(MaxProbCELoss):

    def loss(self, output, labels, n_bins=15, logits=True):
        super().loss(output, labels, n_bins, logits)
        return np.max(self.bin_score)


# https://arxiv.org/abs/1905.11001
# Overconfidence Loss (good in high-risk applications where confident but wrong predictions can be especially harmful)
class OELoss(MaxProbCELoss):

    def loss(self, output, labels, n_bins=15, logits=True):
        super().loss(output, labels, n_bins, logits)
        return np.dot(self.bin_prop, self.bin_conf * np.maximum(self.bin_conf - self.bin_acc, np.zeros(self.n_bins)))


# https://arxiv.org/abs/1904.01685
class SCELoss(CELoss):

    def loss(self, output, labels, n_bins=15, logits=True):
        sce = 0.0
        self.n_bins = n_bins
        self.n_data = len(output)
        self.n_class = len(output[0])

        super().compute_bin_boundaries()
        super().get_probabilities(output, labels, logits)
        super().binary_matrices()

        for i in range(self.n_class):
            super().compute_bins(i)
            sce += np.dot(self.bin_prop, self.bin_score)

        return sce / self.n_class


class TACELoss(CELoss):

    def loss(self, output, labels, threshold=0.01, n_bins=15, logits=True):
        tace = 0.0
        self.n_bins = n_bins
        self.n_data = len(output)
        self.n_class = len(output[0])

        super().get_probabilities(output, labels, logits)
        self.probabilities[self.probabilities < threshold] = 0
        super().binary_matrices()

        for i in range(self.n_class):
            super().compute_bin_boundaries(self.probabilities[:, i])
            super().compute_bins(i)
            tace += np.dot(self.bin_prop, self.bin_score)

        return tace / self.n_class


# TACELoss with threshold fixed at 0
class ACELoss(TACELoss):

    def loss(self, output, labels, n_bins=15, logits=True):
        return super().loss(output, labels, 0.0, n_bins, logits)
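A minimal usage sketch for the calibration metrics defined above, assuming metrics.py is importable from the working directory; the logits and labels below are toy values invented purely for illustration.

# Minimal usage sketch for the calibration losses above (toy data, binary classification).
import numpy as np
from metrics import ECELoss, SCELoss

# Made-up logits for 4 examples over 2 classes, plus their true labels.
outputs = np.array([[ 2.0, -1.0],
                    [ 0.3,  0.2],
                    [-1.5,  1.0],
                    [ 0.1,  2.2]])
labels = np.array([0, 1, 1, 1])

ece = ECELoss().loss(outputs, labels, n_bins=5, logits=True)  # expected calibration error
sce = SCELoss().loss(outputs, labels, n_bins=5, logits=True)  # classwise calibration error
print(f"ECE: {ece:.4f}, SCE: {sce:.4f}")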
new_fine_tuning/README.md
ADDED
@@ -0,0 +1,197 @@
## Pre-training Data

### ratio_proportion_change3 : Calculating Percent Change and Final Amounts
> clear;python3 prepare_pretraining_input_vocab_file.py -analyze_dataset_by_section True -workspace_name ratio_proportion_change3 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -pretrain True -train_file_path pretraining/pretrain1000.txt -train_info_path pretraining/pretrain1000_info.txt -test_file_path pretraining/test1000.txt -test_info_path pretraining/test1000_info.txt

> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -pretrain True -train_file_path pretraining/pretrain2000.txt -train_info_path pretraining/pretrain2000_info.txt -test_file_path pretraining/test2000.txt -test_info_path pretraining/test2000_info.txt

#### Test simple
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -code full -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path full.txt -train_info_path full_info.txt

> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -code gt -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path er.txt -train_info_path er_info.txt -test_file_path me.txt -test_info_path me_info.txt

> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -code correct -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path correct.txt -train_info_path correct_info.txt -test_file_path incorrect.txt -test_info_path incorrect_info.txt -final_step FinalAnswer

> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -code progress -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path graduated.txt -train_info_path graduated_info.txt -test_file_path promoted.txt -test_info_path promoted_info.txt

### ratio_proportion_change4 : Using Percents and Percent Change
> clear;python3 prepare_pretraining_input_vocab_file.py -analyze_dataset_by_section True -workspace_name ratio_proportion_change4 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor NumeratorLabel1 DenominatorLabel1 -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -pretrain True -train_file_path pretraining/pretrain1000.txt -train_info_path pretraining/pretrain1000_info.txt -test_file_path pretraining/test1000.txt -test_info_path pretraining/test1000_info.txt

> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor NumeratorLabel1 DenominatorLabel1 -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -pretrain True -train_file_path pretraining/pretrain2000.txt -train_info_path pretraining/pretrain2000_info.txt -test_file_path pretraining/test2000.txt -test_info_path pretraining/test2000_info.txt

#### Test simple
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -code full -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path full.txt -train_info_path full_info.txt

> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -code gt -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path er.txt -train_info_path er_info.txt -test_file_path me.txt -test_info_path me_info.txt

> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -code correct -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path correct.txt -train_info_path correct_info.txt -test_file_path incorrect.txt -test_info_path incorrect_info.txt -final_step FinalAnswer

> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -code progress -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -train_file_path graduated.txt -train_info_path graduated_info.txt -test_file_path promoted.txt -test_info_path promoted_info.txt

## Pretraining

### ratio_proportion_change3 : Calculating Percent Change and Final Amounts
> clear;python3 src/main.py -workspace_name ratio_proportion_change3_1920 -code pretrain1000 --pretrain_dataset pretraining/pretrain1000.txt --pretrain_val_dataset pretraining/test1000.txt
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000 --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt

#### Test simple models
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 1 --attn_heads 1

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 1 --attn_heads 2

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 2 --attn_heads 2

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 2 --attn_heads 4

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 4 --attn_heads 4

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 4 --attn_heads 8

### ratio_proportion_change4 : Using Percents and Percent Change
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain1000 --pretrain_dataset pretraining/pretrain1000.txt --pretrain_val_dataset pretraining/test1000.txt
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000 --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt

#### Test simple models
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_1l1h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 1 --attn_heads 1

> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_1l2h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 1 --attn_heads 2

> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_2l2h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 2 --attn_heads 2

> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_2l4h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 2 --attn_heads 4

> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_4l4h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 4 --attn_heads 4

> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -code pretrain2000_4l8h-5lr --pretrain_dataset pretraining/pretrain2000.txt --pretrain_val_dataset pretraining/test2000.txt --layers 4 --attn_heads 8

## Preparing Fine Tuning Data

### ratio_proportion_change3 : Calculating Percent Change and Final Amounts
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change3 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -final_step FinalAnswer

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task check2 --train_dataset finetuning/check2/train.txt --test_dataset finetuning/check2/test.txt --train_label finetuning/check2/train_label.txt --test_label finetuning/check2/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279 --epochs 51

#### Attention Head Check
<!-- > PercentChange NumeratorQuantity2 NumeratorQuantity1 DenominatorQuantity1 OptionalTask_1 EquationAnswer NumeratorFactor EquationAnswer NumeratorFactor EquationAnswer NumeratorFactor DenominatorFactor NumeratorFactor DenominatorFactor NumeratorFactor DenominatorFactor FirstRow1:2 FirstRow1:1 FirstRow2:1 FirstRow2:2 FirstRow2:1 SecondRow ThirdRow FinalAnswerDirection ThirdRow FinalAnswer -->

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep598 --attention True -finetune_task promoted

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep823 --attention True -finetune_task promoted

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l2h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l2h-5lr/bert_trained.seq_encoder.model.ep1045 --attention True -finetune_task promoted

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_2l4h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_2l4h-5lr/bert_trained.seq_encoder.model.ep1336 --attention True -finetune_task promoted

<!-- > clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep923 --attention True -->

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l4h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l4h-5lr/bert_trained.seq_encoder.model.ep871 --attention True -finetune_task promoted

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset full/full_attn.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task full

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset full/full.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task full;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset gt/er.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task er;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset gt/me.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task me;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset correct/correct.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task correct;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset correct/incorrect.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task incorrect;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset progress/graduated.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task graduated;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_4l8h-5lr --train_dataset progress/promoted.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_4l8h-5lr/bert_trained.seq_encoder.model.ep1349 --attention True -finetune_task promoted

<!-- PercentChange NumeratorQuantity2 NumeratorQuantity1 DenominatorQuantity1 OptionalTask_2 FirstRow2:1 FirstRow2:2 FirstRow1:1 SecondRow ThirdRow FinalAnswer FinalAnswerDirection --> me

<!-- PercentChange NumeratorQuantity2 NumeratorQuantity1 DenominatorQuantity1 OptionalTask_1 DenominatorFactor NumeratorFactor OptionalTask_2 EquationAnswer FirstRow1:1 FirstRow1:2 FirstRow2:2 FirstRow2:1 FirstRow1:2 SecondRow ThirdRow FinalAnswer --> er

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l1h-5lr --train_dataset pretraining/attention_train.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l1h-5lr/bert_trained.seq_encoder.model.ep273 --attention True

<!-- PercentChange NumeratorQuantity2 NumeratorQuantity1 DenominatorQuantity1 OptionalTask_1 DenominatorFactor NumeratorFactor OptionalTask_2 EquationAnswer FirstRow1:1 FirstRow1:2 FirstRow2:2 FirstRow2:1 FirstRow1:2 SecondRow ThirdRow FinalAnswer -->

> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -code pretrain2000_1l2h-5lr --train_dataset pretraining/attention_train.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000_1l2h-5lr/bert_trained.seq_encoder.model.ep1021 --attention True

### ratio_proportion_change4 : Using Percents and Percent Change
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name ratio_proportion_change4 -opt_step1 OptionalTask_1 EquationAnswer NumeratorFactor DenominatorFactor NumeratorLabel1 DenominatorLabel1 -opt_step2 OptionalTask_2 FirstRow1:1 FirstRow1:2 FirstRow2:1 FirstRow2:2 SecondRow ThirdRow -final_step FinalAnswer

### scale_drawings_3 : Calculating Measurements Using a Scale
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name scale_drawings_3 -opt_step1 opt1-check opt1-ratio-L-n opt1-ratio-L-d opt1-ratio-R-n opt1-ratio-R-d opt1-me2-top-3 opt1-me2-top-4 opt1-me2-top-2 opt1-me2-top-1 opt1-me2-middle-1 opt1-me2-bottom-1 -opt_step2 opt2-check opt2-ratio-L-n opt2-ratio-L-d opt2-ratio-R-n opt2-ratio-R-d opt2-me2-top-3 opt2-me2-top-4 opt2-me2-top-1 opt2-me2-top-2 opt2-me2-middle-1 opt2-me2-bottom-1 -final_step unk-value1 unk-value2

### sales_tax_discounts_two_rates : Solving Problems with Both Sales Tax and Discounts
> clear;python3 prepare_pretraining_input_vocab_file.py -workspace_name sales_tax_discounts_two_rates -opt_step1 optionalTaskGn salestaxFactor2 discountFactor2 multiplyOrderStatementGn -final_step totalCost1

# Fine Tuning Pre-trained model

## ratio_proportion_change3 : Calculating Percent Change and Final Amounts
> Selected Pretrained model: **ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279**
> New **bert/ratio_proportion_change3/output/pretrain2000/bert_trained.seq_encoder.model.ep731**

### 10per
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task 10per --train_dataset finetuning/10per/train.txt --test_dataset finetuning/10per/test.txt --train_label finetuning/10per/train_label.txt --test_label finetuning/10per/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000/bert_trained.seq_encoder.model.ep731 --epochs 51

### IS
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task IS --train_dataset finetuning/IS/train.txt --test_dataset finetuning/FS/train.txt --train_label finetuning/IS/train_label.txt --test_label finetuning/FS/train_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000/bert_trained.seq_encoder.model.ep731 --epochs 51

### FS
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task FS --train_dataset finetuning/FS/train.txt --test_dataset finetuning/IS/train.txt --train_label finetuning/FS/train_label.txt --test_label finetuning/IS/train_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/pretrain2000/bert_trained.seq_encoder.model.ep731 --epochs 51

### correctness
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task correctness --train_dataset finetuning/correctness/train.txt --test_dataset finetuning/correctness/test.txt --train_label finetuning/correctness/train_label.txt --test_label finetuning/correctness/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279 --epochs 51

### SL
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task SL --train_dataset finetuning/SL/train.txt --test_dataset finetuning/SL/test.txt --train_label finetuning/SL/train_label.txt --test_label finetuning/SL/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279 --epochs 51

### effectiveness
> clear;python3 src/main.py -workspace_name ratio_proportion_change3 -finetune_task effectiveness --train_dataset finetuning/effectiveness/train.txt --test_dataset finetuning/effectiveness/test.txt --train_label finetuning/effectiveness/train_label.txt --test_label finetuning/effectiveness/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279 --epochs 51

## ratio_proportion_change4 : Using Percents and Percent Change
> Selected Pretrained model: **ratio_proportion_change4/output/bert_trained.seq_encoder.model.ep287**
### 10per
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -finetune_task 10per --train_dataset finetuning/10per/train.txt --test_dataset finetuning/10per/test.txt --train_label finetuning/10per/train_label.txt --test_label finetuning/10per/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change4/output/bert_trained.seq_encoder.model.ep287 --epochs 51

### IS

### FS

### correctness
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -finetune_task correctness --train_dataset finetuning/correctness/train.txt --test_dataset finetuning/correctness/test.txt --train_label finetuning/correctness/train_label.txt --test_label finetuning/correctness/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change4/output/bert_trained.seq_encoder.model.ep287 --epochs 51

### SL
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -finetune_task SL --train_dataset finetuning/SL/train.txt --test_dataset finetuning/SL/test.txt --train_label finetuning/SL/train_label.txt --test_label finetuning/SL/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change4/output/bert_trained.seq_encoder.model.ep287 --epochs 51

### effectiveness
> clear;python3 src/main.py -workspace_name ratio_proportion_change4 -finetune_task effectiveness --train_dataset finetuning/effectiveness/train.txt --test_dataset finetuning/effectiveness/test.txt --train_label finetuning/effectiveness/train_label.txt --test_label finetuning/effectiveness/test_label.txt --pretrained_bert_checkpoint ratio_proportion_change4/output/bert_trained.seq_encoder.model.ep287 --epochs 51

## scale_drawings_3 : Calculating Measurements Using a Scale
> Selected Pretrained model: **scale_drawings_3/output/bert_trained.seq_encoder.model.ep252**
### 10per
> clear;python3 src/main.py -workspace_name scale_drawings_3 -finetune_task 10per --train_dataset finetuning/10per/train.txt --test_dataset finetuning/10per/test.txt --train_label finetuning/10per/train_label.txt --test_label finetuning/10per/test_label.txt --pretrained_bert_checkpoint scale_drawings_3/output/bert_trained.seq_encoder.model.ep252 --epochs 51

### IS

### FS

### correctness
> clear;python3 src/main.py -workspace_name scale_drawings_3 -finetune_task correctness --train_dataset finetuning/correctness/train.txt --test_dataset finetuning/correctness/test.txt --train_label finetuning/correctness/train_label.txt --test_label finetuning/correctness/test_label.txt --pretrained_bert_checkpoint scale_drawings_3/output/bert_trained.seq_encoder.model.ep252 --epochs 51

### SL
> clear;python3 src/main.py -workspace_name scale_drawings_3 -finetune_task SL --train_dataset finetuning/SL/train.txt --test_dataset finetuning/SL/test.txt --train_label finetuning/SL/train_label.txt --test_label finetuning/SL/test_label.txt --pretrained_bert_checkpoint scale_drawings_3/output/bert_trained.seq_encoder.model.ep252 --epochs 51

### effectiveness

## sales_tax_discounts_two_rates : Solving Problems with Both Sales Tax and Discounts
> Selected Pretrained model: **sales_tax_discounts_two_rates/output/bert_trained.seq_encoder.model.ep255**

### 10per
> clear;python3 src/main.py -workspace_name sales_tax_discounts_two_rates -finetune_task 10per --train_dataset finetuning/10per/train.txt --test_dataset finetuning/10per/test.txt --train_label finetuning/10per/train_label.txt --test_label finetuning/10per/test_label.txt --pretrained_bert_checkpoint sales_tax_discounts_two_rates/output/bert_trained.seq_encoder.model.ep255 --epochs 51

### IS

### FS

### correctness
> clear;python3 src/main.py -workspace_name sales_tax_discounts_two_rates -finetune_task correctness --train_dataset finetuning/correctness/train.txt --test_dataset finetuning/correctness/test.txt --train_label finetuning/correctness/train_label.txt --test_label finetuning/correctness/test_label.txt --pretrained_bert_checkpoint sales_tax_discounts_two_rates/output/bert_trained.seq_encoder.model.ep255 --epochs 51

### SL

### effectiveness
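The checkpoints referenced above (for example ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279) are consumed by src/main.py through a plain torch.load, so one quick way to inspect a checkpoint locally is a sketch like the following; it assumes the file exists at that relative path and, as the training script expects, deserializes into a full encoder module rather than a bare state_dict.

# Minimal sketch: inspect one of the pretrained checkpoints listed above.
import torch

ckpt_path = "ratio_proportion_change3/output/bert_trained.seq_encoder.model.ep279"
bert = torch.load(ckpt_path, map_location=torch.device("cpu"))  # full module, per src/main.py
bert.eval()
print(type(bert))                                  # class of the serialized encoder
print(sum(p.numel() for p in bert.parameters()))   # parameter count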
new_fine_tuning/__pycache__/metrics.cpython-312.pyc
ADDED
Binary file (9.16 kB).

new_fine_tuning/__pycache__/recalibration.cpython-312.pyc
ADDED
Binary file (5.51 kB).

new_fine_tuning/__pycache__/visualization.cpython-312.pyc
ADDED
Binary file (5.28 kB).
new_hint_fine_tuned.py
ADDED
@@ -0,0 +1,131 @@
import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, TensorDataset
from src.dataset import TokenizerDataset
from src.bert import BERT
from src.pretrainer import BERTFineTuneTrainer1
from src.vocab import Vocab
from CustomBERTModel import CustomBERTModel  # classifier wrapper defined in CustomBERTModel.py at the repo root
import pandas as pd


def preprocess_labels(label_csv_path):
    try:
        labels_df = pd.read_csv(label_csv_path)
        labels = labels_df['last_hint_class'].values.astype(int)
        return torch.tensor(labels, dtype=torch.long)
    except Exception as e:
        print(f"Error reading dataset file: {e}")
        return None


def preprocess_data(data_path, vocab, max_length=128):
    try:
        with open(data_path, 'r') as f:
            sequences = f.readlines()
    except Exception as e:
        print(f"Error reading data file: {e}")
        return None, None

    tokenized_sequences = []
    for sequence in sequences:
        sequence = sequence.strip()
        if sequence:
            encoded = vocab.to_seq(sequence, seq_len=max_length)
            encoded = encoded[:max_length] + [vocab.vocab.get('[PAD]', 0)] * (max_length - len(encoded))
            segment_label = [0] * max_length

            tokenized_sequences.append({
                'input_ids': torch.tensor(encoded),
                'segment_label': torch.tensor(segment_label)
            })

    input_ids = torch.cat([t['input_ids'].unsqueeze(0) for t in tokenized_sequences], dim=0)
    segment_labels = torch.cat([t['segment_label'].unsqueeze(0) for t in tokenized_sequences], dim=0)

    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Segment labels shape: {segment_labels.shape}")

    return input_ids, segment_labels


def custom_collate_fn(batch):
    inputs = [item['input_ids'].unsqueeze(0) for item in batch]
    labels = [item['label'].unsqueeze(0) for item in batch]
    segment_labels = [item['segment_label'].unsqueeze(0) for item in batch]

    inputs = torch.cat(inputs, dim=0)
    labels = torch.cat(labels, dim=0)
    segment_labels = torch.cat(segment_labels, dim=0)

    return {
        'input': inputs,
        'label': labels,
        'segment_label': segment_labels
    }


def main(opt):
    # Set device to GPU if available, otherwise use CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load vocabulary
    vocab = Vocab(opt.vocab_file)
    vocab.load_vocab()

    # Preprocess data and labels
    input_ids, segment_labels = preprocess_data(opt.data_path, vocab, max_length=50)  # using sequence length 50
    labels = preprocess_labels(opt.dataset)

    if input_ids is None or segment_labels is None or labels is None:
        print("Error in preprocessing data. Exiting.")
        return

    # Create TensorDataset and split into train and validation sets
    dataset = TensorDataset(input_ids, segment_labels, labels)
    val_size = len(dataset) - int(0.8 * len(dataset))
    val_dataset, train_dataset = random_split(dataset, [val_size, len(dataset) - val_size])

    # Create DataLoaders for training and validation
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)

    # Initialize custom BERT model and move it to the device
    custom_model = CustomBERTModel(
        vocab_size=len(vocab.vocab),
        output_dim=2,
        pre_trained_model_path=opt.pre_trained_model_path
    ).to(device)

    # Initialize the fine-tuning trainer
    trainer = BERTFineTuneTrainer1(
        bert=custom_model,
        vocab_size=len(vocab.vocab),
        train_dataloader=train_dataloader,
        test_dataloader=val_dataloader,
        lr=1e-5,  # using learning rate 10^-5 as specified
        num_labels=2,
        with_cuda=torch.cuda.is_available(),
        log_freq=10,
        workspace_name=opt.output_dir,
        log_folder_path=opt.log_folder_path
    )

    # Train the model
    trainer.train(epoch=20)

    # Save the model
    os.makedirs(opt.output_dir, exist_ok=True)
    output_model_file = os.path.join(opt.output_dir, 'fine_tuned_model_3.pth')
    torch.save(custom_model, output_model_file)
    print(f'Model saved to {output_model_file}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Fine-tune BERT model.')
    parser.add_argument('--dataset', type=str, default='/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_train.csv', help='Path to the dataset file.')
    parser.add_argument('--data_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/gt/er.txt', help='Path to the input sequence file.')
    parser.add_argument('--output_dir', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification', help='Directory to save the fine-tuned model.')
    parser.add_argument('--pre_trained_model_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/output/pretrain:1800ms:64hs:4l:8a:50s:64b:1000e:-5lr/bert_trained.seq_encoder.model.ep68', help='Path to the pre-trained BERT model.')
    parser.add_argument('--vocab_file', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt', help='Path to the vocabulary file.')
    parser.add_argument('--log_folder_path', type=str, default='/home/jupyter/bert/ratio_proportion_change3_1920/logs/oct', help='Path to the folder for saving logs.')

    opt = parser.parse_args()
    main(opt)
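Because the script above saves the entire fine-tuned module with torch.save, it can be reloaded directly for inference. A minimal, hedged sketch follows: the path mirrors the argparse default above, the 50-token all-zero dummy input is invented to match the preprocessing length used in main, and CustomBERTModel must be importable on sys.path when unpickling.

# Minimal sketch: reload the saved fine-tuned model and run a dummy forward pass.
import os
import torch
from CustomBERTModel import CustomBERTModel  # class must be importable for unpickling

output_dir = "/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification"
model = torch.load(os.path.join(output_dir, "fine_tuned_model_3.pth"), map_location="cpu")
model.eval()

# One dummy sequence of 50 token ids with an all-zero segment vector (illustrative values only).
input_ids = torch.zeros(1, 50, dtype=torch.long)
segment_label = torch.zeros(1, 50, dtype=torch.long)
with torch.no_grad():
    logits = model(input_ids, segment_label)
print(logits.shape)  # expected: (1, 2) for the two hint classes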
new_test_saved_finetuned_model.py
ADDED
@@ -0,0 +1,613 @@
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.optim import Adam
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
import pickle
|
| 8 |
+
print("here1",os.getcwd())
|
| 9 |
+
from src.dataset import TokenizerDataset, TokenizerDatasetForCalibration
|
| 10 |
+
from src.vocab import Vocab
|
| 11 |
+
print("here3",os.getcwd())
|
| 12 |
+
from src.bert import BERT
|
| 13 |
+
from src.seq_model import BERTSM
|
| 14 |
+
from src.classifier_model import BERTForClassification, BERTForClassificationWithFeats
|
| 15 |
+
# from src.new_finetuning.optim_schedule import ScheduledOptim
|
| 16 |
+
import metrics, recalibration, visualization
|
| 17 |
+
from recalibration import ModelWithTemperature
|
| 18 |
+
import tqdm
|
| 19 |
+
import sys
|
| 20 |
+
import time
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, roc_curve, roc_auc_score
|
| 24 |
+
import matplotlib.pyplot as plt
|
| 25 |
+
import seaborn as sns
|
| 26 |
+
import pandas as pd
|
| 27 |
+
from collections import defaultdict
|
| 28 |
+
print("here3",os.getcwd())
|
| 29 |
+
class BERTFineTuneTrainer:
|
| 30 |
+
|
| 31 |
+
def __init__(self, bertFinetunedClassifierwithFeats: BERT, #BERTForClassificationWithFeats
|
| 32 |
+
vocab_size: int, test_dataloader: DataLoader = None,
|
| 33 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
| 34 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None,
|
| 35 |
+
num_labels=2, log_folder_path: str = None):
|
| 36 |
+
"""
|
| 37 |
+
:param bert: BERT model which you want to train
|
| 38 |
+
:param vocab_size: total word vocab size
|
| 39 |
+
:param test_dataloader: test dataset data loader [can be None]
|
| 40 |
+
:param lr: learning rate of optimizer
|
| 41 |
+
:param betas: Adam optimizer betas
|
| 42 |
+
:param weight_decay: Adam optimizer weight decay param
|
| 43 |
+
:param with_cuda: traning with cuda
|
| 44 |
+
:param log_freq: logging frequency of the batch iteration
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
| 48 |
+
# cuda_condition = torch.cuda.is_available() and with_cuda
|
| 49 |
+
# self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
| 50 |
+
self.device = torch.device("cpu") #torch.device("cuda:0" if cuda_condition else "cpu")
|
| 51 |
+
# print(cuda_condition, " Device used = ", self.device)
|
| 52 |
+
print(" Device used = ", self.device)
|
| 53 |
+
|
| 54 |
+
# available_gpus = list(range(torch.cuda.device_count()))
|
| 55 |
+
|
| 56 |
+
# This BERT model will be saved every epoch
|
| 57 |
+
self.model = bertFinetunedClassifierwithFeats.to("cpu")
|
| 58 |
+
print(self.model.parameters())
|
| 59 |
+
for param in self.model.parameters():
|
| 60 |
+
param.requires_grad = False
|
| 61 |
+
# Initialize the BERT Language Model, with BERT model
|
| 62 |
+
# self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)
|
| 63 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 8).to(self.device)
|
| 64 |
+
# self.model = bertFinetunedClassifierwithFeats
|
| 65 |
+
# print(self.model.bert.parameters())
|
| 66 |
+
# for param in self.model.bert.parameters():
|
| 67 |
+
# param.requires_grad = False
|
| 68 |
+
# BERTForClassificationWithFeats(self.bert, num_labels, 18).to(self.device)
|
| 69 |
+
|
| 70 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 1).to(self.device)
|
| 71 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
| 72 |
+
# if with_cuda and torch.cuda.device_count() > 1:
|
| 73 |
+
# print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
| 74 |
+
# self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
| 75 |
+
|
| 76 |
+
# Setting the train, validation and test data loader
|
| 77 |
+
# self.train_data = train_dataloader
|
| 78 |
+
# self.val_data = val_dataloader
|
| 79 |
+
self.test_data = test_dataloader
|
| 80 |
+
|
| 81 |
+
# self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
|
| 82 |
+
self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
|
| 83 |
+
# self.optim_schedule = ScheduledOptim(self.optim, self.model.bert.hidden, n_warmup_steps=warmup_steps)
|
| 84 |
+
# self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
|
| 85 |
+
self.criterion = nn.CrossEntropyLoss()
|
| 86 |
+
|
| 87 |
+
# if num_labels == 1:
|
| 88 |
+
# self.criterion = nn.MSELoss()
|
| 89 |
+
# elif num_labels == 2:
|
| 90 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
| 91 |
+
# # self.criterion = nn.CrossEntropyLoss()
|
| 92 |
+
# elif num_labels > 2:
|
| 93 |
+
# self.criterion = nn.CrossEntropyLoss()
|
| 94 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
self.log_freq = log_freq
|
| 98 |
+
self.log_folder_path = log_folder_path
|
| 99 |
+
# self.workspace_name = workspace_name
|
| 100 |
+
# self.finetune_task = finetune_task
|
| 101 |
+
# self.save_model = False
|
| 102 |
+
# self.avg_loss = 10000
|
| 103 |
+
self.start_time = time.time()
|
| 104 |
+
# self.probability_list = []
|
| 105 |
+
for fi in ['test']: #'val',
|
| 106 |
+
f = open(self.log_folder_path+f"/log_{fi}_finetuned.txt", 'w')
|
| 107 |
+
f.close()
|
| 108 |
+
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
| 109 |
+
|
| 110 |
+
# def train(self, epoch):
|
| 111 |
+
# self.iteration(epoch, self.train_data)
|
| 112 |
+
|
| 113 |
+
# def val(self, epoch):
|
| 114 |
+
# self.iteration(epoch, self.val_data, phase="val")
|
| 115 |
+
|
| 116 |
+
def test(self, epoch):
|
| 117 |
+
# if epoch == 0:
|
| 118 |
+
# self.avg_loss = 10000
|
| 119 |
+
self.iteration(epoch, self.test_data, phase="test")
|
| 120 |
+
|
| 121 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
| 122 |
+
"""
|
| 123 |
+
loop over the data_loader for training or testing
|
| 124 |
+
if on train status, backward operation is activated
|
| 125 |
+
and also auto save the model every peoch
|
| 126 |
+
|
| 127 |
+
:param epoch: current epoch index
|
| 128 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
| 129 |
+
:param train: boolean value of is train or test
|
| 130 |
+
:return: None
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
# Setting the tqdm progress bar
|
| 134 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
| 135 |
+
desc="EP_%s:%d" % (phase, epoch),
|
| 136 |
+
total=len(data_loader),
|
| 137 |
+
bar_format="{l_bar}{r_bar}")
|
| 138 |
+
|
| 139 |
+
avg_loss = 0.0
|
| 140 |
+
total_correct = 0
|
| 141 |
+
total_element = 0
|
| 142 |
+
plabels = []
|
| 143 |
+
tlabels = []
|
| 144 |
+
probabs = []
|
| 145 |
+
|
| 146 |
+
if phase == "train":
|
| 147 |
+
self.model.train()
|
| 148 |
+
else:
|
| 149 |
+
self.model.eval()
|
| 150 |
+
# self.probability_list = []
|
| 151 |
+
|
| 152 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned.txt", 'a') as f:
|
| 153 |
+
sys.stdout = f
|
| 154 |
+
for i, data in data_iter:
|
| 155 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
| 156 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 157 |
+
if phase == "train":
|
| 158 |
+
logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
| 159 |
+
else:
|
| 160 |
+
with torch.no_grad():
|
| 161 |
+
logits = self.model.forward(data["input"].cpu(), data["segment_label"].cpu(), data["feat"].cpu())
|
| 162 |
+
|
| 163 |
+
logits = logits.cpu()
|
| 164 |
+
loss = self.criterion(logits, data["label"])
|
| 165 |
+
# if torch.cuda.device_count() > 1:
|
| 166 |
+
# loss = loss.mean()
|
| 167 |
+
|
| 168 |
+
# 3. backward and optimization only in train
|
| 169 |
+
# if phase == "train":
|
| 170 |
+
# self.optim_schedule.zero_grad()
|
| 171 |
+
# loss.backward()
|
| 172 |
+
# self.optim_schedule.step_and_update_lr()
|
| 173 |
+
|
| 174 |
+
# prediction accuracy
|
| 175 |
+
probs = nn.Softmax(dim=-1)(logits) # Probabilities
|
| 176 |
+
probabs.extend(probs.detach().cpu().numpy().tolist())
|
| 177 |
+
predicted_labels = torch.argmax(probs, dim=-1) #correct
|
| 178 |
+
# self.probability_list.append(probs)
|
| 179 |
+
# true_labels = torch.argmax(data["label"], dim=-1)
|
| 180 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
| 181 |
+
tlabels.extend(data['label'].cpu().numpy())
|
| 182 |
+
|
| 183 |
+
# Compare predicted labels to true labels and calculate accuracy
|
| 184 |
+
correct = (data['label'] == predicted_labels).sum().item()
|
| 185 |
+
|
| 186 |
+
avg_loss += loss.item()
|
| 187 |
+
total_correct += correct
|
| 188 |
+
# total_element += true_labels.nelement()
|
| 189 |
+
total_element += data["label"].nelement()
|
| 190 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
|
| 191 |
+
|
| 192 |
+
post_fix = {
|
| 193 |
+
"epoch": epoch,
|
| 194 |
+
"iter": i,
|
| 195 |
+
"avg_loss": avg_loss / (i + 1),
|
| 196 |
+
"avg_acc": total_correct / total_element * 100 if total_element != 0 else 0,
|
| 197 |
+
"loss": loss.item()
|
| 198 |
+
}
|
| 199 |
+
if i % self.log_freq == 0:
|
| 200 |
+
data_iter.write(str(post_fix))
|
| 201 |
+
|
| 202 |
+
precisions = precision_score(tlabels, plabels, average="weighted", zero_division=0)
|
| 203 |
+
recalls = recall_score(tlabels, plabels, average="weighted")
|
| 204 |
+
f1_scores = f1_score(tlabels, plabels, average="weighted")
|
| 205 |
+
cmatrix = confusion_matrix(tlabels, plabels)
|
| 206 |
+
end_time = time.time()
|
| 207 |
+
auc_score = roc_auc_score(tlabels, plabels)
|
| 208 |
+
final_msg = {
|
| 209 |
+
"epoch": f"EP{epoch}_{phase}",
|
| 210 |
+
"avg_loss": avg_loss / len(data_iter),
|
| 211 |
+
"total_acc": total_correct * 100.0 / total_element,
|
| 212 |
+
"precisions": precisions,
|
| 213 |
+
"recalls": recalls,
|
| 214 |
+
"f1_scores": f1_scores,
|
| 215 |
+
# "confusion_matrix": f"{cmatrix}",
|
| 216 |
+
# "true_labels": f"{tlabels}",
|
| 217 |
+
# "predicted_labels": f"{plabels}",
|
| 218 |
+
"time_taken_from_start": end_time - self.start_time,
|
| 219 |
+
"auc_score":auc_score
|
| 220 |
+
}
|
| 221 |
+
with open("result.txt", 'w') as file:
|
| 222 |
+
for key, value in final_msg.items():
|
| 223 |
+
file.write(f"{key}: {value}\n")
|
| 224 |
+
print(final_msg)
|
| 225 |
+
fpr, tpr, thresholds = roc_curve(tlabels, plabels)
|
| 226 |
+
with open("roc_data.pkl", "wb") as f:
|
| 227 |
+
pickle.dump((fpr, tpr, thresholds), f)
|
| 228 |
+
print(final_msg)
|
| 229 |
+
f.close()
|
| 230 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned_info.txt", 'a') as f1:
|
| 231 |
+
sys.stdout = f1
|
| 232 |
+
final_msg = {
|
| 233 |
+
"epoch": f"EP{epoch}_{phase}",
|
| 234 |
+
"confusion_matrix": f"{cmatrix}",
|
| 235 |
+
"true_labels": f"{tlabels if epoch == 0 else ''}",
|
| 236 |
+
"predicted_labels": f"{plabels}",
|
| 237 |
+
"probabilities": f"{probabs}",
|
| 238 |
+
"time_taken_from_start": end_time - self.start_time
|
| 239 |
+
}
|
| 240 |
+
print(final_msg)
|
| 241 |
+
f1.close()
|
| 242 |
+
sys.stdout = sys.__stdout__
|
| 243 |
+
sys.stdout = sys.__stdout__
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class BERTFineTuneCalibratedTrainer:
|
| 248 |
+
|
| 249 |
+
def __init__(self, bertFinetunedClassifierwithFeats: BERT, #BERTForClassificationWithFeats
|
| 250 |
+
vocab_size: int, test_dataloader: DataLoader = None,
|
| 251 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
| 252 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None,
|
| 253 |
+
num_labels=2, log_folder_path: str = None):
|
| 254 |
+
"""
|
| 255 |
+
:param bert: BERT model which you want to train
|
| 256 |
+
:param vocab_size: total word vocab size
|
| 257 |
+
:param test_dataloader: test dataset data loader [can be None]
|
| 258 |
+
:param lr: learning rate of optimizer
|
| 259 |
+
:param betas: Adam optimizer betas
|
| 260 |
+
:param weight_decay: Adam optimizer weight decay param
|
| 261 |
+
:param with_cuda: traning with cuda
|
| 262 |
+
:param log_freq: logging frequency of the batch iteration
|
| 263 |
+
"""
|
| 264 |
+
|
| 265 |
+
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
| 266 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
| 267 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
| 268 |
+
print(cuda_condition, " Device used = ", self.device)
|
| 269 |
+
|
| 270 |
+
# available_gpus = list(range(torch.cuda.device_count()))
|
| 271 |
+
|
| 272 |
+
# This BERT model will be saved every epoch
|
| 273 |
+
self.model = bertFinetunedClassifierwithFeats
|
| 274 |
+
print(self.model.parameters())
|
| 275 |
+
for param in self.model.parameters():
|
| 276 |
+
param.requires_grad = False
|
| 277 |
+
# Initialize the BERT Language Model, with BERT model
|
| 278 |
+
# self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)
|
| 279 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 8).to(self.device)
|
| 280 |
+
# self.model = bertFinetunedClassifierwithFeats
|
| 281 |
+
# print(self.model.bert.parameters())
|
| 282 |
+
# for param in self.model.bert.parameters():
|
| 283 |
+
# param.requires_grad = False
|
| 284 |
+
# BERTForClassificationWithFeats(self.bert, num_labels, 18).to(self.device)
|
| 285 |
+
|
| 286 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 1).to(self.device)
|
| 287 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
| 288 |
+
# if with_cuda and torch.cuda.device_count() > 1:
|
| 289 |
+
# print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
| 290 |
+
# self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
| 291 |
+
|
| 292 |
+
# Setting the train, validation and test data loader
|
| 293 |
+
# self.train_data = train_dataloader
|
| 294 |
+
# self.val_data = val_dataloader
|
| 295 |
+
self.test_data = test_dataloader
|
| 296 |
+
|
| 297 |
+
# self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
|
| 298 |
+
self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
|
| 299 |
+
# self.optim_schedule = ScheduledOptim(self.optim, self.model.bert.hidden, n_warmup_steps=warmup_steps)
|
| 300 |
+
# self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
|
| 301 |
+
self.criterion = nn.CrossEntropyLoss()
|
| 302 |
+
|
| 303 |
+
# if num_labels == 1:
|
| 304 |
+
# self.criterion = nn.MSELoss()
|
| 305 |
+
# elif num_labels == 2:
|
| 306 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
| 307 |
+
# # self.criterion = nn.CrossEntropyLoss()
|
| 308 |
+
# elif num_labels > 2:
|
| 309 |
+
# self.criterion = nn.CrossEntropyLoss()
|
| 310 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
self.log_freq = log_freq
|
| 314 |
+
self.log_folder_path = log_folder_path
|
| 315 |
+
# self.workspace_name = workspace_name
|
| 316 |
+
# self.finetune_task = finetune_task
|
| 317 |
+
# self.save_model = False
|
| 318 |
+
# self.avg_loss = 10000
|
| 319 |
+
self.start_time = time.time()
|
| 320 |
+
# self.probability_list = []
|
| 321 |
+
for fi in ['test']: #'val',
|
| 322 |
+
f = open(self.log_folder_path+f"/log_{fi}_finetuned.txt", 'w')
|
| 323 |
+
f.close()
|
| 324 |
+
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
| 325 |
+
|
| 326 |
+
# def train(self, epoch):
|
| 327 |
+
# self.iteration(epoch, self.train_data)
|
| 328 |
+
|
| 329 |
+
# def val(self, epoch):
|
| 330 |
+
# self.iteration(epoch, self.val_data, phase="val")
|
| 331 |
+
|
| 332 |
+
def test(self, epoch):
|
| 333 |
+
# if epoch == 0:
|
| 334 |
+
# self.avg_loss = 10000
|
| 335 |
+
self.iteration(epoch, self.test_data, phase="test")
|
| 336 |
+
|
| 337 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
| 338 |
+
"""
|
| 339 |
+
loop over the data_loader for training or testing
|
| 340 |
+
if on train status, backward operation is activated
|
| 341 |
+
and also auto save the model every peoch
|
| 342 |
+
|
| 343 |
+
:param epoch: current epoch index
|
| 344 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
| 345 |
+
:param train: boolean value of is train or test
|
| 346 |
+
:return: None
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
# Setting the tqdm progress bar
|
| 350 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
| 351 |
+
desc="EP_%s:%d" % (phase, epoch),
|
| 352 |
+
total=len(data_loader),
|
| 353 |
+
bar_format="{l_bar}{r_bar}")
|
| 354 |
+
|
| 355 |
+
avg_loss = 0.0
|
| 356 |
+
total_correct = 0
|
| 357 |
+
total_element = 0
|
| 358 |
+
plabels = []
|
| 359 |
+
tlabels = []
|
| 360 |
+
probabs = []
|
| 361 |
+
|
| 362 |
+
if phase == "train":
|
| 363 |
+
self.model.train()
|
| 364 |
+
else:
|
| 365 |
+
self.model.eval()
|
| 366 |
+
# self.probability_list = []
|
| 367 |
+
|
| 368 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned.txt", 'a') as f:
|
| 369 |
+
sys.stdout = f
|
| 370 |
+
for i, data in data_iter:
|
| 371 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
| 372 |
+
# print(data_pair[0])
|
| 373 |
+
data = {key: value.to(self.device) for key, value in data[0].items()}
|
| 374 |
+
# print(f"data : {data}")
|
| 375 |
+
# data = {key: value.to(self.device) for key, value in data.items()}
|
| 376 |
+
|
| 377 |
+
# if phase == "train":
|
| 378 |
+
# logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
| 379 |
+
# else:
|
| 380 |
+
with torch.no_grad():
|
| 381 |
+
# logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
| 382 |
+
logits = self.model.forward(data)
|
| 383 |
+
|
| 384 |
+
loss = self.criterion(logits, data["label"])
|
| 385 |
+
if torch.cuda.device_count() > 1:
|
| 386 |
+
loss = loss.mean()
|
| 387 |
+
|
| 388 |
+
# 3. backward and optimization only in train
|
| 389 |
+
# if phase == "train":
|
| 390 |
+
# self.optim_schedule.zero_grad()
|
| 391 |
+
# loss.backward()
|
| 392 |
+
# self.optim_schedule.step_and_update_lr()
|
| 393 |
+
|
| 394 |
+
# prediction accuracy
|
| 395 |
+
probs = nn.Softmax(dim=-1)(logits) # Probabilities
|
| 396 |
+
probabs.extend(probs.detach().cpu().numpy().tolist())
|
| 397 |
+
predicted_labels = torch.argmax(probs, dim=-1) #correct
|
| 398 |
+
# self.probability_list.append(probs)
|
| 399 |
+
# true_labels = torch.argmax(data["label"], dim=-1)
|
| 400 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
| 401 |
+
tlabels.extend(data['label'].cpu().numpy())
|
| 402 |
+
positive_class_probs = [prob[1] for prob in probabs]
|
| 403 |
+
|
| 404 |
+
# Compare predicted labels to true labels and calculate accuracy
|
| 405 |
+
correct = (data['label'] == predicted_labels).sum().item()
|
| 406 |
+
|
| 407 |
+
avg_loss += loss.item()
|
| 408 |
+
total_correct += correct
|
| 409 |
+
# total_element += true_labels.nelement()
|
| 410 |
+
total_element += data["label"].nelement()
|
| 411 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
|
| 412 |
+
|
| 413 |
+
post_fix = {
|
| 414 |
+
"epoch": epoch,
|
| 415 |
+
"iter": i,
|
| 416 |
+
"avg_loss": avg_loss / (i + 1),
|
| 417 |
+
"avg_acc": total_correct / total_element * 100 if total_element != 0 else 0,
|
| 418 |
+
"loss": loss.item()
|
| 419 |
+
}
|
| 420 |
+
if i % self.log_freq == 0:
|
| 421 |
+
data_iter.write(str(post_fix))
|
| 422 |
+
|
| 423 |
+
precisions = precision_score(tlabels, plabels, average="weighted", zero_division=0)
|
| 424 |
+
recalls = recall_score(tlabels, plabels, average="weighted")
|
| 425 |
+
f1_scores = f1_score(tlabels, plabels, average="weighted")
|
| 426 |
+
cmatrix = confusion_matrix(tlabels, plabels)
|
| 427 |
+
auc_score = roc_auc_score(tlabels, positive_class_probs)
|
| 428 |
+
end_time = time.time()
|
| 429 |
+
final_msg = {
|
| 430 |
+
"epoch": f"EP{epoch}_{phase}",
|
| 431 |
+
"avg_loss": avg_loss / len(data_iter),
|
| 432 |
+
"total_acc": total_correct * 100.0 / total_element,
|
| 433 |
+
"precisions": precisions,
|
| 434 |
+
"recalls": recalls,
|
| 435 |
+
"f1_scores": f1_scores,
|
| 436 |
+
"auc_score":auc_score,
|
| 437 |
+
# "confusion_matrix": f"{cmatrix}",
|
| 438 |
+
# "true_labels": f"{tlabels}",
|
| 439 |
+
# "predicted_labels": f"{plabels}",
|
| 440 |
+
"time_taken_from_start": end_time - self.start_time
|
| 441 |
+
}
|
| 442 |
+
with open("result.txt", 'w') as file:
|
| 443 |
+
for key, value in final_msg.items():
|
| 444 |
+
file.write(f"{key}: {value}\n")
|
| 445 |
+
|
| 446 |
+
print(final_msg)
|
| 447 |
+
fpr, tpr, thresholds = roc_curve(tlabels, positive_class_probs)
|
| 448 |
+
f.close()
|
| 449 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned_info.txt", 'a') as f1:
|
| 450 |
+
sys.stdout = f1
|
| 451 |
+
final_msg = {
|
| 452 |
+
"epoch": f"EP{epoch}_{phase}",
|
| 453 |
+
"confusion_matrix": f"{cmatrix}",
|
| 454 |
+
"true_labels": f"{tlabels if epoch == 0 else ''}",
|
| 455 |
+
"predicted_labels": f"{plabels}",
|
| 456 |
+
"probabilities": f"{probabs}",
|
| 457 |
+
"time_taken_from_start": end_time - self.start_time
|
| 458 |
+
}
|
| 459 |
+
print(final_msg)
|
| 460 |
+
f1.close()
|
| 461 |
+
sys.stdout = sys.__stdout__
|
| 462 |
+
sys.stdout = sys.__stdout__
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def train():
|
| 467 |
+
parser = argparse.ArgumentParser()
|
| 468 |
+
|
| 469 |
+
parser.add_argument('-workspace_name', type=str, default=None)
|
| 470 |
+
parser.add_argument('-code', type=str, default=None, help="folder for pretraining outputs and logs")
|
| 471 |
+
parser.add_argument('-finetune_task', type=str, default=None, help="folder inside finetuning")
|
| 472 |
+
parser.add_argument("-attention", type=bool, default=False, help="analyse attention scores")
|
| 473 |
+
parser.add_argument("-diff_test_folder", type=bool, default=False, help="use for different test folder")
|
| 474 |
+
parser.add_argument("-embeddings", type=bool, default=False, help="get and analyse embeddings")
|
| 475 |
+
parser.add_argument('-embeddings_file_name', type=str, default=None, help="file name of embeddings")
|
| 476 |
+
parser.add_argument("-pretrain", type=bool, default=False, help="pretraining: true, or false")
|
| 477 |
+
# parser.add_argument('-opts', nargs='+', type=str, default=None, help='List of optional steps')
|
| 478 |
+
parser.add_argument("-max_mask", type=int, default=0.15, help="% of input tokens selected for masking")
|
| 479 |
+
# parser.add_argument("-p", "--pretrain_dataset", type=str, default="pretraining/pretrain.txt", help="pretraining dataset for bert")
|
| 480 |
+
# parser.add_argument("-pv", "--pretrain_val_dataset", type=str, default="pretraining/test.txt", help="pretraining validation dataset for bert")
|
| 481 |
+
# default="finetuning/test.txt",
|
| 482 |
+
parser.add_argument("-vocab_path", type=str, default="pretraining/vocab.txt", help="built vocab model path with bert-vocab")
|
| 483 |
+
|
| 484 |
+
parser.add_argument("-train_dataset_path", type=str, default="train.txt", help="fine tune train dataset for progress classifier")
|
| 485 |
+
parser.add_argument("-val_dataset_path", type=str, default="val.txt", help="test set for evaluate fine tune train set")
|
| 486 |
+
parser.add_argument("-test_dataset_path", type=str, default="test.txt", help="test set for evaluate fine tune train set")
|
| 487 |
+
parser.add_argument("-num_labels", type=int, default=2, help="Number of labels")
|
| 488 |
+
parser.add_argument("-train_label_path", type=str, default="train_label.txt", help="fine tune train dataset for progress classifier")
|
| 489 |
+
parser.add_argument("-val_label_path", type=str, default="val_label.txt", help="test set for evaluate fine tune train set")
|
| 490 |
+
parser.add_argument("-test_label_path", type=str, default="test_label.txt", help="test set for evaluate fine tune train set")
|
| 491 |
+
##### change Checkpoint for finetuning
|
| 492 |
+
parser.add_argument("-pretrained_bert_checkpoint", type=str, default=None, help="checkpoint of saved pretrained bert model")
|
| 493 |
+
parser.add_argument("-finetuned_bert_classifier_checkpoint", type=str, default=None, help="checkpoint of saved finetuned bert model") #."output_feb09/bert_trained.model.ep40"
|
| 494 |
+
#."output_feb09/bert_trained.model.ep40"
|
| 495 |
+
parser.add_argument('-check_epoch', type=int, default=None)
|
| 496 |
+
|
| 497 |
+
parser.add_argument("-hs", "--hidden", type=int, default=64, help="hidden size of transformer model") #64
|
| 498 |
+
parser.add_argument("-l", "--layers", type=int, default=4, help="number of layers") #4
|
| 499 |
+
parser.add_argument("-a", "--attn_heads", type=int, default=4, help="number of attention heads") #8
|
| 500 |
+
parser.add_argument("-s", "--seq_len", type=int, default=5, help="maximum sequence length")
|
| 501 |
+
|
| 502 |
+
parser.add_argument("-b", "--batch_size", type=int, default=500, help="number of batch_size") #64
|
| 503 |
+
parser.add_argument("-e", "--epochs", type=int, default=1)#1501, help="number of epochs") #501
|
| 504 |
+
# Use 50 for pretrain, and 10 for fine tune
|
| 505 |
+
parser.add_argument("-w", "--num_workers", type=int, default=0, help="dataloader worker size")
|
| 506 |
+
|
| 507 |
+
# Later run with cuda
|
| 508 |
+
parser.add_argument("--with_cuda", type=bool, default=False, help="training with CUDA: true, or false")
|
| 509 |
+
parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
|
| 510 |
+
# parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
|
| 511 |
+
parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
|
| 512 |
+
# parser.add_argument("--on_memory", type=bool, default=False, help="Loading on memory: true or false")
|
| 513 |
+
|
| 514 |
+
parser.add_argument("--dropout", type=float, default=0.1, help="dropout of network")
|
| 515 |
+
parser.add_argument("--lr", type=float, default=1e-05, help="learning rate of adam") #1e-3
|
| 516 |
+
parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
|
| 517 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
|
| 518 |
+
parser.add_argument("--adam_beta2", type=float, default=0.98, help="adam first beta value") #0.999
|
| 519 |
+
|
| 520 |
+
parser.add_argument("-o", "--output_path", type=str, default="bert_trained.seq_encoder.model", help="ex)output/bert.model")
|
| 521 |
+
# parser.add_argument("-o", "--output_path", type=str, default="output/bert_fine_tuned.model", help="ex)output/bert.model")
|
| 522 |
+
|
| 523 |
+
args = parser.parse_args()
|
| 524 |
+
for k,v in vars(args).items():
|
| 525 |
+
if 'path' in k:
|
| 526 |
+
if v:
|
| 527 |
+
if k == "output_path":
|
| 528 |
+
if args.code:
|
| 529 |
+
setattr(args, f"{k}", args.workspace_name+f"/output/{args.code}/"+v)
|
| 530 |
+
elif args.finetune_task:
|
| 531 |
+
setattr(args, f"{k}", args.workspace_name+f"/output/{args.finetune_task}/"+v)
|
| 532 |
+
else:
|
| 533 |
+
setattr(args, f"{k}", args.workspace_name+"/output/"+v)
|
| 534 |
+
elif k != "vocab_path":
|
| 535 |
+
if args.pretrain:
|
| 536 |
+
setattr(args, f"{k}", args.workspace_name+"/pretraining/"+v)
|
| 537 |
+
else:
|
| 538 |
+
if args.code:
|
| 539 |
+
setattr(args, f"{k}", args.workspace_name+f"/{args.code}/"+v)
|
| 540 |
+
elif args.finetune_task:
|
| 541 |
+
if args.diff_test_folder and "test" in k:
|
| 542 |
+
setattr(args, f"{k}", args.workspace_name+f"/finetuning/"+v)
|
| 543 |
+
else:
|
| 544 |
+
setattr(args, f"{k}", args.workspace_name+f"/finetuning/{args.finetune_task}/"+v)
|
| 545 |
+
else:
|
| 546 |
+
setattr(args, f"{k}", args.workspace_name+"/finetuning/"+v)
|
| 547 |
+
else:
|
| 548 |
+
setattr(args, f"{k}", args.workspace_name+"/"+v)
|
| 549 |
+
|
| 550 |
+
print(f"args.{k} : {getattr(args, f'{k}')}")
|
| 551 |
+
|
| 552 |
+
print("Loading Vocab", args.vocab_path)
|
| 553 |
+
vocab_obj = Vocab(args.vocab_path)
|
| 554 |
+
vocab_obj.load_vocab()
|
| 555 |
+
print("Vocab Size: ", len(vocab_obj.vocab))
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
print("Testing using finetuned model......")
|
| 559 |
+
print("Loading Test Dataset", args.test_dataset_path)
|
| 560 |
+
test_dataset = TokenizerDataset(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)
|
| 561 |
+
# test_dataset = TokenizerDatasetForCalibration(args.test_dataset_path, args.test_label_path, vocab_obj, seq_len=args.seq_len)
|
| 562 |
+
|
| 563 |
+
print("Creating Dataloader...")
|
| 564 |
+
test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
|
| 565 |
+
|
| 566 |
+
print("Load fine-tuned BERT classifier model with feats")
|
| 567 |
+
# cuda_condition = torch.cuda.is_available() and args.with_cuda
|
| 568 |
+
device = torch.device("cpu") #torch.device("cuda:0" if cuda_condition else "cpu")
|
| 569 |
+
finetunedBERTclassifier = torch.load(args.finetuned_bert_classifier_checkpoint, map_location=device)
|
| 570 |
+
if isinstance(finetunedBERTclassifier, torch.nn.DataParallel):
|
| 571 |
+
finetunedBERTclassifier = finetunedBERTclassifier.module
|
| 572 |
+
|
| 573 |
+
new_log_folder = f"{args.workspace_name}/logs"
|
| 574 |
+
new_output_folder = f"{args.workspace_name}/output"
|
| 575 |
+
if args.finetune_task: # is sent almost all the time
|
| 576 |
+
new_log_folder = f"{args.workspace_name}/logs/{args.finetune_task}"
|
| 577 |
+
new_output_folder = f"{args.workspace_name}/output/{args.finetune_task}"
|
| 578 |
+
|
| 579 |
+
if not os.path.exists(new_log_folder):
|
| 580 |
+
os.makedirs(new_log_folder)
|
| 581 |
+
if not os.path.exists(new_output_folder):
|
| 582 |
+
os.makedirs(new_output_folder)
|
| 583 |
+
|
| 584 |
+
print("Creating BERT Fine Tuned Test Trainer")
|
| 585 |
+
trainer = BERTFineTuneTrainer(finetunedBERTclassifier,
|
| 586 |
+
len(vocab_obj.vocab), test_dataloader=test_data_loader,
|
| 587 |
+
lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
|
| 588 |
+
with_cuda=args.with_cuda, cuda_devices = args.cuda_devices, log_freq=args.log_freq,
|
| 589 |
+
workspace_name = args.workspace_name, num_labels=args.num_labels, log_folder_path=new_log_folder)
|
| 590 |
+
|
| 591 |
+
# trainer = BERTFineTuneCalibratedTrainer(finetunedBERTclassifier,
|
| 592 |
+
# len(vocab_obj.vocab), test_dataloader=test_data_loader,
|
| 593 |
+
# lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
|
| 594 |
+
# with_cuda=args.with_cuda, cuda_devices = args.cuda_devices, log_freq=args.log_freq,
|
| 595 |
+
# workspace_name = args.workspace_name, num_labels=args.num_labels, log_folder_path=new_log_folder)
|
| 596 |
+
print("Testing fine-tuned model Start....")
|
| 597 |
+
start_time = time.time()
|
| 598 |
+
repoch = range(args.check_epoch, args.epochs) if args.check_epoch else range(args.epochs)
|
| 599 |
+
counter = 0
|
| 600 |
+
# patience = 10
|
| 601 |
+
for epoch in repoch:
|
| 602 |
+
print(f'Test Epoch {epoch} Starts, Time: {time.strftime("%D %T", time.localtime(time.time()))}')
|
| 603 |
+
trainer.test(epoch)
|
| 604 |
+
# pickle.dump(trainer.probability_list, open(f"{args.workspace_name}/output/aaai/change4_mid_prob_{epoch}.pkl","wb"))
|
| 605 |
+
print(f'Test Epoch {epoch} Ends, Time: {time.strftime("%D %T", time.localtime(time.time()))} \n')
|
| 606 |
+
end_time = time.time()
|
| 607 |
+
print("Time Taken to fine-tune model = ", end_time - start_time)
|
| 608 |
+
print(f'Pretraining Ends, Time: {time.strftime("%D %T", time.localtime(end_time))}')
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
if __name__ == "__main__":
|
| 613 |
+
train()
|
plot.png
CHANGED
|
|
prepare_pretraining_input_vocab_file.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ratio_proportion_change3_2223/sch_largest_100-coded/pretraining/vocab.txt
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[PAD]
|
| 2 |
+
[UNK]
|
| 3 |
+
[MASK]
|
| 4 |
+
[CLS]
|
| 5 |
+
[SEP]
|
| 6 |
+
DenominatorFactor
|
| 7 |
+
DenominatorQuantity1-0
|
| 8 |
+
DenominatorQuantity1-1
|
| 9 |
+
DenominatorQuantity1-2
|
| 10 |
+
EquationAnswer
|
| 11 |
+
FinalAnswer-0
|
| 12 |
+
FinalAnswer-1
|
| 13 |
+
FinalAnswer-2
|
| 14 |
+
FinalAnswerDirection-0
|
| 15 |
+
FinalAnswerDirection-1
|
| 16 |
+
FinalAnswerDirection-2
|
| 17 |
+
FirstRow1:1
|
| 18 |
+
FirstRow1:2
|
| 19 |
+
FirstRow2:1
|
| 20 |
+
FirstRow2:2
|
| 21 |
+
NumeratorFactor
|
| 22 |
+
NumeratorQuantity1-0
|
| 23 |
+
NumeratorQuantity1-1
|
| 24 |
+
NumeratorQuantity1-2
|
| 25 |
+
NumeratorQuantity2-0
|
| 26 |
+
NumeratorQuantity2-1
|
| 27 |
+
NumeratorQuantity2-2
|
| 28 |
+
OptionalTask_1
|
| 29 |
+
OptionalTask_2
|
| 30 |
+
PercentChange-0
|
| 31 |
+
PercentChange-1
|
| 32 |
+
PercentChange-2
|
| 33 |
+
SecondRow
|
| 34 |
+
ThirdRow
|
recalibration.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn, optim
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
import metrics
|
| 6 |
+
|
| 7 |
+
class ModelWithTemperature(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
A thin decorator, which wraps a model with temperature scaling
|
| 10 |
+
model (nn.Module):
|
| 11 |
+
A classification neural network
|
| 12 |
+
NB: Output of the neural network should be the classification logits,
|
| 13 |
+
NOT the softmax (or log softmax)!
|
| 14 |
+
"""
|
| 15 |
+
def __init__(self, model, device="cpu"):
|
| 16 |
+
super(ModelWithTemperature, self).__init__()
|
| 17 |
+
self.model = model
|
| 18 |
+
self.device = torch.device(device)
|
| 19 |
+
self.temperature = nn.Parameter(torch.ones(1) * 1.5)
|
| 20 |
+
|
| 21 |
+
def forward(self, input):
|
| 22 |
+
logits = self.model(input["input"], input["segment_label"], input["feat"])
|
| 23 |
+
return self.temperature_scale(logits)
|
| 24 |
+
|
| 25 |
+
def temperature_scale(self, logits):
|
| 26 |
+
"""
|
| 27 |
+
Perform temperature scaling on logits
|
| 28 |
+
"""
|
| 29 |
+
# Expand temperature to match the size of logits
|
| 30 |
+
temperature = self.temperature.unsqueeze(1).expand(logits.size(0), logits.size(1)).to(self.device)
|
| 31 |
+
return logits / temperature
|
| 32 |
+
|
| 33 |
+
# This function probably should live outside of this class, but whatever
|
| 34 |
+
def set_temperature(self, valid_loader):
|
| 35 |
+
"""
|
| 36 |
+
Tune the tempearature of the model (using the validation set).
|
| 37 |
+
We're going to set it to optimize NLL.
|
| 38 |
+
valid_loader (DataLoader): validation set loader
|
| 39 |
+
"""
|
| 40 |
+
#self.cuda()
|
| 41 |
+
nll_criterion = nn.CrossEntropyLoss()
|
| 42 |
+
ece_criterion = metrics.ECELoss()
|
| 43 |
+
|
| 44 |
+
# First: collect all the logits and labels for the validation set
|
| 45 |
+
logits_list = []
|
| 46 |
+
labels_list = []
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
for input, label in valid_loader:
|
| 49 |
+
# print("Input = ", input["input"])
|
| 50 |
+
# print("Input = ", input["segment_label"])
|
| 51 |
+
# print("Input = ", input["feat"])
|
| 52 |
+
# input = input
|
| 53 |
+
logits = self.model(input["input"].to(self.device), input["segment_label"].to(self.device), input["feat"].to(self.device))
|
| 54 |
+
logits_list.append(logits)
|
| 55 |
+
labels_list.append(label)
|
| 56 |
+
logits = torch.cat(logits_list).to(self.device)
|
| 57 |
+
labels = torch.cat(labels_list).to(self.device)
|
| 58 |
+
|
| 59 |
+
# Calculate NLL and ECE before temperature scaling
|
| 60 |
+
before_temperature_nll = nll_criterion(logits, labels).item()
|
| 61 |
+
before_temperature_ece = ece_criterion.loss(logits.cpu().numpy(),labels.cpu().numpy(),15)
|
| 62 |
+
#before_temperature_ece = ece_criterion(logits, labels).item()
|
| 63 |
+
#ece_2 = ece_criterion_2.loss(logits,labels)
|
| 64 |
+
print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))
|
| 65 |
+
#print(ece_2)
|
| 66 |
+
# Next: optimize the temperature w.r.t. NLL
|
| 67 |
+
optimizer = optim.LBFGS([self.temperature], lr=0.005, max_iter=1000)
|
| 68 |
+
|
| 69 |
+
def eval():
|
| 70 |
+
loss = nll_criterion(self.temperature_scale(logits.to(self.device)), labels.to(self.device))
|
| 71 |
+
loss.backward()
|
| 72 |
+
return loss
|
| 73 |
+
optimizer.step(eval)
|
| 74 |
+
|
| 75 |
+
# Calculate NLL and ECE after temperature scaling
|
| 76 |
+
after_temperature_nll = nll_criterion(self.temperature_scale(logits), labels).item()
|
| 77 |
+
after_temperature_ece = ece_criterion.loss(self.temperature_scale(logits).detach().cpu().numpy(),labels.cpu().numpy(),15)
|
| 78 |
+
#after_temperature_ece = ece_criterion(self.temperature_scale(logits), labels).item()
|
| 79 |
+
print('Optimal temperature: %.3f' % self.temperature.item())
|
| 80 |
+
print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))
|
| 81 |
+
|
| 82 |
+
return self
|
src/__pycache__/attention.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/attention.cpython-312.pyc and b/src/__pycache__/attention.cpython-312.pyc differ
|
|
|
src/__pycache__/bert.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/bert.cpython-312.pyc and b/src/__pycache__/bert.cpython-312.pyc differ
|
|
|
src/__pycache__/classifier_model.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/classifier_model.cpython-312.pyc and b/src/__pycache__/classifier_model.cpython-312.pyc differ
|
|
|
src/__pycache__/dataset.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/dataset.cpython-312.pyc and b/src/__pycache__/dataset.cpython-312.pyc differ
|
|
|
src/__pycache__/embedding.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/embedding.cpython-312.pyc and b/src/__pycache__/embedding.cpython-312.pyc differ
|
|
|
src/__pycache__/seq_model.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/seq_model.cpython-312.pyc and b/src/__pycache__/seq_model.cpython-312.pyc differ
|
|
|
src/__pycache__/transformer.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/transformer.cpython-312.pyc and b/src/__pycache__/transformer.cpython-312.pyc differ
|
|
|
src/__pycache__/transformer_component.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/transformer_component.cpython-312.pyc and b/src/__pycache__/transformer_component.cpython-312.pyc differ
|
|
|
src/__pycache__/vocab.cpython-312.pyc
CHANGED
|
Binary files a/src/__pycache__/vocab.cpython-312.pyc and b/src/__pycache__/vocab.cpython-312.pyc differ
|
|
|
src/attention.py
CHANGED
|
@@ -3,11 +3,19 @@ import torch.nn.functional as F
|
|
| 3 |
import torch
|
| 4 |
|
| 5 |
import math
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class Attention(nn.Module):
|
| 9 |
"""
|
| 10 |
Compute 'Scaled Dot Product Attention
|
|
|
|
| 11 |
"""
|
| 12 |
|
| 13 |
def __init__(self):
|
|
@@ -45,7 +53,10 @@ class MultiHeadedAttention(nn.Module):
|
|
| 45 |
self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
|
| 46 |
self.output_linear = nn.Linear(d_model, d_model)
|
| 47 |
self.attention = Attention()
|
|
|
|
|
|
|
| 48 |
|
|
|
|
| 49 |
self.dropout = nn.Dropout(p=dropout)
|
| 50 |
|
| 51 |
def forward(self, query, key, value, mask=None):
|
|
@@ -59,6 +70,14 @@ class MultiHeadedAttention(nn.Module):
|
|
| 59 |
query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
| 60 |
for l, x in zip(self.linear_layers, (query, key, value))]
|
| 61 |
# 2) Apply attention on all the projected vectors in batch.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
|
| 63 |
# torch.Size([64, 8, 100, 100])
|
| 64 |
# print("Attention", attn.shape)
|
|
@@ -67,4 +86,5 @@ class MultiHeadedAttention(nn.Module):
|
|
| 67 |
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
|
| 68 |
|
| 69 |
return self.output_linear(x)
|
| 70 |
-
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
|
| 5 |
import math
|
| 6 |
+
<<<<<<< HEAD
|
| 7 |
+
import pickle
|
| 8 |
+
|
| 9 |
+
class Attention(nn.Module):
|
| 10 |
+
"""
|
| 11 |
+
Compute Scaled Dot Product Attention
|
| 12 |
+
=======
|
| 13 |
|
| 14 |
|
| 15 |
class Attention(nn.Module):
|
| 16 |
"""
|
| 17 |
Compute 'Scaled Dot Product Attention
|
| 18 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 19 |
"""
|
| 20 |
|
| 21 |
def __init__(self):
|
|
|
|
| 53 |
self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
|
| 54 |
self.output_linear = nn.Linear(d_model, d_model)
|
| 55 |
self.attention = Attention()
|
| 56 |
+
<<<<<<< HEAD
|
| 57 |
+
=======
|
| 58 |
|
| 59 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 60 |
self.dropout = nn.Dropout(p=dropout)
|
| 61 |
|
| 62 |
def forward(self, query, key, value, mask=None):
|
|
|
|
| 70 |
query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
| 71 |
for l, x in zip(self.linear_layers, (query, key, value))]
|
| 72 |
# 2) Apply attention on all the projected vectors in batch.
|
| 73 |
+
<<<<<<< HEAD
|
| 74 |
+
x, p_attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
|
| 75 |
+
|
| 76 |
+
# 3) "Concat" using a view and apply a final linear.
|
| 77 |
+
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
|
| 78 |
+
|
| 79 |
+
return self.output_linear(x), p_attn
|
| 80 |
+
=======
|
| 81 |
x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
|
| 82 |
# torch.Size([64, 8, 100, 100])
|
| 83 |
# print("Attention", attn.shape)
|
|
|
|
| 86 |
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
|
| 87 |
|
| 88 |
return self.output_linear(x)
|
| 89 |
+
|
| 90 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
src/bert.py
CHANGED
|
@@ -1,7 +1,14 @@
|
|
| 1 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from transformer import TransformerBlock
|
| 4 |
from embedding import BERTEmbedding
|
|
|
|
| 5 |
|
| 6 |
class BERT(nn.Module):
|
| 7 |
"""
|
|
@@ -31,10 +38,37 @@ class BERT(nn.Module):
|
|
| 31 |
# multi-layers transformer blocks, deep network
|
| 32 |
self.transformer_blocks = nn.ModuleList(
|
| 33 |
[TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
def forward(self, x, segment_info):
|
| 36 |
# attention masking for padded token
|
| 37 |
# torch.ByteTensor([batch_size, 1, seq_len, seq_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
|
| 39 |
# print("bert mask: ", mask)
|
| 40 |
# embedding the indexed sequence to sequence of vectors
|
|
@@ -43,5 +77,6 @@ class BERT(nn.Module):
|
|
| 43 |
# running over multiple transformer blocks
|
| 44 |
for transformer in self.transformer_blocks:
|
| 45 |
x = transformer.forward(x, mask)
|
|
|
|
| 46 |
|
| 47 |
return x
|
|
|
|
| 1 |
import torch.nn as nn
|
| 2 |
+
<<<<<<< HEAD
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .transformer import TransformerBlock
|
| 6 |
+
from .embedding import BERTEmbedding
|
| 7 |
+
=======
|
| 8 |
|
| 9 |
from transformer import TransformerBlock
|
| 10 |
from embedding import BERTEmbedding
|
| 11 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 12 |
|
| 13 |
class BERT(nn.Module):
|
| 14 |
"""
|
|
|
|
| 38 |
# multi-layers transformer blocks, deep network
|
| 39 |
self.transformer_blocks = nn.ModuleList(
|
| 40 |
[TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])
|
| 41 |
+
<<<<<<< HEAD
|
| 42 |
+
# self.attention_values = []
|
| 43 |
+
=======
|
| 44 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 45 |
|
| 46 |
def forward(self, x, segment_info):
|
| 47 |
# attention masking for padded token
|
| 48 |
# torch.ByteTensor([batch_size, 1, seq_len, seq_len)
|
| 49 |
+
<<<<<<< HEAD
|
| 50 |
+
|
| 51 |
+
device = x.device
|
| 52 |
+
|
| 53 |
+
masked = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1)
|
| 54 |
+
r,e,c = masked.shape
|
| 55 |
+
mask = torch.zeros((r, e, c), dtype=torch.bool).to(device=device)
|
| 56 |
+
|
| 57 |
+
for i in range(r):
|
| 58 |
+
mask[i] = masked[i].T*masked[i]
|
| 59 |
+
mask = mask.unsqueeze(1)
|
| 60 |
+
# mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
|
| 61 |
+
|
| 62 |
+
# print("bert mask: ", mask)
|
| 63 |
+
# embedding the indexed sequence to sequence of vectors
|
| 64 |
+
x = self.embedding(x, segment_info)
|
| 65 |
+
|
| 66 |
+
# self.attention_values = []
|
| 67 |
+
# running over multiple transformer blocks
|
| 68 |
+
for transformer in self.transformer_blocks:
|
| 69 |
+
x = transformer.forward(x, mask)
|
| 70 |
+
# self.attention_values.append(transformer.p_attn)
|
| 71 |
+
=======
|
| 72 |
mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
|
| 73 |
# print("bert mask: ", mask)
|
| 74 |
# embedding the indexed sequence to sequence of vectors
|
|
|
|
| 77 |
# running over multiple transformer blocks
|
| 78 |
for transformer in self.transformer_blocks:
|
| 79 |
x = transformer.forward(x, mask)
|
| 80 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 81 |
|
| 82 |
return x
|
src/classifier_model.py
CHANGED
|
@@ -1,16 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch.nn as nn
|
| 2 |
|
| 3 |
from bert import BERT
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class BERTForClassification(nn.Module):
|
| 7 |
"""
|
|
|
|
|
|
|
|
|
|
| 8 |
Progress Classifier Model
|
|
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
def __init__(self, bert: BERT, vocab_size, n_labels):
|
| 12 |
"""
|
| 13 |
:param bert: BERT model which should be trained
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
:param vocab_size: total vocab size for masked_lm
|
| 15 |
"""
|
| 16 |
|
|
@@ -21,4 +71,5 @@ class BERTForClassification(nn.Module):
|
|
| 21 |
|
| 22 |
def forward(self, x, segment_label):
|
| 23 |
x = self.bert(x, segment_label)
|
| 24 |
-
return x, self.linear(x[:, 0])
|
|
|
|
|
|
| 1 |
+
<<<<<<< HEAD
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from .bert import BERT
|
| 6 |
+
=======
|
| 7 |
import torch.nn as nn
|
| 8 |
|
| 9 |
from bert import BERT
|
| 10 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 11 |
|
| 12 |
|
| 13 |
class BERTForClassification(nn.Module):
|
| 14 |
"""
|
| 15 |
+
<<<<<<< HEAD
|
| 16 |
+
Fine-tune Task Classifier Model
|
| 17 |
+
=======
|
| 18 |
Progress Classifier Model
|
| 19 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 20 |
"""
|
| 21 |
|
| 22 |
def __init__(self, bert: BERT, vocab_size, n_labels):
|
| 23 |
"""
|
| 24 |
:param bert: BERT model which should be trained
|
| 25 |
+
<<<<<<< HEAD
|
| 26 |
+
:param vocab_size: total vocab size
|
| 27 |
+
:param n_labels: number of labels for the task
|
| 28 |
+
"""
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.bert = bert
|
| 31 |
+
self.linear = nn.Linear(self.bert.hidden, n_labels)
|
| 32 |
+
|
| 33 |
+
def forward(self, x, segment_label):
|
| 34 |
+
x = self.bert(x, segment_label)
|
| 35 |
+
return self.linear(x[:, 0])
|
| 36 |
+
|
| 37 |
+
class BERTForClassificationWithFeats(nn.Module):
|
| 38 |
+
"""
|
| 39 |
+
Fine-tune Task Classifier Model
|
| 40 |
+
BERT embeddings concatenated with features
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(self, bert: BERT, n_labels, feat_size=9):
|
| 44 |
+
"""
|
| 45 |
+
:param bert: BERT model which should be trained
|
| 46 |
+
:param vocab_size: total vocab size
|
| 47 |
+
:param n_labels: number of labels for the task
|
| 48 |
+
"""
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.bert = bert
|
| 51 |
+
# self.linear1 = nn.Linear(self.bert.hidden+feat_size, 128)
|
| 52 |
+
self.linear = nn.Linear(self.bert.hidden+feat_size, n_labels)
|
| 53 |
+
# self.RELU = nn.ReLU()
|
| 54 |
+
# self.linear2 = nn.Linear(128, n_labels)
|
| 55 |
+
|
| 56 |
+
def forward(self, x, segment_label, feat):
|
| 57 |
+
x = self.bert(x, segment_label)
|
| 58 |
+
x = torch.cat((x[:, 0], feat), dim=-1)
|
| 59 |
+
# x = self.linear1(x)
|
| 60 |
+
# x = self.RELU(x)
|
| 61 |
+
# return self.linear2(x)
|
| 62 |
+
return self.linear(x)
|
| 63 |
+
=======
|
| 64 |
:param vocab_size: total vocab size for masked_lm
|
| 65 |
"""
|
| 66 |
|
|
|
|
| 71 |
|
| 72 |
def forward(self, x, segment_label):
|
| 73 |
x = self.bert(x, segment_label)
|
| 74 |
+
return x, self.linear(x[:, 0])
|
| 75 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
src/dataset.py
CHANGED
|
@@ -4,17 +4,28 @@ import pandas as pd
|
|
| 4 |
import numpy as np
|
| 5 |
import tqdm
|
| 6 |
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from vocab import Vocab
|
| 8 |
import pickle
|
| 9 |
import copy
|
| 10 |
from sklearn.preprocessing import OneHotEncoder
|
|
|
|
| 11 |
|
| 12 |
class PretrainerDataset(Dataset):
|
| 13 |
"""
|
| 14 |
Class name: PretrainDataset
|
| 15 |
|
| 16 |
"""
|
|
|
|
|
|
|
|
|
|
| 17 |
def __init__(self, dataset_path, vocab, seq_len=30, select_next_seq= False):
|
|
|
|
| 18 |
self.dataset_path = dataset_path
|
| 19 |
self.vocab = vocab # Vocab object
|
| 20 |
|
|
@@ -35,6 +46,22 @@ class PretrainerDataset(Dataset):
|
|
| 35 |
self.index_documents[i] = []
|
| 36 |
else:
|
| 37 |
self.index_documents[i].append(index)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
self.lines.append(line.split())
|
| 39 |
len_line = len(line.split())
|
| 40 |
seq_len_list.append(len_line)
|
|
@@ -49,6 +76,7 @@ class PretrainerDataset(Dataset):
|
|
| 49 |
print("Sequence length set at ", self.seq_len)
|
| 50 |
print("select_next_seq: ", self.select_next_seq)
|
| 51 |
print(len(self.index_documents))
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
def __len__(self):
|
|
@@ -56,6 +84,53 @@ class PretrainerDataset(Dataset):
|
|
| 56 |
|
| 57 |
def __getitem__(self, item):
|
| 58 |
token_a = self.lines[item]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
token_b = None
|
| 60 |
is_same_student = None
|
| 61 |
sa_masked = None
|
|
(Hunks in src/dataset.py; the corresponding new-file lines, including the unresolved merge-conflict markers committed here, are listed in full below.)

@@ -92,6 +167,7 @@ class PretrainerDataset(Dataset):
@@ -100,6 +176,28 @@ class PretrainerDataset(Dataset):
@@ -108,17 +206,34 @@ class PretrainerDataset(Dataset):
@@ -127,11 +242,53 @@ class PretrainerDataset(Dataset):
@@ -167,6 +324,7 @@ class PretrainerDataset(Dataset):
@@ -174,15 +332,89 @@ class TokenizerDataset(Dataset):
@@ -234,11 +466,14 @@ class TokenizerDataset(Dataset):
@@ -253,17 +488,46 @@ class TokenizerDataset(Dataset):
@@ -274,11 +538,132 @@ class TokenizerDataset(Dataset):
| 4 |
import numpy as np
|
| 5 |
import tqdm
|
| 6 |
import random
|
| 7 |
+
<<<<<<< HEAD
|
| 8 |
+
from .vocab import Vocab
|
| 9 |
+
import pickle
|
| 10 |
+
import copy
|
| 11 |
+
# from sklearn.preprocessing import OneHotEncoder
|
| 12 |
+
=======
|
| 13 |
from vocab import Vocab
|
| 14 |
import pickle
|
| 15 |
import copy
|
| 16 |
from sklearn.preprocessing import OneHotEncoder
|
| 17 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 18 |
|
| 19 |
class PretrainerDataset(Dataset):
|
| 20 |
"""
|
| 21 |
Class name: PretrainerDataset
|
| 22 |
|
| 23 |
"""
|
| 24 |
+
<<<<<<< HEAD
|
| 25 |
+
def __init__(self, dataset_path, vocab, seq_len=30, max_mask=0.15):
|
| 26 |
+
=======
|
| 27 |
def __init__(self, dataset_path, vocab, seq_len=30, select_next_seq= False):
|
| 28 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 29 |
self.dataset_path = dataset_path
|
| 30 |
self.vocab = vocab # Vocab object
|
| 31 |
|
|
|
|
| 46 |
self.index_documents[i] = []
|
| 47 |
else:
|
| 48 |
self.index_documents[i].append(index)
|
| 49 |
+
<<<<<<< HEAD
|
| 50 |
+
self.lines.append(line.split("\t"))
|
| 51 |
+
len_line = len(line.split("\t"))
|
| 52 |
+
seq_len_list.append(len_line)
|
| 53 |
+
index+=1
|
| 54 |
+
reader.close()
|
| 55 |
+
print("Sequence Stats: len: %s, min: %s, max: %s, average: %s"% (len(seq_len_list),
|
| 56 |
+
min(seq_len_list), max(seq_len_list), sum(seq_len_list)/len(seq_len_list)))
|
| 57 |
+
print("Unique Sequences: ", len({tuple(ll) for ll in self.lines}))
|
| 58 |
+
self.index_documents = {k:v for k,v in self.index_documents.items() if v}
|
| 59 |
+
print(len(self.index_documents))
|
| 60 |
+
self.seq_len = seq_len
|
| 61 |
+
print("Sequence length set at: ", self.seq_len)
|
| 62 |
+
self.max_mask = max_mask
|
| 63 |
+
print("% of input tokens selected for masking : ",self.max_mask)
|
| 64 |
+
=======
|
| 65 |
self.lines.append(line.split())
|
| 66 |
len_line = len(line.split())
|
| 67 |
seq_len_list.append(len_line)
|
|
|
|
| 76 |
print("Sequence length set at ", self.seq_len)
|
| 77 |
print("select_next_seq: ", self.select_next_seq)
|
| 78 |
print(len(self.index_documents))
|
| 79 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 80 |
|
| 81 |
|
| 82 |
def __len__(self):
|
|
|
|
| 84 |
|
| 85 |
def __getitem__(self, item):
|
| 86 |
token_a = self.lines[item]
|
| 87 |
+
<<<<<<< HEAD
|
| 88 |
+
# sa_masked = None
|
| 89 |
+
# sa_masked_label = None
|
| 90 |
+
# token_b = None
|
| 91 |
+
# is_same_student = None
|
| 92 |
+
# sb_masked = None
|
| 93 |
+
# sb_masked_label = None
|
| 94 |
+
|
| 95 |
+
# if self.select_next_seq:
|
| 96 |
+
# is_same_student, token_b = self.get_token_b(item)
|
| 97 |
+
# is_same_student = 1 if is_same_student else 0
|
| 98 |
+
# token_a1, token_b1 = self.truncate_to_max_seq(token_a, token_b)
|
| 99 |
+
# sa_masked, sa_masked_label = self.random_mask_seq(token_a1)
|
| 100 |
+
# sb_masked, sb_masked_label = self.random_mask_seq(token_b1)
|
| 101 |
+
# else:
|
| 102 |
+
token_a = token_a[:self.seq_len-2]
|
| 103 |
+
sa_masked, sa_masked_label, sa_masked_pos = self.random_mask_seq(token_a)
|
| 104 |
+
|
| 105 |
+
s1 = ([self.vocab.vocab['[CLS]']] + sa_masked + [self.vocab.vocab['[SEP]']])
|
| 106 |
+
s1_label = ([self.vocab.vocab['[PAD]']] + sa_masked_label + [self.vocab.vocab['[PAD]']])
|
| 107 |
+
segment_label = [1 for _ in range(len(s1))]
|
| 108 |
+
masked_pos = ([0] + sa_masked_pos + [0])
|
| 109 |
+
|
| 110 |
+
# if self.select_next_seq:
|
| 111 |
+
# s1 = s1 + sb_masked + [self.vocab.vocab['[SEP]']]
|
| 112 |
+
# s1_label = s1_label + sb_masked_label + [self.vocab.vocab['[PAD]']]
|
| 113 |
+
# segment_label = segment_label + [2 for _ in range(len(sb_masked)+1)]
|
| 114 |
+
|
| 115 |
+
padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
|
| 116 |
+
s1.extend(padding)
|
| 117 |
+
s1_label.extend(padding)
|
| 118 |
+
segment_label.extend(padding)
|
| 119 |
+
masked_pos.extend(padding)
|
| 120 |
+
|
| 121 |
+
output = {'bert_input': s1,
|
| 122 |
+
'bert_label': s1_label,
|
| 123 |
+
'segment_label': segment_label,
|
| 124 |
+
'masked_pos': masked_pos}
|
| 125 |
+
# print(f"tokenA: {token_a}")
|
| 126 |
+
# print(f"output: {output}")
|
| 127 |
+
|
| 128 |
+
# if self.select_next_seq:
|
| 129 |
+
# output['is_same_student'] = is_same_student
|
| 130 |
+
|
| 131 |
+
# print(item, len(s1), len(s1_label), len(segment_label))
|
| 132 |
+
# print(f"{item}.")
|
| 133 |
+
=======
|
| 134 |
token_b = None
|
| 135 |
is_same_student = None
|
| 136 |
sa_masked = None
|
|
|
|
| 167 |
if self.select_next_seq:
|
| 168 |
output['is_same_student'] = is_same_student
|
| 169 |
# print(item, len(s1), len(s1_label), len(segment_label))
|
| 170 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 171 |
return {key: torch.tensor(value) for key, value in output.items()}
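For orientation, a minimal sketch of the tensors __getitem__ above assembles, following the HEAD side of the conflict (which also returns masked_pos); the toy vocabulary, token names, and seq_len=8 are illustrative, not repo values:

# Illustrative only: mirrors the [CLS]/[SEP] framing and [PAD] padding done in __getitem__.
vocab = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '[MASK]': 4, 'EquationAnswer': 5, 'FinalAnswer': 6}
seq_len = 8
sa_masked, sa_masked_label, sa_masked_pos = [4, 6], [5, 0], [1, 0]   # one masked position

s1 = [vocab['[CLS]']] + sa_masked + [vocab['[SEP]']]
s1_label = [vocab['[PAD]']] + sa_masked_label + [vocab['[PAD]']]
segment_label = [1] * len(s1)
masked_pos = [0] + sa_masked_pos + [0]
padding = [vocab['[PAD]']] * (seq_len - len(s1))
for seq in (s1, s1_label, segment_label, masked_pos):
    seq.extend(padding)

print(s1)             # [2, 4, 6, 3, 0, 0, 0, 0]
print(s1_label)       # [0, 5, 0, 0, 0, 0, 0, 0]
print(segment_label)  # [1, 1, 1, 1, 0, 0, 0, 0]
print(masked_pos)     # [0, 1, 0, 0, 0, 0, 0, 0]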
|
| 172 |
|
| 173 |
def random_mask_seq(self, tokens):
|
|
|
|
| 176 |
Output: masked token seq, output label
|
| 177 |
"""
|
| 178 |
|
| 179 |
+
<<<<<<< HEAD
|
| 180 |
+
masked_pos = []
|
| 181 |
+
output_labels = []
|
| 182 |
+
output_tokens = copy.deepcopy(tokens)
|
| 183 |
+
opt_step = False
|
| 184 |
+
for i, token in enumerate(tokens):
|
| 185 |
+
if token in ['OptionalTask_1', 'EquationAnswer', 'NumeratorFactor', 'DenominatorFactor', 'OptionalTask_2', 'FirstRow1:1', 'FirstRow1:2', 'FirstRow2:1', 'FirstRow2:2', 'SecondRow', 'ThirdRow']:
|
| 186 |
+
opt_step = True
|
| 187 |
+
# if opt_step:
|
| 188 |
+
# prob = random.random()
|
| 189 |
+
# if prob < self.max_mask:
|
| 190 |
+
# output_tokens[i] = random.choice([3,7,8,9,11,12,13,14,15,16,22,23,24,25,26,27,30,31,32])
|
| 191 |
+
# masked_pos.append(1)
|
| 192 |
+
# else:
|
| 193 |
+
# output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
| 194 |
+
# masked_pos.append(0)
|
| 195 |
+
# output_labels.append(self.vocab.vocab.get(token, self.vocab.vocab['[UNK]']))
|
| 196 |
+
# opt_step = False
|
| 197 |
+
# else:
|
| 198 |
+
prob = random.random()
|
| 199 |
+
if prob < self.max_mask:
|
| 200 |
+
=======
|
| 201 |
# masked_pos_label = {}
|
| 202 |
output_labels = []
|
| 203 |
output_tokens = copy.deepcopy(tokens)
|
|
|
|
| 206 |
for i, token in enumerate(tokens):
|
| 207 |
prob = random.random()
|
| 208 |
if prob < 0.15:
|
| 209 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 210 |
# chooses 15% of token positions at random
|
| 211 |
# prob /= 0.15
|
| 212 |
prob = random.random()
|
| 213 |
if prob < 0.8: #[MASK] token 80% of the time
|
| 214 |
output_tokens[i] = self.vocab.vocab['[MASK]']
|
| 215 |
+
<<<<<<< HEAD
|
| 216 |
+
masked_pos.append(1)
|
| 217 |
+
elif prob < 0.9: # a random token 10% of the time
|
| 218 |
+
# print(".......0.8-0.9......")
|
| 219 |
+
if opt_step:
|
| 220 |
+
output_tokens[i] = random.choice([7,8,9,11,12,13,14,15,16,22,23,24,25,26,27,30,31,32])
|
| 221 |
+
opt_step = False
|
| 222 |
+
else:
|
| 223 |
+
output_tokens[i] = random.randint(1, len(self.vocab.vocab)-1)
|
| 224 |
+
masked_pos.append(1)
|
| 225 |
+
else: # the unchanged i-th token 10% of the time
|
| 226 |
+
# print(".......unchanged......")
|
| 227 |
+
output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
| 228 |
+
masked_pos.append(0)
|
| 229 |
+
=======
|
| 230 |
elif prob < 0.9: # a random token 10% of the time
|
| 231 |
# print(".......0.8-0.9......")
|
| 232 |
output_tokens[i] = random.randint(1, len(self.vocab.vocab)-1)
|
| 233 |
else: # the unchanged i-th token 10% of the time
|
| 234 |
# print(".......unchanged......")
|
| 235 |
output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
| 236 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 237 |
# True Label
|
| 238 |
output_labels.append(self.vocab.vocab.get(token, self.vocab.vocab['[UNK]']))
|
| 239 |
# masked_pos_label[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
|
|
|
| 242 |
output_tokens[i] = self.vocab.vocab.get(token, self.vocab.vocab['[UNK]'])
|
| 243 |
# Padded label
|
| 244 |
output_labels.append(self.vocab.vocab['[PAD]'])
|
| 245 |
+
<<<<<<< HEAD
|
| 246 |
+
masked_pos.append(0)
|
| 247 |
+
=======
|
| 248 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 249 |
# label_position = []
|
| 250 |
# label_tokens = []
|
| 251 |
# for k, v in masked_pos_label.items():
|
| 252 |
# label_position.append(k)
|
| 253 |
# label_tokens.append(v)
|
| 254 |
+
<<<<<<< HEAD
|
| 255 |
+
return output_tokens, output_labels, masked_pos
|
| 256 |
+
|
| 257 |
+
# def get_token_b(self, item):
|
| 258 |
+
# document_id = [k for k,v in self.index_documents.items() if item in v][0]
|
| 259 |
+
# random_document_id = document_id
|
| 260 |
+
|
| 261 |
+
# if random.random() < 0.5:
|
| 262 |
+
# document_ids = [k for k in self.index_documents.keys() if k != document_id]
|
| 263 |
+
# random_document_id = random.choice(document_ids)
|
| 264 |
+
|
| 265 |
+
# same_student = (random_document_id == document_id)
|
| 266 |
+
|
| 267 |
+
# nex_seq_list = self.index_documents.get(random_document_id)
|
| 268 |
+
|
| 269 |
+
# if same_student:
|
| 270 |
+
# if len(nex_seq_list) != 1:
|
| 271 |
+
# nex_seq_list = [v for v in nex_seq_list if v !=item]
|
| 272 |
+
|
| 273 |
+
# next_seq = random.choice(nex_seq_list)
|
| 274 |
+
# tokens = self.lines[next_seq]
|
| 275 |
+
# # print(f"item = {item}, tokens: {tokens}")
|
| 276 |
+
# # print(f"item={item}, next={next_seq}, same_student = {same_student}, {document_id} == {random_document_id}, b. {tokens}")
|
| 277 |
+
# return same_student, tokens
|
| 278 |
+
|
| 279 |
+
# def truncate_to_max_seq(self, s1, s2):
|
| 280 |
+
# sa = copy.deepcopy(s1)
|
| 281 |
+
# sb = copy.deepcopy(s1)
|
| 282 |
+
# total_allowed_seq = self.seq_len - 3
|
| 283 |
+
|
| 284 |
+
# while((len(sa)+len(sb)) > total_allowed_seq):
|
| 285 |
+
# if random.random() < 0.5:
|
| 286 |
+
# sa.pop()
|
| 287 |
+
# else:
|
| 288 |
+
# sb.pop()
|
| 289 |
+
# return sa, sb
|
| 290 |
+
|
| 291 |
+
=======
|
| 292 |
return output_tokens, output_labels
|
| 293 |
|
| 294 |
def get_token_b(self, item):
|
|
|
|
| 324 |
else:
|
| 325 |
sb.pop()
|
| 326 |
return sa, sb
|
| 327 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
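For reference, a self-contained sketch of the 15% / 80-10-10 selection rule that random_mask_seq above implements; the helper name and toy vocab are assumptions, not repo code:

import copy
import random

def mask_tokens(tokens, vocab, mask_rate=0.15):
    # Toy stand-in for random_mask_seq: returns (masked ids, label ids, masked positions).
    out, labels, positions = copy.deepcopy(tokens), [], []
    for i, tok in enumerate(tokens):
        tok_id = vocab.get(tok, vocab['[UNK]'])
        if random.random() < mask_rate:            # ~15% of positions are selected
            roll = random.random()
            if roll < 0.8:                         # 80% of selected: replace with [MASK]
                out[i] = vocab['[MASK]']
                positions.append(1)
            elif roll < 0.9:                       # 10% of selected: replace with a random id
                out[i] = random.randint(1, len(vocab) - 1)
                positions.append(1)
            else:                                  # 10% of selected: keep the original id
                out[i] = tok_id
                positions.append(0)
            labels.append(tok_id)                  # true label recorded only at selected positions
        else:
            out[i] = tok_id
            labels.append(vocab['[PAD]'])          # padded label elsewhere
            positions.append(0)
    return out, labels, positions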
|
| 328 |
|
| 329 |
class TokenizerDataset(Dataset):
|
| 330 |
"""
|
|
|
|
| 332 |
Tokenize the data in the dataset
|
| 333 |
|
| 334 |
"""
|
| 335 |
+
<<<<<<< HEAD
|
| 336 |
+
def __init__(self, dataset_path, label_path, vocab, seq_len=30):
|
| 337 |
+
self.dataset_path = dataset_path
|
| 338 |
+
self.label_path = label_path
|
| 339 |
+
self.vocab = vocab # Vocab object
|
| 340 |
+
# self.encoder = OneHotEncoder(sparse=False)
|
| 341 |
+
=======
|
| 342 |
def __init__(self, dataset_path, label_path, vocab, seq_len=30, train=True):
|
| 343 |
self.dataset_path = dataset_path
|
| 344 |
self.label_path = label_path
|
| 345 |
self.vocab = vocab # Vocab object
|
| 346 |
self.encoder = OneHotEncoder(sparse_output=False)
|
| 347 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 348 |
|
| 349 |
# Related to input dataset file
|
| 350 |
self.lines = []
|
| 351 |
self.labels = []
|
| 352 |
+
<<<<<<< HEAD
|
| 353 |
+
self.feats = []
|
| 354 |
+
if self.label_path:
|
| 355 |
+
self.label_file = open(self.label_path, "r")
|
| 356 |
+
for line in self.label_file:
|
| 357 |
+
if line:
|
| 358 |
+
line = line.strip()
|
| 359 |
+
if not line:
|
| 360 |
+
continue
|
| 361 |
+
self.labels.append(int(line))
|
| 362 |
+
self.label_file.close()
|
| 363 |
+
|
| 364 |
+
# Comment this section if you are not using feat attribute
|
| 365 |
+
try:
|
| 366 |
+
j = 0
|
| 367 |
+
dataset_info_file = open(self.label_path.replace("label", "info"), "r")
|
| 368 |
+
for line in dataset_info_file:
|
| 369 |
+
if line:
|
| 370 |
+
line = line.strip()
|
| 371 |
+
if not line:
|
| 372 |
+
continue
|
| 373 |
+
|
| 374 |
+
# # highGRschool_w_prior
|
| 375 |
+
# feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
| 376 |
+
|
| 377 |
+
# highGRschool_w_prior_w_diffskill_wo_fa
|
| 378 |
+
feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
| 379 |
+
feat2 = [float(i) for i in line.split(",")[-2].split("\t")]
|
| 380 |
+
feat_vec.extend(feat2[1:])
|
| 381 |
+
|
| 382 |
+
# # highGRschool_w_prior_w_p_diffskill_wo_fa
|
| 383 |
+
# feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
| 384 |
+
# feat2 = [-float(i) for i in line.split(",")[-2].split("\t")]
|
| 385 |
+
# feat_vec.extend(feat2[1:])
|
| 386 |
+
|
| 387 |
+
# # highGRschool_w_prior_w_diffskill_0fa_skill
|
| 388 |
+
# feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
| 389 |
+
# feat2 = [float(i) for i in line.split(",")[-2].split("\t")]
|
| 390 |
+
# fa_feat_vec = [float(i) for i in line.split(",")[-1].split("\t")]
|
| 391 |
+
|
| 392 |
+
# diff_skill = [f2 if f1==0 else 0 for f2, f1 in zip(feat2, fa_feat_vec)]
|
| 393 |
+
# feat_vec.extend(diff_skill)
|
| 394 |
+
|
| 395 |
+
if j == 0:
|
| 396 |
+
print(len(feat_vec))
|
| 397 |
+
j+=1
|
| 398 |
+
|
| 399 |
+
# feat_vec.extend(feat2[1:])
|
| 400 |
+
# feat_vec.extend(feat2)
|
| 401 |
+
# feat_vec = [float(i) for i in line.split(",")[-2].split("\t")]
|
| 402 |
+
# feat_vec = feat_vec[1:]
|
| 403 |
+
# feat_vec = [float(line.split(",")[-1])]
|
| 404 |
+
# feat_vec = [float(i) for i in line.split(",")[-1].split("\t")]
|
| 405 |
+
# feat_vec = [ft-f1 for ft, f1 in zip(feat_vec, fa_feat_vec)]
|
| 406 |
+
|
| 407 |
+
self.feats.append(feat_vec)
|
| 408 |
+
dataset_info_file.close()
|
| 409 |
+
except Exception as e:
|
| 410 |
+
print(e)
|
| 411 |
+
# labeler = np.array([0, 1]) #np.unique(self.labels)
|
| 412 |
+
# print(f"Labeler {labeler}")
|
| 413 |
+
# self.encoder.fit(labeler.reshape(-1,1))
|
| 414 |
+
# self.labels = self.encoder.transform(np.array(self.labels).reshape(-1,1))
|
| 415 |
+
|
| 416 |
+
self.file = open(self.dataset_path, "r")
|
| 417 |
+
=======
|
| 418 |
self.labels = []
|
| 419 |
|
| 420 |
self.label_file = open(self.label_path, "r")
|
|
|
|
| 466 |
|
| 467 |
self.file = open(self.dataset_path, "r")
|
| 468 |
# index = 0
|
| 469 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 470 |
for line in self.file:
|
| 471 |
if line:
|
| 472 |
line = line.strip()
|
| 473 |
if line:
|
| 474 |
self.lines.append(line)
|
| 475 |
+
<<<<<<< HEAD
|
| 476 |
+
=======
|
| 477 |
# if train:
|
| 478 |
# if index in indices_of_zeros:
|
| 479 |
# # if index in indices_of_prom:
|
|
|
|
| 488 |
# self.labels.append(labels[index])
|
| 489 |
# self.labels.append(progress[index])
|
| 490 |
# index += 1
|
| 491 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 492 |
self.file.close()
|
| 493 |
|
| 494 |
self.len = len(self.lines)
|
| 495 |
self.seq_len = seq_len
|
| 496 |
+
<<<<<<< HEAD
|
| 497 |
+
print("Sequence length set at ", self.seq_len, len(self.lines), len(self.labels) if self.label_path else 0)
|
| 498 |
+
=======
|
| 499 |
|
| 500 |
print("Sequence length set at ", self.seq_len, len(self.lines), len(self.labels))
|
| 501 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 502 |
|
| 503 |
def __len__(self):
|
| 504 |
return self.len
|
| 505 |
|
| 506 |
def __getitem__(self, item):
|
| 507 |
+
<<<<<<< HEAD
|
| 508 |
+
org_line = self.lines[item].split("\t")
|
| 509 |
+
dup_line = []
|
| 510 |
+
opt = False
|
| 511 |
+
for l in org_line:
|
| 512 |
+
if l in ["OptionalTask_1", "EquationAnswer", "NumeratorFactor", "DenominatorFactor", "OptionalTask_2", "FirstRow1:1", "FirstRow1:2", "FirstRow2:1", "FirstRow2:2", "SecondRow", "ThirdRow"]:
|
| 513 |
+
opt = True
|
| 514 |
+
if opt and 'FinalAnswer-' in l:
|
| 515 |
+
dup_line.append('[UNK]')
|
| 516 |
+
else:
|
| 517 |
+
dup_line.append(l)
|
| 518 |
+
dup_line = "\t".join(dup_line)
|
| 519 |
+
# print(dup_line)
|
| 520 |
+
s1 = self.vocab.to_seq(dup_line, self.seq_len) # This is like tokenizer and adds [CLS] and [SEP].
|
| 521 |
+
s1_label = self.labels[item] if self.label_path else 0
|
| 522 |
+
segment_label = [1 for _ in range(len(s1))]
|
| 523 |
+
s1_feat = self.feats[item] if len(self.feats)>0 else 0
|
| 524 |
+
padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
|
| 525 |
+
s1.extend(padding), segment_label.extend(padding)
|
| 526 |
+
|
| 527 |
+
output = {'input': s1,
|
| 528 |
+
'label': s1_label,
|
| 529 |
+
'feat': s1_feat,
|
| 530 |
+
=======
|
| 531 |
|
| 532 |
s1 = self.vocab.to_seq(self.lines[item], self.seq_len) # This is like tokenizer and adds [CLS] and [SEP].
|
| 533 |
s1_label = self.labels[item]
|
|
|
|
| 538 |
|
| 539 |
output = {'bert_input': s1,
|
| 540 |
'progress_status': s1_label,
|
| 541 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 542 |
'segment_label': segment_label}
|
| 543 |
return {key: torch.tensor(value) for key, value in output.items()}
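A small sketch of the optional-task handling in __getitem__ above: once an optional-task step is seen, later FinalAnswer- actions are replaced with [UNK] before tokenizing (the sample action string is hypothetical):

OPT_STEPS = {"OptionalTask_1", "EquationAnswer", "NumeratorFactor", "DenominatorFactor",
             "OptionalTask_2", "FirstRow1:1", "FirstRow1:2", "FirstRow2:1", "FirstRow2:2",
             "SecondRow", "ThirdRow"}

def hide_final_answer(line):
    # Same rule as the dup_line loop above, written as a helper for clarity.
    out, opt = [], False
    for step in line.split("\t"):
        if step in OPT_STEPS:
            opt = True
        out.append('[UNK]' if opt and 'FinalAnswer-' in step else step)
    return "\t".join(out)

print(hide_final_answer("SomeStep\tOptionalTask_1\tEquationAnswer\tFinalAnswer-1"))
# -> SomeStep  OptionalTask_1  EquationAnswer  [UNK]  (tab-separated)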
|
| 544 |
|
| 545 |
|
| 546 |
+
<<<<<<< HEAD
|
| 547 |
+
class TokenizerDatasetForCalibration(Dataset):
|
| 548 |
+
"""
|
| 549 |
+
Class name: TokenizerDatasetForCalibration
|
| 550 |
+
Tokenize the data in the dataset
|
| 551 |
+
|
| 552 |
+
"""
|
| 553 |
+
def __init__(self, dataset_path, label_path, vocab, seq_len=30):
|
| 554 |
+
self.dataset_path = dataset_path
|
| 555 |
+
self.label_path = label_path
|
| 556 |
+
self.vocab = vocab # Vocab object
|
| 557 |
+
# self.encoder = OneHotEncoder(sparse=False)
|
| 558 |
+
|
| 559 |
+
# Related to input dataset file
|
| 560 |
+
self.lines = []
|
| 561 |
+
self.labels = []
|
| 562 |
+
self.feats = []
|
| 563 |
+
if self.label_path:
|
| 564 |
+
self.label_file = open(self.label_path, "r")
|
| 565 |
+
for line in self.label_file:
|
| 566 |
+
if line:
|
| 567 |
+
line = line.strip()
|
| 568 |
+
if not line:
|
| 569 |
+
continue
|
| 570 |
+
self.labels.append(int(line))
|
| 571 |
+
self.label_file.close()
|
| 572 |
+
|
| 573 |
+
# Comment this section if you are not using feat attribute
|
| 574 |
+
try:
|
| 575 |
+
j = 0
|
| 576 |
+
dataset_info_file = open(self.label_path.replace("label", "info"), "r")
|
| 577 |
+
for line in dataset_info_file:
|
| 578 |
+
if line:
|
| 579 |
+
line = line.strip()
|
| 580 |
+
if not line:
|
| 581 |
+
continue
|
| 582 |
+
|
| 583 |
+
# # highGRschool_w_prior
|
| 584 |
+
# feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
| 585 |
+
|
| 586 |
+
# highGRschool_w_prior_w_diffskill_wo_fa
|
| 587 |
+
feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
| 588 |
+
feat2 = [float(i) for i in line.split(",")[-2].split("\t")]
|
| 589 |
+
feat_vec.extend(feat2[1:])
|
| 590 |
+
|
| 591 |
+
# # highGRschool_w_prior_w_diffskill_0fa_skill
|
| 592 |
+
# feat_vec = [float(i) for i in line.split(",")[-3].split("\t")]
|
| 593 |
+
# feat2 = [float(i) for i in line.split(",")[-2].split("\t")]
|
| 594 |
+
# fa_feat_vec = [float(i) for i in line.split(",")[-1].split("\t")]
|
| 595 |
+
|
| 596 |
+
# diff_skill = [f2 if f1==0 else 0 for f2, f1 in zip(feat2, fa_feat_vec)]
|
| 597 |
+
# feat_vec.extend(diff_skill)
|
| 598 |
+
|
| 599 |
+
if j == 0:
|
| 600 |
+
print(len(feat_vec))
|
| 601 |
+
j+=1
|
| 602 |
+
|
| 603 |
+
# feat_vec.extend(feat2[1:])
|
| 604 |
+
# feat_vec.extend(feat2)
|
| 605 |
+
# feat_vec = [float(i) for i in line.split(",")[-2].split("\t")]
|
| 606 |
+
# feat_vec = feat_vec[1:]
|
| 607 |
+
# feat_vec = [float(line.split(",")[-1])]
|
| 608 |
+
# feat_vec = [float(i) for i in line.split(",")[-1].split("\t")]
|
| 609 |
+
# feat_vec = [ft-f1 for ft, f1 in zip(feat_vec, fa_feat_vec)]
|
| 610 |
+
|
| 611 |
+
self.feats.append(feat_vec)
|
| 612 |
+
dataset_info_file.close()
|
| 613 |
+
except Exception as e:
|
| 614 |
+
print(e)
|
| 615 |
+
# labeler = np.array([0, 1]) #np.unique(self.labels)
|
| 616 |
+
# print(f"Labeler {labeler}")
|
| 617 |
+
# self.encoder.fit(labeler.reshape(-1,1))
|
| 618 |
+
# self.labels = self.encoder.transform(np.array(self.labels).reshape(-1,1))
|
| 619 |
+
|
| 620 |
+
self.file = open(self.dataset_path, "r")
|
| 621 |
+
for line in self.file:
|
| 622 |
+
if line:
|
| 623 |
+
line = line.strip()
|
| 624 |
+
if line:
|
| 625 |
+
self.lines.append(line)
|
| 626 |
+
self.file.close()
|
| 627 |
+
|
| 628 |
+
self.len = len(self.lines)
|
| 629 |
+
self.seq_len = seq_len
|
| 630 |
+
print("Sequence length set at ", self.seq_len, len(self.lines), len(self.labels) if self.label_path else 0)
|
| 631 |
+
|
| 632 |
+
def __len__(self):
|
| 633 |
+
return self.len
|
| 634 |
+
|
| 635 |
+
def __getitem__(self, item):
|
| 636 |
+
org_line = self.lines[item].split("\t")
|
| 637 |
+
dup_line = []
|
| 638 |
+
opt = False
|
| 639 |
+
for l in org_line:
|
| 640 |
+
if l in ["OptionalTask_1", "EquationAnswer", "NumeratorFactor", "DenominatorFactor", "OptionalTask_2", "FirstRow1:1", "FirstRow1:2", "FirstRow2:1", "FirstRow2:2", "SecondRow", "ThirdRow"]:
|
| 641 |
+
opt = True
|
| 642 |
+
if opt and 'FinalAnswer-' in l:
|
| 643 |
+
dup_line.append('[UNK]')
|
| 644 |
+
else:
|
| 645 |
+
dup_line.append(l)
|
| 646 |
+
dup_line = "\t".join(dup_line)
|
| 647 |
+
# print(dup_line)
|
| 648 |
+
s1 = self.vocab.to_seq(dup_line, self.seq_len) # This is like tokenizer and adds [CLS] and [SEP].
|
| 649 |
+
s1_label = self.labels[item] if self.label_path else 0
|
| 650 |
+
segment_label = [1 for _ in range(len(s1))]
|
| 651 |
+
s1_feat = self.feats[item] if len(self.feats)>0 else 0
|
| 652 |
+
padding = [self.vocab.vocab['[PAD]'] for _ in range(self.seq_len - len(s1))]
|
| 653 |
+
s1.extend(padding), segment_label.extend(padding)
|
| 654 |
+
|
| 655 |
+
output = {'input': s1,
|
| 656 |
+
'label': s1_label,
|
| 657 |
+
'feat': s1_feat,
|
| 658 |
+
'segment_label': segment_label}
|
| 659 |
+
return ({key: torch.tensor(value) for key, value in output.items()}, s1_label)
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
# if __name__ == "__main__":
|
| 664 |
+
=======
|
| 665 |
# if __name__ == "__main__":
|
| 666 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 667 |
# # import pickle
|
| 668 |
# # k = pickle.load(open("dataset/CL4999_1920/unique_steps_list.pkl","rb"))
|
| 669 |
# # print(k)
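A hedged wiring sketch for how a DataLoader consumes these datasets; ToyTokenizerDataset below is a stand-in that only mimics the output keys of TokenizerDataset above (input, label, feat, segment_label), with feat_dim=17 chosen to match the classifier call in src/pretrainer.py:

import torch
from torch.utils.data import Dataset, DataLoader

class ToyTokenizerDataset(Dataset):
    # Stand-in with the same output keys as TokenizerDataset, for wiring checks only.
    def __init__(self, n=4, seq_len=30, feat_dim=17):
        self.n, self.seq_len, self.feat_dim = n, seq_len, feat_dim
    def __len__(self):
        return self.n
    def __getitem__(self, item):
        return {'input': torch.zeros(self.seq_len, dtype=torch.long),
                'label': torch.tensor(item % 2),
                'feat': torch.zeros(self.feat_dim),
                'segment_label': torch.ones(self.seq_len, dtype=torch.long)}

loader = DataLoader(ToyTokenizerDataset(), batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch['input'].shape, batch['feat'].shape, batch['label'])
# torch.Size([2, 30]) torch.Size([2, 17]) tensor([0, 1])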
|
src/pretrainer.py
CHANGED
|
(Unchanged context for each hunk in src/pretrainer.py; the added lines, including the unresolved merge-conflict markers committed here, appear in the new-file listing that follows.)

@@ -1,5 +1,42 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader
@@ -67,6 +104,7 @@ class BERTTrainer:
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, same_student_prediction = False,
                 workspace_name=None):
        """
        :param bert: BERT model which you want to train
        :param vocab_size: total word vocab size
@@ -79,6 +117,17 @@ class BERTTrainer:
        :param log_freq: logging frequency of the batch iteration
        """

        # Setup cuda device for BERT training, argument -c, --cuda should be true
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
@@ -87,15 +136,24 @@ class BERTTrainer:
        # This BERT model will be saved every epoch
        self.bert = bert
        # Initialize the BERT Language Model, with BERT model
        self.model = BERTSM(bert, vocab_size).to(self.device)

        # Distributed GPU training if CUDA can detect more than 1 GPU
        if with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)

        # Setting the train and test data loader
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
@@ -106,19 +164,44 @@ class BERTTrainer:
        self.criterion = nn.NLLLoss(ignore_index=0)

        self.log_freq
        self.same_student_prediction = same_student_prediction
        self.workspace_name = workspace_name
        self.save_model = False
        self.avg_loss = 10000
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        loop over the data_loader for training or testing
        if on train status, backward operation is activated
@@ -129,6 +212,30 @@ class BERTTrainer:
        :param train: boolean value of is train or test
        :return: None
        """
        str_code = "train" if train else "test"
        code = "masked_prediction" if self.same_student_prediction else "masked"
@@ -155,10 +262,25 @@ class BERTTrainer:

        avg_loss = 0.0
        with open(self.log_file, 'a') as f:
            sys.stdout = f
            for i, data in data_iter:
                # 0. batch_data will be sent into the device(GPU or cpu)
                data = {key: value.to(self.device) for key, value in data.items()}

                # 1. forward the next_sentence_prediction and masked_lm model
                # next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])
@@ -184,10 +306,49 @@ class BERTTrainer:

                # 3. backward and optimization only in train
                if train:
                    self.optim_schedule.zero_grad()
                    loss.backward()
                    self.optim_schedule.step_and_update_lr()

                non_zero_mask = (data["bert_label"] != 0).float()
                predictions = torch.argmax(mask_lm_output, dim=-1)
@@ -249,6 +410,7 @@ class BERTTrainer:
        # pickle.dump(bert_hidden_representations, open(f"embeddings/{code}/{str_code}_embeddings_{epoch}.pkl","wb"))


    def save(self, epoch, file_path="output/bert_trained.model"):
        """
@@ -270,7 +432,12 @@ class BERTFineTuneTrainer:
    def __init__(self, bert: BERT, vocab_size: int,
                 train_dataloader: DataLoader, test_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None, num_labels=2):
        """
        :param bert: BERT model which you want to train
        :param vocab_size: total word vocab size
@@ -286,6 +453,302 @@ class BERTFineTuneTrainer:
        # Setup cuda device for BERT training, argument -c, --cuda should be true
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
        print("Device used = ", self.device)

        # This BERT model will be saved every epoch
@@ -320,15 +783,28 @@ class BERTFineTuneTrainer:
        self.workspace_name = workspace_name
        self.save_model = False
        self.avg_loss = 10000
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        loop over the data_loader for training or testing
        if on train status, backward operation is activated
@@ -339,6 +815,12 @@ class BERTFineTuneTrainer:
        :param train: boolean value of is train or test
        :return: None
        """
        str_code = "train" if train else "test"

        self.log_file = f"{self.workspace_name}/logs/masked/log_{str_code}_FS_finetuned.txt"
@@ -352,6 +834,7 @@ class BERTFineTuneTrainer:
        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

@@ -360,6 +843,28 @@ class BERTFineTuneTrainer:
        total_element = 0
        plabels = []
        tlabels = []
        eval_accurate_nb = 0
        nb_eval_examples = 0
        logits_list = []
@@ -390,10 +895,81 @@ class BERTFineTuneTrainer:
                progress_loss = self.criterion(logits, data["progress_status"])
                loss = progress_loss

                if torch.cuda.device_count() > 1:
                    loss = loss.mean()

                # 3. backward and optimization only in train
                if train:
                    self.optim.zero_grad()
                    loss.backward()
@@ -489,13 +1065,40 @@ class BERTFineTuneTrainer:
            f.close()
        sys.stdout = sys.__stdout__
        if train:
            self.save_model = False
            if self.avg_loss > (avg_loss / len(data_iter)):
                self.save_model = True
                self.avg_loss = (avg_loss / len(data_iter))

        # plt_test.show()
        # print("EP%d_%s, " % (epoch, str_code))

    def save(self, epoch, file_path="output/bert_fine_tuned_trained.model"):
        """
@@ -510,3 +1113,113 @@ class BERTFineTuneTrainer:
        self.model.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
+
<<<<<<< HEAD
|
| 4 |
+
# from torch.nn import functional as F
|
| 5 |
+
from torch.optim import Adam
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
# import pickle
|
| 8 |
+
|
| 9 |
+
from .bert import BERT
|
| 10 |
+
from .seq_model import BERTSM
|
| 11 |
+
from .classifier_model import BERTForClassification, BERTForClassificationWithFeats
|
| 12 |
+
from .optim_schedule import ScheduledOptim
|
| 13 |
+
|
| 14 |
+
import tqdm
|
| 15 |
+
import sys
|
| 16 |
+
import time
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
|
| 21 |
+
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
import seaborn as sns
|
| 24 |
+
import pandas as pd
|
| 25 |
+
from collections import defaultdict
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
class BERTTrainer:
|
| 29 |
+
"""
|
| 30 |
+
BERTTrainer pretrains BERT model on input sequence of strategies.
|
| 31 |
+
BERTTrainer makes the pretrained BERT model with one training objective:
|
| 32 |
+
1. Masked Strategy Modeling :Masked SM
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, bert: BERT, vocab_size: int,
|
| 36 |
+
train_dataloader: DataLoader, val_dataloader: DataLoader = None, test_dataloader: DataLoader = None,
|
| 37 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=5000,
|
| 38 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, log_folder_path: str = None):
|
| 39 |
+
=======
|
| 40 |
from torch.nn import functional as F
|
| 41 |
from torch.optim import Adam, SGD
|
| 42 |
from torch.utils.data import DataLoader
|
|
|
|
| 104 |
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
| 105 |
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, same_student_prediction = False,
|
| 106 |
workspace_name=None):
|
| 107 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 108 |
"""
|
| 109 |
:param bert: BERT model which you want to train
|
| 110 |
:param vocab_size: total word vocab size
|
|
|
|
| 117 |
:param log_freq: logging frequency of the batch iteration
|
| 118 |
"""
|
| 119 |
|
| 120 |
+
<<<<<<< HEAD
|
| 121 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
| 122 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
| 123 |
+
print(cuda_condition, " Device used = ", self.device)
|
| 124 |
+
|
| 125 |
+
available_gpus = list(range(torch.cuda.device_count()))
|
| 126 |
+
|
| 127 |
+
# This BERT model will be saved
|
| 128 |
+
self.bert = bert.to(self.device)
|
| 129 |
+
# Initialize the BERT Sequence Model, with BERT model
|
| 130 |
+
=======
|
| 131 |
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
| 132 |
cuda_condition = torch.cuda.is_available() and with_cuda
|
| 133 |
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
|
|
|
| 136 |
# This BERT model will be saved every epoch
|
| 137 |
self.bert = bert
|
| 138 |
# Initialize the BERT Language Model, with BERT model
|
| 139 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 140 |
self.model = BERTSM(bert, vocab_size).to(self.device)
|
| 141 |
|
| 142 |
# Distributed GPU training if CUDA can detect more than 1 GPU
|
| 143 |
if with_cuda and torch.cuda.device_count() > 1:
|
| 144 |
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
| 145 |
+
<<<<<<< HEAD
|
| 146 |
+
self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
| 147 |
+
|
| 148 |
+
# Setting the train, validation and test data loader
|
| 149 |
+
self.train_data = train_dataloader
|
| 150 |
+
self.val_data = val_dataloader
|
| 151 |
+
=======
|
| 152 |
self.model = nn.DataParallel(self.model, device_ids=cuda_devices)
|
| 153 |
|
| 154 |
# Setting the train and test data loader
|
| 155 |
self.train_data = train_dataloader
|
| 156 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 157 |
self.test_data = test_dataloader
|
| 158 |
|
| 159 |
# Setting the Adam optimizer with hyper-param
|
|
|
|
| 164 |
self.criterion = nn.NLLLoss(ignore_index=0)
|
| 165 |
|
| 166 |
self.log_freq = log_freq
|
| 167 |
+
<<<<<<< HEAD
|
| 168 |
+
self.log_folder_path = log_folder_path
|
| 169 |
+
# self.workspace_name = workspace_name
|
| 170 |
+
self.save_model = False
|
| 171 |
+
# self.code = code
|
| 172 |
+
self.avg_loss = 10000
|
| 173 |
+
for fi in ['train', 'val', 'test']:
|
| 174 |
+
f = open(self.log_folder_path+f"/log_{fi}_pretrained.txt", 'w')
|
| 175 |
+
f.close()
|
| 176 |
+
self.start_time = time.time()
|
| 177 |
+
|
| 178 |
+
=======
|
| 179 |
self.same_student_prediction = same_student_prediction
|
| 180 |
self.workspace_name = workspace_name
|
| 181 |
self.save_model = False
|
| 182 |
self.avg_loss = 10000
|
| 183 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 184 |
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
| 185 |
|
| 186 |
def train(self, epoch):
|
| 187 |
self.iteration(epoch, self.train_data)
|
| 188 |
|
| 189 |
+
<<<<<<< HEAD
|
| 190 |
+
def val(self, epoch):
|
| 191 |
+
if epoch == 0:
|
| 192 |
+
self.avg_loss = 10000
|
| 193 |
+
self.iteration(epoch, self.val_data, phase="val")
|
| 194 |
+
|
| 195 |
+
def test(self, epoch):
|
| 196 |
+
self.iteration(epoch, self.test_data, phase="test")
|
| 197 |
+
|
| 198 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
| 199 |
+
=======
|
| 200 |
def test(self, epoch):
|
| 201 |
self.iteration(epoch, self.test_data, train=False)
|
| 202 |
|
| 203 |
def iteration(self, epoch, data_loader, train=True):
|
| 204 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 205 |
"""
|
| 206 |
loop over the data_loader for training or testing
|
| 207 |
if on train status, backward operation is activated
|
|
|
|
| 212 |
:param train: boolean value of is train or test
|
| 213 |
:return: None
|
| 214 |
"""
|
| 215 |
+
<<<<<<< HEAD
|
| 216 |
+
|
| 217 |
+
# self.log_file = f"{self.workspace_name}/logs/{self.code}/log_{phase}_pretrained.txt"
|
| 218 |
+
# bert_hidden_representations = [] can be used
|
| 219 |
+
# if epoch == 0:
|
| 220 |
+
# f = open(self.log_file, 'w')
|
| 221 |
+
# f.close()
|
| 222 |
+
|
| 223 |
+
# Progress bar
|
| 224 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
| 225 |
+
desc="EP_%s:%d" % (phase, epoch),
|
| 226 |
+
total=len(data_loader),
|
| 227 |
+
bar_format="{l_bar}{r_bar}")
|
| 228 |
+
|
| 229 |
+
total_correct = 0
|
| 230 |
+
total_element = 0
|
| 231 |
+
avg_loss = 0.0
|
| 232 |
+
|
| 233 |
+
if phase == "train":
|
| 234 |
+
self.model.train()
|
| 235 |
+
else:
|
| 236 |
+
self.model.eval()
|
| 237 |
+
with open(self.log_folder_path+f"/log_{phase}_pretrained.txt", 'a') as f:
|
| 238 |
+
=======
|
| 239 |
str_code = "train" if train else "test"
|
| 240 |
code = "masked_prediction" if self.same_student_prediction else "masked"
|
| 241 |
|
|
|
|
| 262 |
|
| 263 |
avg_loss = 0.0
|
| 264 |
with open(self.log_file, 'a') as f:
|
| 265 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 266 |
sys.stdout = f
|
| 267 |
for i, data in data_iter:
|
| 268 |
# 0. batch_data will be sent into the device(GPU or cpu)
|
| 269 |
data = {key: value.to(self.device) for key, value in data.items()}
|
| 270 |
+
<<<<<<< HEAD
|
| 271 |
+
|
| 272 |
+
# 1. forward masked_sm model
|
| 273 |
+
# mask_sm_output is log-probabilities output
|
| 274 |
+
mask_sm_output, bert_hidden_rep = self.model.forward(data["bert_input"], data["segment_label"])
|
| 275 |
+
|
| 276 |
+
# 2. NLLLoss of predicting masked token word
|
| 277 |
+
loss = self.criterion(mask_sm_output.transpose(1, 2), data["bert_label"])
|
| 278 |
+
if torch.cuda.device_count() > 1:
|
| 279 |
+
loss = loss.mean()
|
| 280 |
+
|
| 281 |
+
# 3. backward and optimization only in train
|
| 282 |
+
if phase == "train":
|
| 283 |
+
=======
|
| 284 |
|
| 285 |
# 1. forward the next_sentence_prediction and masked_lm model
|
| 286 |
# next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])
|
|
|
|
| 306 |
|
| 307 |
# 3. backward and optimization only in train
|
| 308 |
if train:
|
| 309 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 310 |
self.optim_schedule.zero_grad()
|
| 311 |
loss.backward()
|
| 312 |
self.optim_schedule.step_and_update_lr()
|
| 313 |
|
| 314 |
+
<<<<<<< HEAD
|
| 315 |
+
# tokens with highest log-probabilities creates a predicted sequence
|
| 316 |
+
pred_tokens = torch.argmax(mask_sm_output, dim=-1)
|
| 317 |
+
mask_correct = (data["bert_label"] == pred_tokens) & data["masked_pos"]
|
| 318 |
+
|
| 319 |
+
total_correct += mask_correct.sum().item()
|
| 320 |
+
total_element += data["masked_pos"].sum().item()
|
| 321 |
+
avg_loss +=loss.item()
|
| 322 |
+
|
| 323 |
+
torch.cuda.empty_cache()
|
| 324 |
+
|
| 325 |
+
post_fix = {
|
| 326 |
+
"epoch": epoch,
|
| 327 |
+
"iter": i,
|
| 328 |
+
"avg_loss": avg_loss / (i + 1),
|
| 329 |
+
"avg_acc_mask": (total_correct / total_element * 100) if total_element != 0 else 0,
|
| 330 |
+
"loss": loss.item()
|
| 331 |
+
}
|
| 332 |
+
if i % self.log_freq == 0:
|
| 333 |
+
data_iter.write(str(post_fix))
|
| 334 |
+
|
| 335 |
+
end_time = time.time()
|
| 336 |
+
final_msg = {
|
| 337 |
+
"epoch": f"EP{epoch}_{phase}",
|
| 338 |
+
"avg_loss": avg_loss / len(data_iter),
|
| 339 |
+
"total_masked_acc": (total_correct / total_element * 100) if total_element != 0 else 0,
|
| 340 |
+
"time_taken_from_start": end_time - self.start_time
|
| 341 |
+
}
|
| 342 |
+
print(final_msg)
|
| 343 |
+
f.close()
|
| 344 |
+
sys.stdout = sys.__stdout__
|
| 345 |
+
|
| 346 |
+
if phase == "val":
|
| 347 |
+
self.save_model = False
|
| 348 |
+
if self.avg_loss > (avg_loss / len(data_iter)):
|
| 349 |
+
self.save_model = True
|
| 350 |
+
self.avg_loss = (avg_loss / len(data_iter))
|
| 351 |
+
=======
|
| 352 |
|
| 353 |
non_zero_mask = (data["bert_label"] != 0).float()
|
| 354 |
predictions = torch.argmax(mask_lm_output, dim=-1)
|
|
|
|
| 410 |
# pickle.dump(bert_hidden_representations, open(f"embeddings/{code}/{str_code}_embeddings_{epoch}.pkl","wb"))
|
| 411 |
|
| 412 |
|
| 413 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
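A compact, illustrative sketch of the masked-LM objective used above: NLLLoss(ignore_index=0) over log-probabilities, plus the masked-token accuracy computed from argmax predictions (shapes and sample labels are made up):

import torch
import torch.nn as nn

# log_probs: (batch, seq_len, vocab_size) log-softmax output; labels hold the true token id
# at masked positions and 0 ([PAD]) everywhere else, so ignore_index=0 skips the rest.
criterion = nn.NLLLoss(ignore_index=0)
log_probs = torch.log_softmax(torch.randn(2, 6, 30), dim=-1)
labels = torch.tensor([[5, 0, 7, 0, 0, 3],
                       [0, 9, 0, 0, 4, 0]])
masked_pos = (labels != 0)

loss = criterion(log_probs.transpose(1, 2), labels)   # NLLLoss expects (batch, vocab, seq)
pred_tokens = torch.argmax(log_probs, dim=-1)
mask_correct = ((pred_tokens == labels) & masked_pos).sum().item()
print(loss.item(), mask_correct / masked_pos.sum().item())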
|
| 414 |
|
| 415 |
def save(self, epoch, file_path="output/bert_trained.model"):
|
| 416 |
"""
|
|
|
|
| 432 |
def __init__(self, bert: BERT, vocab_size: int,
|
| 433 |
train_dataloader: DataLoader, test_dataloader: DataLoader = None,
|
| 434 |
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
| 435 |
+
<<<<<<< HEAD
|
| 436 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None,
|
| 437 |
+
num_labels=2, log_folder_path: str = None):
|
| 438 |
+
=======
|
| 439 |
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None, num_labels=2):
|
| 440 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 441 |
"""
|
| 442 |
:param bert: BERT model which you want to train
|
| 443 |
:param vocab_size: total word vocab size
|
|
|
|
| 453 |
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
| 454 |
cuda_condition = torch.cuda.is_available() and with_cuda
|
| 455 |
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
| 456 |
+
<<<<<<< HEAD
|
| 457 |
+
print(cuda_condition, " Device used = ", self.device)
|
| 458 |
+
|
| 459 |
+
available_gpus = list(range(torch.cuda.device_count()))
|
| 460 |
+
|
| 461 |
+
# This BERT model will be saved every epoch
|
| 462 |
+
self.bert = bert
|
| 463 |
+
for param in self.bert.parameters():
|
| 464 |
+
param.requires_grad = False
|
| 465 |
+
# Initialize the BERT Language Model, with BERT model
|
| 466 |
+
# self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)
|
| 467 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 8).to(self.device)
|
| 468 |
+
self.model = BERTForClassificationWithFeats(self.bert, num_labels, 17).to(self.device)
|
| 469 |
+
|
| 470 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 1).to(self.device)
|
| 471 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
| 472 |
+
if with_cuda and torch.cuda.device_count() > 1:
|
| 473 |
+
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
| 474 |
+
self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
| 475 |
+
|
| 476 |
+
# Setting the train, validation and test data loader
|
| 477 |
+
self.train_data = train_dataloader
|
| 478 |
+
# self.val_data = val_dataloader
|
| 479 |
+
self.test_data = test_dataloader
|
| 480 |
+
|
| 481 |
+
# self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
|
| 482 |
+
self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
|
| 483 |
+
self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
|
| 484 |
+
# self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
|
| 485 |
+
self.criterion = nn.CrossEntropyLoss()
|
| 486 |
+
|
| 487 |
+
# if num_labels == 1:
|
| 488 |
+
# self.criterion = nn.MSELoss()
|
| 489 |
+
# elif num_labels == 2:
|
| 490 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
| 491 |
+
# # self.criterion = nn.CrossEntropyLoss()
|
| 492 |
+
# elif num_labels > 2:
|
| 493 |
+
# self.criterion = nn.CrossEntropyLoss()
|
| 494 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
self.log_freq = log_freq
|
| 498 |
+
self.log_folder_path = log_folder_path
|
| 499 |
+
# self.workspace_name = workspace_name
|
| 500 |
+
# self.finetune_task = finetune_task
|
| 501 |
+
self.save_model = False
|
| 502 |
+
self.avg_loss = 10000
|
| 503 |
+
self.start_time = time.time()
|
| 504 |
+
# self.probability_list = []
|
| 505 |
+
for fi in ['train', 'test']: #'val',
|
| 506 |
+
f = open(self.log_folder_path+f"/log_{fi}_finetuned.txt", 'w')
|
| 507 |
+
f.close()
|
| 508 |
+
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
| 509 |
+
|
| 510 |
+
def train(self, epoch):
|
| 511 |
+
self.iteration(epoch, self.train_data)
|
| 512 |
+
|
| 513 |
+
# def val(self, epoch):
|
| 514 |
+
# self.iteration(epoch, self.val_data, phase="val")
|
| 515 |
+
|
| 516 |
+
def test(self, epoch):
|
| 517 |
+
if epoch == 0:
|
| 518 |
+
self.avg_loss = 10000
|
| 519 |
+
self.iteration(epoch, self.test_data, phase="test")
|
| 520 |
+
|
| 521 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
| 522 |
+
"""
|
| 523 |
+
loop over the data_loader for training or testing
|
| 524 |
+
if on train status, backward operation is activated
|
| 525 |
+
and also auto save the model every peoch
|
| 526 |
+
|
| 527 |
+
:param epoch: current epoch index
|
| 528 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
| 529 |
+
:param train: boolean value of is train or test
|
| 530 |
+
:return: None
|
| 531 |
+
"""
|
| 532 |
+
|
| 533 |
+
# Setting the tqdm progress bar
|
| 534 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
| 535 |
+
desc="EP_%s:%d" % (phase, epoch),
|
| 536 |
+
total=len(data_loader),
|
| 537 |
+
bar_format="{l_bar}{r_bar}")
|
| 538 |
+
|
| 539 |
+
avg_loss = 0.0
|
| 540 |
+
total_correct = 0
|
| 541 |
+
total_element = 0
|
| 542 |
+
plabels = []
|
| 543 |
+
tlabels = []
|
| 544 |
+
probabs = []
|
| 545 |
+
|
| 546 |
+
if phase == "train":
|
| 547 |
+
self.model.train()
|
| 548 |
+
else:
|
| 549 |
+
self.model.eval()
|
| 550 |
+
# self.probability_list = []
|
| 551 |
+
|
| 552 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned.txt", 'a') as f:
|
| 553 |
+
sys.stdout = f
|
| 554 |
+
for i, data in data_iter:
|
| 555 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
| 556 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 557 |
+
if phase == "train":
|
| 558 |
+
logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
| 559 |
+
else:
|
| 560 |
+
with torch.no_grad():
|
| 561 |
+
logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
|
| 562 |
+
|
| 563 |
+
loss = self.criterion(logits, data["label"])
|
| 564 |
+
if torch.cuda.device_count() > 1:
|
| 565 |
+
loss = loss.mean()
|
| 566 |
+
|
| 567 |
+
# 3. backward and optimization only in train
|
| 568 |
+
if phase == "train":
|
| 569 |
+
self.optim_schedule.zero_grad()
|
| 570 |
+
loss.backward()
|
| 571 |
+
self.optim_schedule.step_and_update_lr()
|
| 572 |
+
|
| 573 |
+
# prediction accuracy
|
| 574 |
+
probs = nn.Softmax(dim=-1)(logits) # Probabilities
|
| 575 |
+
probabs.extend(probs.detach().cpu().numpy().tolist())
|
| 576 |
+
predicted_labels = torch.argmax(probs, dim=-1) #correct
|
| 577 |
+
# self.probability_list.append(probs)
|
| 578 |
+
# true_labels = torch.argmax(data["label"], dim=-1)
|
| 579 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
| 580 |
+
tlabels.extend(data['label'].cpu().numpy())
|
| 581 |
+
|
| 582 |
+
# Compare predicted labels to true labels and calculate accuracy
|
| 583 |
+
correct = (data['label'] == predicted_labels).sum().item()
|
| 584 |
+
|
| 585 |
+
avg_loss += loss.item()
|
| 586 |
+
total_correct += correct
|
| 587 |
+
# total_element += true_labels.nelement()
|
| 588 |
+
total_element += data["label"].nelement()
|
| 589 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
|
| 590 |
+
|
| 591 |
+
post_fix = {
|
| 592 |
+
"epoch": epoch,
|
| 593 |
+
"iter": i,
|
| 594 |
+
"avg_loss": avg_loss / (i + 1),
|
| 595 |
+
"avg_acc": total_correct / total_element * 100 if total_element != 0 else 0,
|
| 596 |
+
"loss": loss.item()
|
| 597 |
+
}
|
| 598 |
+
if i % self.log_freq == 0:
|
| 599 |
+
data_iter.write(str(post_fix))
|
| 600 |
+
|
| 601 |
+
precisions = precision_score(tlabels, plabels, average="weighted", zero_division=0)
|
| 602 |
+
recalls = recall_score(tlabels, plabels, average="weighted")
|
| 603 |
+
f1_scores = f1_score(tlabels, plabels, average="weighted")
|
| 604 |
+
cmatrix = confusion_matrix(tlabels, plabels)
|
| 605 |
+
end_time = time.time()
|
| 606 |
+
final_msg = {
|
| 607 |
+
"epoch": f"EP{epoch}_{phase}",
|
| 608 |
+
"avg_loss": avg_loss / len(data_iter),
|
| 609 |
+
"total_acc": total_correct * 100.0 / total_element,
|
| 610 |
+
"precisions": precisions,
|
| 611 |
+
"recalls": recalls,
|
| 612 |
+
"f1_scores": f1_scores,
|
| 613 |
+
# "confusion_matrix": f"{cmatrix}",
|
| 614 |
+
# "true_labels": f"{tlabels}",
|
| 615 |
+
# "predicted_labels": f"{plabels}",
|
| 616 |
+
"time_taken_from_start": end_time - self.start_time
|
| 617 |
+
}
|
| 618 |
+
print(final_msg)
|
| 619 |
+
f.close()
|
| 620 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned_info.txt", 'a') as f1:
|
| 621 |
+
sys.stdout = f1
|
| 622 |
+
final_msg = {
|
| 623 |
+
"epoch": f"EP{epoch}_{phase}",
|
| 624 |
+
"confusion_matrix": f"{cmatrix}",
|
| 625 |
+
"true_labels": f"{tlabels if epoch == 0 else ''}",
|
| 626 |
+
"predicted_labels": f"{plabels}",
|
| 627 |
+
"probabilities": f"{probabs}",
|
| 628 |
+
"time_taken_from_start": end_time - self.start_time
|
| 629 |
+
}
|
| 630 |
+
print(final_msg)
|
| 631 |
+
f1.close()
|
| 632 |
+
sys.stdout = sys.__stdout__
|
| 633 |
+
sys.stdout = sys.__stdout__
|
| 634 |
+
|
| 635 |
+
if phase == "test":
|
| 636 |
+
self.save_model = False
|
| 637 |
+
if self.avg_loss > (avg_loss / len(data_iter)):
|
| 638 |
+
self.save_model = True
|
| 639 |
+
self.avg_loss = (avg_loss / len(data_iter))
|
| 640 |
+
|
| 641 |
+
def iteration_1(self, epoch_idx, data):
|
| 642 |
+
try:
|
| 643 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 644 |
+
logits = self.model(data['input_ids'], data['segment_label'])
|
| 645 |
+
# Ensure logits is a tensor, not a tuple
|
| 646 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 647 |
+
loss = loss_fct(logits, data['labels'])
|
| 648 |
+
|
| 649 |
+
# Backpropagation and optimization
|
| 650 |
+
self.optim.zero_grad()
|
| 651 |
+
loss.backward()
|
| 652 |
+
self.optim.step()
|
| 653 |
+
|
| 654 |
+
if self.log_freq > 0 and epoch_idx % self.log_freq == 0:
|
| 655 |
+
print(f"Epoch {epoch_idx}: Loss = {loss.item()}")
|
| 656 |
+
|
| 657 |
+
return loss
|
| 658 |
+
|
| 659 |
+
except Exception as e:
|
| 660 |
+
print(f"Error during iteration: {e}")
|
| 661 |
+
raise
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def save(self, epoch, file_path="output/bert_fine_tuned_trained.model"):
|
| 665 |
+
"""
|
| 666 |
+
Saving the current BERT model on file_path
|
| 667 |
+
|
| 668 |
+
:param epoch: current epoch number
|
| 669 |
+
:param file_path: model output path which gonna be file_path+"ep%d" % epoch
|
| 670 |
+
:return: final_output_path
|
| 671 |
+
"""
|
| 672 |
+
output_path = file_path + ".ep%d" % epoch
|
| 673 |
+
torch.save(self.model.cpu(), output_path)
|
| 674 |
+
self.model.to(self.device)
|
| 675 |
+
print("EP:%d Model Saved on:" % epoch, output_path)
|
| 676 |
+
return output_path
|
| 677 |
+
|
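The save() method above pickles the entire nn.Module rather than its state_dict, so a checkpoint written this way is reloaded with a plain torch.load call. A minimal sketch, assuming the default file_path and epoch 0 (the ".ep0" suffix comes from file_path + ".ep%d" % epoch above); the defining classes must be importable at load time for unpickling to succeed:

import torch

# torch.save(self.model.cpu(), output_path) stores the whole module object,
# so torch.load returns the module directly; no load_state_dict step is needed.
model = torch.load("output/bert_fine_tuned_trained.model.ep0", map_location="cpu")
model.eval()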
| 678 |
+
class BERTFineTuneTrainer1:
|
| 679 |
+
|
| 680 |
+
def __init__(self, bert: BERT, vocab_size: int,
|
| 681 |
+
train_dataloader: DataLoader, test_dataloader: DataLoader = None,
|
| 682 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
| 683 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None,
|
| 684 |
+
num_labels=2, log_folder_path: str = None):
|
| 685 |
+
"""
|
| 686 |
+
:param bert: BERT model which you want to train
|
| 687 |
+
:param vocab_size: total word vocab size
|
| 688 |
+
:param train_dataloader: train dataset data loader
|
| 689 |
+
:param test_dataloader: test dataset data loader [can be None]
|
| 690 |
+
:param lr: learning rate of optimizer
|
| 691 |
+
:param betas: Adam optimizer betas
|
| 692 |
+
:param weight_decay: Adam optimizer weight decay param
|
| 693 |
+
:param with_cuda: training with cuda
|
| 694 |
+
:param log_freq: logging frequency of the batch iteration
|
| 695 |
+
"""
|
| 696 |
+
|
| 697 |
+
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
| 698 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
| 699 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
| 700 |
+
print(cuda_condition, " Device used = ", self.device)
|
| 701 |
+
|
| 702 |
+
available_gpus = list(range(torch.cuda.device_count()))
|
| 703 |
+
|
| 704 |
+
# This BERT model will be saved every epoch
|
| 705 |
+
self.bert = bert
|
| 706 |
+
for param in self.bert.parameters():
|
| 707 |
+
param.requires_grad = False
|
| 708 |
+
# Initialize the BERT Language Model, with BERT model
|
| 709 |
+
self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)
|
| 710 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 8).to(self.device)
|
| 711 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 8*2).to(self.device)
|
| 712 |
+
|
| 713 |
+
# self.model = BERTForClassificationWithFeats(self.bert, num_labels, 1).to(self.device)
|
| 714 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
| 715 |
+
if with_cuda and torch.cuda.device_count() > 1:
|
| 716 |
+
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
| 717 |
+
self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
| 718 |
+
|
| 719 |
+
# Setting the train, validation and test data loader
|
| 720 |
+
self.train_data = train_dataloader
|
| 721 |
+
# self.val_data = val_dataloader
|
| 722 |
+
self.test_data = test_dataloader
|
| 723 |
+
|
| 724 |
+
# self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
|
| 725 |
+
self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
|
| 726 |
+
self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
|
| 727 |
+
# self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
|
| 728 |
+
self.criterion = nn.CrossEntropyLoss()
|
| 729 |
+
|
| 730 |
+
# if num_labels == 1:
|
| 731 |
+
# self.criterion = nn.MSELoss()
|
| 732 |
+
# elif num_labels == 2:
|
| 733 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
| 734 |
+
# # self.criterion = nn.CrossEntropyLoss()
|
| 735 |
+
# elif num_labels > 2:
|
| 736 |
+
# self.criterion = nn.CrossEntropyLoss()
|
| 737 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
self.log_freq = log_freq
|
| 741 |
+
self.log_folder_path = log_folder_path
|
| 742 |
+
# self.workspace_name = workspace_name
|
| 743 |
+
# self.finetune_task = finetune_task
|
| 744 |
+
self.save_model = False
|
| 745 |
+
self.avg_loss = 10000
|
| 746 |
+
self.start_time = time.time()
|
| 747 |
+
# self.probability_list = []
|
| 748 |
+
for fi in ['train', 'test']: #'val',
|
| 749 |
+
f = open(self.log_folder_path+f"/log_{fi}_finetuned.txt", 'w')
|
| 750 |
+
f.close()
|
| 751 |
+
=======
|
| 752 |
print("Device used = ", self.device)
|
| 753 |
|
| 754 |
# This BERT model will be saved every epoch
|
|
|
|
| 783 |
self.workspace_name = workspace_name
|
| 784 |
self.save_model = False
|
| 785 |
self.avg_loss = 10000
|
| 786 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 787 |
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
| 788 |
|
| 789 |
def train(self, epoch):
|
| 790 |
self.iteration(epoch, self.train_data)
|
| 791 |
|
| 792 |
+
<<<<<<< HEAD
|
| 793 |
+
# def val(self, epoch):
|
| 794 |
+
# self.iteration(epoch, self.val_data, phase="val")
|
| 795 |
+
|
| 796 |
+
def test(self, epoch):
|
| 797 |
+
if epoch == 0:
|
| 798 |
+
self.avg_loss = 10000
|
| 799 |
+
self.iteration(epoch, self.test_data, phase="test")
|
| 800 |
+
|
| 801 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
| 802 |
+
=======
|
| 803 |
def test(self, epoch):
|
| 804 |
self.iteration(epoch, self.test_data, train=False)
|
| 805 |
|
| 806 |
def iteration(self, epoch, data_loader, train=True):
|
| 807 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 808 |
"""
|
| 809 |
loop over the data_loader for training or testing
|
| 810 |
if in train mode, the backward operation is activated
|
|
|
|
| 815 |
:param train: boolean value of is train or test
|
| 816 |
:return: None
|
| 817 |
"""
|
| 818 |
+
<<<<<<< HEAD
|
| 819 |
+
|
| 820 |
+
# Setting the tqdm progress bar
|
| 821 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
| 822 |
+
desc="EP_%s:%d" % (phase, epoch),
|
| 823 |
+
=======
|
| 824 |
str_code = "train" if train else "test"
|
| 825 |
|
| 826 |
self.log_file = f"{self.workspace_name}/logs/masked/log_{str_code}_FS_finetuned.txt"
|
|
|
|
| 834 |
# Setting the tqdm progress bar
|
| 835 |
data_iter = tqdm.tqdm(enumerate(data_loader),
|
| 836 |
desc="EP_%s:%d" % (str_code, epoch),
|
| 837 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 838 |
total=len(data_loader),
|
| 839 |
bar_format="{l_bar}{r_bar}")
|
| 840 |
|
|
|
|
| 843 |
total_element = 0
|
| 844 |
plabels = []
|
| 845 |
tlabels = []
|
| 846 |
+
<<<<<<< HEAD
|
| 847 |
+
probabs = []
|
| 848 |
+
|
| 849 |
+
if phase == "train":
|
| 850 |
+
self.model.train()
|
| 851 |
+
else:
|
| 852 |
+
self.model.eval()
|
| 853 |
+
# self.probability_list = []
|
| 854 |
+
|
| 855 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned.txt", 'a') as f:
|
| 856 |
+
sys.stdout = f
|
| 857 |
+
for i, data in data_iter:
|
| 858 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
| 859 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 860 |
+
if phase == "train":
|
| 861 |
+
logits = self.model.forward(data["input"], data["segment_label"])#, data["feat"])
|
| 862 |
+
else:
|
| 863 |
+
with torch.no_grad():
|
| 864 |
+
logits = self.model.forward(data["input"], data["segment_label"])#, data["feat"])
|
| 865 |
+
|
| 866 |
+
loss = self.criterion(logits, data["label"])
|
| 867 |
+
=======
|
| 868 |
eval_accurate_nb = 0
|
| 869 |
nb_eval_examples = 0
|
| 870 |
logits_list = []
|
|
|
|
| 895 |
progress_loss = self.criterion(logits, data["progress_status"])
|
| 896 |
loss = progress_loss
|
| 897 |
|
| 898 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 899 |
if torch.cuda.device_count() > 1:
|
| 900 |
loss = loss.mean()
|
| 901 |
|
| 902 |
# 3. backward and optimization only in train
|
| 903 |
+
<<<<<<< HEAD
|
| 904 |
+
if phase == "train":
|
| 905 |
+
self.optim_schedule.zero_grad()
|
| 906 |
+
loss.backward()
|
| 907 |
+
self.optim_schedule.step_and_update_lr()
|
| 908 |
+
|
| 909 |
+
# prediction accuracy
|
| 910 |
+
probs = nn.Softmax(dim=-1)(logits) # Probabilities
|
| 911 |
+
probabs.extend(probs.detach().cpu().numpy().tolist())
|
| 912 |
+
predicted_labels = torch.argmax(probs, dim=-1) #correct
|
| 913 |
+
# self.probability_list.append(probs)
|
| 914 |
+
# true_labels = torch.argmax(data["label"], dim=-1)
|
| 915 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
| 916 |
+
tlabels.extend(data['label'].cpu().numpy())
|
| 917 |
+
|
| 918 |
+
# Compare predicted labels to true labels and calculate accuracy
|
| 919 |
+
correct = (data['label'] == predicted_labels).sum().item()
|
| 920 |
+
|
| 921 |
+
avg_loss += loss.item()
|
| 922 |
+
total_correct += correct
|
| 923 |
+
# total_element += true_labels.nelement()
|
| 924 |
+
total_element += data["label"].nelement()
|
| 925 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
|
| 926 |
+
|
| 927 |
+
post_fix = {
|
| 928 |
+
"epoch": epoch,
|
| 929 |
+
"iter": i,
|
| 930 |
+
"avg_loss": avg_loss / (i + 1),
|
| 931 |
+
"avg_acc": total_correct / total_element * 100 if total_element != 0 else 0,
|
| 932 |
+
"loss": loss.item()
|
| 933 |
+
}
|
| 934 |
+
if i % self.log_freq == 0:
|
| 935 |
+
data_iter.write(str(post_fix))
|
| 936 |
+
|
| 937 |
+
precisions = precision_score(tlabels, plabels, average="weighted", zero_division=0)
|
| 938 |
+
recalls = recall_score(tlabels, plabels, average="weighted")
|
| 939 |
+
f1_scores = f1_score(tlabels, plabels, average="weighted")
|
| 940 |
+
cmatrix = confusion_matrix(tlabels, plabels)
|
| 941 |
+
end_time = time.time()
|
| 942 |
+
final_msg = {
|
| 943 |
+
"epoch": f"EP{epoch}_{phase}",
|
| 944 |
+
"avg_loss": avg_loss / len(data_iter),
|
| 945 |
+
"total_acc": total_correct * 100.0 / total_element,
|
| 946 |
+
"precisions": precisions,
|
| 947 |
+
"recalls": recalls,
|
| 948 |
+
"f1_scores": f1_scores,
|
| 949 |
+
# "confusion_matrix": f"{cmatrix}",
|
| 950 |
+
# "true_labels": f"{tlabels}",
|
| 951 |
+
# "predicted_labels": f"{plabels}",
|
| 952 |
+
"time_taken_from_start": end_time - self.start_time
|
| 953 |
+
}
|
| 954 |
+
print(final_msg)
|
| 955 |
+
f.close()
|
| 956 |
+
with open(self.log_folder_path+f"/log_{phase}_finetuned_info.txt", 'a') as f1:
|
| 957 |
+
sys.stdout = f1
|
| 958 |
+
final_msg = {
|
| 959 |
+
"epoch": f"EP{epoch}_{phase}",
|
| 960 |
+
"confusion_matrix": f"{cmatrix}",
|
| 961 |
+
"true_labels": f"{tlabels if epoch == 0 else ''}",
|
| 962 |
+
"predicted_labels": f"{plabels}",
|
| 963 |
+
"probabilities": f"{probabs}",
|
| 964 |
+
"time_taken_from_start": end_time - self.start_time
|
| 965 |
+
}
|
| 966 |
+
print(final_msg)
|
| 967 |
+
f1.close()
|
| 968 |
+
sys.stdout = sys.__stdout__
|
| 969 |
+
sys.stdout = sys.__stdout__
|
| 970 |
+
|
| 971 |
+
if phase == "test":
|
| 972 |
+
=======
|
| 973 |
if train:
|
| 974 |
self.optim.zero_grad()
|
| 975 |
loss.backward()
|
|
|
|
| 1065 |
f.close()
|
| 1066 |
sys.stdout = sys.__stdout__
|
| 1067 |
if train:
|
| 1068 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 1069 |
self.save_model = False
|
| 1070 |
if self.avg_loss > (avg_loss / len(data_iter)):
|
| 1071 |
self.save_model = True
|
| 1072 |
self.avg_loss = (avg_loss / len(data_iter))
|
| 1073 |
+
<<<<<<< HEAD
|
| 1074 |
+
|
| 1075 |
+
def iteration_1(self, epoch_idx, data):
|
| 1076 |
+
try:
|
| 1077 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 1078 |
+
logits = self.model(data['input_ids'], data['segment_label'])
|
| 1079 |
+
# Ensure logits is a tensor, not a tuple
|
| 1080 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 1081 |
+
loss = loss_fct(logits, data['labels'])
|
| 1082 |
+
|
| 1083 |
+
# Backpropagation and optimization
|
| 1084 |
+
self.optim.zero_grad()
|
| 1085 |
+
loss.backward()
|
| 1086 |
+
self.optim.step()
|
| 1087 |
+
|
| 1088 |
+
if self.log_freq > 0 and epoch_idx % self.log_freq == 0:
|
| 1089 |
+
print(f"Epoch {epoch_idx}: Loss = {loss.item()}")
|
| 1090 |
+
|
| 1091 |
+
return loss
|
| 1092 |
+
|
| 1093 |
+
except Exception as e:
|
| 1094 |
+
print(f"Error during iteration: {e}")
|
| 1095 |
+
raise
|
| 1096 |
+
|
| 1097 |
+
=======
|
| 1098 |
|
| 1099 |
# plt_test.show()
|
| 1100 |
# print("EP%d_%s, " % (epoch, str_code))
|
| 1101 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 1102 |
|
| 1103 |
def save(self, epoch, file_path="output/bert_fine_tuned_trained.model"):
|
| 1104 |
"""
|
|
|
|
| 1113 |
self.model.to(self.device)
|
| 1114 |
print("EP:%d Model Saved on:" % epoch, output_path)
|
| 1115 |
return output_path
|
| 1116 |
+
<<<<<<< HEAD
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
class BERTAttention:
|
| 1120 |
+
def __init__(self, bert: BERT, vocab_obj, train_dataloader: DataLoader, workspace_name=None, code=None, finetune_task=None, with_cuda=True):
|
| 1121 |
+
|
| 1122 |
+
# available_gpus = list(range(torch.cuda.device_count()))
|
| 1123 |
+
|
| 1124 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
| 1125 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
| 1126 |
+
print(with_cuda, cuda_condition, " Device used = ", self.device)
|
| 1127 |
+
self.bert = bert.to(self.device)
|
| 1128 |
+
|
| 1129 |
+
# if with_cuda and torch.cuda.device_count() > 1:
|
| 1130 |
+
# print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
| 1131 |
+
# self.bert = nn.DataParallel(self.bert, device_ids=available_gpus)
|
| 1132 |
+
|
| 1133 |
+
self.train_dataloader = train_dataloader
|
| 1134 |
+
self.workspace_name = workspace_name
|
| 1135 |
+
self.code = code
|
| 1136 |
+
self.finetune_task = finetune_task
|
| 1137 |
+
self.vocab_obj = vocab_obj
|
| 1138 |
+
|
| 1139 |
+
def getAttention(self):
|
| 1140 |
+
# self.log_file = f"{self.workspace_name}/logs/{self.code}/log_attention.txt"
|
| 1141 |
+
|
| 1142 |
+
|
| 1143 |
+
labels = ['PercentChange', 'NumeratorQuantity2', 'NumeratorQuantity1', 'DenominatorQuantity1',
|
| 1144 |
+
'OptionalTask_1', 'EquationAnswer', 'NumeratorFactor', 'DenominatorFactor',
|
| 1145 |
+
'OptionalTask_2', 'FirstRow1:1', 'FirstRow1:2', 'FirstRow2:1', 'FirstRow2:2', 'SecondRow',
|
| 1146 |
+
'ThirdRow', 'FinalAnswer','FinalAnswerDirection']
|
| 1147 |
+
df_all = pd.DataFrame(0.0, index=labels, columns=labels)
|
| 1148 |
+
# Setting the tqdm progress bar
|
| 1149 |
+
data_iter = tqdm.tqdm(enumerate(self.train_dataloader),
|
| 1150 |
+
desc="attention",
|
| 1151 |
+
total=len(self.train_dataloader),
|
| 1152 |
+
bar_format="{l_bar}{r_bar}")
|
| 1153 |
+
count = 0
|
| 1154 |
+
for i, data in data_iter:
|
| 1155 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 1156 |
+
a = self.bert.forward(data["bert_input"], data["segment_label"])
|
| 1157 |
+
non_zero = np.sum(data["segment_label"].cpu().detach().numpy())
|
| 1158 |
+
|
| 1159 |
+
# Last Transformer Layer
|
| 1160 |
+
last_layer = self.bert.attention_values[-1].transpose(1,0,2,3)
|
| 1161 |
+
# print(last_layer.shape)
|
| 1162 |
+
head, d_model, s, s = last_layer.shape
|
| 1163 |
+
|
| 1164 |
+
for d in range(d_model):
|
| 1165 |
+
seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])[1:non_zero-1]
|
| 1166 |
+
# df_all = pd.DataFrame(0.0, index=seq_labels, columns=seq_labels)
|
| 1167 |
+
indices_to_choose = defaultdict(int)
|
| 1168 |
+
|
| 1169 |
+
for k,s in enumerate(seq_labels):
|
| 1170 |
+
if s in labels:
|
| 1171 |
+
indices_to_choose[s] = k
|
| 1172 |
+
indices_chosen = list(indices_to_choose.values())
|
| 1173 |
+
selected_seq_labels = [s for l,s in enumerate(seq_labels) if l in indices_chosen]
|
| 1174 |
+
# print(len(seq_labels), len(selected_seq_labels))
|
| 1175 |
+
for h in range(head):
|
| 1176 |
+
# fig, ax = plt.subplots(figsize=(12, 12))
|
| 1177 |
+
# seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])#[1:non_zero-1]
|
| 1178 |
+
# seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])[1:non_zero-1]
|
| 1179 |
+
# indices_to_choose = defaultdict(int)
|
| 1180 |
+
|
| 1181 |
+
# for k,s in enumerate(seq_labels):
|
| 1182 |
+
# if s in labels:
|
| 1183 |
+
# indices_to_choose[s] = k
|
| 1184 |
+
# indices_chosen = list(indices_to_choose.values())
|
| 1185 |
+
# selected_seq_labels = [s for l,s in enumerate(seq_labels) if l in indices_chosen]
|
| 1186 |
+
# print(f"Chosen index: {seq_labels, indices_to_choose, indices_chosen, selected_seq_labels}")
|
| 1187 |
+
|
| 1188 |
+
df_cm = pd.DataFrame(last_layer[h][d][indices_chosen,:][:,indices_chosen], index = selected_seq_labels, columns = selected_seq_labels)
|
| 1189 |
+
df_all = df_all.add(df_cm, fill_value=0)
|
| 1190 |
+
count += 1
|
| 1191 |
+
|
| 1192 |
+
# df_cm = pd.DataFrame(last_layer[h][d][1:non_zero-1,:][:,1:non_zero-1], index=seq_labels, columns=seq_labels)
|
| 1193 |
+
# df_all = df_all.add(df_cm, fill_value=0)
|
| 1194 |
+
|
| 1195 |
+
# df_all = df_all.reindex(index=seq_labels, columns=seq_labels)
|
| 1196 |
+
# sns.heatmap(df_all, annot=False)
|
| 1197 |
+
# plt.title("Attentions") #Probabilities
|
| 1198 |
+
# plt.xlabel("Steps")
|
| 1199 |
+
# plt.ylabel("Steps")
|
| 1200 |
+
# plt.grid(True)
|
| 1201 |
+
# plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True, labelrotation=90)
|
| 1202 |
+
# plt.savefig(f"{self.workspace_name}/plots/{self.code}/{self.finetune_task}_attention_scores_over_[{h}]_head_n_data[{d}].png", bbox_inches='tight')
|
| 1203 |
+
# plt.show()
|
| 1204 |
+
# plt.close()
|
| 1205 |
+
|
| 1206 |
+
|
| 1207 |
+
|
| 1208 |
+
print(f"Count of total : {count, head * self.train_dataloader.dataset.len}")
|
| 1209 |
+
df_all = df_all.div(count) # head * self.train_dataloader.dataset.len
|
| 1210 |
+
df_all = df_all.reindex(index=labels, columns=labels)
|
| 1211 |
+
sns.heatmap(df_all, annot=False)
|
| 1212 |
+
plt.title("Attentions") #Probabilities
|
| 1213 |
+
plt.xlabel("Steps")
|
| 1214 |
+
plt.ylabel("Steps")
|
| 1215 |
+
plt.grid(True)
|
| 1216 |
+
plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True, labelrotation=90)
|
| 1217 |
+
plt.savefig(f"{self.workspace_name}/plots/{self.code}/{self.finetune_task}_attention_scores.png", bbox_inches='tight')
|
| 1218 |
+
plt.show()
|
| 1219 |
+
plt.close()
|
| 1220 |
+
|
| 1221 |
+
|
| 1222 |
+
|
| 1223 |
+
|
| 1224 |
+
=======
|
| 1225 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
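Note that the committed src/pretrainer.py above still contains unresolved Git conflict markers (<<<<<<< HEAD, =======, >>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896), which will raise a SyntaxError on import until the conflict is resolved. Assuming the HEAD side of the conflict is kept, a minimal usage sketch for BERTFineTuneTrainer1 follows; the encoder, vocabulary size, dataloaders, and log folder are placeholder assumptions, while the constructor arguments and the train/test/save methods come from the class itself:

# Hypothetical driver loop. `bert`, `train_loader`, and `test_loader` are assumed
# to exist already (a pre-trained BERT encoder and torch DataLoader objects), and
# the log folder is assumed to have been created beforehand.
trainer = BERTFineTuneTrainer1(
    bert,
    vocab_size=1000,                     # placeholder
    train_dataloader=train_loader,
    test_dataloader=test_loader,
    lr=1e-4,
    num_labels=2,
    log_folder_path="logs/finetune",
)

for epoch in range(10):
    trainer.train(epoch)
    trainer.test(epoch)                  # sets trainer.save_model when the test loss improves
    if trainer.save_model:
        trainer.save(epoch)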
src/reference_code/bert_reference_code.py
ADDED
|
@@ -0,0 +1,1622 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch BERT model. """
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import math
|
| 21 |
+
import os
|
| 22 |
+
import warnings
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.utils.checkpoint
|
| 26 |
+
from torch import nn
|
| 27 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
| 28 |
+
|
| 29 |
+
from .activations import gelu, gelu_new, swish
|
| 30 |
+
from .configuration_bert import BertConfig
|
| 31 |
+
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
| 32 |
+
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
| 38 |
+
|
| 39 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 40 |
+
"bert-base-uncased",
|
| 41 |
+
"bert-large-uncased",
|
| 42 |
+
"bert-base-cased",
|
| 43 |
+
"bert-large-cased",
|
| 44 |
+
"bert-base-multilingual-uncased",
|
| 45 |
+
"bert-base-multilingual-cased",
|
| 46 |
+
"bert-base-chinese",
|
| 47 |
+
"bert-base-german-cased",
|
| 48 |
+
"bert-large-uncased-whole-word-masking",
|
| 49 |
+
"bert-large-cased-whole-word-masking",
|
| 50 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad",
|
| 51 |
+
"bert-large-cased-whole-word-masking-finetuned-squad",
|
| 52 |
+
"bert-base-cased-finetuned-mrpc",
|
| 53 |
+
"bert-base-german-dbmdz-cased",
|
| 54 |
+
"bert-base-german-dbmdz-uncased",
|
| 55 |
+
"cl-tohoku/bert-base-japanese",
|
| 56 |
+
"cl-tohoku/bert-base-japanese-whole-word-masking",
|
| 57 |
+
"cl-tohoku/bert-base-japanese-char",
|
| 58 |
+
"cl-tohoku/bert-base-japanese-char-whole-word-masking",
|
| 59 |
+
"TurkuNLP/bert-base-finnish-cased-v1",
|
| 60 |
+
"TurkuNLP/bert-base-finnish-uncased-v1",
|
| 61 |
+
"wietsedv/bert-base-dutch-cased",
|
| 62 |
+
# See all BERT models at https://huggingface.co/models?filter=bert
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
| 67 |
+
""" Load tf checkpoints in a pytorch model.
|
| 68 |
+
"""
|
| 69 |
+
try:
|
| 70 |
+
import re
|
| 71 |
+
import numpy as np
|
| 72 |
+
import tensorflow as tf
|
| 73 |
+
except ImportError:
|
| 74 |
+
logger.error(
|
| 75 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
| 76 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
| 77 |
+
)
|
| 78 |
+
raise
|
| 79 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
| 80 |
+
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
| 81 |
+
# Load weights from TF model
|
| 82 |
+
init_vars = tf.train.list_variables(tf_path)
|
| 83 |
+
names = []
|
| 84 |
+
arrays = []
|
| 85 |
+
for name, shape in init_vars:
|
| 86 |
+
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
| 87 |
+
array = tf.train.load_variable(tf_path, name)
|
| 88 |
+
names.append(name)
|
| 89 |
+
arrays.append(array)
|
| 90 |
+
|
| 91 |
+
for name, array in zip(names, arrays):
|
| 92 |
+
name = name.split("/")
|
| 93 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
| 94 |
+
# which are not required for using pretrained model
|
| 95 |
+
if any(
|
| 96 |
+
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
| 97 |
+
for n in name
|
| 98 |
+
):
|
| 99 |
+
logger.info("Skipping {}".format("/".join(name)))
|
| 100 |
+
continue
|
| 101 |
+
pointer = model
|
| 102 |
+
for m_name in name:
|
| 103 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
| 104 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
| 105 |
+
else:
|
| 106 |
+
scope_names = [m_name]
|
| 107 |
+
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
| 108 |
+
pointer = getattr(pointer, "weight")
|
| 109 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
| 110 |
+
pointer = getattr(pointer, "bias")
|
| 111 |
+
elif scope_names[0] == "output_weights":
|
| 112 |
+
pointer = getattr(pointer, "weight")
|
| 113 |
+
elif scope_names[0] == "squad":
|
| 114 |
+
pointer = getattr(pointer, "classifier")
|
| 115 |
+
else:
|
| 116 |
+
try:
|
| 117 |
+
pointer = getattr(pointer, scope_names[0])
|
| 118 |
+
except AttributeError:
|
| 119 |
+
logger.info("Skipping {}".format("/".join(name)))
|
| 120 |
+
continue
|
| 121 |
+
if len(scope_names) >= 2:
|
| 122 |
+
num = int(scope_names[1])
|
| 123 |
+
pointer = pointer[num]
|
| 124 |
+
if m_name[-11:] == "_embeddings":
|
| 125 |
+
pointer = getattr(pointer, "weight")
|
| 126 |
+
elif m_name == "kernel":
|
| 127 |
+
array = np.transpose(array)
|
| 128 |
+
try:
|
| 129 |
+
assert pointer.shape == array.shape
|
| 130 |
+
except AssertionError as e:
|
| 131 |
+
e.args += (pointer.shape, array.shape)
|
| 132 |
+
raise
|
| 133 |
+
logger.info("Initialize PyTorch weight {}".format(name))
|
| 134 |
+
pointer.data = torch.from_numpy(array)
|
| 135 |
+
return model
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def mish(x):
|
| 139 |
+
return x * torch.tanh(nn.functional.softplus(x))
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
BertLayerNorm = torch.nn.LayerNorm
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class BertEmbeddings(nn.Module):
|
| 149 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
def __init__(self, config):
|
| 153 |
+
super().__init__()
|
| 154 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 155 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 156 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
| 157 |
+
|
| 158 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 159 |
+
# any TensorFlow checkpoint file
|
| 160 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 161 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 162 |
+
|
| 163 |
+
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
| 164 |
+
if input_ids is not None:
|
| 165 |
+
input_shape = input_ids.size()
|
| 166 |
+
else:
|
| 167 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 168 |
+
|
| 169 |
+
seq_length = input_shape[1]
|
| 170 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 171 |
+
if position_ids is None:
|
| 172 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
| 173 |
+
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
| 174 |
+
if token_type_ids is None:
|
| 175 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
| 176 |
+
|
| 177 |
+
if inputs_embeds is None:
|
| 178 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 179 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 180 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 181 |
+
|
| 182 |
+
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
| 183 |
+
embeddings = self.LayerNorm(embeddings)
|
| 184 |
+
embeddings = self.dropout(embeddings)
|
| 185 |
+
return embeddings
|
| 186 |
+
|
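In equation form, BertEmbeddings above computes, for token id x at position p with segment id t (a restatement of the code, written in LaTeX only for compactness):

e = \mathrm{Dropout}\big(\mathrm{LayerNorm}(E_{\mathrm{word}}[x] + E_{\mathrm{pos}}[p] + E_{\mathrm{type}}[t])\big)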
| 187 |
+
|
| 188 |
+
class BertSelfAttention(nn.Module):
|
| 189 |
+
def __init__(self, config):
|
| 190 |
+
super().__init__()
|
| 191 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 192 |
+
raise ValueError(
|
| 193 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
| 194 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
self.num_attention_heads = config.num_attention_heads
|
| 198 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 199 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 200 |
+
|
| 201 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 202 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 203 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 204 |
+
|
| 205 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 206 |
+
|
| 207 |
+
def transpose_for_scores(self, x):
|
| 208 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 209 |
+
x = x.view(*new_x_shape)
|
| 210 |
+
return x.permute(0, 2, 1, 3)
|
| 211 |
+
|
| 212 |
+
def forward(
|
| 213 |
+
self,
|
| 214 |
+
hidden_states,
|
| 215 |
+
attention_mask=None,
|
| 216 |
+
head_mask=None,
|
| 217 |
+
encoder_hidden_states=None,
|
| 218 |
+
encoder_attention_mask=None,
|
| 219 |
+
output_attentions=False,
|
| 220 |
+
):
|
| 221 |
+
mixed_query_layer = self.query(hidden_states)
|
| 222 |
+
|
| 223 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 224 |
+
# and values come from an encoder; the attention mask needs to be
|
| 225 |
+
# such that the encoder's padding tokens are not attended to.
|
| 226 |
+
if encoder_hidden_states is not None:
|
| 227 |
+
mixed_key_layer = self.key(encoder_hidden_states)
|
| 228 |
+
mixed_value_layer = self.value(encoder_hidden_states)
|
| 229 |
+
attention_mask = encoder_attention_mask
|
| 230 |
+
else:
|
| 231 |
+
mixed_key_layer = self.key(hidden_states)
|
| 232 |
+
mixed_value_layer = self.value(hidden_states)
|
| 233 |
+
|
| 234 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 235 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
| 236 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
| 237 |
+
|
| 238 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 239 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 240 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 241 |
+
if attention_mask is not None:
|
| 242 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
| 243 |
+
attention_scores = attention_scores + attention_mask
|
| 244 |
+
|
| 245 |
+
# Normalize the attention scores to probabilities.
|
| 246 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
| 247 |
+
|
| 248 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 249 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 250 |
+
attention_probs = self.dropout(attention_probs)
|
| 251 |
+
|
| 252 |
+
# Mask heads if we want to
|
| 253 |
+
if head_mask is not None:
|
| 254 |
+
attention_probs = attention_probs * head_mask
|
| 255 |
+
|
| 256 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 257 |
+
|
| 258 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 259 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 260 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 261 |
+
|
| 262 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 263 |
+
return outputs
|
| 264 |
+
|
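The score computation in BertSelfAttention above is standard scaled dot-product attention; writing d_k for attention_head_size and M for the additive attention mask, each head computes:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V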
| 265 |
+
|
| 266 |
+
class BertSelfOutput(nn.Module):
|
| 267 |
+
def __init__(self, config):
|
| 268 |
+
super().__init__()
|
| 269 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 270 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 271 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 272 |
+
|
| 273 |
+
def forward(self, hidden_states, input_tensor):
|
| 274 |
+
hidden_states = self.dense(hidden_states)
|
| 275 |
+
hidden_states = self.dropout(hidden_states)
|
| 276 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 277 |
+
return hidden_states
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class BertAttention(nn.Module):
|
| 281 |
+
def __init__(self, config):
|
| 282 |
+
super().__init__()
|
| 283 |
+
self.self = BertSelfAttention(config)
|
| 284 |
+
self.output = BertSelfOutput(config)
|
| 285 |
+
self.pruned_heads = set()
|
| 286 |
+
|
| 287 |
+
def prune_heads(self, heads):
|
| 288 |
+
if len(heads) == 0:
|
| 289 |
+
return
|
| 290 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 291 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
# Prune linear layers
|
| 295 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 296 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 297 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 298 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 299 |
+
|
| 300 |
+
# Update hyper params and store pruned heads
|
| 301 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 302 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 303 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 304 |
+
|
| 305 |
+
def forward(
|
| 306 |
+
self,
|
| 307 |
+
hidden_states,
|
| 308 |
+
attention_mask=None,
|
| 309 |
+
head_mask=None,
|
| 310 |
+
encoder_hidden_states=None,
|
| 311 |
+
encoder_attention_mask=None,
|
| 312 |
+
output_attentions=False,
|
| 313 |
+
):
|
| 314 |
+
self_outputs = self.self(
|
| 315 |
+
hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 319 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 320 |
+
return outputs
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class BertIntermediate(nn.Module):
|
| 324 |
+
def __init__(self, config):
|
| 325 |
+
super().__init__()
|
| 326 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 327 |
+
if isinstance(config.hidden_act, str):
|
| 328 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 329 |
+
else:
|
| 330 |
+
self.intermediate_act_fn = config.hidden_act
|
| 331 |
+
|
| 332 |
+
def forward(self, hidden_states):
|
| 333 |
+
hidden_states = self.dense(hidden_states)
|
| 334 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 335 |
+
return hidden_states
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class BertOutput(nn.Module):
|
| 339 |
+
def __init__(self, config):
|
| 340 |
+
super().__init__()
|
| 341 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 342 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 343 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 344 |
+
|
| 345 |
+
def forward(self, hidden_states, input_tensor):
|
| 346 |
+
hidden_states = self.dense(hidden_states)
|
| 347 |
+
hidden_states = self.dropout(hidden_states)
|
| 348 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 349 |
+
return hidden_states
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class BertLayer(nn.Module):
|
| 353 |
+
def __init__(self, config):
|
| 354 |
+
super().__init__()
|
| 355 |
+
self.attention = BertAttention(config)
|
| 356 |
+
self.is_decoder = config.is_decoder
|
| 357 |
+
if self.is_decoder:
|
| 358 |
+
self.crossattention = BertAttention(config)
|
| 359 |
+
self.intermediate = BertIntermediate(config)
|
| 360 |
+
self.output = BertOutput(config)
|
| 361 |
+
|
| 362 |
+
def forward(
|
| 363 |
+
self,
|
| 364 |
+
hidden_states,
|
| 365 |
+
attention_mask=None,
|
| 366 |
+
head_mask=None,
|
| 367 |
+
encoder_hidden_states=None,
|
| 368 |
+
encoder_attention_mask=None,
|
| 369 |
+
output_attentions=False,
|
| 370 |
+
):
|
| 371 |
+
self_attention_outputs = self.attention(
|
| 372 |
+
hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
|
| 373 |
+
)
|
| 374 |
+
attention_output = self_attention_outputs[0]
|
| 375 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 376 |
+
|
| 377 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
| 378 |
+
cross_attention_outputs = self.crossattention(
|
| 379 |
+
attention_output,
|
| 380 |
+
attention_mask,
|
| 381 |
+
head_mask,
|
| 382 |
+
encoder_hidden_states,
|
| 383 |
+
encoder_attention_mask,
|
| 384 |
+
output_attentions,
|
| 385 |
+
)
|
| 386 |
+
attention_output = cross_attention_outputs[0]
|
| 387 |
+
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
| 388 |
+
|
| 389 |
+
intermediate_output = self.intermediate(attention_output)
|
| 390 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 391 |
+
outputs = (layer_output,) + outputs
|
| 392 |
+
return outputs
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class BertEncoder(nn.Module):
|
| 396 |
+
def __init__(self, config):
|
| 397 |
+
super().__init__()
|
| 398 |
+
self.config = config
|
| 399 |
+
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
| 400 |
+
|
| 401 |
+
def forward(
|
| 402 |
+
self,
|
| 403 |
+
hidden_states,
|
| 404 |
+
attention_mask=None,
|
| 405 |
+
head_mask=None,
|
| 406 |
+
encoder_hidden_states=None,
|
| 407 |
+
encoder_attention_mask=None,
|
| 408 |
+
output_attentions=False,
|
| 409 |
+
output_hidden_states=False,
|
| 410 |
+
):
|
| 411 |
+
all_hidden_states = ()
|
| 412 |
+
all_attentions = ()
|
| 413 |
+
for i, layer_module in enumerate(self.layer):
|
| 414 |
+
if output_hidden_states:
|
| 415 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 416 |
+
|
| 417 |
+
if getattr(self.config, "gradient_checkpointing", False):
|
| 418 |
+
|
| 419 |
+
def create_custom_forward(module):
|
| 420 |
+
def custom_forward(*inputs):
|
| 421 |
+
return module(*inputs, output_attentions)
|
| 422 |
+
|
| 423 |
+
return custom_forward
|
| 424 |
+
|
| 425 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 426 |
+
create_custom_forward(layer_module),
|
| 427 |
+
hidden_states,
|
| 428 |
+
attention_mask,
|
| 429 |
+
head_mask[i],
|
| 430 |
+
encoder_hidden_states,
|
| 431 |
+
encoder_attention_mask,
|
| 432 |
+
)
|
| 433 |
+
else:
|
| 434 |
+
layer_outputs = layer_module(
|
| 435 |
+
hidden_states,
|
| 436 |
+
attention_mask,
|
| 437 |
+
head_mask[i],
|
| 438 |
+
encoder_hidden_states,
|
| 439 |
+
encoder_attention_mask,
|
| 440 |
+
output_attentions,
|
| 441 |
+
)
|
| 442 |
+
hidden_states = layer_outputs[0]
|
| 443 |
+
|
| 444 |
+
if output_attentions:
|
| 445 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 446 |
+
|
| 447 |
+
# Add last layer
|
| 448 |
+
if output_hidden_states:
|
| 449 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 450 |
+
|
| 451 |
+
outputs = (hidden_states,)
|
| 452 |
+
if output_hidden_states:
|
| 453 |
+
outputs = outputs + (all_hidden_states,)
|
| 454 |
+
if output_attentions:
|
| 455 |
+
outputs = outputs + (all_attentions,)
|
| 456 |
+
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
| 457 |
+
|
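BertEncoder.forward() above switches to torch.utils.checkpoint when config.gradient_checkpointing is truthy, trading extra forward-pass recomputation for lower activation memory. A minimal sketch of flipping that flag with the stock transformers classes this reference file mirrors; note this is version-dependent, and newer transformers releases expose model.gradient_checkpointing_enable() instead:

from transformers import BertConfig, BertModel

config = BertConfig()                   # defaults: 12 layers, hidden size 768
config.gradient_checkpointing = True    # read by the getattr() check in BertEncoder.forward()
model = BertModel(config)               # activations are recomputed during the backward pass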
| 458 |
+
|
| 459 |
+
class BertPooler(nn.Module):
|
| 460 |
+
def __init__(self, config):
|
| 461 |
+
super().__init__()
|
| 462 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 463 |
+
self.activation = nn.Tanh()
|
| 464 |
+
|
| 465 |
+
def forward(self, hidden_states):
|
| 466 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 467 |
+
# to the first token.
|
| 468 |
+
first_token_tensor = hidden_states[:, 0]
|
| 469 |
+
pooled_output = self.dense(first_token_tensor)
|
| 470 |
+
pooled_output = self.activation(pooled_output)
|
| 471 |
+
return pooled_output
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
class BertPredictionHeadTransform(nn.Module):
|
| 475 |
+
def __init__(self, config):
|
| 476 |
+
super().__init__()
|
| 477 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 478 |
+
if isinstance(config.hidden_act, str):
|
| 479 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 480 |
+
else:
|
| 481 |
+
self.transform_act_fn = config.hidden_act
|
| 482 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 483 |
+
|
| 484 |
+
def forward(self, hidden_states):
|
| 485 |
+
hidden_states = self.dense(hidden_states)
|
| 486 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 487 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 488 |
+
return hidden_states
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
class BertLMPredictionHead(nn.Module):
|
| 492 |
+
def __init__(self, config):
|
| 493 |
+
super().__init__()
|
| 494 |
+
self.transform = BertPredictionHeadTransform(config)
|
| 495 |
+
|
| 496 |
+
# The output weights are the same as the input embeddings, but there is
|
| 497 |
+
# an output-only bias for each token.
|
| 498 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 499 |
+
|
| 500 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 501 |
+
|
| 502 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 503 |
+
self.decoder.bias = self.bias
|
| 504 |
+
|
| 505 |
+
def forward(self, hidden_states):
|
| 506 |
+
hidden_states = self.transform(hidden_states)
|
| 507 |
+
hidden_states = self.decoder(hidden_states)
|
| 508 |
+
return hidden_states
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
class BertOnlyMLMHead(nn.Module):
|
| 512 |
+
def __init__(self, config):
|
| 513 |
+
super().__init__()
|
| 514 |
+
self.predictions = BertLMPredictionHead(config)
|
| 515 |
+
|
| 516 |
+
def forward(self, sequence_output):
|
| 517 |
+
prediction_scores = self.predictions(sequence_output)
|
| 518 |
+
return prediction_scores
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class BertOnlyNSPHead(nn.Module):
|
| 522 |
+
def __init__(self, config):
|
| 523 |
+
super().__init__()
|
| 524 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
| 525 |
+
|
| 526 |
+
def forward(self, pooled_output):
|
| 527 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
| 528 |
+
return seq_relationship_score
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class BertPreTrainingHeads(nn.Module):
|
| 532 |
+
def __init__(self, config):
|
| 533 |
+
super().__init__()
|
| 534 |
+
self.predictions = BertLMPredictionHead(config)
|
| 535 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
| 536 |
+
|
| 537 |
+
def forward(self, sequence_output, pooled_output):
|
| 538 |
+
prediction_scores = self.predictions(sequence_output)
|
| 539 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
| 540 |
+
return prediction_scores, seq_relationship_score
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class BertPreTrainedModel(PreTrainedModel):
|
| 544 |
+
""" An abstract class to handle weights initialization and
|
| 545 |
+
a simple interface for downloading and loading pretrained models.
|
| 546 |
+
"""
|
| 547 |
+
|
| 548 |
+
config_class = BertConfig
|
| 549 |
+
load_tf_weights = load_tf_weights_in_bert
|
| 550 |
+
base_model_prefix = "bert"
|
| 551 |
+
|
| 552 |
+
def _init_weights(self, module):
|
| 553 |
+
""" Initialize the weights """
|
| 554 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 555 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 556 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 557 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 558 |
+
elif isinstance(module, BertLayerNorm):
|
| 559 |
+
module.bias.data.zero_()
|
| 560 |
+
module.weight.data.fill_(1.0)
|
| 561 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
| 562 |
+
module.bias.data.zero_()
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
BERT_START_DOCSTRING = r"""
|
| 566 |
+
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
| 567 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
|
| 568 |
+
usage and behavior.
|
| 569 |
+
|
| 570 |
+
Parameters:
|
| 571 |
+
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
| 572 |
+
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
| 573 |
+
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
| 574 |
+
"""
|
| 575 |
+
|
| 576 |
+
BERT_INPUTS_DOCSTRING = r"""
|
| 577 |
+
Args:
|
| 578 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
|
| 579 |
+
Indices of input sequence tokens in the vocabulary.
|
| 580 |
+
|
| 581 |
+
Indices can be obtained using :class:`transformers.BertTokenizer`.
|
| 582 |
+
See :func:`transformers.PreTrainedTokenizer.encode` and
|
| 583 |
+
:func:`transformers.PreTrainedTokenizer.__call__` for details.
|
| 584 |
+
|
| 585 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
| 586 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
| 587 |
+
Mask to avoid performing attention on padding token indices.
|
| 588 |
+
Mask values selected in ``[0, 1]``:
|
| 589 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
| 590 |
+
|
| 591 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
| 592 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
| 593 |
+
Segment token indices to indicate first and second portions of the inputs.
|
| 594 |
+
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
| 595 |
+
corresponds to a `sentence B` token
|
| 596 |
+
|
| 597 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
| 598 |
+
position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
| 599 |
+
Indices of positions of each input sequence tokens in the position embeddings.
|
| 600 |
+
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
| 601 |
+
|
| 602 |
+
`What are position IDs? <../glossary.html#position-ids>`_
|
| 603 |
+
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
| 604 |
+
Mask to nullify selected heads of the self-attention modules.
|
| 605 |
+
Mask values selected in ``[0, 1]``:
|
| 606 |
+
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
| 607 |
+
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
| 608 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
| 609 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
| 610 |
+
than the model's internal embedding lookup matrix.
|
| 611 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
| 612 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
| 613 |
+
if the model is configured as a decoder.
|
| 614 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
| 615 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
| 616 |
+
is used in the cross-attention if the model is configured as a decoder.
|
| 617 |
+
Mask values selected in ``[0, 1]``:
|
| 618 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
| 619 |
+
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
| 620 |
+
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
| 621 |
+
"""
|
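The argument block above maps one-to-one onto what the tokenizer produces. A minimal sketch of building these tensors, assuming the `bert-base-uncased` vocabulary is available locally (the checkpoint name is an assumption; any BERT vocabulary works):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
enc = tokenizer("First sentence.", "Second sentence.", return_tensors="pt")

# input_ids: token indices; attention_mask: 1 for real tokens, 0 for padding;
# token_type_ids: 0 for sentence A tokens, 1 for sentence B tokens
print(enc["input_ids"].shape, enc["attention_mask"].shape, enc["token_type_ids"].shape)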
| 622 |
+
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
@add_start_docstrings(
|
| 627 |
+
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
| 628 |
+
BERT_START_DOCSTRING,
|
| 629 |
+
)
|
| 630 |
+
class BertModel(BertPreTrainedModel):
|
| 631 |
+
"""
|
| 632 |
+
|
| 633 |
+
The model can behave as an encoder (with only self-attention) as well
|
| 634 |
+
as a decoder, in which case a layer of cross-attention is added between
|
| 635 |
+
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
|
| 636 |
+
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
| 637 |
+
|
| 638 |
+
To behave as a decoder, the model needs to be initialized with the
|
| 639 |
+
:obj:`is_decoder` argument of the configuration set to :obj:`True`; an
|
| 640 |
+
:obj:`encoder_hidden_states` is expected as an input to the forward pass.
|
| 641 |
+
|
| 642 |
+
.. _`Attention is all you need`:
|
| 643 |
+
https://arxiv.org/abs/1706.03762
|
| 644 |
+
|
| 645 |
+
"""
|
| 646 |
+
|
| 647 |
+
def __init__(self, config):
|
| 648 |
+
super().__init__(config)
|
| 649 |
+
self.config = config
|
| 650 |
+
|
| 651 |
+
self.embeddings = BertEmbeddings(config)
|
| 652 |
+
self.encoder = BertEncoder(config)
|
| 653 |
+
self.pooler = BertPooler(config)
|
| 654 |
+
|
| 655 |
+
self.init_weights()
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
def get_input_embeddings(self):
|
| 660 |
+
return self.embeddings.word_embeddings
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def set_input_embeddings(self, value):
|
| 666 |
+
self.embeddings.word_embeddings = value
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def _prune_heads(self, heads_to_prune):
|
| 670 |
+
""" Prunes heads of the model.
|
| 671 |
+
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
| 672 |
+
See base class PreTrainedModel
|
| 673 |
+
"""
|
| 674 |
+
for layer, heads in heads_to_prune.items():
|
| 675 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 680 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 681 |
+
def forward(
|
| 682 |
+
self,
|
| 683 |
+
input_ids=None,
|
| 684 |
+
attention_mask=None,
|
| 685 |
+
token_type_ids=None,
|
| 686 |
+
position_ids=None,
|
| 687 |
+
head_mask=None,
|
| 688 |
+
inputs_embeds=None,
|
| 689 |
+
encoder_hidden_states=None,
|
| 690 |
+
encoder_attention_mask=None,
|
| 691 |
+
output_attentions=None,
|
| 692 |
+
output_hidden_states=None,
|
| 693 |
+
):
|
| 694 |
+
r"""
|
| 695 |
+
Return:
|
| 696 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 697 |
+
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
| 698 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 699 |
+
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
| 700 |
+
Last layer hidden-state of the first token of the sequence (classification token)
|
| 701 |
+
further processed by a Linear layer and a Tanh activation function. The Linear
|
| 702 |
+
layer weights are trained from the next sentence prediction (classification)
|
| 703 |
+
objective during pre-training.
|
| 704 |
+
|
| 705 |
+
This output is usually *not* a good summary
|
| 706 |
+
of the semantic content of the input; you're often better off averaging or pooling
|
| 707 |
+
the sequence of hidden-states for the whole input sequence.
|
| 708 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 709 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 710 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 711 |
+
|
| 712 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 713 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 714 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 715 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 716 |
+
|
| 717 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 718 |
+
heads.
|
| 719 |
+
"""
|
| 720 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 721 |
+
output_hidden_states = (
|
| 722 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 726 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 727 |
+
elif input_ids is not None:
|
| 728 |
+
input_shape = input_ids.size()
|
| 729 |
+
elif inputs_embeds is not None:
|
| 730 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 731 |
+
else:
|
| 732 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 733 |
+
|
| 734 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 735 |
+
|
| 736 |
+
if attention_mask is None:
|
| 737 |
+
attention_mask = torch.ones(input_shape, device=device)
|
| 738 |
+
if token_type_ids is None:
|
| 739 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
| 740 |
+
|
| 741 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 742 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 743 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
| 744 |
+
|
| 745 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
| 746 |
+
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 747 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
| 748 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 749 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 750 |
+
if encoder_attention_mask is None:
|
| 751 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 752 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 753 |
+
else:
|
| 754 |
+
encoder_extended_attention_mask = None
|
| 755 |
+
|
| 756 |
+
# Prepare head mask if needed
|
| 757 |
+
# 1.0 in head_mask indicate we keep the head
|
| 758 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 759 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 760 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 761 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 762 |
+
|
| 763 |
+
embedding_output = self.embeddings(
|
| 764 |
+
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
| 765 |
+
)
|
| 766 |
+
encoder_outputs = self.encoder(
|
| 767 |
+
embedding_output,
|
| 768 |
+
attention_mask=extended_attention_mask,
|
| 769 |
+
head_mask=head_mask,
|
| 770 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 771 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
| 772 |
+
output_attentions=output_attentions,
|
| 773 |
+
output_hidden_states=output_hidden_states,
|
| 774 |
+
)
|
| 775 |
+
sequence_output = encoder_outputs[0]
|
| 776 |
+
pooled_output = self.pooler(sequence_output)
|
| 777 |
+
|
| 778 |
+
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
| 779 |
+
1:
|
| 780 |
+
] # add hidden_states and attentions if they are here
|
| 781 |
+
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
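A minimal sketch of consuming the tuple this forward pass returns, assuming a local `bert-base-uncased` checkpoint (indexing with [0] and [1] also works if the installed version returns a model-output object instead of a plain tuple):

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
model = BertModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

sequence_output = outputs[0]  # (batch_size, sequence_length, hidden_size)
pooled_output = outputs[1]    # (batch_size, hidden_size), tanh-processed [CLS] state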
| 782 |
+
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
@add_start_docstrings(
|
| 788 |
+
"""Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
|
| 789 |
+
a `next sentence prediction (classification)` head. """,
|
| 790 |
+
BERT_START_DOCSTRING,
|
| 791 |
+
)
|
| 792 |
+
class BertForPreTraining(BertPreTrainedModel):
|
| 793 |
+
def __init__(self, config):
|
| 794 |
+
super().__init__(config)
|
| 795 |
+
|
| 796 |
+
self.bert = BertModel(config)
|
| 797 |
+
self.cls = BertPreTrainingHeads(config)
|
| 798 |
+
|
| 799 |
+
self.init_weights()
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
def get_output_embeddings(self):
|
| 804 |
+
return self.cls.predictions.decoder
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
|
| 809 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 810 |
+
def forward(
|
| 811 |
+
self,
|
| 812 |
+
input_ids=None,
|
| 813 |
+
attention_mask=None,
|
| 814 |
+
token_type_ids=None,
|
| 815 |
+
position_ids=None,
|
| 816 |
+
head_mask=None,
|
| 817 |
+
inputs_embeds=None,
|
| 818 |
+
labels=None,
|
| 819 |
+
next_sentence_label=None,
|
| 820 |
+
output_attentions=None,
|
| 821 |
+
output_hidden_states=None,
|
| 822 |
+
**kwargs
|
| 823 |
+
):
|
| 824 |
+
r"""
|
| 825 |
+
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
|
| 826 |
+
Labels for computing the masked language modeling loss.
|
| 827 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
| 828 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
| 829 |
+
in ``[0, ..., config.vocab_size]``
|
| 830 |
+
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
|
| 831 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
|
| 832 |
+
Indices should be in ``[0, 1]``.
|
| 833 |
+
``0`` indicates sequence B is a continuation of sequence A,
|
| 834 |
+
``1`` indicates sequence B is a random sequence.
|
| 835 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
| 836 |
+
Used to hide legacy arguments that have been deprecated.
|
| 837 |
+
|
| 838 |
+
Returns:
|
| 839 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 840 |
+
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
| 841 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
|
| 842 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
| 843 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 844 |
+
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
| 845 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False
|
| 846 |
+
continuation before SoftMax).
|
| 847 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 848 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 849 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 850 |
+
|
| 851 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 852 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 853 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 854 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 855 |
+
|
| 856 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 857 |
+
heads.
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
Examples::
|
| 861 |
+
|
| 862 |
+
>>> from transformers import BertTokenizer, BertForPreTraining
|
| 863 |
+
>>> import torch
|
| 864 |
+
|
| 865 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 866 |
+
>>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
|
| 867 |
+
|
| 868 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 869 |
+
>>> outputs = model(**inputs)
|
| 870 |
+
|
| 871 |
+
>>> prediction_scores, seq_relationship_scores = outputs[:2]
|
| 872 |
+
|
| 873 |
+
"""
|
| 874 |
+
if "masked_lm_labels" in kwargs:
|
| 875 |
+
warnings.warn(
|
| 876 |
+
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
| 877 |
+
DeprecationWarning,
|
| 878 |
+
)
|
| 879 |
+
labels = kwargs.pop("masked_lm_labels")
|
| 880 |
+
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
| 881 |
+
|
| 882 |
+
outputs = self.bert(
|
| 883 |
+
input_ids,
|
| 884 |
+
attention_mask=attention_mask,
|
| 885 |
+
token_type_ids=token_type_ids,
|
| 886 |
+
position_ids=position_ids,
|
| 887 |
+
head_mask=head_mask,
|
| 888 |
+
inputs_embeds=inputs_embeds,
|
| 889 |
+
output_attentions=output_attentions,
|
| 890 |
+
output_hidden_states=output_hidden_states,
|
| 891 |
+
)
|
| 892 |
+
|
| 893 |
+
sequence_output, pooled_output = outputs[:2]
|
| 894 |
+
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
| 895 |
+
|
| 896 |
+
outputs = (prediction_scores, seq_relationship_score,) + outputs[
|
| 897 |
+
2:
|
| 898 |
+
] # add hidden states and attention if they are here
|
| 899 |
+
|
| 900 |
+
if labels is not None and next_sentence_label is not None:
|
| 901 |
+
loss_fct = CrossEntropyLoss()
|
| 902 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 903 |
+
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
| 904 |
+
total_loss = masked_lm_loss + next_sentence_loss
|
| 905 |
+
outputs = (total_loss,) + outputs
|
| 906 |
+
|
| 907 |
+
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
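When both label tensors are supplied, the total loss above is simply the sum of the two head losses. A standalone sketch with illustrative shapes and labels (no checkpoint needed):

import torch
from torch.nn import CrossEntropyLoss

vocab_size, seq_len = 30522, 8
prediction_scores = torch.randn(2, seq_len, vocab_size)             # MLM head output
seq_relationship_score = torch.randn(2, 2)                          # NSP head output
labels = torch.full((2, seq_len), -100, dtype=torch.long)           # -100 = ignored position
labels[:, 3] = 42                                                   # pretend one masked token per row
next_sentence_label = torch.tensor([0, 1])

loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size), labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss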
| 908 |
+
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
@add_start_docstrings(
|
| 912 |
+
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
|
| 913 |
+
)
|
| 914 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
| 915 |
+
def __init__(self, config):
|
| 916 |
+
super().__init__(config)
|
| 917 |
+
assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
|
| 918 |
+
|
| 919 |
+
self.bert = BertModel(config)
|
| 920 |
+
self.cls = BertOnlyMLMHead(config)
|
| 921 |
+
|
| 922 |
+
self.init_weights()
|
| 923 |
+
|
| 924 |
+
def get_output_embeddings(self):
|
| 925 |
+
return self.cls.predictions.decoder
|
| 926 |
+
|
| 927 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 928 |
+
def forward(
|
| 929 |
+
self,
|
| 930 |
+
input_ids=None,
|
| 931 |
+
attention_mask=None,
|
| 932 |
+
token_type_ids=None,
|
| 933 |
+
position_ids=None,
|
| 934 |
+
head_mask=None,
|
| 935 |
+
inputs_embeds=None,
|
| 936 |
+
labels=None,
|
| 937 |
+
encoder_hidden_states=None,
|
| 938 |
+
encoder_attention_mask=None,
|
| 939 |
+
output_attentions=None,
|
| 940 |
+
output_hidden_states=None,
|
| 941 |
+
**kwargs
|
| 942 |
+
):
|
| 943 |
+
r"""
|
| 944 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
| 945 |
+
Labels for computing the left-to-right language modeling loss (next word prediction).
|
| 946 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
| 947 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
| 948 |
+
in ``[0, ..., config.vocab_size]``
|
| 949 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
| 950 |
+
Used to hide legacy arguments that have been deprecated.
|
| 951 |
+
|
| 952 |
+
Returns:
|
| 953 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 954 |
+
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
| 955 |
+
Next token prediction loss.
|
| 956 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
| 957 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 958 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 959 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 960 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 961 |
+
|
| 962 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 963 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 964 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 965 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 966 |
+
|
| 967 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 968 |
+
heads.
|
| 969 |
+
|
| 970 |
+
Example::
|
| 971 |
+
|
| 972 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
| 973 |
+
>>> import torch
|
| 974 |
+
|
| 975 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
| 976 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
| 977 |
+
>>> config.is_decoder = True
|
| 978 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
| 979 |
+
|
| 980 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 981 |
+
>>> outputs = model(**inputs)
|
| 982 |
+
|
| 983 |
+
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
| 984 |
+
"""
|
| 985 |
+
|
| 986 |
+
outputs = self.bert(
|
| 987 |
+
input_ids,
|
| 988 |
+
attention_mask=attention_mask,
|
| 989 |
+
token_type_ids=token_type_ids,
|
| 990 |
+
position_ids=position_ids,
|
| 991 |
+
head_mask=head_mask,
|
| 992 |
+
inputs_embeds=inputs_embeds,
|
| 993 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 994 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 995 |
+
output_attentions=output_attentions,
|
| 996 |
+
output_hidden_states=output_hidden_states,
|
| 997 |
+
)
|
| 998 |
+
|
| 999 |
+
sequence_output = outputs[0]
|
| 1000 |
+
prediction_scores = self.cls(sequence_output)
|
| 1001 |
+
|
| 1002 |
+
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
| 1003 |
+
|
| 1004 |
+
if labels is not None:
|
| 1005 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
| 1006 |
+
prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
| 1007 |
+
labels = labels[:, 1:].contiguous()
|
| 1008 |
+
loss_fct = CrossEntropyLoss()
|
| 1009 |
+
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 1010 |
+
outputs = (ltr_lm_loss,) + outputs
|
| 1011 |
+
|
| 1012 |
+
return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
|
| 1013 |
+
|
| 1014 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
| 1015 |
+
input_shape = input_ids.shape
|
| 1016 |
+
|
| 1017 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
| 1018 |
+
if attention_mask is None:
|
| 1019 |
+
attention_mask = input_ids.new_ones(input_shape)
|
| 1020 |
+
|
| 1021 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
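The shift-by-one in the causal loss above can be shown in isolation; this standalone sketch reproduces the same next-token objective with illustrative shapes:

import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 8, 30522
prediction_scores = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# position t is scored against token t+1: drop the last prediction and the first label
shifted_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()
ltr_lm_loss = CrossEntropyLoss()(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))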
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
| 1027 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
| 1028 |
+
def __init__(self, config):
|
| 1029 |
+
super().__init__(config)
|
| 1030 |
+
assert (
|
| 1031 |
+
not config.is_decoder
|
| 1032 |
+
), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
|
| 1033 |
+
|
| 1034 |
+
self.bert = BertModel(config)
|
| 1035 |
+
self.cls = BertOnlyMLMHead(config)
|
| 1036 |
+
|
| 1037 |
+
self.init_weights()
|
| 1038 |
+
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
def get_output_embeddings(self):
|
| 1042 |
+
return self.cls.predictions.decoder
|
| 1043 |
+
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
|
| 1047 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 1048 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 1049 |
+
def forward(
|
| 1050 |
+
self,
|
| 1051 |
+
input_ids=None,
|
| 1052 |
+
attention_mask=None,
|
| 1053 |
+
token_type_ids=None,
|
| 1054 |
+
position_ids=None,
|
| 1055 |
+
head_mask=None,
|
| 1056 |
+
inputs_embeds=None,
|
| 1057 |
+
labels=None,
|
| 1058 |
+
encoder_hidden_states=None,
|
| 1059 |
+
encoder_attention_mask=None,
|
| 1060 |
+
output_attentions=None,
|
| 1061 |
+
output_hidden_states=None,
|
| 1062 |
+
**kwargs
|
| 1063 |
+
):
|
| 1064 |
+
r"""
|
| 1065 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
| 1066 |
+
Labels for computing the masked language modeling loss.
|
| 1067 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
| 1068 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
| 1069 |
+
in ``[0, ..., config.vocab_size]``
|
| 1070 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
| 1071 |
+
Used to hide legacy arguments that have been deprecated.
|
| 1072 |
+
|
| 1073 |
+
Returns:
|
| 1074 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1075 |
+
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
| 1076 |
+
Masked language modeling loss.
|
| 1077 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
| 1078 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 1079 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1080 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1081 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1082 |
+
|
| 1083 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1084 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1085 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1086 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1087 |
+
|
| 1088 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1089 |
+
heads.
|
| 1090 |
+
"""
|
| 1091 |
+
if "masked_lm_labels" in kwargs:
|
| 1092 |
+
warnings.warn(
|
| 1093 |
+
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
| 1094 |
+
DeprecationWarning,
|
| 1095 |
+
)
|
| 1096 |
+
labels = kwargs.pop("masked_lm_labels")
|
| 1097 |
+
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
|
| 1098 |
+
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
| 1099 |
+
|
| 1100 |
+
outputs = self.bert(
|
| 1101 |
+
input_ids,
|
| 1102 |
+
attention_mask=attention_mask,
|
| 1103 |
+
token_type_ids=token_type_ids,
|
| 1104 |
+
position_ids=position_ids,
|
| 1105 |
+
head_mask=head_mask,
|
| 1106 |
+
inputs_embeds=inputs_embeds,
|
| 1107 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1108 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 1109 |
+
output_attentions=output_attentions,
|
| 1110 |
+
output_hidden_states=output_hidden_states,
|
| 1111 |
+
)
|
| 1112 |
+
|
| 1113 |
+
sequence_output = outputs[0]
|
| 1114 |
+
prediction_scores = self.cls(sequence_output)
|
| 1115 |
+
|
| 1116 |
+
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
| 1117 |
+
|
| 1118 |
+
if labels is not None:
|
| 1119 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
| 1120 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 1121 |
+
outputs = (masked_lm_loss,) + outputs
|
| 1122 |
+
|
| 1123 |
+
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
| 1127 |
+
input_shape = input_ids.shape
|
| 1128 |
+
effective_batch_size = input_shape[0]
|
| 1129 |
+
|
| 1130 |
+
# add a dummy token
|
| 1131 |
+
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
|
| 1132 |
+
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
|
| 1133 |
+
dummy_token = torch.full(
|
| 1134 |
+
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
|
| 1135 |
+
)
|
| 1136 |
+
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
| 1137 |
+
|
| 1138 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
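A minimal sketch of the masked-LM objective this class implements: only positions whose label is not -100 contribute to the loss. The checkpoint name is an assumption:

import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
labels = inputs["input_ids"].clone()
labels[inputs["input_ids"] != tokenizer.mask_token_id] = -100  # score only the masked slot

outputs = model(**inputs, labels=labels)
masked_lm_loss, prediction_scores = outputs[0], outputs[1]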
| 1139 |
+
|
| 1140 |
+
|
| 1141 |
+
|
| 1142 |
+
|
| 1143 |
+
|
| 1144 |
+
@add_start_docstrings(
|
| 1145 |
+
"""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
|
| 1146 |
+
)
|
| 1147 |
+
class BertForNextSentencePrediction(BertPreTrainedModel):
|
| 1148 |
+
def __init__(self, config):
|
| 1149 |
+
super().__init__(config)
|
| 1150 |
+
|
| 1151 |
+
self.bert = BertModel(config)
|
| 1152 |
+
self.cls = BertOnlyNSPHead(config)
|
| 1153 |
+
|
| 1154 |
+
self.init_weights()
|
| 1155 |
+
|
| 1156 |
+
|
| 1157 |
+
|
| 1158 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 1159 |
+
def forward(
|
| 1160 |
+
self,
|
| 1161 |
+
input_ids=None,
|
| 1162 |
+
attention_mask=None,
|
| 1163 |
+
token_type_ids=None,
|
| 1164 |
+
position_ids=None,
|
| 1165 |
+
head_mask=None,
|
| 1166 |
+
inputs_embeds=None,
|
| 1167 |
+
next_sentence_label=None,
|
| 1168 |
+
output_attentions=None,
|
| 1169 |
+
output_hidden_states=None,
|
| 1170 |
+
):
|
| 1171 |
+
r"""
|
| 1172 |
+
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
| 1173 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
|
| 1174 |
+
Indices should be in ``[0, 1]``.
|
| 1175 |
+
``0`` indicates sequence B is a continuation of sequence A,
|
| 1176 |
+
``1`` indicates sequence B is a random sequence.
|
| 1177 |
+
|
| 1178 |
+
Returns:
|
| 1179 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1180 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
|
| 1181 |
+
Next sequence prediction (classification) loss.
|
| 1182 |
+
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
| 1183 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
|
| 1184 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1185 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1186 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1187 |
+
|
| 1188 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1189 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1190 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1191 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1192 |
+
|
| 1193 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1194 |
+
heads.
|
| 1195 |
+
|
| 1196 |
+
Examples::
|
| 1197 |
+
|
| 1198 |
+
>>> from transformers import BertTokenizer, BertForNextSentencePrediction
|
| 1199 |
+
>>> import torch
|
| 1200 |
+
|
| 1201 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 1202 |
+
>>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
| 1203 |
+
|
| 1204 |
+
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
| 1205 |
+
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
| 1206 |
+
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
|
| 1207 |
+
|
| 1208 |
+
>>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
|
| 1209 |
+
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
| 1210 |
+
"""
|
| 1211 |
+
|
| 1212 |
+
outputs = self.bert(
|
| 1213 |
+
input_ids,
|
| 1214 |
+
attention_mask=attention_mask,
|
| 1215 |
+
token_type_ids=token_type_ids,
|
| 1216 |
+
position_ids=position_ids,
|
| 1217 |
+
head_mask=head_mask,
|
| 1218 |
+
inputs_embeds=inputs_embeds,
|
| 1219 |
+
output_attentions=output_attentions,
|
| 1220 |
+
output_hidden_states=output_hidden_states,
|
| 1221 |
+
)
|
| 1222 |
+
|
| 1223 |
+
pooled_output = outputs[1]
|
| 1224 |
+
|
| 1225 |
+
seq_relationship_score = self.cls(pooled_output)
|
| 1226 |
+
|
| 1227 |
+
outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
|
| 1228 |
+
if next_sentence_label is not None:
|
| 1229 |
+
loss_fct = CrossEntropyLoss()
|
| 1230 |
+
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
| 1231 |
+
outputs = (next_sentence_loss,) + outputs
|
| 1232 |
+
|
| 1233 |
+
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
|
| 1234 |
+
|
| 1235 |
+
|
| 1236 |
+
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
@add_start_docstrings(
|
| 1240 |
+
"""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
| 1241 |
+
the pooled output) e.g. for GLUE tasks. """,
|
| 1242 |
+
BERT_START_DOCSTRING,
|
| 1243 |
+
)
|
| 1244 |
+
class BertForSequenceClassification(BertPreTrainedModel):
|
| 1245 |
+
def __init__(self, config):
|
| 1246 |
+
super().__init__(config)
|
| 1247 |
+
self.num_labels = config.num_labels
|
| 1248 |
+
|
| 1249 |
+
self.bert = BertModel(config)
|
| 1250 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1251 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1252 |
+
|
| 1253 |
+
self.init_weights()
|
| 1254 |
+
|
| 1255 |
+
|
| 1256 |
+
|
| 1257 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 1258 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 1259 |
+
def forward(
|
| 1260 |
+
self,
|
| 1261 |
+
input_ids=None,
|
| 1262 |
+
attention_mask=None,
|
| 1263 |
+
token_type_ids=None,
|
| 1264 |
+
position_ids=None,
|
| 1265 |
+
head_mask=None,
|
| 1266 |
+
inputs_embeds=None,
|
| 1267 |
+
labels=None,
|
| 1268 |
+
output_attentions=None,
|
| 1269 |
+
output_hidden_states=None,
|
| 1270 |
+
):
|
| 1271 |
+
r"""
|
| 1272 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
| 1273 |
+
Labels for computing the sequence classification/regression loss.
|
| 1274 |
+
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
| 1275 |
+
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
| 1276 |
+
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1277 |
+
|
| 1278 |
+
Returns:
|
| 1279 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1280 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
| 1281 |
+
Classification (or regression if config.num_labels==1) loss.
|
| 1282 |
+
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
| 1283 |
+
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
| 1284 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1285 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1286 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1287 |
+
|
| 1288 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1289 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1290 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1291 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1292 |
+
|
| 1293 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1294 |
+
heads.
|
| 1295 |
+
"""
|
| 1296 |
+
|
| 1297 |
+
outputs = self.bert(
|
| 1298 |
+
input_ids,
|
| 1299 |
+
attention_mask=attention_mask,
|
| 1300 |
+
token_type_ids=token_type_ids,
|
| 1301 |
+
position_ids=position_ids,
|
| 1302 |
+
head_mask=head_mask,
|
| 1303 |
+
inputs_embeds=inputs_embeds,
|
| 1304 |
+
output_attentions=output_attentions,
|
| 1305 |
+
output_hidden_states=output_hidden_states,
|
| 1306 |
+
)
|
| 1307 |
+
|
| 1308 |
+
pooled_output = outputs[1]
|
| 1309 |
+
|
| 1310 |
+
pooled_output = self.dropout(pooled_output)
|
| 1311 |
+
logits = self.classifier(pooled_output)
|
| 1312 |
+
|
| 1313 |
+
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
| 1314 |
+
|
| 1315 |
+
if labels is not None:
|
| 1316 |
+
if self.num_labels == 1:
|
| 1317 |
+
# We are doing regression
|
| 1318 |
+
loss_fct = MSELoss()
|
| 1319 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
| 1320 |
+
else:
|
| 1321 |
+
loss_fct = CrossEntropyLoss()
|
| 1322 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1323 |
+
outputs = (loss,) + outputs
|
| 1324 |
+
|
| 1325 |
+
return outputs # (loss), logits, (hidden_states), (attentions)
|
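A short sketch of the loss branch above: with num_labels == 1 the head acts as a regressor (mean-squared error), otherwise as a classifier (cross-entropy). Values are illustrative:

import torch
from torch.nn import CrossEntropyLoss, MSELoss

num_labels = 3
logits = torch.randn(4, num_labels)        # (batch_size, num_labels) from the classifier head
class_labels = torch.tensor([0, 2, 1, 2])  # classification targets

if num_labels == 1:
    loss = MSELoss()(logits.view(-1), class_labels.float().view(-1))
else:
    loss = CrossEntropyLoss()(logits.view(-1, num_labels), class_labels.view(-1))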
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
|
| 1329 |
+
|
| 1330 |
+
|
| 1331 |
+
@add_start_docstrings(
|
| 1332 |
+
"""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
| 1333 |
+
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
| 1334 |
+
BERT_START_DOCSTRING,
|
| 1335 |
+
)
|
| 1336 |
+
class BertForMultipleChoice(BertPreTrainedModel):
|
| 1337 |
+
def __init__(self, config):
|
| 1338 |
+
super().__init__(config)
|
| 1339 |
+
|
| 1340 |
+
self.bert = BertModel(config)
|
| 1341 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1342 |
+
self.classifier = nn.Linear(config.hidden_size, 1)
|
| 1343 |
+
|
| 1344 |
+
self.init_weights()
|
| 1345 |
+
|
| 1346 |
+
|
| 1347 |
+
|
| 1348 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
|
| 1349 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 1350 |
+
def forward(
|
| 1351 |
+
self,
|
| 1352 |
+
input_ids=None,
|
| 1353 |
+
attention_mask=None,
|
| 1354 |
+
token_type_ids=None,
|
| 1355 |
+
position_ids=None,
|
| 1356 |
+
head_mask=None,
|
| 1357 |
+
inputs_embeds=None,
|
| 1358 |
+
labels=None,
|
| 1359 |
+
output_attentions=None,
|
| 1360 |
+
output_hidden_states=None,
|
| 1361 |
+
):
|
| 1362 |
+
r"""
|
| 1363 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
| 1364 |
+
Labels for computing the multiple choice classification loss.
|
| 1365 |
+
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
|
| 1366 |
+
of the input tensors. (see `input_ids` above)
|
| 1367 |
+
|
| 1368 |
+
Returns:
|
| 1369 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1370 |
+
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
|
| 1371 |
+
Classification loss.
|
| 1372 |
+
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
| 1373 |
+
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
| 1374 |
+
|
| 1375 |
+
Classification scores (before SoftMax).
|
| 1376 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1377 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1378 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1379 |
+
|
| 1380 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1381 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1382 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1383 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1384 |
+
|
| 1385 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1386 |
+
heads.
|
| 1387 |
+
"""
|
| 1388 |
+
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
| 1389 |
+
|
| 1390 |
+
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
| 1391 |
+
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
| 1392 |
+
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
| 1393 |
+
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
| 1394 |
+
inputs_embeds = (
|
| 1395 |
+
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
| 1396 |
+
if inputs_embeds is not None
|
| 1397 |
+
else None
|
| 1398 |
+
)
|
| 1399 |
+
|
| 1400 |
+
outputs = self.bert(
|
| 1401 |
+
input_ids,
|
| 1402 |
+
attention_mask=attention_mask,
|
| 1403 |
+
token_type_ids=token_type_ids,
|
| 1404 |
+
position_ids=position_ids,
|
| 1405 |
+
head_mask=head_mask,
|
| 1406 |
+
inputs_embeds=inputs_embeds,
|
| 1407 |
+
output_attentions=output_attentions,
|
| 1408 |
+
output_hidden_states=output_hidden_states,
|
| 1409 |
+
)
|
| 1410 |
+
|
| 1411 |
+
pooled_output = outputs[1]
|
| 1412 |
+
|
| 1413 |
+
pooled_output = self.dropout(pooled_output)
|
| 1414 |
+
logits = self.classifier(pooled_output)
|
| 1415 |
+
reshaped_logits = logits.view(-1, num_choices)
|
| 1416 |
+
|
| 1417 |
+
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
| 1418 |
+
|
| 1419 |
+
if labels is not None:
|
| 1420 |
+
loss_fct = CrossEntropyLoss()
|
| 1421 |
+
loss = loss_fct(reshaped_logits, labels)
|
| 1422 |
+
outputs = (loss,) + outputs
|
| 1423 |
+
|
| 1424 |
+
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
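A small sketch of the reshaping used above: the choices are flattened into the batch dimension before the encoder runs, and the per-choice scores are folded back afterwards. Shapes are illustrative:

import torch

batch_size, num_choices, seq_len = 2, 4, 16
input_ids = torch.randint(0, 30522, (batch_size, num_choices, seq_len))

flat_input_ids = input_ids.view(-1, input_ids.size(-1))  # (batch_size * num_choices, seq_len)
# the encoder and the 1-unit classifier run on the flattened batch, giving one logit per choice
logits = torch.randn(batch_size * num_choices, 1)
reshaped_logits = logits.view(-1, num_choices)           # (batch_size, num_choices)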
| 1425 |
+
|
| 1426 |
+
|
| 1427 |
+
|
| 1428 |
+
|
| 1429 |
+
|
| 1430 |
+
@add_start_docstrings(
|
| 1431 |
+
"""Bert Model with a token classification head on top (a linear layer on top of
|
| 1432 |
+
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
| 1433 |
+
BERT_START_DOCSTRING,
|
| 1434 |
+
)
|
| 1435 |
+
class BertForTokenClassification(BertPreTrainedModel):
|
| 1436 |
+
def __init__(self, config):
|
| 1437 |
+
super().__init__(config)
|
| 1438 |
+
self.num_labels = config.num_labels
|
| 1439 |
+
|
| 1440 |
+
self.bert = BertModel(config)
|
| 1441 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1442 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1443 |
+
|
| 1444 |
+
self.init_weights()
|
| 1445 |
+
|
| 1446 |
+
|
| 1447 |
+
|
| 1448 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 1449 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 1450 |
+
def forward(
|
| 1451 |
+
self,
|
| 1452 |
+
input_ids=None,
|
| 1453 |
+
attention_mask=None,
|
| 1454 |
+
token_type_ids=None,
|
| 1455 |
+
position_ids=None,
|
| 1456 |
+
head_mask=None,
|
| 1457 |
+
inputs_embeds=None,
|
| 1458 |
+
labels=None,
|
| 1459 |
+
output_attentions=None,
|
| 1460 |
+
output_hidden_states=None,
|
| 1461 |
+
):
|
| 1462 |
+
r"""
|
| 1463 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
| 1464 |
+
Labels for computing the token classification loss.
|
| 1465 |
+
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
| 1466 |
+
|
| 1467 |
+
Returns:
|
| 1468 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1469 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
|
| 1470 |
+
Classification loss.
|
| 1471 |
+
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
|
| 1472 |
+
Classification scores (before SoftMax).
|
| 1473 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1474 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1475 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1476 |
+
|
| 1477 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1478 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1479 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1480 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1481 |
+
|
| 1482 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1483 |
+
heads.
|
| 1484 |
+
"""
|
| 1485 |
+
|
| 1486 |
+
outputs = self.bert(
|
| 1487 |
+
input_ids,
|
| 1488 |
+
attention_mask=attention_mask,
|
| 1489 |
+
token_type_ids=token_type_ids,
|
| 1490 |
+
position_ids=position_ids,
|
| 1491 |
+
head_mask=head_mask,
|
| 1492 |
+
inputs_embeds=inputs_embeds,
|
| 1493 |
+
output_attentions=output_attentions,
|
| 1494 |
+
output_hidden_states=output_hidden_states,
|
| 1495 |
+
)
|
| 1496 |
+
|
| 1497 |
+
sequence_output = outputs[0]
|
| 1498 |
+
|
| 1499 |
+
sequence_output = self.dropout(sequence_output)
|
| 1500 |
+
logits = self.classifier(sequence_output)
|
| 1501 |
+
|
| 1502 |
+
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
| 1503 |
+
if labels is not None:
|
| 1504 |
+
loss_fct = CrossEntropyLoss()
|
| 1505 |
+
# Only keep active parts of the loss
|
| 1506 |
+
if attention_mask is not None:
|
| 1507 |
+
active_loss = attention_mask.view(-1) == 1
|
| 1508 |
+
active_logits = logits.view(-1, self.num_labels)
|
| 1509 |
+
active_labels = torch.where(
|
| 1510 |
+
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
| 1511 |
+
)
|
| 1512 |
+
loss = loss_fct(active_logits, active_labels)
|
| 1513 |
+
else:
|
| 1514 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1515 |
+
outputs = (loss,) + outputs
|
| 1516 |
+
|
| 1517 |
+
return outputs # (loss), scores, (hidden_states), (attentions)
|
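A sketch of the attention-mask handling above: padded positions get their label replaced by the loss function's ignore_index so they never contribute to the token-classification loss. Shapes and labels are illustrative:

import torch
from torch.nn import CrossEntropyLoss

num_labels, seq_len = 5, 6
logits = torch.randn(2, seq_len, num_labels)
labels = torch.randint(0, num_labels, (2, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]])

loss_fct = CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_labels = torch.where(active_loss, labels.view(-1),
                            torch.tensor(loss_fct.ignore_index).type_as(labels))
loss = loss_fct(logits.view(-1, num_labels), active_labels)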
| 1518 |
+
|
| 1519 |
+
|
| 1520 |
+
|
| 1521 |
+
|
| 1522 |
+
|
| 1523 |
+
@add_start_docstrings(
|
| 1524 |
+
"""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
| 1525 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
|
| 1526 |
+
BERT_START_DOCSTRING,
|
| 1527 |
+
)
|
| 1528 |
+
class BertForQuestionAnswering(BertPreTrainedModel):
|
| 1529 |
+
def __init__(self, config):
|
| 1530 |
+
super().__init__(config)
|
| 1531 |
+
self.num_labels = config.num_labels
|
| 1532 |
+
|
| 1533 |
+
self.bert = BertModel(config)
|
| 1534 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
| 1535 |
+
|
| 1536 |
+
self.init_weights()
|
| 1537 |
+
|
| 1538 |
+
|
| 1539 |
+
|
| 1540 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 1541 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 1542 |
+
def forward(
|
| 1543 |
+
self,
|
| 1544 |
+
input_ids=None,
|
| 1545 |
+
attention_mask=None,
|
| 1546 |
+
token_type_ids=None,
|
| 1547 |
+
position_ids=None,
|
| 1548 |
+
head_mask=None,
|
| 1549 |
+
inputs_embeds=None,
|
| 1550 |
+
start_positions=None,
|
| 1551 |
+
end_positions=None,
|
| 1552 |
+
output_attentions=None,
|
| 1553 |
+
output_hidden_states=None,
|
| 1554 |
+
):
|
| 1555 |
+
r"""
|
| 1556 |
+
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
| 1557 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1558 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
| 1559 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
| 1560 |
+
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
| 1561 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1562 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
| 1563 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
| 1564 |
+
|
| 1565 |
+
Returns:
|
| 1566 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1567 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
| 1568 |
+
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
| 1569 |
+
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
| 1570 |
+
Span-start scores (before SoftMax).
|
| 1571 |
+
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
| 1572 |
+
Span-end scores (before SoftMax).
|
| 1573 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1574 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1575 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1576 |
+
|
| 1577 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1578 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1579 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1580 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1581 |
+
|
| 1582 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1583 |
+
heads.
|
| 1584 |
+
"""
|
| 1585 |
+
|
| 1586 |
+
outputs = self.bert(
|
| 1587 |
+
input_ids,
|
| 1588 |
+
attention_mask=attention_mask,
|
| 1589 |
+
token_type_ids=token_type_ids,
|
| 1590 |
+
position_ids=position_ids,
|
| 1591 |
+
head_mask=head_mask,
|
| 1592 |
+
inputs_embeds=inputs_embeds,
|
| 1593 |
+
output_attentions=output_attentions,
|
| 1594 |
+
output_hidden_states=output_hidden_states,
|
| 1595 |
+
)
|
| 1596 |
+
|
| 1597 |
+
sequence_output = outputs[0]
|
| 1598 |
+
|
| 1599 |
+
logits = self.qa_outputs(sequence_output)
|
| 1600 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1601 |
+
start_logits = start_logits.squeeze(-1)
|
| 1602 |
+
end_logits = end_logits.squeeze(-1)
|
| 1603 |
+
|
| 1604 |
+
outputs = (start_logits, end_logits,) + outputs[2:]
|
| 1605 |
+
if start_positions is not None and end_positions is not None:
|
| 1606 |
+
# If we are on multi-GPU, split adds a dimension
|
| 1607 |
+
if len(start_positions.size()) > 1:
|
| 1608 |
+
start_positions = start_positions.squeeze(-1)
|
| 1609 |
+
if len(end_positions.size()) > 1:
|
| 1610 |
+
end_positions = end_positions.squeeze(-1)
|
| 1611 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1612 |
+
ignored_index = start_logits.size(1)
|
| 1613 |
+
start_positions.clamp_(0, ignored_index)
|
| 1614 |
+
end_positions.clamp_(0, ignored_index)
|
| 1615 |
+
|
| 1616 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 1617 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 1618 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 1619 |
+
total_loss = (start_loss + end_loss) / 2
|
| 1620 |
+
outputs = (total_loss,) + outputs
|
| 1621 |
+
|
| 1622 |
+
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
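A minimal sketch of turning the start/end logits above into an answer span. The checkpoint name is an assumption; a SQuAD-fine-tuned checkpoint would give meaningful spans, while a freshly initialized QA head yields arbitrary ones:

import torch
from transformers import BertTokenizer, BertForQuestionAnswering

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")

question = "Who wrote the paper?"
context = "The paper was written by Vaswani et al."
inputs = tokenizer(question, context, return_tensors="pt")
outputs = model(**inputs)

start_logits, end_logits = outputs[0], outputs[1]
start = int(torch.argmax(start_logits))
end = int(torch.argmax(end_logits))
answer = tokenizer.decode(inputs["input_ids"][0][start:end + 1])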
src/reference_code/evaluate_embeddings.py
ADDED
|
@@ -0,0 +1,136 @@
|
| 1 |
+
from torch.utils.data import DataLoader
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
import numpy
|
| 5 |
+
|
| 6 |
+
import pickle
|
| 7 |
+
import tqdm
|
| 8 |
+
|
| 9 |
+
from ..bert import BERT
|
| 10 |
+
from ..vocab import Vocab
|
| 11 |
+
from ..dataset import TokenizerDataset
|
| 12 |
+
import argparse
|
| 13 |
+
from itertools import combinations
|
| 14 |
+
|
| 15 |
+
def generate_subset(s):
|
| 16 |
+
subsets = []
|
| 17 |
+
for r in range(len(s) + 1):
|
| 18 |
+
combinations_result = combinations(s, r)
|
| 19 |
+
if r==1:
|
| 20 |
+
subsets.extend(([item] for sublist in combinations_result for item in sublist))
|
| 21 |
+
else:
|
| 22 |
+
subsets.extend((list(sublist) for sublist in combinations_result))
|
| 23 |
+
subsets_dict = {i:s for i, s in enumerate(subsets)}
|
| 24 |
+
return subsets_dict
|
| 25 |
+
|
| 26 |
+
if __name__ == "__main__":
|
| 27 |
+
parser = argparse.ArgumentParser()
|
| 28 |
+
|
| 29 |
+
parser.add_argument('-workspace_name', type=str, default=None)
|
| 30 |
+
parser.add_argument("-seq_len", type=int, default=100, help="maximum sequence length")
|
| 31 |
+
parser.add_argument('-pretrain', type=bool, default=False)
|
| 32 |
+
parser.add_argument('-masked_pred', type=bool, default=False)
|
| 33 |
+
parser.add_argument('-epoch', type=str, default=None)
|
| 34 |
+
# parser.add_argument('-set_label', type=bool, default=False)
|
| 35 |
+
# parser.add_argument('--label_standard', nargs='+', type=str, help='List of optional tasks')
|
| 36 |
+
|
| 37 |
+
options = parser.parse_args()
|
| 38 |
+
|
| 39 |
+
folder_path = options.workspace_name+"/" if options.workspace_name else ""
|
| 40 |
+
|
| 41 |
+
# if options.set_label:
|
| 42 |
+
# label_standard = generate_subset({'optional-tasks-1', 'optional-tasks-2'})
|
| 43 |
+
# pickle.dump(label_standard, open(f"{folder_path}pretraining/pretrain_opt_label.pkl", "wb"))
|
| 44 |
+
# else:
|
| 45 |
+
# label_standard = pickle.load(open(f"{folder_path}pretraining/pretrain_opt_label.pkl", "rb"))
|
| 46 |
+
# print(f"options.label _standard: {options.label_standard}")
|
| 47 |
+
vocab_path = f"{folder_path}check/pretraining/vocab.txt"
|
| 48 |
+
# vocab_path = f"{folder_path}pretraining/vocab.txt"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
print("Loading Vocab", vocab_path)
|
| 52 |
+
vocab_obj = Vocab(vocab_path)
|
| 53 |
+
vocab_obj.load_vocab()
|
| 54 |
+
print("Vocab Size: ", len(vocab_obj.vocab))
|
| 55 |
+
|
| 56 |
+
# label_standard = list(pickle.load(open(f"dataset/CL4999_1920/{options.workspace_name}/unique_problems_list.pkl", "rb")))
|
| 57 |
+
# label_standard = generate_subset({'optional-tasks-1', 'optional-tasks-2', 'OptionalTask_1', 'OptionalTask_2'})
|
| 58 |
+
# pickle.dump(label_standard, open(f"{folder_path}pretraining/pretrain_opt_label.pkl", "wb"))
|
| 59 |
+
|
| 60 |
+
if options.masked_pred:
|
| 61 |
+
str_code = "masked_prediction"
|
| 62 |
+
output_name = f"{folder_path}output/bert_trained.seq_model.ep{options.epoch}"
|
| 63 |
+
else:
|
| 64 |
+
str_code = "masked"
|
| 65 |
+
output_name = f"{folder_path}output/bert_trained.seq_encoder.model.ep{options.epoch}"
|
| 66 |
+
|
| 67 |
+
folder_path = folder_path+"check/"
|
| 68 |
+
# folder_path = folder_path
|
| 69 |
+
if options.pretrain:
|
| 70 |
+
pretrain_file = f"{folder_path}pretraining/pretrain.txt"
|
| 71 |
+
pretrain_label = f"{folder_path}pretraining/pretrain_opt.pkl"
|
| 72 |
+
|
| 73 |
+
# pretrain_file = f"{folder_path}finetuning/train.txt"
|
| 74 |
+
# pretrain_label = f"{folder_path}finetuning/train_label.txt"
|
| 75 |
+
|
| 76 |
+
embedding_file_path = f"{folder_path}embeddings/pretrain_embeddings_{str_code}_{options.epoch}.pkl"
|
| 77 |
+
print("Loading Pretrain Dataset ", pretrain_file)
|
| 78 |
+
pretrain_dataset = TokenizerDataset(pretrain_file, pretrain_label, vocab_obj, seq_len=options.seq_len)
|
| 79 |
+
|
| 80 |
+
print("Creating Dataloader")
|
| 81 |
+
pretrain_data_loader = DataLoader(pretrain_dataset, batch_size=32, num_workers=4)
|
| 82 |
+
else:
|
| 83 |
+
val_file = f"{folder_path}pretraining/test.txt"
|
| 84 |
+
val_label = f"{folder_path}pretraining/test_opt.txt"
|
| 85 |
+
|
| 86 |
+
# val_file = f"{folder_path}finetuning/test.txt"
|
| 87 |
+
# val_label = f"{folder_path}finetuning/test_label.txt"
|
| 88 |
+
embedding_file_path = f"{folder_path}embeddings/test_embeddings_{str_code}_{options.epoch}.pkl"
|
| 89 |
+
|
| 90 |
+
print("Loading Validation Dataset ", val_file)
|
| 91 |
+
val_dataset = TokenizerDataset(val_file, val_label, vocab_obj, seq_len=options.seq_len)
|
| 92 |
+
|
| 93 |
+
print("Creating Dataloader")
|
| 94 |
+
val_data_loader = DataLoader(val_dataset, batch_size=32, num_workers=4)
|
| 95 |
+
|
| 96 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 97 |
+
print(device)
|
| 98 |
+
print("Load Pre-trained BERT model...")
|
| 99 |
+
print(output_name)
|
| 100 |
+
bert = torch.load(output_name, map_location=device)
|
| 101 |
+
# learned_parameters = model_ep0.state_dict()
|
| 102 |
+
for param in bert.parameters():
|
| 103 |
+
param.requires_grad = False
|
| 104 |
+
|
| 105 |
+
if options.pretrain:
|
| 106 |
+
print("Pretrain-embeddings....")
|
| 107 |
+
data_iter = tqdm.tqdm(enumerate(pretrain_data_loader),
|
| 108 |
+
desc="pre-train",
|
| 109 |
+
total=len(pretrain_data_loader),
|
| 110 |
+
bar_format="{l_bar}{r_bar}")
|
| 111 |
+
pretrain_embeddings = []
|
| 112 |
+
for i, data in data_iter:
|
| 113 |
+
data = {key: value.to(device) for key, value in data.items()}
|
| 114 |
+
hrep = bert(data["bert_input"], data["segment_label"])
|
| 115 |
+
# print(hrep[:,0].cpu().detach().numpy())
|
| 116 |
+
embeddings = [h for h in hrep[:,0].cpu().detach().numpy()]
|
| 117 |
+
pretrain_embeddings.extend(embeddings)
|
| 118 |
+
pickle.dump(pretrain_embeddings, open(embedding_file_path,"wb"))
|
| 119 |
+
# pickle.dump(pretrain_embeddings, open("embeddings/finetune_cfa_train_embeddings.pkl","wb"))
|
| 120 |
+
|
| 121 |
+
else:
|
| 122 |
+
print("Validation-embeddings....")
|
| 123 |
+
data_iter = tqdm.tqdm(enumerate(val_data_loader),
|
| 124 |
+
desc="validation",
|
| 125 |
+
total=len(val_data_loader),
|
| 126 |
+
bar_format="{l_bar}{r_bar}")
|
| 127 |
+
val_embeddings = []
|
| 128 |
+
for i, data in data_iter:
|
| 129 |
+
data = {key: value.to(device) for key, value in data.items()}
|
| 130 |
+
hrep = bert(data["bert_input"], data["segment_label"])
|
| 131 |
+
# print(,hrep[:,0].shape)
|
| 132 |
+
embeddings = [h for h in hrep[:,0].cpu().detach().numpy()]
|
| 133 |
+
val_embeddings.extend(embeddings)
|
| 134 |
+
pickle.dump(val_embeddings, open(embedding_file_path,"wb"))
|
| 135 |
+
# pickle.dump(val_embeddings, open("embeddings/finetune_cfa_test_embeddings.pkl","wb"))
|
| 136 |
+
|
src/reference_code/metrics.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.special import softmax
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class CELoss(object):
|
| 6 |
+
|
| 7 |
+
def compute_bin_boundaries(self, probabilities = np.array([])):
|
| 8 |
+
|
| 9 |
+
#uniform bin spacing
|
| 10 |
+
if probabilities.size == 0:
|
| 11 |
+
bin_boundaries = np.linspace(0, 1, self.n_bins + 1)
|
| 12 |
+
self.bin_lowers = bin_boundaries[:-1]
|
| 13 |
+
self.bin_uppers = bin_boundaries[1:]
|
| 14 |
+
else:
|
| 15 |
+
#size of bins
|
| 16 |
+
bin_n = int(self.n_data/self.n_bins)
|
| 17 |
+
|
| 18 |
+
bin_boundaries = np.array([])
|
| 19 |
+
|
| 20 |
+
probabilities_sort = np.sort(probabilities)
|
| 21 |
+
|
| 22 |
+
for i in range(0,self.n_bins):
|
| 23 |
+
bin_boundaries = np.append(bin_boundaries,probabilities_sort[i*bin_n])
|
| 24 |
+
bin_boundaries = np.append(bin_boundaries,1.0)
|
| 25 |
+
|
| 26 |
+
self.bin_lowers = bin_boundaries[:-1]
|
| 27 |
+
self.bin_uppers = bin_boundaries[1:]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_probabilities(self, output, labels, logits):
|
| 31 |
+
#If not probabilities apply softmax!
|
| 32 |
+
if logits:
|
| 33 |
+
self.probabilities = softmax(output, axis=1)
|
| 34 |
+
else:
|
| 35 |
+
self.probabilities = output
|
| 36 |
+
|
| 37 |
+
self.labels = np.argmax(labels, axis=1)
|
| 38 |
+
self.confidences = np.max(self.probabilities, axis=1)
|
| 39 |
+
self.predictions = np.argmax(self.probabilities, axis=1)
|
| 40 |
+
self.accuracies = np.equal(self.predictions, self.labels)
|
| 41 |
+
|
| 42 |
+
def binary_matrices(self):
|
| 43 |
+
idx = np.arange(self.n_data)
|
| 44 |
+
#make matrices of zeros
|
| 45 |
+
pred_matrix = np.zeros([self.n_data,self.n_class])
|
| 46 |
+
label_matrix = np.zeros([self.n_data,self.n_class])
|
| 47 |
+
#self.acc_matrix = np.zeros([self.n_data,self.n_class])
|
| 48 |
+
pred_matrix[idx,self.predictions] = 1
|
| 49 |
+
label_matrix[idx,self.labels] = 1
|
| 50 |
+
|
| 51 |
+
self.acc_matrix = np.equal(pred_matrix, label_matrix)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def compute_bins(self, index = None):
|
| 55 |
+
self.bin_prop = np.zeros(self.n_bins)
|
| 56 |
+
self.bin_acc = np.zeros(self.n_bins)
|
| 57 |
+
self.bin_conf = np.zeros(self.n_bins)
|
| 58 |
+
self.bin_score = np.zeros(self.n_bins)
|
| 59 |
+
|
| 60 |
+
if index == None:
|
| 61 |
+
confidences = self.confidences
|
| 62 |
+
accuracies = self.accuracies
|
| 63 |
+
else:
|
| 64 |
+
confidences = self.probabilities[:,index]
|
| 65 |
+
accuracies = self.acc_matrix[:,index]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
for i, (bin_lower, bin_upper) in enumerate(zip(self.bin_lowers, self.bin_uppers)):
|
| 69 |
+
# Calculated |confidence - accuracy| in each bin
|
| 70 |
+
in_bin = np.greater(confidences,bin_lower.item()) * np.less_equal(confidences,bin_upper.item())
|
| 71 |
+
self.bin_prop[i] = np.mean(in_bin)
|
| 72 |
+
|
| 73 |
+
if self.bin_prop[i].item() > 0:
|
| 74 |
+
self.bin_acc[i] = np.mean(accuracies[in_bin])
|
| 75 |
+
self.bin_conf[i] = np.mean(confidences[in_bin])
|
| 76 |
+
self.bin_score[i] = np.abs(self.bin_conf[i] - self.bin_acc[i])
|
| 77 |
+
|
| 78 |
+
class MaxProbCELoss(CELoss):
|
| 79 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
| 80 |
+
self.n_bins = n_bins
|
| 81 |
+
super().compute_bin_boundaries()
|
| 82 |
+
super().get_probabilities(output, labels, logits)
|
| 83 |
+
super().compute_bins()
|
| 84 |
+
|
| 85 |
+
#http://people.cs.pitt.edu/~milos/research/AAAI_Calibration.pdf
|
| 86 |
+
class ECELoss(MaxProbCELoss):
|
| 87 |
+
|
| 88 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
| 89 |
+
super().loss(output, labels, n_bins, logits)
|
| 90 |
+
return np.dot(self.bin_prop,self.bin_score)
|
| 91 |
+
|
| 92 |
+
class MCELoss(MaxProbCELoss):
|
| 93 |
+
|
| 94 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
| 95 |
+
super().loss(output, labels, n_bins, logits)
|
| 96 |
+
return np.max(self.bin_score)
|
| 97 |
+
|
| 98 |
+
#https://arxiv.org/abs/1905.11001
|
| 99 |
+
#Overconfidence Loss (Good in high risk applications where confident but wrong predictions can be especially harmful)
|
| 100 |
+
class OELoss(MaxProbCELoss):
|
| 101 |
+
|
| 102 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
| 103 |
+
super().loss(output, labels, n_bins, logits)
|
| 104 |
+
return np.dot(self.bin_prop,self.bin_conf * np.maximum(self.bin_conf-self.bin_acc,np.zeros(self.n_bins)))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
#https://arxiv.org/abs/1904.01685
|
| 108 |
+
class SCELoss(CELoss):
|
| 109 |
+
|
| 110 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
| 111 |
+
sce = 0.0
|
| 112 |
+
self.n_bins = n_bins
|
| 113 |
+
self.n_data = len(output)
|
| 114 |
+
self.n_class = len(output[0])
|
| 115 |
+
|
| 116 |
+
super().compute_bin_boundaries()
|
| 117 |
+
super().get_probabilities(output, labels, logits)
|
| 118 |
+
super().binary_matrices()
|
| 119 |
+
|
| 120 |
+
for i in range(self.n_class):
|
| 121 |
+
super().compute_bins(i)
|
| 122 |
+
sce += np.dot(self.bin_prop,self.bin_score)
|
| 123 |
+
|
| 124 |
+
return sce/self.n_class
|
| 125 |
+
|
| 126 |
+
class TACELoss(CELoss):
|
| 127 |
+
|
| 128 |
+
def loss(self, output, labels, threshold = 0.01, n_bins = 15, logits = True):
|
| 129 |
+
tace = 0.0
|
| 130 |
+
self.n_bins = n_bins
|
| 131 |
+
self.n_data = len(output)
|
| 132 |
+
self.n_class = len(output[0])
|
| 133 |
+
|
| 134 |
+
super().get_probabilities(output, labels, logits)
|
| 135 |
+
self.probabilities[self.probabilities < threshold] = 0
|
| 136 |
+
super().binary_matrices()
|
| 137 |
+
|
| 138 |
+
for i in range(self.n_class):
|
| 139 |
+
super().compute_bin_boundaries(self.probabilities[:,i])
|
| 140 |
+
super().compute_bins(i)
|
| 141 |
+
tace += np.dot(self.bin_prop,self.bin_score)
|
| 142 |
+
|
| 143 |
+
return tace/self.n_class
|
| 144 |
+
|
| 145 |
+
#create TACELoss with threshold fixed at 0
|
| 146 |
+
class ACELoss(TACELoss):
|
| 147 |
+
|
| 148 |
+
def loss(self, output, labels, n_bins = 15, logits = True):
|
| 149 |
+
return super().loss(output, labels, 0.0 , n_bins, logits)
|
src/reference_code/pretrainer-old.py
ADDED
|
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from torch.optim import Adam, SGD
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
import pickle
|
| 7 |
+
|
| 8 |
+
from ..bert import BERT
|
| 9 |
+
from ..seq_model import BERTSM
|
| 10 |
+
from ..classifier_model import BERTForClassification
|
| 11 |
+
from ..optim_schedule import ScheduledOptim
|
| 12 |
+
|
| 13 |
+
import tqdm
|
| 14 |
+
import sys
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
# import visualization
|
| 19 |
+
|
| 20 |
+
from sklearn.metrics import precision_score, recall_score, f1_score
|
| 21 |
+
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
import seaborn as sns
|
| 24 |
+
import pandas as pd
|
| 25 |
+
from collections import defaultdict
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
class ECE(nn.Module):
|
| 29 |
+
|
| 30 |
+
def __init__(self, n_bins=15):
|
| 31 |
+
"""
|
| 32 |
+
n_bins (int): number of confidence interval bins
|
| 33 |
+
"""
|
| 34 |
+
super(ECE, self).__init__()
|
| 35 |
+
bin_boundaries = torch.linspace(0, 1, n_bins + 1)
|
| 36 |
+
self.bin_lowers = bin_boundaries[:-1]
|
| 37 |
+
self.bin_uppers = bin_boundaries[1:]
|
| 38 |
+
|
| 39 |
+
def forward(self, logits, labels):
|
| 40 |
+
softmaxes = F.softmax(logits, dim=1)
|
| 41 |
+
confidences, predictions = torch.max(softmaxes, 1)
|
| 42 |
+
labels = torch.argmax(labels,1)
|
| 43 |
+
accuracies = predictions.eq(labels)
|
| 44 |
+
|
| 45 |
+
ece = torch.zeros(1, device=logits.device)
|
| 46 |
+
for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
|
| 47 |
+
# Calculated |confidence - accuracy| in each bin
|
| 48 |
+
in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
|
| 49 |
+
prop_in_bin = in_bin.float().mean()
|
| 50 |
+
if prop_in_bin.item() > 0:
|
| 51 |
+
accuracy_in_bin = accuracies[in_bin].float().mean()
|
| 52 |
+
avg_confidence_in_bin = confidences[in_bin].mean()
|
| 53 |
+
ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
|
| 54 |
+
|
| 55 |
+
return ece
|
| 56 |
+
|
| 57 |
+
def accurate_nb(preds, labels):
|
| 58 |
+
pred_flat = np.argmax(preds, axis=1).flatten()
|
| 59 |
+
labels_flat = np.argmax(labels, axis=1).flatten()
|
| 60 |
+
labels_flat = labels.flatten()
|
| 61 |
+
return np.sum(pred_flat == labels_flat)
|
| 62 |
+
|
| 63 |
+
class BERTTrainer:
|
| 64 |
+
"""
|
| 65 |
+
BERTTrainer pretrains BERT model on input sequence of strategies.
|
| 66 |
+
BERTTrainer make the pretrained BERT model with one training method objective.
|
| 67 |
+
1. Masked Strategy Modelling : 3.3.1 Task #1: Masked SM
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(self, bert: BERT, vocab_size: int,
|
| 71 |
+
train_dataloader: DataLoader, val_dataloader: DataLoader = None, test_dataloader: DataLoader = None,
|
| 72 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=5000,
|
| 73 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, same_student_prediction = False,
|
| 74 |
+
workspace_name=None, code=None):
|
| 75 |
+
"""
|
| 76 |
+
:param bert: BERT model which you want to train
|
| 77 |
+
:param vocab_size: total word vocab size
|
| 78 |
+
:param train_dataloader: train dataset data loader
|
| 79 |
+
:param test_dataloader: test dataset data loader [can be None]
|
| 80 |
+
:param lr: learning rate of optimizer
|
| 81 |
+
:param betas: Adam optimizer betas
|
| 82 |
+
:param weight_decay: Adam optimizer weight decay param
|
| 83 |
+
:param with_cuda: traning with cuda
|
| 84 |
+
:param log_freq: logging frequency of the batch iteration
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
| 88 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
| 89 |
+
print(cuda_condition, " Device used = ", self.device)
|
| 90 |
+
|
| 91 |
+
available_gpus = list(range(torch.cuda.device_count()))
|
| 92 |
+
|
| 93 |
+
# This BERT model will be saved every epoch
|
| 94 |
+
self.bert = bert.to(self.device)
|
| 95 |
+
# Initialize the BERT Language Model, with BERT model
|
| 96 |
+
self.model = BERTSM(bert, vocab_size).to(self.device)
|
| 97 |
+
|
| 98 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
| 99 |
+
if with_cuda and torch.cuda.device_count() > 1:
|
| 100 |
+
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
| 101 |
+
self.model = nn.DataParallel(self.model, device_ids=available_gpus)
|
| 102 |
+
|
| 103 |
+
# Setting the train and test data loader
|
| 104 |
+
self.train_data = train_dataloader
|
| 105 |
+
self.val_data = val_dataloader
|
| 106 |
+
self.test_data = test_dataloader
|
| 107 |
+
|
| 108 |
+
# Setting the Adam optimizer with hyper-param
|
| 109 |
+
self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
|
| 110 |
+
self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
|
| 111 |
+
|
| 112 |
+
# Using Negative Log Likelihood Loss function for predicting the masked_token
|
| 113 |
+
self.criterion = nn.NLLLoss(ignore_index=0)
|
| 114 |
+
|
| 115 |
+
self.log_freq = log_freq
|
| 116 |
+
self.same_student_prediction = same_student_prediction
|
| 117 |
+
self.workspace_name = workspace_name
|
| 118 |
+
self.save_model = False
|
| 119 |
+
self.code = code
|
| 120 |
+
self.avg_loss = 10000
|
| 121 |
+
self.start_time = time.time()
|
| 122 |
+
|
| 123 |
+
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
| 124 |
+
|
| 125 |
+
def train(self, epoch):
|
| 126 |
+
self.iteration(epoch, self.train_data)
|
| 127 |
+
|
| 128 |
+
def val(self, epoch):
|
| 129 |
+
self.iteration(epoch, self.val_data, phase="val")
|
| 130 |
+
|
| 131 |
+
def test(self, epoch):
|
| 132 |
+
self.iteration(epoch, self.test_data, phase="test")
|
| 133 |
+
|
| 134 |
+
def iteration(self, epoch, data_loader, phase="train"):
|
| 135 |
+
"""
|
| 136 |
+
loop over the data_loader for training or testing
|
| 137 |
+
if on train status, backward operation is activated
|
| 138 |
+
and also auto save the model every peoch
|
| 139 |
+
|
| 140 |
+
:param epoch: current epoch index
|
| 141 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
| 142 |
+
:param train: boolean value of is train or test
|
| 143 |
+
:return: None
|
| 144 |
+
"""
|
| 145 |
+
# str_code = "train" if train else "test"
|
| 146 |
+
# code = "masked_prediction" if self.same_student_prediction else "masked"
|
| 147 |
+
|
| 148 |
+
self.log_file = f"{self.workspace_name}/logs/{self.code}/log_{phase}_pretrained.txt"
|
| 149 |
+
# bert_hidden_representations = []
|
| 150 |
+
if epoch == 0:
|
| 151 |
+
f = open(self.log_file, 'w')
|
| 152 |
+
f.close()
|
| 153 |
+
if phase == "val":
|
| 154 |
+
self.avg_loss = 10000
|
| 155 |
+
# Setting the tqdm progress bar
|
| 156 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
| 157 |
+
desc="EP_%s:%d" % (phase, epoch),
|
| 158 |
+
total=len(data_loader),
|
| 159 |
+
bar_format="{l_bar}{r_bar}")
|
| 160 |
+
|
| 161 |
+
avg_loss_mask = 0.0
|
| 162 |
+
total_correct_mask = 0
|
| 163 |
+
total_element_mask = 0
|
| 164 |
+
|
| 165 |
+
avg_loss_pred = 0.0
|
| 166 |
+
total_correct_pred = 0
|
| 167 |
+
total_element_pred = 0
|
| 168 |
+
|
| 169 |
+
avg_loss = 0.0
|
| 170 |
+
|
| 171 |
+
if phase == "train":
|
| 172 |
+
self.model.train()
|
| 173 |
+
else:
|
| 174 |
+
self.model.eval()
|
| 175 |
+
with open(self.log_file, 'a') as f:
|
| 176 |
+
sys.stdout = f
|
| 177 |
+
for i, data in data_iter:
|
| 178 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
| 179 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 180 |
+
# if i == 0:
|
| 181 |
+
# print(f"data : {data[0]}")
|
| 182 |
+
# 1. forward the next_sentence_prediction and masked_lm model
|
| 183 |
+
# next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])
|
| 184 |
+
if self.same_student_prediction:
|
| 185 |
+
bert_hidden_rep, mask_lm_output, same_student_output = self.model.forward(data["bert_input"], data["segment_label"], self.same_student_prediction)
|
| 186 |
+
else:
|
| 187 |
+
bert_hidden_rep, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"], self.same_student_prediction)
|
| 188 |
+
|
| 189 |
+
# embeddings = [h for h in bert_hidden_rep.cpu().detach().numpy()]
|
| 190 |
+
# bert_hidden_representations.extend(embeddings)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# 2-2. NLLLoss of predicting masked token word
|
| 194 |
+
mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])
|
| 195 |
+
|
| 196 |
+
# 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
|
| 197 |
+
if self.same_student_prediction:
|
| 198 |
+
# 2-1. NLL(negative log likelihood) loss of is_next classification result
|
| 199 |
+
same_student_loss = self.criterion(same_student_output, data["is_same_student"])
|
| 200 |
+
loss = same_student_loss + mask_loss
|
| 201 |
+
else:
|
| 202 |
+
loss = mask_loss
|
| 203 |
+
|
| 204 |
+
# 3. backward and optimization only in train
|
| 205 |
+
if phase == "train":
|
| 206 |
+
self.optim_schedule.zero_grad()
|
| 207 |
+
loss.backward()
|
| 208 |
+
self.optim_schedule.step_and_update_lr()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# print(f"mask_lm_output : {mask_lm_output}")
|
| 212 |
+
# non_zero_mask = (data["bert_label"] != 0).float()
|
| 213 |
+
# print(f"bert_label : {data['bert_label']}")
|
| 214 |
+
non_zero_mask = (data["bert_label"] != 0).float()
|
| 215 |
+
predictions = torch.argmax(mask_lm_output, dim=-1)
|
| 216 |
+
# print(f"predictions : {predictions}")
|
| 217 |
+
predicted_masked = predictions*non_zero_mask
|
| 218 |
+
# print(f"predicted_masked : {predicted_masked}")
|
| 219 |
+
mask_correct = ((data["bert_label"] == predicted_masked)*non_zero_mask).sum().item()
|
| 220 |
+
# print(f"mask_correct : {mask_correct}")
|
| 221 |
+
# print(f"non_zero_mask.sum().item() : {non_zero_mask.sum().item()}")
|
| 222 |
+
|
| 223 |
+
avg_loss_mask += loss.item()
|
| 224 |
+
total_correct_mask += mask_correct
|
| 225 |
+
total_element_mask += non_zero_mask.sum().item()
|
| 226 |
+
# total_element_mask += data["bert_label"].sum().item()
|
| 227 |
+
|
| 228 |
+
torch.cuda.empty_cache()
|
| 229 |
+
post_fix = {
|
| 230 |
+
"epoch": epoch,
|
| 231 |
+
"iter": i,
|
| 232 |
+
"avg_loss": avg_loss_mask / (i + 1),
|
| 233 |
+
"avg_acc_mask": (total_correct_mask / total_element_mask * 100) if total_element_mask != 0 else 0,
|
| 234 |
+
"loss": loss.item()
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
# next sentence prediction accuracy
|
| 238 |
+
if self.same_student_prediction:
|
| 239 |
+
correct = same_student_output.argmax(dim=-1).eq(data["is_same_student"]).sum().item()
|
| 240 |
+
avg_loss_pred += loss.item()
|
| 241 |
+
total_correct_pred += correct
|
| 242 |
+
total_element_pred += data["is_same_student"].nelement()
|
| 243 |
+
# correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
|
| 244 |
+
post_fix["avg_loss"] = avg_loss_pred / (i + 1)
|
| 245 |
+
post_fix["avg_acc_pred"] = total_correct_pred / total_element_pred * 100
|
| 246 |
+
post_fix["loss"] = loss.item()
|
| 247 |
+
|
| 248 |
+
avg_loss +=loss.item()
|
| 249 |
+
|
| 250 |
+
if i % self.log_freq == 0:
|
| 251 |
+
data_iter.write(str(post_fix))
|
| 252 |
+
# if not train and epoch > 20 :
|
| 253 |
+
# pickle.dump(mask_lm_output.cpu().detach().numpy(), open(f"logs/mask/mask_out_e{epoch}_{i}.pkl","wb"))
|
| 254 |
+
# pickle.dump(data["bert_label"].cpu().detach().numpy(), open(f"logs/mask/label_e{epoch}_{i}.pkl","wb"))
|
| 255 |
+
end_time = time.time()
|
| 256 |
+
final_msg = {
|
| 257 |
+
"epoch": f"EP{epoch}_{phase}",
|
| 258 |
+
"avg_loss": avg_loss / len(data_iter),
|
| 259 |
+
"total_masked_acc": total_correct_mask * 100.0 / total_element_mask if total_element_mask != 0 else 0,
|
| 260 |
+
"time_taken_from_start": end_time - self.start_time
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
if self.same_student_prediction:
|
| 264 |
+
final_msg["total_prediction_acc"] = total_correct_pred * 100.0 / total_element_pred
|
| 265 |
+
|
| 266 |
+
print(final_msg)
|
| 267 |
+
|
| 268 |
+
f.close()
|
| 269 |
+
sys.stdout = sys.__stdout__
|
| 270 |
+
|
| 271 |
+
if phase == "val":
|
| 272 |
+
self.save_model = False
|
| 273 |
+
if self.avg_loss > (avg_loss / len(data_iter)):
|
| 274 |
+
self.save_model = True
|
| 275 |
+
self.avg_loss = (avg_loss / len(data_iter))
|
| 276 |
+
|
| 277 |
+
# pickle.dump(bert_hidden_representations, open(f"embeddings/{code}/{str_code}_embeddings_{epoch}.pkl","wb"))
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def save(self, epoch, file_path="output/bert_trained.model"):
|
| 282 |
+
"""
|
| 283 |
+
Saving the current BERT model on file_path
|
| 284 |
+
|
| 285 |
+
:param epoch: current epoch number
|
| 286 |
+
:param file_path: model output path which gonna be file_path+"ep%d" % epoch
|
| 287 |
+
:return: final_output_path
|
| 288 |
+
"""
|
| 289 |
+
# if self.code:
|
| 290 |
+
# fpath = file_path.split("/")
|
| 291 |
+
# # output_path = fpath[0]+ "/"+ fpath[1]+f"/{self.code}/" + fpath[2] + ".ep%d" % epoch
|
| 292 |
+
# output_path = "/",join(fpath[0]+ "/"+ fpath[1]+f"/{self.code}/" + fpath[-1] + ".ep%d" % epoch
|
| 293 |
+
|
| 294 |
+
# else:
|
| 295 |
+
output_path = file_path + ".ep%d" % epoch
|
| 296 |
+
|
| 297 |
+
torch.save(self.bert.cpu(), output_path)
|
| 298 |
+
self.bert.to(self.device)
|
| 299 |
+
print("EP:%d Model Saved on:" % epoch, output_path)
|
| 300 |
+
return output_path
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class BERTFineTuneTrainer:
|
| 304 |
+
|
| 305 |
+
def __init__(self, bert: BERT, vocab_size: int,
|
| 306 |
+
train_dataloader: DataLoader, test_dataloader: DataLoader = None,
|
| 307 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
| 308 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None,
|
| 309 |
+
num_labels=2, finetune_task=""):
|
| 310 |
+
"""
|
| 311 |
+
:param bert: BERT model which you want to train
|
| 312 |
+
:param vocab_size: total word vocab size
|
| 313 |
+
:param train_dataloader: train dataset data loader
|
| 314 |
+
:param test_dataloader: test dataset data loader [can be None]
|
| 315 |
+
:param lr: learning rate of optimizer
|
| 316 |
+
:param betas: Adam optimizer betas
|
| 317 |
+
:param weight_decay: Adam optimizer weight decay param
|
| 318 |
+
:param with_cuda: traning with cuda
|
| 319 |
+
:param log_freq: logging frequency of the batch iteration
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
# Setup cuda device for BERT training, argument -c, --cuda should be true
|
| 323 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
| 324 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
| 325 |
+
print(with_cuda, cuda_condition, " Device used = ", self.device)
|
| 326 |
+
|
| 327 |
+
# This BERT model will be saved every epoch
|
| 328 |
+
self.bert = bert
|
| 329 |
+
for param in self.bert.parameters():
|
| 330 |
+
param.requires_grad = False
|
| 331 |
+
# Initialize the BERT Language Model, with BERT model
|
| 332 |
+
self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)
|
| 333 |
+
|
| 334 |
+
# Distributed GPU training if CUDA can detect more than 1 GPU
|
| 335 |
+
if with_cuda and torch.cuda.device_count() > 1:
|
| 336 |
+
print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
| 337 |
+
self.model = nn.DataParallel(self.model, device_ids=cuda_devices)
|
| 338 |
+
|
| 339 |
+
# Setting the train and test data loader
|
| 340 |
+
self.train_data = train_dataloader
|
| 341 |
+
self.test_data = test_dataloader
|
| 342 |
+
|
| 343 |
+
self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
|
| 344 |
+
# self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
|
| 345 |
+
|
| 346 |
+
if num_labels == 1:
|
| 347 |
+
self.criterion = nn.MSELoss()
|
| 348 |
+
elif num_labels == 2:
|
| 349 |
+
self.criterion = nn.BCEWithLogitsLoss()
|
| 350 |
+
# self.criterion = nn.CrossEntropyLoss()
|
| 351 |
+
elif num_labels > 2:
|
| 352 |
+
self.criterion = nn.CrossEntropyLoss()
|
| 353 |
+
# self.criterion = nn.BCEWithLogitsLoss()
|
| 354 |
+
|
| 355 |
+
# self.ece_criterion = ECE().to(self.device)
|
| 356 |
+
|
| 357 |
+
self.log_freq = log_freq
|
| 358 |
+
self.workspace_name = workspace_name
|
| 359 |
+
self.finetune_task = finetune_task
|
| 360 |
+
self.save_model = False
|
| 361 |
+
self.avg_loss = 10000
|
| 362 |
+
self.start_time = time.time()
|
| 363 |
+
self.probability_list = []
|
| 364 |
+
print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
| 365 |
+
|
| 366 |
+
def train(self, epoch):
|
| 367 |
+
self.iteration(epoch, self.train_data)
|
| 368 |
+
|
| 369 |
+
def test(self, epoch):
|
| 370 |
+
self.iteration(epoch, self.test_data, train=False)
|
| 371 |
+
|
| 372 |
+
def iteration(self, epoch, data_loader, train=True):
|
| 373 |
+
"""
|
| 374 |
+
loop over the data_loader for training or testing
|
| 375 |
+
if on train status, backward operation is activated
|
| 376 |
+
and also auto save the model every peoch
|
| 377 |
+
|
| 378 |
+
:param epoch: current epoch index
|
| 379 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
| 380 |
+
:param train: boolean value of is train or test
|
| 381 |
+
:return: None
|
| 382 |
+
"""
|
| 383 |
+
str_code = "train" if train else "test"
|
| 384 |
+
|
| 385 |
+
self.log_file = f"{self.workspace_name}/logs/{self.finetune_task}/log_{str_code}_finetuned.txt"
|
| 386 |
+
|
| 387 |
+
if epoch == 0:
|
| 388 |
+
f = open(self.log_file, 'w')
|
| 389 |
+
f.close()
|
| 390 |
+
if not train:
|
| 391 |
+
self.avg_loss = 10000
|
| 392 |
+
|
| 393 |
+
# Setting the tqdm progress bar
|
| 394 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
| 395 |
+
desc="EP_%s:%d" % (str_code, epoch),
|
| 396 |
+
total=len(data_loader),
|
| 397 |
+
bar_format="{l_bar}{r_bar}")
|
| 398 |
+
|
| 399 |
+
avg_loss = 0.0
|
| 400 |
+
total_correct = 0
|
| 401 |
+
total_element = 0
|
| 402 |
+
plabels = []
|
| 403 |
+
tlabels = []
|
| 404 |
+
|
| 405 |
+
eval_accurate_nb = 0
|
| 406 |
+
nb_eval_examples = 0
|
| 407 |
+
logits_list = []
|
| 408 |
+
labels_list = []
|
| 409 |
+
|
| 410 |
+
if train:
|
| 411 |
+
self.model.train()
|
| 412 |
+
else:
|
| 413 |
+
self.model.eval()
|
| 414 |
+
self.probability_list = []
|
| 415 |
+
with open(self.log_file, 'a') as f:
|
| 416 |
+
sys.stdout = f
|
| 417 |
+
|
| 418 |
+
for i, data in data_iter:
|
| 419 |
+
# 0. batch_data will be sent into the device(GPU or cpu)
|
| 420 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 421 |
+
if train:
|
| 422 |
+
h_rep, logits = self.model.forward(data["bert_input"], data["segment_label"])
|
| 423 |
+
else:
|
| 424 |
+
with torch.no_grad():
|
| 425 |
+
h_rep, logits = self.model.forward(data["bert_input"], data["segment_label"])
|
| 426 |
+
# print(logits, logits.shape)
|
| 427 |
+
logits_list.append(logits.cpu())
|
| 428 |
+
labels_list.append(data["progress_status"].cpu())
|
| 429 |
+
# print(">>>>>>>>>>>>", progress_output)
|
| 430 |
+
# print(f"{epoch}---nelement--- {data['progress_status'].nelement()}")
|
| 431 |
+
# print(data["progress_status"].shape, logits.shape)
|
| 432 |
+
progress_loss = self.criterion(logits, data["progress_status"])
|
| 433 |
+
loss = progress_loss
|
| 434 |
+
|
| 435 |
+
if torch.cuda.device_count() > 1:
|
| 436 |
+
loss = loss.mean()
|
| 437 |
+
|
| 438 |
+
# 3. backward and optimization only in train
|
| 439 |
+
if train:
|
| 440 |
+
self.optim.zero_grad()
|
| 441 |
+
loss.backward()
|
| 442 |
+
# torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 443 |
+
self.optim.step()
|
| 444 |
+
|
| 445 |
+
# progress prediction accuracy
|
| 446 |
+
# correct = progress_output.argmax(dim=-1).eq(data["progress_status"]).sum().item()
|
| 447 |
+
probs = nn.LogSoftmax(dim=-1)(logits)
|
| 448 |
+
self.probability_list.append(probs)
|
| 449 |
+
predicted_labels = torch.argmax(probs, dim=-1)
|
| 450 |
+
true_labels = torch.argmax(data["progress_status"], dim=-1)
|
| 451 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
| 452 |
+
tlabels.extend(true_labels.cpu().numpy())
|
| 453 |
+
|
| 454 |
+
# Compare predicted labels to true labels and calculate accuracy
|
| 455 |
+
correct = (predicted_labels == true_labels).sum().item()
|
| 456 |
+
avg_loss += loss.item()
|
| 457 |
+
total_correct += correct
|
| 458 |
+
# total_element += true_labels.nelement()
|
| 459 |
+
total_element += data["progress_status"].nelement()
|
| 460 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
|
| 461 |
+
|
| 462 |
+
# if train:
|
| 463 |
+
post_fix = {
|
| 464 |
+
"epoch": epoch,
|
| 465 |
+
"iter": i,
|
| 466 |
+
"avg_loss": avg_loss / (i + 1),
|
| 467 |
+
"avg_acc": total_correct / total_element * 100,
|
| 468 |
+
"loss": loss.item()
|
| 469 |
+
}
|
| 470 |
+
# else:
|
| 471 |
+
# logits = logits.detach().cpu().numpy()
|
| 472 |
+
# label_ids = data["progress_status"].to('cpu').numpy()
|
| 473 |
+
# tmp_eval_nb = accurate_nb(logits, label_ids)
|
| 474 |
+
|
| 475 |
+
# eval_accurate_nb += tmp_eval_nb
|
| 476 |
+
# nb_eval_examples += label_ids.shape[0]
|
| 477 |
+
|
| 478 |
+
# # total_element += data["progress_status"].nelement()
|
| 479 |
+
# # avg_loss += loss.item()
|
| 480 |
+
|
| 481 |
+
# post_fix = {
|
| 482 |
+
# "epoch": epoch,
|
| 483 |
+
# "iter": i,
|
| 484 |
+
# "avg_loss": avg_loss / (i + 1),
|
| 485 |
+
# "avg_acc": tmp_eval_nb / total_element * 100,
|
| 486 |
+
# "loss": loss.item()
|
| 487 |
+
# }
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
if i % self.log_freq == 0:
|
| 491 |
+
data_iter.write(str(post_fix))
|
| 492 |
+
|
| 493 |
+
# precisions = precision_score(plabels, tlabels, average="weighted")
|
| 494 |
+
# recalls = recall_score(plabels, tlabels, average="weighted")
|
| 495 |
+
f1_scores = f1_score(plabels, tlabels, average="weighted")
|
| 496 |
+
# if train:
|
| 497 |
+
end_time = time.time()
|
| 498 |
+
final_msg = {
|
| 499 |
+
"epoch": f"EP{epoch}_{str_code}",
|
| 500 |
+
"avg_loss": avg_loss / len(data_iter),
|
| 501 |
+
"total_acc": total_correct * 100.0 / total_element,
|
| 502 |
+
# "precisions": precisions,
|
| 503 |
+
# "recalls": recalls,
|
| 504 |
+
"f1_scores": f1_scores,
|
| 505 |
+
"time_taken_from_start": end_time - self.start_time
|
| 506 |
+
}
|
| 507 |
+
# else:
|
| 508 |
+
# eval_accuracy = eval_accurate_nb/nb_eval_examples
|
| 509 |
+
|
| 510 |
+
# logits_ece = torch.cat(logits_list)
|
| 511 |
+
# labels_ece = torch.cat(labels_list)
|
| 512 |
+
# ece = self.ece_criterion(logits_ece, labels_ece).item()
|
| 513 |
+
# end_time = time.time()
|
| 514 |
+
# final_msg = {
|
| 515 |
+
# "epoch": f"EP{epoch}_{str_code}",
|
| 516 |
+
# "eval_accuracy": eval_accuracy,
|
| 517 |
+
# "ece": ece,
|
| 518 |
+
# "avg_loss": avg_loss / len(data_iter),
|
| 519 |
+
# "precisions": precisions,
|
| 520 |
+
# "recalls": recalls,
|
| 521 |
+
# "f1_scores": f1_scores,
|
| 522 |
+
# "time_taken_from_start": end_time - self.start_time
|
| 523 |
+
# }
|
| 524 |
+
# if self.save_model:
|
| 525 |
+
# conf_hist = visualization.ConfidenceHistogram()
|
| 526 |
+
# plt_test = conf_hist.plot(np.array(logits_ece), np.array(labels_ece), title= f"Confidence Histogram {epoch}")
|
| 527 |
+
# plt_test.savefig(f"{self.workspace_name}/plots/confidence_histogram/{self.finetune_task}/conf_histogram_test_{epoch}.png",bbox_inches='tight')
|
| 528 |
+
# plt_test.close()
|
| 529 |
+
|
| 530 |
+
# rel_diagram = visualization.ReliabilityDiagram()
|
| 531 |
+
# plt_test_2 = rel_diagram.plot(np.array(logits_ece), np.array(labels_ece),title=f"Reliability Diagram {epoch}")
|
| 532 |
+
# plt_test_2.savefig(f"{self.workspace_name}/plots/confidence_histogram/{self.finetune_task}/rel_diagram_test_{epoch}.png",bbox_inches='tight')
|
| 533 |
+
# plt_test_2.close()
|
| 534 |
+
print(final_msg)
|
| 535 |
+
|
| 536 |
+
# print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=", total_correct * 100.0 / total_element)
|
| 537 |
+
f.close()
|
| 538 |
+
sys.stdout = sys.__stdout__
|
| 539 |
+
self.save_model = False
|
| 540 |
+
if self.avg_loss > (avg_loss / len(data_iter)):
|
| 541 |
+
self.save_model = True
|
| 542 |
+
self.avg_loss = (avg_loss / len(data_iter))
|
| 543 |
+
|
| 544 |
+
def iteration_1(self, epoch_idx, data):
|
| 545 |
+
try:
|
| 546 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 547 |
+
logits = self.model(data['input_ids'], data['segment_label'])
|
| 548 |
+
# Ensure logits is a tensor, not a tuple
|
| 549 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 550 |
+
loss = loss_fct(logits, data['labels'])
|
| 551 |
+
|
| 552 |
+
# Backpropagation and optimization
|
| 553 |
+
self.optim.zero_grad()
|
| 554 |
+
loss.backward()
|
| 555 |
+
self.optim.step()
|
| 556 |
+
|
| 557 |
+
if self.log_freq > 0 and epoch_idx % self.log_freq == 0:
|
| 558 |
+
print(f"Epoch {epoch_idx}: Loss = {loss.item()}")
|
| 559 |
+
|
| 560 |
+
return loss
|
| 561 |
+
|
| 562 |
+
except Exception as e:
|
| 563 |
+
print(f"Error during iteration: {e}")
|
| 564 |
+
raise
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
# plt_test.show()
|
| 571 |
+
# print("EP%d_%s, " % (epoch, str_code))
|
| 572 |
+
|
| 573 |
+
def save(self, epoch, file_path="output/bert_fine_tuned_trained.model"):
|
| 574 |
+
"""
|
| 575 |
+
Saving the current BERT model on file_path
|
| 576 |
+
|
| 577 |
+
:param epoch: current epoch number
|
| 578 |
+
:param file_path: model output path which gonna be file_path+"ep%d" % epoch
|
| 579 |
+
:return: final_output_path
|
| 580 |
+
"""
|
| 581 |
+
if self.finetune_task:
|
| 582 |
+
fpath = file_path.split("/")
|
| 583 |
+
output_path = fpath[0]+ "/"+ fpath[1]+f"/{self.finetune_task}/" + fpath[2] + ".ep%d" % epoch
|
| 584 |
+
else:
|
| 585 |
+
output_path = file_path + ".ep%d" % epoch
|
| 586 |
+
torch.save(self.model.cpu(), output_path)
|
| 587 |
+
self.model.to(self.device)
|
| 588 |
+
print("EP:%d Model Saved on:" % epoch, output_path)
|
| 589 |
+
return output_path
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class BERTAttention:
|
| 593 |
+
def __init__(self, bert: BERT, vocab_obj, train_dataloader: DataLoader, workspace_name=None, code=None, finetune_task=None, with_cuda=True):
|
| 594 |
+
|
| 595 |
+
# available_gpus = list(range(torch.cuda.device_count()))
|
| 596 |
+
|
| 597 |
+
cuda_condition = torch.cuda.is_available() and with_cuda
|
| 598 |
+
self.device = torch.device("cuda:0" if cuda_condition else "cpu")
|
| 599 |
+
print(with_cuda, cuda_condition, " Device used = ", self.device)
|
| 600 |
+
self.bert = bert.to(self.device)
|
| 601 |
+
|
| 602 |
+
# if with_cuda and torch.cuda.device_count() > 1:
|
| 603 |
+
# print("Using %d GPUS for BERT" % torch.cuda.device_count())
|
| 604 |
+
# self.bert = nn.DataParallel(self.bert, device_ids=available_gpus)
|
| 605 |
+
|
| 606 |
+
self.train_dataloader = train_dataloader
|
| 607 |
+
self.workspace_name = workspace_name
|
| 608 |
+
self.code = code
|
| 609 |
+
self.finetune_task = finetune_task
|
| 610 |
+
self.vocab_obj = vocab_obj
|
| 611 |
+
|
| 612 |
+
def getAttention(self):
|
| 613 |
+
# self.log_file = f"{self.workspace_name}/logs/{self.code}/log_attention.txt"
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
labels = ['PercentChange', 'NumeratorQuantity2', 'NumeratorQuantity1', 'DenominatorQuantity1',
|
| 617 |
+
'OptionalTask_1', 'EquationAnswer', 'NumeratorFactor', 'DenominatorFactor',
|
| 618 |
+
'OptionalTask_2', 'FirstRow1:1', 'FirstRow1:2', 'FirstRow2:1', 'FirstRow2:2', 'SecondRow',
|
| 619 |
+
'ThirdRow', 'FinalAnswer','FinalAnswerDirection']
|
| 620 |
+
df_all = pd.DataFrame(0.0, index=labels, columns=labels)
|
| 621 |
+
# Setting the tqdm progress bar
|
| 622 |
+
data_iter = tqdm.tqdm(enumerate(self.train_dataloader),
|
| 623 |
+
desc="attention",
|
| 624 |
+
total=len(self.train_dataloader),
|
| 625 |
+
bar_format="{l_bar}{r_bar}")
|
| 626 |
+
count = 0
|
| 627 |
+
for i, data in data_iter:
|
| 628 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 629 |
+
a = self.bert.forward(data["bert_input"], data["segment_label"])
|
| 630 |
+
non_zero = np.sum(data["segment_label"].cpu().detach().numpy())
|
| 631 |
+
|
| 632 |
+
# Last Transformer Layer
|
| 633 |
+
last_layer = self.bert.attention_values[-1].transpose(1,0,2,3)
|
| 634 |
+
# print(last_layer.shape)
|
| 635 |
+
head, d_model, s, s = last_layer.shape
|
| 636 |
+
|
| 637 |
+
for d in range(d_model):
|
| 638 |
+
seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])[1:non_zero-1]
|
| 639 |
+
# df_all = pd.DataFrame(0.0, index=seq_labels, columns=seq_labels)
|
| 640 |
+
indices_to_choose = defaultdict(int)
|
| 641 |
+
|
| 642 |
+
for k,s in enumerate(seq_labels):
|
| 643 |
+
if s in labels:
|
| 644 |
+
indices_to_choose[s] = k
|
| 645 |
+
indices_chosen = list(indices_to_choose.values())
|
| 646 |
+
selected_seq_labels = [s for l,s in enumerate(seq_labels) if l in indices_chosen]
|
| 647 |
+
# print(len(seq_labels), len(selected_seq_labels))
|
| 648 |
+
for h in range(head):
|
| 649 |
+
# fig, ax = plt.subplots(figsize=(12, 12))
|
| 650 |
+
# seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])#[1:non_zero-1]
|
| 651 |
+
# seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])[1:non_zero-1]
|
| 652 |
+
# indices_to_choose = defaultdict(int)
|
| 653 |
+
|
| 654 |
+
# for k,s in enumerate(seq_labels):
|
| 655 |
+
# if s in labels:
|
| 656 |
+
# indices_to_choose[s] = k
|
| 657 |
+
# indices_chosen = list(indices_to_choose.values())
|
| 658 |
+
# selected_seq_labels = [s for l,s in enumerate(seq_labels) if l in indices_chosen]
|
| 659 |
+
# print(f"Chosen index: {seq_labels, indices_to_choose, indices_chosen, selected_seq_labels}")
|
| 660 |
+
|
| 661 |
+
df_cm = pd.DataFrame(last_layer[h][d][indices_chosen,:][:,indices_chosen], index = selected_seq_labels, columns = selected_seq_labels)
|
| 662 |
+
df_all = df_all.add(df_cm, fill_value=0)
|
| 663 |
+
count += 1
|
| 664 |
+
|
| 665 |
+
# df_cm = pd.DataFrame(last_layer[h][d][1:non_zero-1,:][:,1:non_zero-1], index=seq_labels, columns=seq_labels)
|
| 666 |
+
# df_all = df_all.add(df_cm, fill_value=0)
|
| 667 |
+
|
| 668 |
+
# df_all = df_all.reindex(index=seq_labels, columns=seq_labels)
|
| 669 |
+
# sns.heatmap(df_all, annot=False)
|
| 670 |
+
# plt.title("Attentions") #Probabilities
|
| 671 |
+
# plt.xlabel("Steps")
|
| 672 |
+
# plt.ylabel("Steps")
|
| 673 |
+
# plt.grid(True)
|
| 674 |
+
# plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True, labelrotation=90)
|
| 675 |
+
# plt.savefig(f"{self.workspace_name}/plots/{self.code}/{self.finetune_task}_attention_scores_over_[{h}]_head_n_data[{d}].png", bbox_inches='tight')
|
| 676 |
+
# plt.show()
|
| 677 |
+
# plt.close()
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
print(f"Count of total : {count, head * self.train_dataloader.dataset.len}")
|
| 682 |
+
df_all = df_all.div(count) # head * self.train_dataloader.dataset.len
|
| 683 |
+
df_all = df_all.reindex(index=labels, columns=labels)
|
| 684 |
+
sns.heatmap(df_all, annot=False)
|
| 685 |
+
plt.title("Attentions") #Probabilities
|
| 686 |
+
plt.xlabel("Steps")
|
| 687 |
+
plt.ylabel("Steps")
|
| 688 |
+
plt.grid(True)
|
| 689 |
+
plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True, labelrotation=90)
|
| 690 |
+
plt.savefig(f"{self.workspace_name}/plots/{self.code}/{self.finetune_task}_attention_scores.png", bbox_inches='tight')
|
| 691 |
+
plt.show()
|
| 692 |
+
plt.close()
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
|
src/reference_code/test.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn, optim
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
| 5 |
+
import numpy as np
|
| 6 |
+
from keras.preprocessing.sequence import pad_sequences
|
| 7 |
+
from transformers import BertTokenizer
|
| 8 |
+
from transformers import BertForSequenceClassification
|
| 9 |
+
import random
|
| 10 |
+
from sklearn.metrics import f1_score
|
| 11 |
+
from utils import *
|
| 12 |
+
import os
|
| 13 |
+
import argparse
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import warnings
|
| 18 |
+
warnings.filterwarnings("ignore")
|
| 19 |
+
|
| 20 |
+
class ModelWithTemperature(nn.Module):
|
| 21 |
+
"""
|
| 22 |
+
A thin decorator, which wraps a model with temperature scaling
|
| 23 |
+
model (nn.Module):
|
| 24 |
+
A classification neural network
|
| 25 |
+
NB: Output of the neural network should be the classification logits,
|
| 26 |
+
NOT the softmax (or log softmax)!
|
| 27 |
+
"""
|
| 28 |
+
def __init__(self, model):
|
| 29 |
+
super(ModelWithTemperature, self).__init__()
|
| 30 |
+
self.model = model
|
| 31 |
+
self.temperature = nn.Parameter(torch.ones(1) * 1.5)
|
| 32 |
+
|
| 33 |
+
def forward(self, input_ids, token_type_ids, attention_mask):
|
| 34 |
+
logits = self.model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
|
| 35 |
+
return self.temperature_scale(logits)
|
| 36 |
+
|
| 37 |
+
def temperature_scale(self, logits):
|
| 38 |
+
"""
|
| 39 |
+
Perform temperature scaling on logits
|
| 40 |
+
"""
|
| 41 |
+
# Expand temperature to match the size of logits
|
| 42 |
+
temperature = self.temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
|
| 43 |
+
return logits / temperature
|
| 44 |
+
|
| 45 |
+
# This function probably should live outside of this class, but whatever
|
| 46 |
+
def set_temperature(self, valid_loader, args):
|
| 47 |
+
"""
|
| 48 |
+
Tune the tempearature of the model (using the validation set).
|
| 49 |
+
We're going to set it to optimize NLL.
|
| 50 |
+
valid_loader (DataLoader): validation set loader
|
| 51 |
+
"""
|
| 52 |
+
nll_criterion = nn.CrossEntropyLoss()
|
| 53 |
+
ece_criterion = ECE().to(args.device)
|
| 54 |
+
|
| 55 |
+
# First: collect all the logits and labels for the validation set
|
| 56 |
+
logits_list = []
|
| 57 |
+
labels_list = []
|
| 58 |
+
with torch.no_grad():
|
| 59 |
+
for step, batch in enumerate(valid_loader):
|
| 60 |
+
batch = tuple(t.to(args.device) for t in batch)
|
| 61 |
+
b_input_ids, b_input_mask, b_labels = batch
|
| 62 |
+
logits = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)[0]
|
| 63 |
+
logits_list.append(logits)
|
| 64 |
+
labels_list.append(b_labels)
|
| 65 |
+
logits = torch.cat(logits_list)
|
| 66 |
+
labels = torch.cat(labels_list)
|
| 67 |
+
|
| 68 |
+
# Calculate NLL and ECE before temperature scaling
|
| 69 |
+
before_temperature_nll = nll_criterion(logits, labels).item()
|
| 70 |
+
before_temperature_ece = ece_criterion(logits, labels).item()
|
| 71 |
+
print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))
|
| 72 |
+
|
| 73 |
+
# Next: optimize the temperature w.r.t. NLL
|
| 74 |
+
optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50)
|
| 75 |
+
|
| 76 |
+
def eval():
|
| 77 |
+
loss = nll_criterion(self.temperature_scale(logits), labels)
|
| 78 |
+
loss.backward()
|
| 79 |
+
return loss
|
| 80 |
+
optimizer.step(eval)
|
| 81 |
+
|
| 82 |
+
# Calculate NLL and ECE after temperature scaling
|
| 83 |
+
after_temperature_nll = nll_criterion(self.temperature_scale(logits), labels).item()
|
| 84 |
+
after_temperature_ece = ece_criterion(self.temperature_scale(logits), labels).item()
|
| 85 |
+
print('Optimal temperature: %.3f' % self.temperature.item())
|
| 86 |
+
print('After temperature - NLL: %.3f, ECE: %.3f' % (after_temperature_nll, after_temperature_ece))
|
| 87 |
+
|
| 88 |
+
return self
|
| 89 |
+
|
| 90 |
+
class ECE(nn.Module):
|
| 91 |
+
|
| 92 |
+
def __init__(self, n_bins=15):
|
| 93 |
+
"""
|
| 94 |
+
n_bins (int): number of confidence interval bins
|
| 95 |
+
"""
|
| 96 |
+
super(ECE, self).__init__()
|
| 97 |
+
bin_boundaries = torch.linspace(0, 1, n_bins + 1)
|
| 98 |
+
self.bin_lowers = bin_boundaries[:-1]
|
| 99 |
+
self.bin_uppers = bin_boundaries[1:]
|
| 100 |
+
|
| 101 |
+
def forward(self, logits, labels):
|
| 102 |
+
softmaxes = F.softmax(logits, dim=1)
|
| 103 |
+
confidences, predictions = torch.max(softmaxes, 1)
|
| 104 |
+
accuracies = predictions.eq(labels)
|
| 105 |
+
|
| 106 |
+
ece = torch.zeros(1, device=logits.device)
|
| 107 |
+
for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
|
| 108 |
+
# Calculated |confidence - accuracy| in each bin
|
| 109 |
+
in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
|
| 110 |
+
prop_in_bin = in_bin.float().mean()
|
| 111 |
+
if prop_in_bin.item() > 0:
|
| 112 |
+
accuracy_in_bin = accuracies[in_bin].float().mean()
|
| 113 |
+
avg_confidence_in_bin = confidences[in_bin].mean()
|
| 114 |
+
ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
|
| 115 |
+
|
| 116 |
+
return ece
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class ECE_v2(nn.Module):
|
| 120 |
+
def __init__(self, n_bins=15):
|
| 121 |
+
"""
|
| 122 |
+
n_bins (int): number of confidence interval bins
|
| 123 |
+
"""
|
| 124 |
+
super(ECE_v2, self).__init__()
|
| 125 |
+
bin_boundaries = torch.linspace(0, 1, n_bins + 1)
|
| 126 |
+
self.bin_lowers = bin_boundaries[:-1]
|
| 127 |
+
self.bin_uppers = bin_boundaries[1:]
|
| 128 |
+
|
| 129 |
+
def forward(self, softmaxes, labels):
|
| 130 |
+
confidences, predictions = torch.max(softmaxes, 1)
|
| 131 |
+
accuracies = predictions.eq(labels)
|
| 132 |
+
ece = torch.zeros(1, device=softmaxes.device)
|
| 133 |
+
|
| 134 |
+
for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
|
| 135 |
+
# Calculated |confidence - accuracy| in each bin
|
| 136 |
+
in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
|
| 137 |
+
prop_in_bin = in_bin.float().mean()
|
| 138 |
+
if prop_in_bin.item() > 0:
|
| 139 |
+
accuracy_in_bin = accuracies[in_bin].float().mean()
|
| 140 |
+
avg_confidence_in_bin = confidences[in_bin].mean()
|
| 141 |
+
ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
|
| 142 |
+
return ece
|
| 143 |
+
|
| 144 |
+
def accurate_nb(preds, labels):
|
| 145 |
+
pred_flat = np.argmax(preds, axis=1).flatten()
|
| 146 |
+
labels_flat = labels.flatten()
|
| 147 |
+
return np.sum(pred_flat == labels_flat)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def set_seed(args):
|
| 151 |
+
random.seed(args.seed)
|
| 152 |
+
np.random.seed(args.seed)
|
| 153 |
+
torch.manual_seed(args.seed)
|
| 154 |
+
|
| 155 |
+
def apply_dropout(m):
|
| 156 |
+
if type(m) == nn.Dropout:
|
| 157 |
+
m.train()
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def main():
|
| 161 |
+
|
| 162 |
+
parser = argparse.ArgumentParser(description='Test code - measure the detection peformance')
|
| 163 |
+
parser.add_argument('--eva_iter', default=1, type=int, help='number of passes for mc-dropout when evaluation')
|
| 164 |
+
parser.add_argument('--model', type=str, choices=['base', 'manifold-smoothing', 'mc-dropout','temperature'], default='base')
|
| 165 |
+
parser.add_argument('--seed', type=int, default=0, help='random seed for test')
|
| 166 |
+
parser.add_argument("--epochs", default=10, type=int, help="Number of epochs for training.")
|
| 167 |
+
parser.add_argument('--index', type=int, default=0, help='random seed you used during training')
|
| 168 |
+
parser.add_argument('--in_dataset', required=True, help='target dataset: 20news')
|
| 169 |
+
parser.add_argument('--out_dataset', required=True, help='out-of-dist dataset')
|
| 170 |
+
parser.add_argument('--eval_batch_size', type=int, default=32)
|
| 171 |
+
parser.add_argument('--saved_dataset', type=str, default='n')
|
| 172 |
+
parser.add_argument('--eps_out', default=0.001, type=float, help="Perturbation size of out-of-domain adversarial training")
|
| 173 |
+
parser.add_argument("--eps_y", default=0.1, type=float, help="Perturbation size of label")
|
| 174 |
+
parser.add_argument('--eps_in', default=0.0001, type=float, help="Perturbation size of in-domain adversarial training")
|
| 175 |
+
|
| 176 |
+
args = parser.parse_args()
|
| 177 |
+
|
| 178 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 179 |
+
args.device = device
|
| 180 |
+
set_seed(args)
|
| 181 |
+
|
| 182 |
+
outf = 'test/'+args.model+'-'+str(args.index)
|
| 183 |
+
if not os.path.isdir(outf):
|
| 184 |
+
os.makedirs(outf)
|
| 185 |
+
|
| 186 |
+
if args.model == 'base':
|
| 187 |
+
dirname = '{}/BERT-base-{}'.format(args.in_dataset, args.index)
|
| 188 |
+
pretrained_dir = './model_save/{}'.format(dirname)
|
| 189 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
| 190 |
+
model = BertForSequenceClassification.from_pretrained(pretrained_dir)
|
| 191 |
+
model.to(args.device)
|
| 192 |
+
print('Load Tokenizer')
|
| 193 |
+
|
| 194 |
+
elif args.model == 'mc-dropout':
|
| 195 |
+
dirname = '{}/BERT-base-{}'.format(args.in_dataset, args.index)
|
| 196 |
+
pretrained_dir = './model_save/{}'.format(dirname)
|
| 197 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
| 198 |
+
model = BertForSequenceClassification.from_pretrained(pretrained_dir)
|
| 199 |
+
model.to(args.device)
|
| 200 |
+
|
| 201 |
+
elif args.model == 'temperature':
|
| 202 |
+
dirname = '{}/BERT-base-{}'.format(args.in_dataset, args.index)
|
| 203 |
+
pretrained_dir = './model_save/{}'.format(dirname)
|
| 204 |
+
orig_model = BertForSequenceClassification.from_pretrained(pretrained_dir)
|
| 205 |
+
orig_model.to(args.device)
|
| 206 |
+
model = ModelWithTemperature(orig_model)
|
| 207 |
+
model.to(args.device)
|
| 208 |
+
|
| 209 |
+
elif args.model == 'manifold-smoothing':
|
| 210 |
+
dirname = '{}/BERT-mf-{}-{}-{}-{}'.format(args.in_dataset, args.index, args.eps_in, args.eps_y, args.eps_out)
|
| 211 |
+
print(dirname)
|
| 212 |
+
pretrained_dir = './model_save/{}'.format(dirname)
|
| 213 |
+
model = BertForSequenceClassification.from_pretrained(pretrained_dir)
|
| 214 |
+
model.to(args.device)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
if args.saved_dataset == 'n':
|
| 218 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
|
| 219 |
+
train_sentences, val_sentences, test_sentences, train_labels, val_labels, test_labels = load_dataset(args.in_dataset)
|
| 220 |
+
_, _, nt_test_sentences, _, _, nt_test_labels = load_dataset(args.out_dataset)
|
| 221 |
+
|
| 222 |
+
val_input_ids = []
|
| 223 |
+
test_input_ids = []
|
| 224 |
+
nt_test_input_ids = []
|
| 225 |
+
|
| 226 |
+
if args.in_dataset == '20news' or args.in_dataset == '20news-15':
|
| 227 |
+
MAX_LEN = 150
|
| 228 |
+
else:
|
| 229 |
+
MAX_LEN = 256
|
| 230 |
+
|
| 231 |
+
for sent in val_sentences:
|
| 232 |
+
encoded_sent = tokenizer.encode(
|
| 233 |
+
sent, # Sentence to encode.
|
| 234 |
+
add_special_tokens = True, # Add '[CLS]' and '[SEP]'
|
| 235 |
+
truncation= True,
|
| 236 |
+
max_length = MAX_LEN, # Truncate all sentences.
|
| 237 |
+
#return_tensors = 'pt', # Return pytorch tensors.
|
| 238 |
+
)
|
| 239 |
+
# Add the encoded sentence to the list.
|
| 240 |
+
val_input_ids.append(encoded_sent)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
for sent in test_sentences:
|
| 244 |
+
encoded_sent = tokenizer.encode(
|
| 245 |
+
sent, # Sentence to encode.
|
| 246 |
+
add_special_tokens = True, # Add '[CLS]' and '[SEP]'
|
| 247 |
+
truncation= True,
|
| 248 |
+
max_length = MAX_LEN, # Truncate all sentences.
|
| 249 |
+
#return_tensors = 'pt', # Return pytorch tensors.
|
| 250 |
+
)
|
| 251 |
+
# Add the encoded sentence to the list.
|
| 252 |
+
test_input_ids.append(encoded_sent)
|
| 253 |
+
|
| 254 |
+
for sent in nt_test_sentences:
|
| 255 |
+
encoded_sent = tokenizer.encode(
|
| 256 |
+
sent,
|
| 257 |
+
add_special_tokens = True,
|
| 258 |
+
truncation= True,
|
| 259 |
+
max_length = MAX_LEN,
|
| 260 |
+
)
|
| 261 |
+
nt_test_input_ids.append(encoded_sent)
|
| 262 |
+
|
| 263 |
+
# Pad our input tokens
|
| 264 |
+
val_input_ids = pad_sequences(val_input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
|
| 265 |
+
test_input_ids = pad_sequences(test_input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
|
| 266 |
+
nt_test_input_ids = pad_sequences(nt_test_input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
|
| 267 |
+
|
| 268 |
+
val_attention_masks = []
|
| 269 |
+
test_attention_masks = []
|
| 270 |
+
nt_test_attention_masks = []
|
| 271 |
+
|
| 272 |
+
for seq in val_input_ids:
|
| 273 |
+
seq_mask = [float(i>0) for i in seq]
|
| 274 |
+
val_attention_masks.append(seq_mask)
|
| 275 |
+
for seq in test_input_ids:
|
| 276 |
+
seq_mask = [float(i>0) for i in seq]
|
| 277 |
+
test_attention_masks.append(seq_mask)
|
| 278 |
+
for seq in nt_test_input_ids:
|
| 279 |
+
seq_mask = [float(i>0) for i in seq]
|
| 280 |
+
nt_test_attention_masks.append(seq_mask)
|
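For reference, recent versions of the transformers tokenizer can produce the padded ids and attention masks in a single call; a sketch of the equivalent encoding (shown for comparison only — the script keeps the manual pad_sequences and mask-building loops above):

# Equivalent one-call encoding (comparison only; not what this script does).
encoded = tokenizer(list(val_sentences), add_special_tokens=True, truncation=True,
                    max_length=MAX_LEN, padding='max_length', return_tensors='pt')
val_inputs = encoded['input_ids']
val_masks = encoded['attention_mask'].float()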
| 281 |
+
|
| 282 |
+
|
| 283 |
+
val_inputs = torch.tensor(val_input_ids)
|
| 284 |
+
val_labels = torch.tensor(val_labels)
|
| 285 |
+
val_masks = torch.tensor(val_attention_masks)
|
| 286 |
+
|
| 287 |
+
test_inputs = torch.tensor(test_input_ids)
|
| 288 |
+
test_labels = torch.tensor(test_labels)
|
| 289 |
+
test_masks = torch.tensor(test_attention_masks)
|
| 290 |
+
|
| 291 |
+
nt_test_inputs = torch.tensor(nt_test_input_ids)
|
| 292 |
+
nt_test_labels = torch.tensor(nt_test_labels)
|
| 293 |
+
nt_test_masks = torch.tensor(nt_test_attention_masks)
|
| 294 |
+
|
| 295 |
+
val_data = TensorDataset(val_inputs, val_masks, val_labels)
|
| 296 |
+
test_data = TensorDataset(test_inputs, test_masks, test_labels)
|
| 297 |
+
nt_test_data = TensorDataset(nt_test_inputs, nt_test_masks, nt_test_labels)
|
| 298 |
+
|
| 299 |
+
dataset_dir = 'dataset/test'
|
| 300 |
+
if not os.path.exists(dataset_dir):
|
| 301 |
+
os.makedirs(dataset_dir)
|
| 302 |
+
torch.save(val_data, dataset_dir+'/{}_val_in_domain.pt'.format(args.in_dataset))
|
| 303 |
+
torch.save(test_data, dataset_dir+'/{}_test_in_domain.pt'.format(args.in_dataset))
|
| 304 |
+
torch.save(nt_test_data, dataset_dir+'/{}_test_out_of_domain.pt'.format(args.out_dataset))
|
| 305 |
+
|
| 306 |
+
else:
|
| 307 |
+
dataset_dir = 'dataset/test'
|
| 308 |
+
val_data = torch.load(dataset_dir+'/{}_val_in_domain.pt'.format(args.in_dataset))
|
| 309 |
+
test_data = torch.load(dataset_dir+'/{}_test_in_domain.pt'.format(args.in_dataset))
|
| 310 |
+
nt_test_data = torch.load(dataset_dir+'/{}_test_out_of_domain.pt'.format(args.out_dataset))
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
######## saved dataset
|
| 317 |
+
test_sampler = SequentialSampler(test_data)
|
| 318 |
+
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size)
|
| 319 |
+
|
| 320 |
+
nt_test_sampler = SequentialSampler(nt_test_data)
|
| 321 |
+
nt_test_dataloader = DataLoader(nt_test_data, sampler=nt_test_sampler, batch_size=args.eval_batch_size)
|
| 322 |
+
val_sampler = SequentialSampler(val_data)
|
| 323 |
+
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=args.eval_batch_size)
|
| 324 |
+
|
| 325 |
+
if args.model == 'temperature':
|
| 326 |
+
model.set_temperature(val_dataloader, args)
|
| 327 |
+
|
| 328 |
+
model.eval()
|
| 329 |
+
|
| 330 |
+
if args.model == 'mc-dropout':
|
| 331 |
+
model.apply(apply_dropout)
|
| 332 |
+
|
| 333 |
+
correct = 0
|
| 334 |
+
total = 0
|
| 335 |
+
output_list = []
|
| 336 |
+
labels_list = []
|
| 337 |
+
|
| 338 |
+
##### validation data
|
| 339 |
+
with torch.no_grad():
|
| 340 |
+
for step, batch in enumerate(val_dataloader):
|
| 341 |
+
batch = tuple(t.to(args.device) for t in batch)
|
| 342 |
+
b_input_ids, b_input_mask, b_labels = batch
|
| 343 |
+
total += b_labels.shape[0]
|
| 344 |
+
batch_output = 0
|
| 345 |
+
for j in range(args.eva_iter):
|
| 346 |
+
if args.model == 'temperature':
|
| 347 |
+
current_batch = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask) #logits
|
| 348 |
+
else:
|
| 349 |
+
current_batch = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask)[0] #logits
|
| 350 |
+
batch_output = batch_output + F.softmax(current_batch, dim=1)
|
| 351 |
+
batch_output = batch_output/args.eva_iter
|
| 352 |
+
output_list.append(batch_output)
|
| 353 |
+
labels_list.append(b_labels)
|
| 354 |
+
score, predicted = batch_output.max(1)
|
| 355 |
+
correct += predicted.eq(b_labels).sum().item()
|
| 356 |
+
|
| 357 |
+
###calculate accuracy and ECE
|
| 358 |
+
val_eval_accuracy = correct/total
|
| 359 |
+
print("Val Accuracy: {}".format(val_eval_accuracy))
|
| 360 |
+
ece_criterion = ECE_v2().to(args.device)
|
| 361 |
+
softmaxes_ece = torch.cat(output_list)
|
| 362 |
+
labels_ece = torch.cat(labels_list)
|
| 363 |
+
val_ece = ece_criterion(softmaxes_ece, labels_ece).item()
|
| 364 |
+
print('ECE on Val data: {}'.format(val_ece))
|
| 365 |
+
|
| 366 |
+
#### Test data
|
| 367 |
+
correct = 0
|
| 368 |
+
total = 0
|
| 369 |
+
output_list = []
|
| 370 |
+
labels_list = []
|
| 371 |
+
predict_list = []
|
| 372 |
+
true_list = []
|
| 373 |
+
true_list_ood = []
|
| 374 |
+
predict_mis = []
|
| 375 |
+
predict_in = []
|
| 376 |
+
score_list = []
|
| 377 |
+
correct_index_all = []
|
| 378 |
+
## test on in-distribution test set
|
| 379 |
+
with torch.no_grad():
|
| 380 |
+
for step, batch in enumerate(test_dataloader):
|
| 381 |
+
batch = tuple(t.to(args.device) for t in batch)
|
| 382 |
+
b_input_ids, b_input_mask, b_labels = batch
|
| 383 |
+
total += b_labels.shape[0]
|
| 384 |
+
batch_output = 0
|
| 385 |
+
for j in range(args.eva_iter):
|
| 386 |
+
if args.model == 'temperature':
|
| 387 |
+
current_batch = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask) #logits
|
| 388 |
+
else:
|
| 389 |
+
current_batch = model(input_ids=b_input_ids, token_type_ids=None, attention_mask=b_input_mask)[0] #logits
|
| 390 |
+
batch_output = batch_output + F.softmax(current_batch, dim=1)
|
| 391 |
+
batch_output = batch_output/args.eva_iter
|
| 392 |
+
output_list.append(batch_output)
|
| 393 |
+
labels_list.append(b_labels)
|
| 394 |
+
score, predicted = batch_output.max(1)
|
| 395 |
+
|
| 396 |
+
correct += predicted.eq(b_labels).sum().item()
|
| 397 |
+
|
| 398 |
+
correct_index = (predicted == b_labels)
|
| 399 |
+
correct_index_all.append(correct_index)
|
| 400 |
+
score_list.append(score)
|
| 401 |
+
|
| 402 |
+
### calculate accuracy
|
| 403 |
+
eval_accuracy = correct/total
|
| 404 |
+
print("Test Accuracy: {}".format(eval_accuracy))
|
| 405 |
+
|
| 406 |
+
##calculate ece
|
| 407 |
+
ece_criterion = ECE_v2().to(args.device)
|
| 408 |
+
softmaxes_ece = torch.cat(output_list)
|
| 409 |
+
labels_ece = torch.cat(labels_list)
|
| 410 |
+
ece = ece_criterion(softmaxes_ece, labels_ece).item()
|
| 411 |
+
print('ECE on Test data: {}'.format(ece))
|
| 412 |
+
|
| 413 |
+
#confidence for in-distribution data
|
| 414 |
+
score_in_array = torch.cat(score_list)
|
| 415 |
+
#indices of data that are classified correctly
|
| 416 |
+
correct_array = torch.cat(correct_index_all)
|
| 417 |
+
label_array = torch.cat(labels_list)
|
| 418 |
+
|
| 419 |
+
### test on out-of-distribution data
|
| 420 |
+
predict_ood = []
|
| 421 |
+
score_ood_list = []
|
| 422 |
+
true_list_ood = []
|
| 423 |
+
with torch.no_grad():
|
| 424 |
+
for step, batch in enumerate(nt_test_dataloader):
|
| 425 |
+
batch = tuple(t.to(args.device) for t in batch)
|
| 426 |
+
b_input_ids, b_input_mask, b_labels = batch
|
| 427 |
+
batch_output = 0
|
| 428 |
+
for j in range(args.eva_iter):
|
| 429 |
+
if args.model == 'temperature':
|
| 430 |
+
current_batch = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
|
| 431 |
+
else:
|
| 432 |
+
current_batch = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)[0]
|
| 433 |
+
batch_output = batch_output + F.softmax(current_batch, dim=1)
|
| 434 |
+
batch_output = batch_output/args.eva_iter
|
| 435 |
+
score_out, _ = batch_output.max(1)
|
| 436 |
+
|
| 437 |
+
score_ood_list.append(score_out)
|
| 438 |
+
|
| 439 |
+
score_ood_array = torch.cat(score_ood_list)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
label_array = label_array.cpu().numpy()
|
| 444 |
+
score_ood_array = score_ood_array.cpu().numpy()
|
| 445 |
+
score_in_array = score_in_array.cpu().numpy()
|
| 446 |
+
correct_array = correct_array.cpu().numpy()
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
####### calculate NBAUCC for detection task
|
| 452 |
+
predict_o = np.zeros(len(score_in_array)+len(score_ood_array))
|
| 453 |
+
true_o = np.ones(len(score_in_array)+len(score_ood_array))
|
| 454 |
+
true_o[:len(score_in_array)] = 0 ## in-distribution data as false, ood data as positive
|
| 455 |
+
true_mis = np.ones(len(score_in_array))
|
| 456 |
+
true_mis[correct_array] = 0 ##true instances as false, misclassified instances as positive
|
| 457 |
+
predict_mis = np.zeros(len(score_in_array))
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
ood_sum = 0
|
| 462 |
+
mis_sum = 0
|
| 463 |
+
|
| 464 |
+
ood_sum_list = []
|
| 465 |
+
mis_sum_list = []
|
| 466 |
+
|
| 467 |
+
#### upper bound of the threshold tau for NBAUCC
|
| 468 |
+
stop_points = [0.50, 1.]
|
| 469 |
+
|
| 470 |
+
for threshold in np.arange(0., 1.01, 0.02):
|
| 471 |
+
predict_ood_index1 = (score_in_array < threshold)
|
| 472 |
+
predict_ood_index2 = (score_ood_array < threshold)
|
| 473 |
+
predict_ood_index = np.concatenate((predict_ood_index1, predict_ood_index2), axis=0)
|
| 474 |
+
predict_o[predict_ood_index] = 1
|
| 475 |
+
predict_mis[score_in_array<threshold] = 1
|
| 476 |
+
|
| 477 |
+
ood = f1_score(true_o, predict_o, average='binary') ##### detection f1 score for a specific threshold
|
| 478 |
+
mis = f1_score(true_mis, predict_mis, average='binary')
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
ood_sum += ood*0.02
|
| 482 |
+
mis_sum += mis*0.02
|
| 483 |
+
|
| 484 |
+
if threshold in stop_points:
|
| 485 |
+
ood_sum_list.append(ood_sum)
|
| 486 |
+
mis_sum_list.append(mis_sum)
|
| 487 |
+
|
| 488 |
+
for i in range(len(stop_points)):
|
| 489 |
+
print('OOD detection, NBAUCC {}: {}'.format(stop_points[i], ood_sum_list[i]/stop_points[i]))
|
| 490 |
+
print('misclassification detection, NBAUCC {}: {}'.format(stop_points[i], mis_sum_list[i]/stop_points[i]))
|
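The two NBAUCC numbers above integrate the detection F1 over thresholds in [0, tau] and divide by tau. The same computation, factored into a small helper for clarity (a sketch; the helper and its names are not part of the script):

import numpy as np
from sklearn.metrics import f1_score

def nbaucc(true_labels, scores, upper=0.5, step=0.02):
    """Normalized bounded area under the F1-vs-threshold curve on [0, upper].
    true_labels: 1 for the class to detect (OOD or misclassified), 0 otherwise.
    scores: confidence scores; an instance is flagged when its score < threshold."""
    area = 0.0
    for threshold in np.arange(0.0, upper + 1e-9, step):
        predictions = (scores < threshold).astype(int)
        area += f1_score(true_labels, predictions, average='binary') * step
    return area / upper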
| 491 |
+
|
| 492 |
+
if __name__ == "__main__":
|
| 493 |
+
main()
|
src/reference_code/utils.py
ADDED
|
@@ -0,0 +1,369 @@
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from collections import Counter
|
| 6 |
+
import numpy as np
|
| 7 |
+
from sklearn.datasets import fetch_20newsgroups
|
| 8 |
+
from collections import Counter, defaultdict
|
| 9 |
+
from nltk.corpus import stopwords
|
| 10 |
+
from sklearn.model_selection import train_test_split
|
| 11 |
+
import re
|
| 12 |
+
from sklearn.utils import shuffle
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def cos_dist(x, y):
|
| 17 |
+
## cosine distance function
|
| 18 |
+
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
| 19 |
+
batch_size = x.size(0)
|
| 20 |
+
c = torch.clamp(1 - cos(x.view(batch_size, -1), y.view(batch_size, -1)),
|
| 21 |
+
min=0)
|
| 22 |
+
return c.mean()
|
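A quick sanity check of cos_dist on illustrative tensors:

# Identical rows give distance 0; one identical and one orthogonal pair average to 0.5.
a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
b = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
print(cos_dist(a, a))   # tensor(0.)
print(cos_dist(a, b))   # tensor(0.5000)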
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def tag_mapping(tags):
|
| 28 |
+
"""
|
| 29 |
+
Create a dictionary and a mapping of tags, sorted by frequency.
|
| 30 |
+
"""
|
| 31 |
+
#tags = [s[1] for s in dataset]
|
| 32 |
+
dico = Counter(tags)
|
| 33 |
+
tag_to_id, id_to_tag = create_mapping(dico)
|
| 34 |
+
print("Found %i unique named entity tags" % len(dico))
|
| 35 |
+
return dico, tag_to_id, id_to_tag
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def create_mapping(dico):
|
| 39 |
+
"""
|
| 40 |
+
Create a mapping (item to ID / ID to item) from a dictionary.
|
| 41 |
+
Items are ordered by decreasing frequency.
|
| 42 |
+
"""
|
| 43 |
+
sorted_items = sorted(dico.items(), key=lambda x: (-x[1], x[0]))
|
| 44 |
+
id_to_item = {i: v[0] for i, v in enumerate(sorted_items)}
|
| 45 |
+
item_to_id = {v: k for k, v in id_to_item.items()}
|
| 46 |
+
return item_to_id, id_to_item
|
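For example, on a toy tag list:

# Illustrative run of tag_mapping / create_mapping.
dico, tag_to_id, id_to_tag = tag_mapping(['O', 'O', 'PER', 'LOC', 'O'])
# dico      -> Counter({'O': 3, 'LOC': 1, 'PER': 1})
# tag_to_id -> {'O': 0, 'LOC': 1, 'PER': 2}   (decreasing frequency, ties broken alphabetically)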
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def clean_str(string):
|
| 52 |
+
"""
|
| 53 |
+
Tokenization/string cleaning for all datasets except for SST.
|
| 54 |
+
Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
|
| 55 |
+
"""
|
| 56 |
+
string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
|
| 57 |
+
string = re.sub(r"\'s", " \'s", string)
|
| 58 |
+
string = re.sub(r"\'ve", " \'ve", string)
|
| 59 |
+
string = re.sub(r"n\'t", " n\'t", string)
|
| 60 |
+
string = re.sub(r"\'re", " \'re", string)
|
| 61 |
+
string = re.sub(r"\'d", " \'d", string)
|
| 62 |
+
string = re.sub(r"\'ll", " \'ll", string)
|
| 63 |
+
string = re.sub(r",", " , ", string)
|
| 64 |
+
string = re.sub(r"!", " ! ", string)
|
| 65 |
+
string = re.sub(r"\(", " \( ", string)
|
| 66 |
+
string = re.sub(r"\)", " \) ", string)
|
| 67 |
+
string = re.sub(r"\?", " \? ", string)
|
| 68 |
+
string = re.sub(r"\s{2,}", " ", string)
|
| 69 |
+
return string.strip().lower()
|
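For instance:

# Contractions are split off, punctuation is separated, and the text is lower-cased.
print(clean_str("I can't believe it's already 2020!"))
# -> "i ca n't believe it 's already 2020 !"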
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def clean_doc(x, word_freq):
|
| 73 |
+
stop_words = set(stopwords.words('english'))
|
| 74 |
+
clean_docs = []
|
| 75 |
+
most_commons = dict(word_freq.most_common(min(len(word_freq), 50000)))
|
| 76 |
+
for doc_content in x:
|
| 77 |
+
doc_words = []
|
| 78 |
+
cleaned = clean_str(doc_content.strip())
|
| 79 |
+
for word in cleaned.split():
|
| 80 |
+
if word not in stop_words and word_freq[word] >= 5:
|
| 81 |
+
if word in most_commons:
|
| 82 |
+
doc_words.append(word)
|
| 83 |
+
else:
|
| 84 |
+
doc_words.append("<UNK>")
|
| 85 |
+
doc_str = ' '.join(doc_words).strip()
|
| 86 |
+
clean_docs.append(doc_str)
|
| 87 |
+
return clean_docs
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def load_dataset(dataset):
|
| 92 |
+
|
| 93 |
+
if dataset == 'sst':
|
| 94 |
+
df_train = pd.read_csv("./dataset/sst/SST-2/train.tsv", delimiter='\t', header=0)
|
| 95 |
+
|
| 96 |
+
df_val = pd.read_csv("./dataset/sst/SST-2/dev.tsv", delimiter='\t', header=0)
|
| 97 |
+
|
| 98 |
+
df_test = pd.read_csv("./dataset/sst/SST-2/sst-test.tsv", delimiter='\t', header=None, names=['sentence', 'label'])
|
| 99 |
+
|
| 100 |
+
train_sentences = df_train.sentence.values
|
| 101 |
+
val_sentences = df_val.sentence.values
|
| 102 |
+
test_sentences = df_test.sentence.values
|
| 103 |
+
train_labels = df_train.label.values
|
| 104 |
+
val_labels = df_val.label.values
|
| 105 |
+
test_labels = df_test.label.values
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if dataset == '20news':
|
| 109 |
+
|
| 110 |
+
VALIDATION_SPLIT = 0.8
|
| 111 |
+
newsgroups_train = fetch_20newsgroups('dataset/20news', subset='train', shuffle=True, random_state=0)
|
| 112 |
+
print(newsgroups_train.target_names)
|
| 113 |
+
print(len(newsgroups_train.data))
|
| 114 |
+
|
| 115 |
+
newsgroups_test = fetch_20newsgroups('dataset/20news', subset='test', shuffle=False)
|
| 116 |
+
|
| 117 |
+
print(len(newsgroups_test.data))
|
| 118 |
+
|
| 119 |
+
train_len = int(VALIDATION_SPLIT * len(newsgroups_train.data))
|
| 120 |
+
|
| 121 |
+
train_sentences = newsgroups_train.data[:train_len]
|
| 122 |
+
val_sentences = newsgroups_train.data[train_len:]
|
| 123 |
+
test_sentences = newsgroups_test.data
|
| 124 |
+
train_labels = newsgroups_train.target[:train_len]
|
| 125 |
+
val_labels = newsgroups_train.target[train_len:]
|
| 126 |
+
test_labels = newsgroups_test.target
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
if dataset == '20news-15':
|
| 131 |
+
VALIDATION_SPLIT = 0.8
|
| 132 |
+
cats = ['alt.atheism',
|
| 133 |
+
'comp.graphics',
|
| 134 |
+
'comp.os.ms-windows.misc',
|
| 135 |
+
'comp.sys.ibm.pc.hardware',
|
| 136 |
+
'comp.sys.mac.hardware',
|
| 137 |
+
'comp.windows.x',
|
| 138 |
+
'rec.autos',
|
| 139 |
+
'rec.motorcycles',
|
| 140 |
+
'rec.sport.baseball',
|
| 141 |
+
'rec.sport.hockey',
|
| 142 |
+
'misc.forsale',
|
| 143 |
+
'sci.crypt',
|
| 144 |
+
'sci.electronics',
|
| 145 |
+
'sci.med',
|
| 146 |
+
'sci.space']
|
| 147 |
+
newsgroups_train = fetch_20newsgroups('dataset/20news', subset='train', shuffle=True, categories=cats, random_state=0)
|
| 148 |
+
print(newsgroups_train.target_names)
|
| 149 |
+
print(len(newsgroups_train.data))
|
| 150 |
+
|
| 151 |
+
newsgroups_test = fetch_20newsgroups('dataset/20news', subset='test', shuffle=False, categories=cats)
|
| 152 |
+
|
| 153 |
+
print(len(newsgroups_test.data))
|
| 154 |
+
|
| 155 |
+
train_len = int(VALIDATION_SPLIT * len(newsgroups_train.data))
|
| 156 |
+
|
| 157 |
+
train_sentences = newsgroups_train.data[:train_len]
|
| 158 |
+
val_sentences = newsgroups_train.data[train_len:]
|
| 159 |
+
test_sentences = newsgroups_test.data
|
| 160 |
+
train_labels = newsgroups_train.target[:train_len]
|
| 161 |
+
val_labels = newsgroups_train.target[train_len:]
|
| 162 |
+
test_labels = newsgroups_test.target
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
if dataset == '20news-5':
|
| 166 |
+
cats = [
|
| 167 |
+
'soc.religion.christian',
|
| 168 |
+
'talk.politics.guns',
|
| 169 |
+
'talk.politics.mideast',
|
| 170 |
+
'talk.politics.misc',
|
| 171 |
+
'talk.religion.misc']
|
| 172 |
+
|
| 173 |
+
newsgroups_test = fetch_20newsgroups('dataset/20news', subset='test', shuffle=False, categories=cats)
|
| 174 |
+
print(newsgroups_test.target_names)
|
| 175 |
+
print(len(newsgroups_test.data))
|
| 176 |
+
|
| 177 |
+
train_sentences = None
|
| 178 |
+
val_sentences = None
|
| 179 |
+
test_sentences = newsgroups_test.data
|
| 180 |
+
train_labels = None
|
| 181 |
+
val_labels = None
|
| 182 |
+
test_labels = newsgroups_test.target
|
| 183 |
+
|
| 184 |
+
if dataset == 'wos':
|
| 185 |
+
TESTING_SPLIT = 0.6
|
| 186 |
+
VALIDATION_SPLIT = 0.8
|
| 187 |
+
file_path = './dataset/WebOfScience/WOS46985/X.txt'
|
| 188 |
+
with open(file_path, 'r') as read_file:
|
| 189 |
+
x_temp = read_file.readlines()
|
| 190 |
+
x_all = []
|
| 191 |
+
for x in x_temp:
|
| 192 |
+
x_all.append(str(x))
|
| 193 |
+
|
| 194 |
+
print(len(x_all))
|
| 195 |
+
|
| 196 |
+
file_path = './dataset/WebOfScience/WOS46985/Y.txt'
|
| 197 |
+
with open(file_path, 'r') as read_file:
|
| 198 |
+
y_temp= read_file.readlines()
|
| 199 |
+
y_all = []
|
| 200 |
+
for y in y_temp:
|
| 201 |
+
y_all.append(int(y))
|
| 202 |
+
print(len(y_all))
|
| 203 |
+
print(max(y_all), min(y_all))
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
x_in = []
|
| 207 |
+
y_in = []
|
| 208 |
+
for i in range(len(x_all)):
|
| 209 |
+
x_in.append(x_all[i])
|
| 210 |
+
y_in.append(y_all[i])
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
train_val_len = int(TESTING_SPLIT * len(x_in))
|
| 214 |
+
train_len = int(VALIDATION_SPLIT * train_val_len)
|
| 215 |
+
|
| 216 |
+
train_sentences = x_in[:train_len]
|
| 217 |
+
val_sentences = x_in[train_len:train_val_len]
|
| 218 |
+
test_sentences = x_in[train_val_len:]
|
| 219 |
+
|
| 220 |
+
train_labels = y_in[:train_len]
|
| 221 |
+
val_labels = y_in[train_len:train_val_len]
|
| 222 |
+
test_labels = y_in[train_val_len:]
|
| 223 |
+
|
| 224 |
+
print(len(train_labels))
|
| 225 |
+
print(len(val_labels))
|
| 226 |
+
print(len(test_labels))
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
if dataset == 'wos-100':
|
| 230 |
+
TESTING_SPLIT = 0.6
|
| 231 |
+
VALIDATION_SPLIT = 0.8
|
| 232 |
+
file_path = './dataset/WebOfScience/WOS46985/X.txt'
|
| 233 |
+
with open(file_path, 'r') as read_file:
|
| 234 |
+
x_temp = read_file.readlines()
|
| 235 |
+
x_all = []
|
| 236 |
+
for x in x_temp:
|
| 237 |
+
x_all.append(str(x))
|
| 238 |
+
|
| 239 |
+
print(len(x_all))
|
| 240 |
+
|
| 241 |
+
file_path = './dataset/WebOfScience/WOS46985/Y.txt'
|
| 242 |
+
with open(file_path, 'r') as read_file:
|
| 243 |
+
y_temp= read_file.readlines()
|
| 244 |
+
y_all = []
|
| 245 |
+
for y in y_temp:
|
| 246 |
+
y_all.append(int(y))
|
| 247 |
+
print(len(y_all))
|
| 248 |
+
print(max(y_all), min(y_all))
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
x_in = []
|
| 252 |
+
y_in = []
|
| 253 |
+
for i in range(len(x_all)):
|
| 254 |
+
if y_all[i] in range(100):
|
| 255 |
+
x_in.append(x_all[i])
|
| 256 |
+
y_in.append(y_all[i])
|
| 257 |
+
|
| 258 |
+
for i in range(133):
|
| 259 |
+
num = 0
|
| 260 |
+
for y in y_in:
|
| 261 |
+
if y == i:
|
| 262 |
+
num = num + 1
|
| 263 |
+
# print(num)
|
| 264 |
+
|
| 265 |
+
train_val_len = int(TESTING_SPLIT * len(x_in))
|
| 266 |
+
train_len = int(VALIDATION_SPLIT * train_val_len)
|
| 267 |
+
|
| 268 |
+
train_sentences = x_in[:train_len]
|
| 269 |
+
val_sentences = x_in[train_len:train_val_len]
|
| 270 |
+
test_sentences = x_in[train_val_len:]
|
| 271 |
+
|
| 272 |
+
train_labels = y_in[:train_len]
|
| 273 |
+
val_labels = y_in[train_len:train_val_len]
|
| 274 |
+
test_labels = y_in[train_val_len:]
|
| 275 |
+
|
| 276 |
+
print(len(train_labels))
|
| 277 |
+
print(len(val_labels))
|
| 278 |
+
print(len(test_labels))
|
| 279 |
+
|
| 280 |
+
if dataset == 'wos-34':
|
| 281 |
+
TESTING_SPLIT = 0.6
|
| 282 |
+
VALIDATION_SPLIT = 0.8
|
| 283 |
+
file_path = './dataset/WebOfScience/WOS46985/X.txt'
|
| 284 |
+
with open(file_path, 'r') as read_file:
|
| 285 |
+
x_temp = read_file.readlines()
|
| 286 |
+
x_all = []
|
| 287 |
+
for x in x_temp:
|
| 288 |
+
x_all.append(str(x))
|
| 289 |
+
|
| 290 |
+
print(len(x_all))
|
| 291 |
+
|
| 292 |
+
file_path = './dataset/WebOfScience/WOS46985/Y.txt'
|
| 293 |
+
with open(file_path, 'r') as read_file:
|
| 294 |
+
y_temp= read_file.readlines()
|
| 295 |
+
y_all = []
|
| 296 |
+
for y in y_temp:
|
| 297 |
+
y_all.append(int(y))
|
| 298 |
+
print(len(y_all))
|
| 299 |
+
print(max(y_all), min(y_all))
|
| 300 |
+
|
| 301 |
+
x_in = []
|
| 302 |
+
y_in = []
|
| 303 |
+
for i in range(len(x_all)):
|
| 304 |
+
if y_all[i] not in range(100):
|
| 305 |
+
x_in.append(x_all[i])
|
| 306 |
+
y_in.append(y_all[i])
|
| 307 |
+
|
| 308 |
+
for i in range(133):
|
| 309 |
+
num = 0
|
| 310 |
+
for y in y_in:
|
| 311 |
+
if y == i:
|
| 312 |
+
num = num + 1
|
| 313 |
+
# print(num)
|
| 314 |
+
|
| 315 |
+
train_val_len = int(TESTING_SPLIT * len(x_in))
|
| 316 |
+
train_len = int(VALIDATION_SPLIT * train_val_len)
|
| 317 |
+
|
| 318 |
+
train_sentences = None
|
| 319 |
+
val_sentences = None
|
| 320 |
+
test_sentences = x_in[train_val_len:]
|
| 321 |
+
|
| 322 |
+
train_labels = None
|
| 323 |
+
val_labels = None
|
| 324 |
+
test_labels = y_in[train_val_len:]
|
| 325 |
+
|
| 326 |
+
print(len(test_labels))
|
| 327 |
+
|
| 328 |
+
if dataset == 'agnews':
|
| 329 |
+
|
| 330 |
+
VALIDATION_SPLIT = 0.8
|
| 331 |
+
labels_in_domain = [1, 2]
|
| 332 |
+
|
| 333 |
+
train_df = pd.read_csv('./dataset/agnews/train.csv', header=None)
|
| 334 |
+
train_df.rename(columns={0: 'label',1: 'title', 2:'sentence'}, inplace=True)
|
| 335 |
+
# train_df = pd.concat([train_df, pd.get_dummies(train_df['label'],prefix='label')], axis=1)
|
| 336 |
+
print(train_df.dtypes)
|
| 337 |
+
train_in_df_sentence = []
|
| 338 |
+
train_in_df_label = []
|
| 339 |
+
|
| 340 |
+
for i in range(len(train_df.sentence.values)):
|
| 341 |
+
sentence_temp = ''.join(str(train_df.sentence.values[i]))
|
| 342 |
+
train_in_df_sentence.append(sentence_temp)
|
| 343 |
+
train_in_df_label.append(train_df.label.values[i]-1)
|
| 344 |
+
|
| 345 |
+
test_df = pd.read_csv('./dataset/agnews/test.csv', header=None)
|
| 346 |
+
test_df.rename(columns={0: 'label',1: 'title', 2:'sentence'}, inplace=True)
|
| 347 |
+
# test_df = pd.concat([test_df, pd.get_dummies(test_df['label'],prefix='label')], axis=1)
|
| 348 |
+
test_in_df_sentence = []
|
| 349 |
+
test_in_df_label = []
|
| 350 |
+
for i in range(len(test_df.sentence.values)):
|
| 351 |
+
test_in_df_sentence.append(str(test_df.sentence.values[i]))
|
| 352 |
+
test_in_df_label.append(test_df.label.values[i]-1)
|
| 353 |
+
|
| 354 |
+
train_len = int(VALIDATION_SPLIT * len(train_in_df_sentence))
|
| 355 |
+
|
| 356 |
+
train_sentences = train_in_df_sentence[:train_len]
|
| 357 |
+
val_sentences = train_in_df_sentence[train_len:]
|
| 358 |
+
test_sentences = test_in_df_sentence
|
| 359 |
+
train_labels = train_in_df_label[:train_len]
|
| 360 |
+
val_labels = train_in_df_label[train_len:]
|
| 361 |
+
test_labels = test_in_df_label
|
| 362 |
+
print(len(train_sentences))
|
| 363 |
+
print(len(val_sentences))
|
| 364 |
+
print(len(test_sentences))
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
return train_sentences, val_sentences, test_sentences, train_labels, val_labels, test_labels
|
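Typical calls, matching how the reference test script pairs an in-distribution split with an out-of-distribution split (the corresponding dataset files or downloads must be available locally):

# In-distribution split plus an OOD split, as used by the detection experiments.
train_sentences, val_sentences, test_sentences, train_labels, val_labels, test_labels = load_dataset('20news-15')
_, _, ood_test_sentences, _, _, ood_test_labels = load_dataset('20news-5')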
| 368 |
+
|
| 369 |
+
|
src/reference_code/visualization.py
ADDED
|
@@ -0,0 +1,78 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
#import matplotlib as mpl
|
| 3 |
+
#mpl.use('Agg')
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
import metrics
|
| 7 |
+
|
| 8 |
+
class ConfidenceHistogram(metrics.MaxProbCELoss):
|
| 9 |
+
|
| 10 |
+
def plot(self, output, labels, n_bins = 15, logits = True, title = None):
|
| 11 |
+
super().loss(output, labels, n_bins, logits)
|
| 12 |
+
#scale each datapoint
|
| 13 |
+
n = len(labels)
|
| 14 |
+
w = np.ones(n)/n
|
| 15 |
+
|
| 16 |
+
plt.rcParams["font.family"] = "serif"
|
| 17 |
+
#size and axis limits
|
| 18 |
+
plt.figure(figsize=(3,3))
|
| 19 |
+
plt.xlim(0,1)
|
| 20 |
+
plt.ylim(0,1)
|
| 21 |
+
plt.xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], ['0.0', '0.2', '0.4', '0.6', '0.8', '1.0'])
|
| 22 |
+
plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], ['0.0', '0.2', '0.4', '0.6', '0.8', '1.0'])
|
| 23 |
+
#plot grid
|
| 24 |
+
plt.grid(color='tab:grey', linestyle=(0, (1, 5)), linewidth=1,zorder=0)
|
| 25 |
+
#plot histogram
|
| 26 |
+
plt.hist(self.confidences,n_bins,weights = w,color='b',range=(0.0,1.0),edgecolor = 'k')
|
| 27 |
+
|
| 28 |
+
#plot vertical dashed lines
|
| 29 |
+
acc = np.mean(self.accuracies)
|
| 30 |
+
conf = np.mean(self.confidences)
|
| 31 |
+
plt.axvline(x=acc, color='tab:grey', linestyle='--', linewidth = 3)
|
| 32 |
+
plt.axvline(x=conf, color='tab:grey', linestyle='--', linewidth = 3)
|
| 33 |
+
if acc > conf:
|
| 34 |
+
plt.text(acc+0.03,0.9,'Accuracy',rotation=90,fontsize=11)
|
| 35 |
+
plt.text(conf-0.07,0.9,'Avg. Confidence',rotation=90, fontsize=11)
|
| 36 |
+
else:
|
| 37 |
+
plt.text(acc-0.07,0.9,'Accuracy',rotation=90,fontsize=11)
|
| 38 |
+
plt.text(conf+0.03,0.9,'Avg. Confidence',rotation=90, fontsize=11)
|
| 39 |
+
|
| 40 |
+
plt.ylabel('% of Samples',fontsize=13)
|
| 41 |
+
plt.xlabel('Confidence',fontsize=13)
|
| 42 |
+
plt.tight_layout()
|
| 43 |
+
if title is not None:
|
| 44 |
+
plt.title(title,fontsize=16)
|
| 45 |
+
return plt
|
| 46 |
+
|
| 47 |
+
class ReliabilityDiagram(metrics.MaxProbCELoss):
|
| 48 |
+
|
| 49 |
+
def plot(self, output, labels, n_bins = 15, logits = True, title = None):
|
| 50 |
+
super().loss(output, labels, n_bins, logits)
|
| 51 |
+
|
| 52 |
+
#computations
|
| 53 |
+
delta = 1.0/n_bins
|
| 54 |
+
x = np.arange(0,1,delta)
|
| 55 |
+
mid = np.linspace(delta/2,1-delta/2,n_bins)
|
| 56 |
+
error = np.abs(np.subtract(mid,self.bin_acc))
|
| 57 |
+
|
| 58 |
+
plt.rcParams["font.family"] = "serif"
|
| 59 |
+
#size and axis limits
|
| 60 |
+
plt.figure(figsize=(3,3))
|
| 61 |
+
plt.xlim(0,1)
|
| 62 |
+
plt.ylim(0,1)
|
| 63 |
+
#plot grid
|
| 64 |
+
plt.grid(color='tab:grey', linestyle=(0, (1, 5)), linewidth=1,zorder=0)
|
| 65 |
+
#plot bars and identity line
|
| 66 |
+
plt.bar(x, self.bin_acc, color = 'b', width=delta,align='edge',edgecolor = 'k',label='Outputs',zorder=5)
|
| 67 |
+
plt.bar(x, error, bottom=np.minimum(self.bin_acc,mid), color = 'mistyrose', alpha=0.5, width=delta,align='edge',edgecolor = 'r',hatch='/',label='Gap',zorder=10)
|
| 68 |
+
ident = [0.0, 1.0]
|
| 69 |
+
plt.plot(ident,ident,linestyle='--',color='tab:grey',zorder=15)
|
| 70 |
+
#labels and legend
|
| 71 |
+
plt.ylabel('Accuracy',fontsize=13)
|
| 72 |
+
plt.xlabel('Confidence',fontsize=13)
|
| 73 |
+
plt.legend(loc='upper left',framealpha=1.0,fontsize='medium')
|
| 74 |
+
if title is not None:
|
| 75 |
+
plt.title(title,fontsize=16)
|
| 76 |
+
plt.tight_layout()
|
| 77 |
+
|
| 78 |
+
return plt
|
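A usage sketch for the two plotting helpers above (assumes the accompanying metrics.MaxProbCELoss; 'outputs' and 'labels' are placeholder numpy arrays of logits and integer labels):

# Illustrative use of ConfidenceHistogram / ReliabilityDiagram; inputs are placeholders.
hist_plt = ConfidenceHistogram().plot(outputs, labels, n_bins=15, logits=True, title='Confidence')
hist_plt.savefig('confidence_histogram.png')

diagram_plt = ReliabilityDiagram().plot(outputs, labels, n_bins=15, logits=True, title='Reliability')
diagram_plt.savefig('reliability_diagram.png')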
src/seq_model.py
CHANGED
|
@@ -1,6 +1,10 @@
|
|
| 1 |
import torch.nn as nn
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
from bert import BERT
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class BERTSM(nn.Module):
|
|
@@ -18,6 +22,12 @@ class BERTSM(nn.Module):
|
|
| 18 |
super().__init__()
|
| 19 |
self.bert = bert
|
| 20 |
self.mask_lm = MaskedSequenceModel(self.bert.hidden, vocab_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
self.same_student = SameStudentPrediction(self.bert.hidden)
|
| 22 |
|
| 23 |
def forward(self, x, segment_label, pred=False):
|
|
@@ -28,6 +38,7 @@ class BERTSM(nn.Module):
|
|
| 28 |
return x[:, 0], self.mask_lm(x), self.same_student(x)
|
| 29 |
else:
|
| 30 |
return x[:, 0], self.mask_lm(x)
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
class MaskedSequenceModel(nn.Module):
|
|
@@ -46,6 +57,9 @@ class MaskedSequenceModel(nn.Module):
|
|
| 46 |
self.softmax = nn.LogSoftmax(dim=-1)
|
| 47 |
|
| 48 |
def forward(self, x):
|
|
|
|
|
|
|
|
|
|
| 49 |
return self.softmax(self.linear(x))
|
| 50 |
|
| 51 |
|
|
@@ -62,3 +76,4 @@ class SameStudentPrediction(nn.Module):
|
|
| 62 |
def forward(self, x):
|
| 63 |
return self.softmax(self.linear(x[:, 0]))
|
| 64 |
|
|
|
|
|
|
| 1 |
import torch.nn as nn
|
| 2 |
|
| 3 |
+
<<<<<<< HEAD
|
| 4 |
+
from .bert import BERT
|
| 5 |
+
=======
|
| 6 |
from bert import BERT
|
| 7 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 8 |
|
| 9 |
|
| 10 |
class BERTSM(nn.Module):
|
|
|
|
| 22 |
super().__init__()
|
| 23 |
self.bert = bert
|
| 24 |
self.mask_lm = MaskedSequenceModel(self.bert.hidden, vocab_size)
|
| 25 |
+
<<<<<<< HEAD
|
| 26 |
+
|
| 27 |
+
def forward(self, x, segment_label):
|
| 28 |
+
x = self.bert(x, segment_label)
|
| 29 |
+
return self.mask_lm(x), x[:, 0]
|
| 30 |
+
=======
|
| 31 |
self.same_student = SameStudentPrediction(self.bert.hidden)
|
| 32 |
|
| 33 |
def forward(self, x, segment_label, pred=False):
|
|
|
|
| 38 |
return x[:, 0], self.mask_lm(x), self.same_student(x)
|
| 39 |
else:
|
| 40 |
return x[:, 0], self.mask_lm(x)
|
| 41 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 42 |
|
| 43 |
|
| 44 |
class MaskedSequenceModel(nn.Module):
|
|
|
|
| 57 |
self.softmax = nn.LogSoftmax(dim=-1)
|
| 58 |
|
| 59 |
def forward(self, x):
|
| 60 |
+
<<<<<<< HEAD
|
| 61 |
+
return self.softmax(self.linear(x))
|
| 62 |
+
=======
|
| 63 |
return self.softmax(self.linear(x))
|
| 64 |
|
| 65 |
|
|
|
|
| 76 |
def forward(self, x):
|
| 77 |
return self.softmax(self.linear(x[:, 0]))
|
| 78 |
|
| 79 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
src/transformer.py
CHANGED
|
@@ -1,7 +1,12 @@
|
|
| 1 |
import torch.nn as nn
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from attention import MultiHeadedAttention
|
| 4 |
from transformer_component import SublayerConnection, PositionwiseFeedForward
|
|
|
|
| 5 |
|
| 6 |
class TransformerBlock(nn.Module):
|
| 7 |
"""
|
|
@@ -25,6 +30,12 @@ class TransformerBlock(nn.Module):
|
|
| 25 |
self.dropout = nn.Dropout(p=dropout)
|
| 26 |
|
| 27 |
def forward(self, x, mask):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
|
|
|
|
| 29 |
x = self.output_sublayer(x, self.feed_forward)
|
| 30 |
return self.dropout(x)
|
|
|
|
| 1 |
import torch.nn as nn
|
| 2 |
|
| 3 |
+
<<<<<<< HEAD
|
| 4 |
+
from .attention import MultiHeadedAttention
|
| 5 |
+
from .transformer_component import SublayerConnection, PositionwiseFeedForward
|
| 6 |
+
=======
|
| 7 |
from attention import MultiHeadedAttention
|
| 8 |
from transformer_component import SublayerConnection, PositionwiseFeedForward
|
| 9 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 10 |
|
| 11 |
class TransformerBlock(nn.Module):
|
| 12 |
"""
|
|
|
|
| 30 |
self.dropout = nn.Dropout(p=dropout)
|
| 31 |
|
| 32 |
def forward(self, x, mask):
|
| 33 |
+
<<<<<<< HEAD
|
| 34 |
+
attn_output, p_attn = self.attention.forward(x, x, x, mask=mask)
|
| 35 |
+
self.p_attn = p_attn.cpu().detach().numpy()
|
| 36 |
+
x = self.input_sublayer(x, lambda _x: attn_output)
|
| 37 |
+
=======
|
| 38 |
x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
|
| 39 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 40 |
x = self.output_sublayer(x, self.feed_forward)
|
| 41 |
return self.dropout(x)
|
src/vocab.py
CHANGED
|
@@ -1,9 +1,22 @@
|
|
| 1 |
import collections
|
| 2 |
import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
class Vocab(object):
|
| 5 |
"""
|
| 6 |
Special tokens predefined in the vocab file are:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
-[UNK]
|
| 8 |
-[MASK]
|
| 9 |
-[CLS]
|
|
@@ -35,7 +48,11 @@ class Vocab(object):
|
|
| 35 |
words = [self.invocab[index] if index < len(self.invocab)
|
| 36 |
else "[%d]" % index for index in seq ]
|
| 37 |
|
|
|
|
|
|
|
|
|
|
| 38 |
return " ".join(words)
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
# if __init__ == "__main__":
|
|
|
|
| 1 |
import collections
|
| 2 |
import tqdm
|
| 3 |
+
<<<<<<< HEAD
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
head_directory = Path(__file__).resolve().parent.parent
|
| 8 |
+
# print(head_directory)
|
| 9 |
+
os.chdir(head_directory)
|
| 10 |
+
=======
|
| 11 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 12 |
|
| 13 |
class Vocab(object):
|
| 14 |
"""
|
| 15 |
Special tokens predefined in the vocab file are:
|
| 16 |
+
<<<<<<< HEAD
|
| 17 |
+
-[PAD]
|
| 18 |
+
=======
|
| 19 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 20 |
-[UNK]
|
| 21 |
-[MASK]
|
| 22 |
-[CLS]
|
|
|
|
| 48 |
words = [self.invocab[index] if index < len(self.invocab)
|
| 49 |
else "[%d]" % index for index in seq ]
|
| 50 |
|
| 51 |
+
<<<<<<< HEAD
|
| 52 |
+
return words #" ".join(words)
|
| 53 |
+
=======
|
| 54 |
return " ".join(words)
|
| 55 |
+
>>>>>>> bffd3381ccb717f802fe651d4111ec0a268e3896
|
| 56 |
|
| 57 |
|
| 58 |
# if __init__ == "__main__":
|
test.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
subprocess.run([
|
| 3 |
+
"python", "new_test_saved_finetuned_model.py",
|
| 4 |
+
"-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
|
| 5 |
+
"-finetune_task", "highGRschool10",
|
| 6 |
+
"-finetuned_bert_classifier_checkpoint",
|
| 7 |
+
"ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42"
|
| 8 |
+
])
|
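A slightly more portable variant of this wrapper would reuse the current interpreter and fail loudly on a non-zero exit code (a sketch, not the committed file):

import subprocess, sys
subprocess.run([
    sys.executable, "new_test_saved_finetuned_model.py",
    "-workspace_name", "ratio_proportion_change3_2223/sch_largest_100-coded",
    "-finetune_task", "highGRschool10",
    "-finetuned_bert_classifier_checkpoint",
    "ratio_proportion_change3_2223/sch_largest_100-coded/output/highGRschool10/bert_fine_tuned.model.ep42",
], check=True)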
test.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
test_hint_fine_tuned.py
ADDED
|
@@ -0,0 +1,45 @@
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
from src.vocab import Vocab
|
| 4 |
+
from src.dataset import TokenizerDataset
|
| 5 |
+
from hint_fine_tuning import CustomBERTModel
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
def test_model(opt):
|
| 9 |
+
print(f"Loading Vocab {opt.vocab_path}")
|
| 10 |
+
vocab = Vocab(opt.vocab_path)
|
| 11 |
+
vocab.load_vocab()
|
| 12 |
+
|
| 13 |
+
print(f"Vocab Size: {len(vocab.vocab)}")
|
| 14 |
+
|
| 15 |
+
test_dataset = TokenizerDataset(opt.test_dataset, opt.test_label, vocab, seq_len=50) # Using sequence length 50
|
| 16 |
+
print(f"Creating Dataloader")
|
| 17 |
+
test_data_loader = DataLoader(test_dataset, batch_size=32, num_workers=4)
|
| 18 |
+
|
| 19 |
+
# Load the entire fine-tuned model (including both architecture and weights)
|
| 20 |
+
print(f"Loading Model from {opt.finetuned_bert_checkpoint}")
|
| 21 |
+
model = torch.load(opt.finetuned_bert_checkpoint, map_location="cpu")
|
| 22 |
+
|
| 23 |
+
print(f"Number of Labels: {opt.num_labels}")
|
| 24 |
+
|
| 25 |
+
model.eval()
|
| 26 |
+
for batch_idx, data in enumerate(test_data_loader):
|
| 27 |
+
inputs = data["input"].to("cpu")
|
| 28 |
+
segment_info = data["segment_label"].to("cpu")
|
| 29 |
+
|
| 30 |
+
with torch.no_grad():
|
| 31 |
+
logits = model(inputs, segment_info)
|
| 32 |
+
|
| 33 |
+
print(f"Batch {batch_idx} logits: {logits}")
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
|
| 36 |
+
parser = argparse.ArgumentParser()
|
| 37 |
+
|
| 38 |
+
parser.add_argument("-t", "--test_dataset", type=str, default="/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_test_dataset.csv", help="test set for evaluating fine-tuned model")
|
| 39 |
+
parser.add_argument("-tlabel", "--test_label", type=str, default="/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/test_infos_only.csv", help="label set for evaluating fine-tuned model")
|
| 40 |
+
parser.add_argument("-c", "--finetuned_bert_checkpoint", type=str, default="/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification/fine_tuned_model_2.pth", help="checkpoint of the saved fine-tuned BERT model")
|
| 41 |
+
parser.add_argument("-v", "--vocab_path", type=str, default="/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt", help="built vocab model path")
|
| 42 |
+
parser.add_argument("-num_labels", type=int, default=2, help="Number of labels")
|
| 43 |
+
|
| 44 |
+
opt = parser.parse_args()
|
| 45 |
+
test_model(opt)
|
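If hard predictions are wanted rather than raw logits, the batch loop in test_model above could convert them before printing; a sketch of the extra lines (assuming the logits are per-class scores):

# Inside the batch loop, after computing 'logits':
probs = torch.softmax(logits, dim=-1)
predicted_labels = torch.argmax(probs, dim=-1)
print(f"Batch {batch_idx} predictions: {predicted_labels.tolist()}")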
test_saved_model.py
ADDED
|
@@ -0,0 +1,234 @@
|
| 1 |
+
# import torch.nn as nn
|
| 2 |
+
# import torch
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.optim import Adam, SGD
|
| 9 |
+
import torch
|
| 10 |
+
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
|
| 11 |
+
|
| 12 |
+
from src.pretrainer import BERTFineTuneTrainer1
|
| 13 |
+
from src.dataset import TokenizerDataset
|
| 14 |
+
from src.vocab import Vocab
|
| 15 |
+
|
| 16 |
+
import tqdm
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
import time
|
| 20 |
+
from src.bert import BERT
|
| 21 |
+
from hint_fine_tuning import CustomBERTModel
|
| 22 |
+
|
| 23 |
+
# from vocab import Vocab
|
| 24 |
+
|
| 25 |
+
# class BERTForSequenceClassification(nn.Module):
|
| 26 |
+
# """
|
| 27 |
+
# Since its classification,
|
| 28 |
+
# n_labels = 2
|
| 29 |
+
# """
|
| 30 |
+
|
| 31 |
+
# def __init__(self, vocab_size, n_labels, layers=None, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
|
| 32 |
+
# super().__init__()
|
| 33 |
+
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 34 |
+
# print(device)
|
| 35 |
+
# # model_ep0 = torch.load("output_1/bert_trained.model.ep0", map_location=device)
|
| 36 |
+
# self.bert = torch.load("output_1/bert_trained.model.ep0", map_location=device)
|
| 37 |
+
# self.dropout = nn.Dropout(dropout)
|
| 38 |
+
# # add an output layer
|
| 39 |
+
# self.
|
| 40 |
+
|
| 41 |
+
# def forward(self, x, segment_info):
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# return x
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class BERTFineTunedTrainer:
|
| 48 |
+
|
| 49 |
+
def __init__(self, bert: CustomBERTModel, vocab_size: int,
|
| 50 |
+
train_dataloader: DataLoader = None, test_dataloader: DataLoader = None,
|
| 51 |
+
lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
|
| 52 |
+
with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None, num_labels=2):
|
| 53 |
+
"""
|
| 54 |
+
:param bert: BERT model which you want to train
|
| 55 |
+
:param vocab_size: total word vocab size
|
| 56 |
+
:param train_dataloader: train dataset data loader
|
| 57 |
+
:param test_dataloader: test dataset data loader [can be None]
|
| 58 |
+
:param lr: learning rate of optimizer
|
| 59 |
+
:param betas: Adam optimizer betas
|
| 60 |
+
:param weight_decay: Adam optimizer weight decay param
|
| 61 |
+
:param with_cuda: training with cuda
|
| 62 |
+
:param log_freq: logging frequency of the batch iteration
|
| 63 |
+
"""
|
| 64 |
+
self.device = "cpu"
|
| 65 |
+
self.model = bert
|
| 66 |
+
self.test_data = test_dataloader
|
| 67 |
+
|
| 68 |
+
self.log_freq = log_freq
|
| 69 |
+
self.workspace_name = workspace_name
|
| 70 |
+
# print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
|
| 71 |
+
|
| 72 |
+
def test(self, epoch):
|
| 73 |
+
self.iteration(epoch, self.test_data, train=False)
|
| 74 |
+
|
| 75 |
+
def iteration(self, epoch, data_loader, train=True):
|
| 76 |
+
"""
|
| 77 |
+
loop over the data_loader for training or testing
|
| 78 |
+
if on train status, backward operation is activated
|
| 79 |
+
and also auto-save the model every epoch
|
| 80 |
+
|
| 81 |
+
:param epoch: current epoch index
|
| 82 |
+
:param data_loader: torch.utils.data.DataLoader for iteration
|
| 83 |
+
:param train: boolean value of is train or test
|
| 84 |
+
:return: None
|
| 85 |
+
"""
|
| 86 |
+
str_code = "train" if train else "test"
|
| 87 |
+
|
| 88 |
+
# Setting the tqdm progress bar
|
| 89 |
+
data_iter = tqdm.tqdm(enumerate(data_loader),
|
| 90 |
+
desc="EP_%s:%d" % (str_code, epoch),
|
| 91 |
+
total=len(data_loader),
|
| 92 |
+
bar_format="{l_bar}{r_bar}")
|
| 93 |
+
|
| 94 |
+
avg_loss = 0.0
|
| 95 |
+
total_correct = 0
|
| 96 |
+
total_element = 0
|
| 97 |
+
|
| 98 |
+
plabels = []
|
| 99 |
+
tlabels = []
|
| 100 |
+
logits_list = []
|
| 101 |
+
labels_list = []
|
| 102 |
+
positive_class_probs = []
|
| 103 |
+
self.model.eval()
|
| 104 |
+
|
| 105 |
+
for i, data in data_iter:
|
| 106 |
+
data = {key: value.to(self.device) for key, value in data.items()}
|
| 107 |
+
|
| 108 |
+
with torch.no_grad():
|
| 109 |
+
h_rep, logits = self.model.forward(data["input"], data["segment_label"])
|
| 110 |
+
# print(logits, logits.shape)
|
| 111 |
+
logits_list.append(logits.cpu())
|
| 112 |
+
labels_list.append(data["label"].cpu())
|
| 113 |
+
|
| 114 |
+
probs = nn.Softmax(dim=-1)(logits)
|
| 115 |
+
predicted_labels = torch.argmax(probs, dim=-1)
|
| 116 |
+
true_labels = torch.argmax(data["label"], dim=-1)
|
| 117 |
+
positive_class_probs.extend(probs[:, 1])
|
| 118 |
+
plabels.extend(predicted_labels.cpu().numpy())
|
| 119 |
+
tlabels.extend(true_labels.cpu().numpy())
|
| 120 |
+
|
| 121 |
+
# print(">>>>>>>>>>>>>>", predicted_labels, true_labels)
|
| 122 |
+
# Compare predicted labels to true labels and calculate accuracy
|
| 123 |
+
correct = (predicted_labels == true_labels).sum().item()
|
| 124 |
+
total_correct += correct
|
| 125 |
+
total_element += data["label"].nelement()
|
| 126 |
+
|
| 127 |
+
precisions = precision_score(tlabels, plabels, average="weighted")
|
| 128 |
+
recalls = recall_score(tlabels, plabels, average="weighted")
|
| 129 |
+
f1_scores = f1_score(tlabels, plabels, average="weighted")
|
| 130 |
+
accuracy = total_correct * 100.0 / total_element
|
| 131 |
+
auc_score = roc_auc_score(tlabels, [p.item() for p in positive_class_probs])  # tlabels is already a plain list; use the positive-class probabilities collected above
|
| 132 |
+
|
| 133 |
+
final_msg = {
|
| 134 |
+
"epoch": f"EP{epoch}_{str_code}",
|
| 135 |
+
"accuracy": accuracy,
|
| 136 |
+
"avg_loss": avg_loss / len(data_iter),
|
| 137 |
+
"precisions": precisions,
|
| 138 |
+
"recalls": recalls,
|
| 139 |
+
"f1_scores": f1_scores
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
print(final_msg)
|
| 143 |
+
|
| 144 |
+
# print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=", total_correct * 100.0 / total_element)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
if __name__ == "__main__":
|
| 148 |
+
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 149 |
+
# print(device)
|
| 150 |
+
# is_model = torch.load("ratio_proportion_change4/output/bert_fine_tuned.IS.model.ep40", map_location=device)
|
| 151 |
+
# learned_parameters = model_ep0.state_dict()
|
| 152 |
+
|
| 153 |
+
# for param_name, param_tensor in learned_parameters.items():
|
| 154 |
+
# print(param_name)
|
| 155 |
+
# print(param_tensor)
|
| 156 |
+
# # print(model_ep0.state_dict())
|
| 157 |
+
# # model_ep0.add_module("out", nn.Linear(10,2))
|
| 158 |
+
# # print(model_ep0)
|
| 159 |
+
# seq_vocab = Vocab("pretraining/vocab_file.txt")
|
| 160 |
+
# seq_vocab.load_vocab()
|
| 161 |
+
# classifier = BERTForSequenceClassification(len(seq_vocab.vocab), 2)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
parser = argparse.ArgumentParser()
|
| 165 |
+
|
| 166 |
+
parser.add_argument('-workspace_name', type=str, default="ratio_proportion_change3_1920")
|
| 167 |
+
# parser.add_argument("-t", "--test_dataset", type=str, default="finetuning/before_June/train_in.txt", help="test set for evaluate fine tune train set")
|
| 168 |
+
# parser.add_argument("-tlabel", "--test_label", type=str, default="finetuning/before_June/train_in_label.txt", help="test set for evaluate fine tune train set")
|
| 169 |
+
# ##### change Checkpoint
|
| 170 |
+
# parser.add_argument("-c", "--finetuned_bert_checkpoint", type=str, default="ratio_proportion_change3/output/before_June/bert_fine_tuned.FS.model.ep30", help="checkpoint of saved pretrained bert model")
|
| 171 |
+
# parser.add_argument("-v", "--vocab_path", type=str, default="pretraining/vocab.txt", help="built vocab model path with bert-vocab")
|
| 172 |
+
parser.add_argument("-t", "--test_dataset", type=str, default="/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/er_test_dataset.csv", help="test set for evaluate fine tune train set")
|
| 173 |
+
parser.add_argument("-tlabel", "--test_label", type=str, default="/home/jupyter/bert/dataset/hint_based/ratio_proportion_change_3/er/test_infos_only.csv", help="test set for evaluate fine tune train set")
|
| 174 |
+
##### change Checkpoint
|
| 175 |
+
parser.add_argument("-c", "--finetuned_bert_checkpoint", type=str, default="/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/output/hint_classification/fine_tuned_model_2.pth", help="checkpoint of saved pretrained bert model")
|
| 176 |
+
parser.add_argument("-v", "--vocab_path", type=str, default="/home/jupyter/bert/ratio_proportion_change3_1920/_Aug23/pretraining/vocab.txt", help="built vocab model path with bert-vocab")
|
| 177 |
+
parser.add_argument("-num_labels", type=int, default=2, help="Number of labels")
|
| 178 |
+
|
| 179 |
+
parser.add_argument("-hs", "--hidden", type=int, default=64, help="hidden size of transformer model")
|
| 180 |
+
parser.add_argument("-l", "--layers", type=int, default=4, help="number of layers")
|
| 181 |
+
parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads")
|
| 182 |
+
parser.add_argument("-s", "--seq_len", type=int, default=100, help="maximum sequence length")
|
| 183 |
+
|
| 184 |
+
parser.add_argument("-b", "--batch_size", type=int, default=32, help="number of batch_size")
|
| 185 |
+
parser.add_argument("-e", "--epochs", type=int, default=1, help="number of epochs")
|
| 186 |
+
# Use 50 for pretrain, and 10 for fine tune
|
| 187 |
+
parser.add_argument("-w", "--num_workers", type=int, default=4, help="dataloader worker size")
|
| 188 |
+
|
| 189 |
+
# Later run with cuda
|
| 190 |
+
parser.add_argument("--with_cuda", type=bool, default=False, help="training with CUDA: true, or false")
|
| 191 |
+
parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n")
|
| 192 |
+
parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus")
|
| 193 |
+
parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids")
|
| 194 |
+
parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false")
|
| 195 |
+
|
| 196 |
+
parser.add_argument("--dropout", type=float, default=0.1, help="dropout of network")
|
| 197 |
+
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam")
|
| 198 |
+
parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam")
|
| 199 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
|
| 200 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam first beta value")
|
| 201 |
+
|
| 202 |
+
args = parser.parse_args()
|
| 203 |
+
for k,v in vars(args).items():
|
| 204 |
+
if ('dataset' in k) or ('path' in k) or ('label' in k):
|
| 205 |
+
if v:
|
| 206 |
+
# setattr(args, f"{k}", args.workspace_name+"/"+v)
|
| 207 |
+
print(f"args.{k} : {getattr(args, f'{k}')}")
|
| 208 |
+
|
| 209 |
+
print("Loading Vocab", args.vocab_path)
|
| 210 |
+
vocab_obj = Vocab(args.vocab_path)
|
| 211 |
+
vocab_obj.load_vocab()
|
| 212 |
+
print("Vocab Size: ", len(vocab_obj.vocab))
|
| 213 |
+
print("Loading Test Dataset", args.test_dataset)
|
| 214 |
+
test_dataset = TokenizerDataset(args.test_dataset, args.test_label, vocab_obj, seq_len=args.seq_len)
|
| 215 |
+
print("Creating Dataloader")
|
| 216 |
+
test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
|
| 217 |
+
bert = torch.load(args.finetuned_bert_checkpoint, map_location="cpu")
|
| 218 |
+
num_labels = 2
|
| 219 |
+
print(f"Number of Labels : {num_labels}")
|
| 220 |
+
print("Creating BERT Fine Tune Trainer")
|
| 221 |
+
trainer = BERTFineTuneTrainer1(bert, len(vocab_obj.vocab), train_dataloader=None, test_dataloader=test_data_loader,
|
| 222 |
+
lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq, workspace_name = args.workspace_name, num_labels=args.num_labels)
|
| 223 |
+
|
| 224 |
+
print("Testing Start....")
|
| 225 |
+
start_time = time.time()
|
| 226 |
+
for epoch in range(args.epochs):
|
| 227 |
+
trainer.test(epoch)
|
| 228 |
+
|
| 229 |
+
end_time = time.time()
|
| 230 |
+
|
| 231 |
+
print("Time Taken to fine tune dataset = ", end_time - start_time)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# bert/ratio_proportion_change3_2223/sch_largest_100-coded/output/Opts/bert_fine_tuned.model.ep22
|
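The loop over vars(args) above currently only prints the dataset, vocab, and label paths; the commented-out setattr suggests they were meant to be resolved against args.workspace_name. Below is a minimal sketch of that prefixing, offered only as an assumption about the intent; the helper name prefix_paths is illustrative and not part of the repository.

import os

def prefix_paths(args):
    # Resolve dataset/vocab/label arguments against the workspace directory.
    # os.path.join leaves absolute defaults such as /home/jupyter/... untouched.
    for k, v in vars(args).items():
        if v and (('dataset' in k) or ('path' in k) or ('label' in k)):
            setattr(args, k, os.path.join(args.workspace_name, v))
            print(f"args.{k} : {getattr(args, k)}")
    return args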
visualization.py
ADDED
@@ -0,0 +1,78 @@
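The plotting classes below rely on metrics.MaxProbCELoss exposing a loss(output, labels, n_bins, logits) method that populates self.confidences, self.accuracies, and self.bin_acc. The following is a minimal stand-in inferred purely from that usage, not the repository's metrics.py (which is the authoritative implementation and may compute more); it is shown only to make the interface the code assumes explicit.

import numpy as np

def _softmax(z):
    # numerically stable row-wise softmax
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

class MaxProbCELoss:
    def loss(self, output, labels, n_bins=15, logits=True):
        probs = _softmax(np.asarray(output, dtype=float)) if logits else np.asarray(output, dtype=float)
        labels = np.asarray(labels)
        preds = probs.argmax(axis=1)
        self.confidences = probs.max(axis=1)               # top-class probability per sample
        self.accuracies = (preds == labels).astype(float)  # 1.0 where the prediction is correct
        # per-bin accuracy over equal-width confidence bins
        edges = np.linspace(0.0, 1.0, n_bins + 1)
        idx = np.clip(np.digitize(self.confidences, edges, right=True) - 1, 0, n_bins - 1)
        self.bin_acc = np.array([self.accuracies[idx == b].mean() if np.any(idx == b) else 0.0
                                 for b in range(n_bins)])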
import numpy as np
#import matplotlib as mpl
#mpl.use('Agg')
import matplotlib.pyplot as plt

import metrics


class ConfidenceHistogram(metrics.MaxProbCELoss):

    def plot(self, output, labels, n_bins=15, logits=True, title=None):
        super().loss(output, labels, n_bins, logits)
        # scale each datapoint so the histogram shows the fraction of samples per bin
        n = len(labels)
        w = np.ones(n) / n

        plt.rcParams["font.family"] = "serif"
        # size and axis limits
        plt.figure(figsize=(4, 3))
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], ['0.0', '0.2', '0.4', '0.6', '0.8', '1.0'])
        plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], ['0.0', '0.2', '0.4', '0.6', '0.8', '1.0'])
        # plot grid
        plt.grid(color='tab:grey', linestyle=(0, (1, 5)), linewidth=1, zorder=0)
        # plot histogram of per-sample confidences
        plt.hist(self.confidences, n_bins, weights=w, color='b', range=(0.0, 1.0), edgecolor='k')

        # plot vertical dashed lines at mean accuracy and mean confidence
        acc = np.mean(self.accuracies)
        conf = np.mean(self.confidences)
        plt.axvline(x=acc, color='tab:grey', linestyle='--', linewidth=3)
        plt.axvline(x=conf, color='tab:grey', linestyle='--', linewidth=3)
        if acc > conf:
            plt.text(acc + 0.03, 0.4, 'Accuracy', rotation=90, fontsize=11)
            plt.text(conf - 0.07, 0.4, 'Avg. Confidence', rotation=90, fontsize=11)
        else:
            plt.text(acc - 0.07, 0.4, 'Accuracy', rotation=90, fontsize=11)
            plt.text(conf + 0.03, 0.4, 'Avg. Confidence', rotation=90, fontsize=11)

        plt.ylabel('% of Samples', fontsize=13)
        plt.xlabel('Confidence', fontsize=13)
        plt.tight_layout()
        if title is not None:
            plt.title(title, fontsize=16)
        return plt


class ReliabilityDiagram(metrics.MaxProbCELoss):

    def plot(self, output, labels, n_bins=15, logits=True, title=None):
        super().loss(output, labels, n_bins, logits)

        # bin edges and midpoints; note the gap bars deliberately skip the first
        # 7 bins (low-confidence region), so n_bins must be larger than 7
        delta = 1.0 / n_bins
        x = np.arange(0, 1, delta)
        mid = np.linspace(delta / 2, 1 - delta / 2, n_bins)
        error = np.concatenate((np.zeros(shape=7), np.abs(np.subtract(mid[7:], self.bin_acc[7:]))))

        plt.rcParams["font.family"] = "serif"
        # size and axis limits
        plt.figure(figsize=(4, 4))
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        # plot grid
        plt.grid(color='tab:grey', linestyle=(0, (1, 5)), linewidth=1, zorder=0)
        # plot bars and identity line
        plt.bar(x, self.bin_acc, color='b', width=delta, align='edge', edgecolor='k', label='Outputs', zorder=5)
        plt.bar(x, error, bottom=np.minimum(self.bin_acc, mid), color='mistyrose', alpha=0.5, width=delta,
                align='edge', edgecolor='r', hatch='/', label='Gap', zorder=10)
        ident = [0.0, 1.0]
        plt.plot(ident, ident, linestyle='--', color='tab:grey', zorder=15)
        # labels and legend
        plt.ylabel('Accuracy', fontsize=13)
        plt.xlabel('Confidence', fontsize=13)
        plt.legend(loc='upper left', framealpha=1.0, fontsize='medium')
        if title is not None:
            plt.title(title, fontsize=16)
        plt.tight_layout()

        return plt
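A hedged usage sketch for the two classes above: the logits and labels here are synthetic placeholders, and constructing the classes with no arguments is an assumption about metrics.MaxProbCELoss rather than something this file guarantees.

import numpy as np
from visualization import ConfidenceHistogram, ReliabilityDiagram

rng = np.random.default_rng(0)
logits = rng.normal(size=(256, 2))        # stand-in for model outputs (N x num_labels)
labels = rng.integers(0, 2, size=256)     # stand-in for ground-truth labels

hist = ConfidenceHistogram().plot(logits, labels, n_bins=15, title="Confidence")
hist.savefig("confidence_histogram.png")  # plot() returns the pyplot module

rel = ReliabilityDiagram().plot(logits, labels, n_bins=15, title="Reliability")
rel.savefig("reliability_diagram.png")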