# esgdata / train_finetune.py
# darisdzakwanhoesien2
# label new
# eab69c3
# train_finetune.py
# Script to fine-tune a SentenceTransformer model on pre-generated ESG triplets.
import pandas as pd
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
from create_triplets import create_triplets # Import the function
# --- File Paths ---
# CSV of pre-generated triplets that fine_tune_model() reads with pandas.
TRIPLETS_PATH: str = "data/esg_triplets.csv"
# Path passed to model.fit() as output_path for the fine-tuned model.
OUTPUT_MODEL_PATH: str = "./fine_tuned_esg_model"
# --- Training Parameters ---
# Name handed to SentenceTransformer() to load the starting model.
BASE_MODEL: str = "all-MiniLM-L6-v2"
# Batch size used by the training DataLoader.
TRAIN_BATCH_SIZE: int = 16
# Number of epochs passed to model.fit().
NUM_EPOCHS: int = 4
# Learning rate passed to model.fit() via optimizer_params["lr"].
LEARNING_RATE: float = 2e-5
def _load_triplet_examples(path):
    """Read the triplets CSV at *path* into a list of ``InputExample`` objects.

    Each example's ``texts`` is ``[anchor, positive, negative]``, taken from
    the CSV columns of the same names. Rows with a missing value in any of
    those three columns are skipped (with a warning) rather than passed on
    as NaN.

    Returns:
        A possibly empty list of ``InputExample``.

    Raises:
        FileNotFoundError: If no file exists at *path* (raised by pandas).
        KeyError: If the file has rows but lacks one of the three columns.
    """
    try:
        triplets_df = pd.read_csv(path)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header to parse; report it as "no examples"
        # so the caller's normal abort message is used instead of a traceback.
        return []

    if triplets_df.empty:
        return []

    usable = triplets_df[["anchor", "positive", "negative"]].dropna()
    skipped = len(triplets_df) - len(usable)
    if skipped:
        print(f"Warning: Skipped {skipped} triplet row(s) with missing text.")

    # itertuples() yields plain (anchor, positive, negative) tuples and avoids
    # building a Series per row the way iterrows() does.
    return [
        InputExample(texts=list(texts))
        for texts in usable.astype(str).itertuples(index=False, name=None)
    ]


def fine_tune_model():
    """Fine-tune the base SentenceTransformer model on pre-generated triplets.

    Loads triplets from ``TRIPLETS_PATH``; if the file does not exist,
    ``create_triplets()`` is called once and the load is retried. Prints a
    message and returns early when the file still cannot be found or when it
    yields no usable examples. Otherwise trains ``BASE_MODEL`` with
    ``MultipleNegativesRankingLoss`` and saves the result to
    ``OUTPUT_MODEL_PATH``.

    Returns:
        None. Progress and outcome are reported via ``print``.
    """
    print("--- Starting ESG Fine-Tuning ---")

    try:
        train_examples = _load_triplet_examples(TRIPLETS_PATH)
    except FileNotFoundError:
        print(f"Info: Triplets file not found at {TRIPLETS_PATH}. Generating it now...")
        create_triplets()  # Generate the triplets
        try:
            train_examples = _load_triplet_examples(TRIPLETS_PATH)
        except FileNotFoundError:
            print(f"Error: Failed to generate or find triplets file at {TRIPLETS_PATH}. Aborting.")
            return

    if not train_examples:
        print("No training examples found. Aborting fine-tuning.")
        return

    # Load the pre-trained model only once the data is known to be usable, so
    # an aborted run does not pay for loading (or downloading) the model.
    model = SentenceTransformer(BASE_MODEL)

    # Create a DataLoader
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=TRAIN_BATCH_SIZE)

    # Define the loss function. MultipleNegativesRankingLoss is great for triplets.
    train_loss = losses.MultipleNegativesRankingLoss(model)

    # Fine-tune the model
    print(f"Starting training for {NUM_EPOCHS} epochs...")
    # NOTE(review): warmup is fixed at 100 steps regardless of dataset size --
    # confirm this is intended when the triplet set is small.
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=NUM_EPOCHS,
        warmup_steps=100,
        optimizer_params={'lr': LEARNING_RATE},
        output_path=OUTPUT_MODEL_PATH,
        show_progress_bar=True,
    )

    print(f"--- Fine-tuning complete. Model saved to {OUTPUT_MODEL_PATH} ---")
    print(f"To use this model, update FINE_TUNED_MODEL_PATH in app.py to '{OUTPUT_MODEL_PATH}'.")
# Run fine-tuning only when this file is executed as a script, not on import.
if __name__ == "__main__":
    fine_tune_model()