darisdzakwanhoesien2 commited on
Commit
eab69c3
·
1 Parent(s): 68bd627
Files changed (1) hide show
  1. train_finetune.py +9 -3
train_finetune.py CHANGED
@@ -4,6 +4,7 @@
4
  import pandas as pd
5
  from sentence_transformers import SentenceTransformer, InputExample, losses
6
  from torch.utils.data import DataLoader
 
7
 
8
  # --- File Paths ---
9
  TRIPLETS_PATH = "data/esg_triplets.csv"
@@ -28,9 +29,14 @@ def fine_tune_model():
28
  triplets_df = pd.read_csv(TRIPLETS_PATH)
29
  train_examples = [InputExample(texts=[row['anchor'], row['positive'], row['negative']]) for _, row in triplets_df.iterrows()]
30
  except FileNotFoundError:
31
- print(f"Error: Triplets file not found at {TRIPLETS_PATH}. Aborting.")
32
- print("Please run 'create_triplets.py' first to generate the training data.")
33
- return
 
 
 
 
 
34
 
35
  if not train_examples:
36
  print("No training examples found. Aborting fine-tuning.")
 
4
  import pandas as pd
5
  from sentence_transformers import SentenceTransformer, InputExample, losses
6
  from torch.utils.data import DataLoader
7
+ from create_triplets import create_triplets # Import the function
8
 
9
  # --- File Paths ---
10
  TRIPLETS_PATH = "data/esg_triplets.csv"
 
29
  triplets_df = pd.read_csv(TRIPLETS_PATH)
30
  train_examples = [InputExample(texts=[row['anchor'], row['positive'], row['negative']]) for _, row in triplets_df.iterrows()]
31
  except FileNotFoundError:
32
+ print(f"Info: Triplets file not found at {TRIPLETS_PATH}. Generating it now...")
33
+ create_triplets() # Generate the triplets
34
+ try:
35
+ triplets_df = pd.read_csv(TRIPLETS_PATH)
36
+ train_examples = [InputExample(texts=[row['anchor'], row['positive'], row['negative']]) for _, row in triplets_df.iterrows()]
37
+ except FileNotFoundError:
38
+ print(f"Error: Failed to generate or find triplets file at {TRIPLETS_PATH}. Aborting.")
39
+ return
40
 
41
  if not train_examples:
42
  print("No training examples found. Aborting fine-tuning.")