#!/usr/bin/env python3
"""Generate song lyrics with a LoRA fine-tuned GPT-Neo 2.7B model.

Loads the fine-tuned checkpoint (either a fully merged model or a PEFT/LoRA
adapter directory), conditions on an artist name, and prints sampled lyrics.
"""

import argparse
import os

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set constants
USERNAME = "jc4121"  # Your username
BASE_PATH = f"/vol/bitbucket/{USERNAME}"
OUTPUT_PATH = f"{BASE_PATH}/gptneo-2.7Bloratunning/output"


def parse_args():
    """Parse command-line options for generation (paths, sampling knobs, seed)."""
    parser = argparse.ArgumentParser(description="Generate lyrics using fine-tuned GPT-Neo 2.7B")
    parser.add_argument("--model_path", type=str, default=f"{OUTPUT_PATH}/final_model",
                        help="Path to the fine-tuned model")
    parser.add_argument("--artist", type=str, default="Taylor Swift",
                        help="Artist name for conditioning")
    parser.add_argument("--max_length", type=int, default=512,
                        help="Maximum length of generated text")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9,
                        help="Top-p sampling parameter")
    parser.add_argument("--top_k", type=int, default=50,
                        help="Top-k sampling parameter")
    parser.add_argument("--num_return_sequences", type=int, default=1,
                        help="Number of sequences to generate")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    return parser.parse_args()


def _load_model(model_path):
    """Load the fine-tuned model from *model_path*.

    Handles both checkpoint layouts:
    - a merged full model (loadable directly by ``AutoModelForCausalLM``);
    - a PEFT/LoRA adapter directory (detected via ``adapter_config.json``),
      in which case the base model is loaded first and the adapter attached.

    The second path is why ``PeftModel``/``PeftConfig`` are imported at all;
    without it, pointing ``--model_path`` at an adapter-only directory crashes.
    """
    load_kwargs = dict(
        load_in_8bit=True,      # 8-bit weights to fit 2.7B on modest GPUs
        device_map="auto",
        torch_dtype=torch.float16,
    )
    if os.path.exists(os.path.join(model_path, "adapter_config.json")):
        # Adapter-only checkpoint: resolve the base model from the adapter
        # config, then layer the LoRA weights on top.
        peft_config = PeftConfig.from_pretrained(model_path)
        base_model = AutoModelForCausalLM.from_pretrained(
            peft_config.base_model_name_or_path, **load_kwargs
        )
        return PeftModel.from_pretrained(base_model, model_path)
    return AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)


def main():
    """Entry point: seed RNG, load model/tokenizer, sample and print lyrics."""
    args = parse_args()
    torch.manual_seed(args.seed)

    print(f"Loading model from {args.model_path}")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # GPT-Neo ships without a pad token; reuse EOS so padding/attention masks
    # are well-defined during tokenization and generation.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model (merged checkpoint or PEFT adapter) in inference mode.
    model = _load_model(args.model_path)
    model.eval()

    # Prepare prompt
    prompt = f"Artist: {args.artist}\nLyrics:"
    print(f"Generating lyrics for artist: {args.artist}")

    # Tokenize prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate text; do_sample=True is required for temperature/top_p/top_k
    # to take effect.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=args.max_length,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            num_return_sequences=args.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )

    # Decode and print generated text
    for i, output in enumerate(outputs):
        generated_text = tokenizer.decode(output, skip_special_tokens=True)
        print(f"\n--- Generated Lyrics {i+1} ---\n")
        print(generated_text)
        print("\n" + "-" * 50)


if __name__ == "__main__":
    main()