# train_finetune.py
# Script to fine-tune a SentenceTransformer model on pre-generated ESG triplets.

import pandas as pd
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

from create_triplets import create_triplets  # Generates the triplets CSV on demand

# --- File Paths ---
TRIPLETS_PATH = "data/esg_triplets.csv"
OUTPUT_MODEL_PATH = "./fine_tuned_esg_model"

# --- Training Parameters ---
BASE_MODEL = "all-MiniLM-L6-v2"
TRAIN_BATCH_SIZE = 16
NUM_EPOCHS = 4
LEARNING_RATE = 2e-5
# Upper bound on learning-rate warm-up steps; the effective value is capped
# at 10% of the total optimizer steps (see fine_tune_model).
WARMUP_STEPS = 100

# Columns the triplets CSV must provide, in (anchor, positive, negative) order.
TRIPLET_COLUMNS = ("anchor", "positive", "negative")


def _load_train_examples(path):
    """Read the triplets CSV at *path* and return a list of InputExample.

    Rows with a missing anchor, positive or negative text are skipped (with a
    warning), because an empty CSV cell is parsed as NaN and cannot be
    tokenized. A completely empty file yields an empty list.

    Raises:
        FileNotFoundError: if *path* does not exist.
        KeyError: if any of TRIPLET_COLUMNS is absent from the CSV header.
    """
    try:
        # dtype=str keeps every text as a Python string even when a cell
        # happens to look numeric.
        triplets_df = pd.read_csv(path, dtype=str)
    except pd.errors.EmptyDataError:
        print(f"Warning: Triplets file at {path} is empty.")
        return []

    # Selecting the columns up front raises a KeyError that names any
    # missing column instead of failing on the first row access.
    complete_rows = triplets_df[list(TRIPLET_COLUMNS)].dropna()

    skipped = len(triplets_df) - len(complete_rows)
    if skipped:
        print(f"Warning: Skipped {skipped} triplet row(s) with missing text.")

    return [
        InputExample(texts=[anchor, positive, negative])
        for anchor, positive, negative in complete_rows.itertuples(index=False, name=None)
    ]


def fine_tune_model():
    """Fine-tune BASE_MODEL on the ESG triplets and save it to OUTPUT_MODEL_PATH.

    The triplets are read from TRIPLETS_PATH. If that file does not exist,
    create_triplets() is called once to generate it and the read is retried.
    The function prints a message and returns without training when the file
    still cannot be found or contains no usable triplets.
    """
    print("--- Starting ESG Fine-Tuning ---")

    try:
        train_examples = _load_train_examples(TRIPLETS_PATH)
    except FileNotFoundError:
        print(f"Info: Triplets file not found at {TRIPLETS_PATH}. Generating it now...")
        create_triplets()  # Generate the triplets
        try:
            train_examples = _load_train_examples(TRIPLETS_PATH)
        except FileNotFoundError:
            print(f"Error: Failed to generate or find triplets file at {TRIPLETS_PATH}. Aborting.")
            return

    if not train_examples:
        print("No training examples found. Aborting fine-tuning.")
        return

    # Load the pre-trained model only after the data has been validated, so a
    # missing or empty dataset does not cost a model load/download.
    model = SentenceTransformer(BASE_MODEL)

    # Create a DataLoader
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=TRAIN_BATCH_SIZE)

    # Define the loss function. MultipleNegativesRankingLoss accepts
    # (anchor, positive, negative) triplets.
    train_loss = losses.MultipleNegativesRankingLoss(model)

    # A fixed warm-up longer than the whole run would keep the learning rate
    # from ever reaching LEARNING_RATE on a small triplet set, so cap it at
    # 10% of the total optimizer steps (at least one step).
    total_steps = len(train_dataloader) * NUM_EPOCHS
    warmup_steps = min(WARMUP_STEPS, max(1, total_steps // 10))

    # Fine-tune the model
    print(f"Starting training for {NUM_EPOCHS} epochs...")
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=NUM_EPOCHS,
        warmup_steps=warmup_steps,
        optimizer_params={'lr': LEARNING_RATE},
        output_path=OUTPUT_MODEL_PATH,
        show_progress_bar=True,
    )

    print(f"--- Fine-tuning complete. Model saved to {OUTPUT_MODEL_PATH} ---")
    print(f"To use this model, update FINE_TUNED_MODEL_PATH in app.py to '{OUTPUT_MODEL_PATH}'.")


if __name__ == "__main__":
    fine_tune_model()