Spaces:
Sleeping
Sleeping
# train_finetune.py
# Script to fine-tune a SentenceTransformer model on pre-generated ESG triplets.
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer, InputExample, losses | |
| from torch.utils.data import DataLoader | |
| from create_triplets import create_triplets # Import the function | |
# --- File Paths ---
# CSV read by fine_tune_model(); it must provide 'anchor', 'positive' and
# 'negative' columns.
TRIPLETS_PATH = "data/esg_triplets.csv"
# Directory handed to model.fit() as output_path for the fine-tuned model.
OUTPUT_MODEL_PATH = "./fine_tuned_esg_model"

# --- Training Parameters ---
# Name of the pre-trained SentenceTransformer checkpoint to start from.
BASE_MODEL = "all-MiniLM-L6-v2"
# Number of triplets per DataLoader batch.
TRAIN_BATCH_SIZE = 16
# Number of full passes over the triplet set.
NUM_EPOCHS = 4
# Learning rate forwarded to the optimizer via optimizer_params.
LEARNING_RATE = 2e-5
def _load_triplet_examples(path):
    """Read the triplets CSV at *path* and build the training examples.

    Rows with a missing value in any of the 'anchor', 'positive' or
    'negative' columns are skipped, because pandas loads blank cells as
    float NaN and those are not usable as text. Remaining values are cast
    to ``str`` so cells that pandas parsed as numbers are still passed on
    as text.

    Returns:
        A list of ``InputExample`` objects, one per complete row. The list
        is empty when the file has no rows or is a zero-byte file.

    Raises:
        FileNotFoundError: If *path* does not exist (propagated from
            ``pd.read_csv`` so the caller can regenerate the file).
        KeyError: If a required column is absent from the CSV.
    """
    required_columns = ["anchor", "positive", "negative"]

    try:
        triplets_df = pd.read_csv(path)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header to parse; report it as "no examples"
        # so the caller's empty-data check handles it.
        return []

    missing_columns = [
        column for column in required_columns if column not in triplets_df.columns
    ]
    if missing_columns:
        raise KeyError(
            f"Triplets file {path} is missing required column(s): {missing_columns}"
        )

    complete_rows = triplets_df.dropna(subset=required_columns)
    skipped_rows = len(triplets_df) - len(complete_rows)
    if skipped_rows:
        print(f"Warning: Skipping {skipped_rows} triplet row(s) with missing text.")

    return [
        InputExample(texts=[str(anchor), str(positive), str(negative)])
        for anchor, positive, negative in zip(
            complete_rows["anchor"],
            complete_rows["positive"],
            complete_rows["negative"],
        )
    ]


def fine_tune_model():
    """
    Main function to fine-tune the model using pre-generated triplets.

    Loads the triplets from TRIPLETS_PATH, generating the file with
    create_triplets() first when it does not exist yet, then fine-tunes
    BASE_MODEL on them and lets model.fit() write the result to
    OUTPUT_MODEL_PATH. Progress and errors are reported with print();
    the function returns None in every case and aborts early when no
    usable training data is available.
    """
    print("--- Starting ESG Fine-Tuning ---")

    try:
        train_examples = _load_triplet_examples(TRIPLETS_PATH)
    except FileNotFoundError:
        print(f"Info: Triplets file not found at {TRIPLETS_PATH}. Generating it now...")
        create_triplets()  # Generate the triplets
        try:
            train_examples = _load_triplet_examples(TRIPLETS_PATH)
        except FileNotFoundError:
            print(f"Error: Failed to generate or find triplets file at {TRIPLETS_PATH}. Aborting.")
            return

    if not train_examples:
        print("No training examples found. Aborting fine-tuning.")
        return

    # Load the pre-trained model only once the data is known to be usable,
    # so an aborted run does not pay for loading it.
    model = SentenceTransformer(BASE_MODEL)

    # Create a DataLoader
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=TRAIN_BATCH_SIZE)

    # Define the loss function. MultipleNegativesRankingLoss is great for triplets.
    train_loss = losses.MultipleNegativesRankingLoss(model)

    # Keep the warmup to at most 10% of all optimizer steps (capped at the
    # previous fixed value of 100) so that a small triplet set does not spend
    # the whole run warming up.
    total_steps = len(train_dataloader) * NUM_EPOCHS
    warmup_steps = min(100, total_steps // 10)

    # Fine-tune the model
    print(f"Starting training for {NUM_EPOCHS} epochs...")
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=NUM_EPOCHS,
        warmup_steps=warmup_steps,
        optimizer_params={'lr': LEARNING_RATE},
        output_path=OUTPUT_MODEL_PATH,
        show_progress_bar=True,
    )

    print(f"--- Fine-tuning complete. Model saved to {OUTPUT_MODEL_PATH} ---")
    print(f"To use this model, update FINE_TUNED_MODEL_PATH in app.py to '{OUTPUT_MODEL_PATH}'.")
# Run the fine-tuning only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    fine_tune_model()