"""Fine-tune a SentenceTransformer on an ESG text corpus stored in a CSV file."""

import os
import sys

import pandas as pd
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

DEFAULT_DATA_PATH = "data/esg_corpus.csv"
DEFAULT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_OUTPUT_PATH = "esg_finetuned_model"
DEFAULT_EPOCHS = 1  # Keep it short for Gradio demo
DEFAULT_BATCH_SIZE = 16
MAX_WARMUP_STEPS = 100


def _load_sentences(data_path: str) -> list[str]:
    """Read the 'text' column of the CSV at *data_path* as clean strings.

    Missing values are dropped, remaining values are converted to ``str``
    and stripped, and blank entries are discarded. Any input problem
    (missing file, unreadable CSV, missing column, no usable rows) is
    reported on stdout and terminates the process with exit code 1.
    """
    if not os.path.exists(data_path):
        print(f"Error: Data file not found at {data_path}")
        sys.exit(1)

    try:
        df = pd.read_csv(data_path)
    except (pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        print(f"Error: Could not parse CSV at {data_path}: {err}")
        sys.exit(1)

    if "text" not in df.columns:
        print("Error: CSV must have a 'text' column.")
        sys.exit(1)

    # Cast to str so numeric-looking cells do not reach the tokenizer as
    # floats/ints, and drop whitespace-only rows that carry no signal.
    cleaned = df["text"].dropna().astype(str).str.strip()
    sentences = [text for text in cleaned if text]

    if not sentences:
        # DataLoader(shuffle=True) raises an opaque ValueError on an empty
        # dataset, so fail early with an actionable message instead.
        print("Error: No non-empty rows found in the 'text' column.")
        sys.exit(1)

    return sentences


def train_model(
    data_path: str = DEFAULT_DATA_PATH,
    *,
    model_name: str = DEFAULT_MODEL_NAME,
    output_path: str = DEFAULT_OUTPUT_PATH,
    epochs: int = DEFAULT_EPOCHS,
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> str:
    """Fine-tune a SentenceTransformer on unlabeled domain text (TAPT).

    Every document is paired with itself as the positive example and
    trained with ``MultipleNegativesRankingLoss``, so the other documents
    in the same batch serve as negatives. No labels are required.

    Args:
        data_path: CSV file containing a 'text' column.
        model_name: Base model name or path to load.
        output_path: Directory the fine-tuned model is written to.
        epochs: Number of passes over the corpus.
        batch_size: Number of documents per training batch.

    Returns:
        The directory the fine-tuned model was saved to.

    Raises:
        SystemExit: With code 1 if the file is missing, cannot be parsed,
            has no 'text' column, or contains no usable text.
    """
    sentences = _load_sentences(data_path)
    print(f"Loaded {len(sentences)} ESG documents for TAPT.")

    examples = [InputExample(texts=[s, s]) for s in sentences]

    model = SentenceTransformer(model_name)
    print(f"Loaded base model: {model_name}")

    train_dataloader = DataLoader(examples, shuffle=True, batch_size=batch_size)
    train_loss = losses.MultipleNegativesRankingLoss(model)

    # A fixed 100-step warm-up can exceed the whole run on a small corpus,
    # leaving the learning rate stuck on the ramp. Use 10% of the total
    # steps instead, capped at the original 100.
    total_steps = len(train_dataloader) * epochs
    warmup_steps = min(MAX_WARMUP_STEPS, total_steps // 10)

    os.makedirs(output_path, exist_ok=True)

    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=epochs,
        warmup_steps=warmup_steps,
        show_progress_bar=True,
        output_path=output_path,
    )

    print(f"Fine-tuned ESG model saved to: {output_path}")
    return output_path


if __name__ == "__main__":
    # Use the CSV path given on the command line, otherwise the default.
    if len(sys.argv) > 1:
        train_model(sys.argv[1])
    else:
        train_model()