Spaces:
Runtime error
Runtime error
| import json | |
| import torch | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import classification_report | |
| from torch.utils.data import Dataset, DataLoader | |
| # Corrected imports for Bert models and AdamW | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
| from torch.optim import AdamW # AdamW is now imported from torch.optim | |
| import torch.nn.functional as F | |
| from tqdm import tqdm | |
| import os | |
| import nltk # Ensure nltk is imported | |
| # Download NLTK data if not already present | |
| try: | |
| nltk.data.find('corpora/wordnet') | |
| except LookupError: | |
| nltk.download('wordnet', quiet=True) | |
| try: | |
| nltk.data.find('taggers/averaged_perceptron_tagger') | |
| except LookupError: | |
| nltk.download('averaged_perceptron_tagger', quiet=True) | |
| # Dataset class for tokenized queries | |
| class AdversarialQueryDataset(Dataset): | |
| def __init__(self, queries, labels, tokenizer, max_len=128): | |
| self.queries = queries | |
| self.labels = labels | |
| self.tokenizer = tokenizer | |
| self.max_len = max_len | |
| def __len__(self): | |
| return len(self.queries) | |
| def __getitem__(self, idx): | |
| query = self.queries[idx] | |
| label = self.labels[idx] | |
| encoding = self.tokenizer.encode_plus( | |
| query, | |
| add_special_tokens=True, | |
| truncation=True, | |
| max_length=self.max_len, | |
| padding='max_length', | |
| return_attention_mask=True, | |
| return_tensors='pt' | |
| ) | |
| return { | |
| 'input_ids': encoding['input_ids'].squeeze(), | |
| 'attention_mask': encoding['attention_mask'].squeeze(), | |
| 'label': torch.tensor(label, dtype=torch.long) | |
| } | |
| def train_and_save_defense_model(data_path="caleb_adversarial_prompts.json", model_save_path="./defense_model"): | |
| """ | |
| Trains an adversarial query detection model and saves it. | |
| """ | |
| if not os.path.exists(data_path): | |
| print(f"Error: Dataset not found at {data_path}. Please create it first.") | |
| return | |
| with open(data_path, "r") as f: | |
| data = json.load(f) | |
| queries = [item["prompt"] for item in data] | |
| labels = [1 if item["is_adversarial"] else 0 for item in data] # 1 = adversarial, 0 = clean | |
| # Ensure consistency: using 'bert-base-uncased' tokenizer and model | |
| tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") | |
| model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2) | |
| train_qs, test_qs, train_labels, test_labels = train_test_split(queries, labels, test_size=0.2, random_state=42) | |
| train_dataset = AdversarialQueryDataset(train_qs, train_labels, tokenizer) | |
| test_dataset = AdversarialQueryDataset(test_qs, test_labels, tokenizer) | |
| train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) | |
| test_loader = DataLoader(test_dataset, batch_size=16) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model.to(device) | |
| optimizer = AdamW(model.parameters(), lr=2e-5) # AdamW from torch.optim | |
| print("Starting model training...") | |
| model.train() | |
| for epoch in range(5): # Train for 3 epochs | |
| total_loss = 0 | |
| for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"): | |
| input_ids = batch['input_ids'].to(device) | |
| attention_mask = batch['attention_mask'].to(device) | |
| labels = batch['label'].to(device) | |
| optimizer.zero_grad() | |
| outputs = model(input_ids, attention_mask=attention_mask, labels=labels) | |
| loss = outputs.loss | |
| loss.backward() | |
| optimizer.step() | |
| total_loss += loss.item() | |
| print(f"Epoch {epoch+1} - Loss: {total_loss:.4f}") | |
| print("\nEvaluating model...") | |
| model.eval() | |
| all_preds, all_labels = [], [] | |
| with torch.no_grad(): | |
| for batch in test_loader: | |
| input_ids = batch['input_ids'].to(device) | |
| attention_mask = batch['attention_mask'].to(device) | |
| labels = batch['label'].to(device) | |
| outputs = model(input_ids, attention_mask=attention_mask) | |
| preds = torch.argmax(F.softmax(outputs.logits, dim=1), dim=1) | |
| all_preds.extend(preds.cpu().numpy()) | |
| all_labels.extend(labels.cpu().numpy()) | |
| report = classification_report(all_labels, all_preds) | |
| print("\nClassification Report:\n", report) | |
| # Save the trained model | |
| os.makedirs(model_save_path, exist_ok=True) | |
| model.save_pretrained(model_save_path) | |
| tokenizer.save_pretrained(model_save_path) | |
| print(f"\nDefense model saved to {model_save_path}") | |
| if __name__ == "__main__": | |
| train_and_save_defense_model() | |