# Source: esgdata / train_finetune.py (author: darisdzakwanhoesien2)
# Commit ef22374: "Stage 9 — Domain Adaptation (Fine-Tuning)"
# train_finetune.py
# Template for fine-tuning a SentenceTransformer model on your ESG corpus.
import pandas as pd
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
# 1. --- Data Preparation ---
# You need a corpus of text that is relevant to your ESG domain.
# The goal is to teach the model what "similar" means in your context.
#
# Common strategies include:
# - Triplet Mining: (anchor, positive, negative)
#   - anchor: "Our company is committed to reducing GHG emissions."
#   - positive: "We have set a target to lower greenhouse gas output by 20%."
#   - negative: "The company reported strong financial results this quarter."
#
# - Semantic Pairs: (sentence1, sentence2, score)
#   - sentence1: "We are investing in renewable energy sources."
#   - sentence2: "The firm is funding solar and wind power projects."
#   - score: 0.9 (highly similar)
#
# Note: prepare_data() below builds triplets only. Scored semantic pairs need
# a score-based loss (e.g. losses.CosineSimilarityLoss) instead of the
# MultipleNegativesRankingLoss used in fine_tune_model().
def prepare_data(corpus_path="path/to/your/esg_corpus.csv"):
    """
    Load an ESG triplet corpus and convert it into InputExamples.

    This is a template function: adapt it to your data format. The CSV is
    expected to have the columns ``anchor``, ``positive`` and ``negative``.
    Each complete row becomes one
    ``InputExample(texts=[anchor, positive, negative])``.

    If the file does not exist, a small built-in dummy corpus is used instead
    so the pipeline can still be smoke-tested end to end.

    Args:
        corpus_path: Path to the CSV corpus.

    Returns:
        A list of ``InputExample`` triplets. The list is empty when the file is
        completely empty or has no complete rows.

    Raises:
        ValueError: If the CSV lacks one of the required columns.
    """
    columns = ["anchor", "positive", "negative"]
    try:
        df = pd.read_csv(corpus_path)
    except FileNotFoundError:
        print(f"Warning: Corpus file not found at {corpus_path}. Using dummy data.")
        # Create dummy data if no file is found
        dummy_data = {
            'anchor': ["Our policy on carbon emissions is strict.", "We support employee well-being."],
            'positive': ["We have a strong stance on GHG reduction.", "Our commitment to staff health is paramount."],
            'negative': ["Financial performance was strong.", "The new product launch was successful."]
        }
        df = pd.DataFrame(dummy_data)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header to parse. Treat it as "no data" so the
        # caller can abort cleanly instead of crashing here.
        print(f"Warning: Corpus file at {corpus_path} is empty.")
        return []

    missing = [col for col in columns if col not in df.columns]
    if missing:
        raise ValueError(
            f"Corpus at {corpus_path} is missing required column(s): {missing}"
        )

    # pandas reads empty CSV cells as NaN (a float). Feeding those to the
    # tokenizer would fail, so drop incomplete triplets and coerce the
    # remaining cells to str.
    complete = df[columns].dropna()
    dropped = len(df) - len(complete)
    if dropped:
        print(f"Warning: Skipped {dropped} row(s) with missing text.")

    train_examples = [
        InputExample(texts=[str(anchor), str(positive), str(negative)])
        for anchor, positive, negative in complete.itertuples(index=False, name=None)
    ]
    print(f"Loaded {len(train_examples)} training examples.")
    return train_examples
# 2. --- Model and Training Configuration ---
# Base checkpoint to adapt. A general-purpose model such as
# 'all-MiniLM-L6-v2' is a good starting point. A finance-specific encoder
# such as 'yiyanghkust/finbert-pretrain' is also an option once it has been
# converted to SentenceTransformer format.
base_model = "all-MiniLM-L6-v2"

# Directory that model.fit() writes the fine-tuned model to.
output_path = "./fine_tuned_esg_model"

# Optimisation hyperparameters: batch size, epoch count, learning rate.
train_batch_size, num_epochs, learning_rate = 16, 4, 2e-5
# 3. --- Fine-Tuning Process ---
def fine_tune_model(warmup_ratio=0.1):
    """
    Run the fine-tuning pipeline end to end.

    Loads the training triplets via ``prepare_data()`` and fine-tunes
    ``base_model`` with ``MultipleNegativesRankingLoss``. The result is saved
    to ``output_path``. Batch size, epoch count and learning rate come from the
    module-level settings ``train_batch_size``, ``num_epochs`` and
    ``learning_rate``.

    Args:
        warmup_ratio: Fraction of the total number of optimizer steps used
            for learning-rate warmup. Defaults to 0.1.

    Returns:
        None. Returns early, without loading the model, when there is no
        training data.
    """
    import math

    print("--- Starting ESG Domain Adaptation (Fine-Tuning) ---")
    # Prepare the data before loading the model, so an empty corpus aborts
    # without paying the cost of loading (and possibly downloading) the model.
    train_examples = prepare_data()
    if not train_examples:
        print("No training data. Aborting fine-tuning.")
        return

    # Load a pre-trained SentenceTransformer model
    model = SentenceTransformer(base_model)

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
    # MultipleNegativesRankingLoss accepts (anchor, positive, negative)
    # triplets. Per the sentence-transformers docs, it also uses the other
    # examples in the batch as extra negatives.
    train_loss = losses.MultipleNegativesRankingLoss(model)

    # Derive warmup from the real schedule length. The previous fixed value of
    # 100 steps exceeds the total number of steps on small corpora (e.g. the
    # dummy data gives 1 step/epoch), so the learning rate never reached
    # `learning_rate`.
    total_steps = len(train_dataloader) * num_epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)

    print(f"Starting training for {num_epochs} epochs...")
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        optimizer_params={'lr': learning_rate},
        output_path=output_path,
        show_progress_bar=True,
    )
    print(f"--- Fine-tuning complete. Model saved to {output_path} ---")
    print(f"To use this model, update FINE_TUNED_MODEL_PATH in app.py to '{output_path}'.")
# Script entry point: run the fine-tuning pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    fine_tune_model()