# gptneo-2.7Bloratunning / generate.py
# Uploaded by jacob-c via huggingface_hub (commit 1f99d4e, verified)
#!/usr/bin/env python3
import argparse
import os

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Set constants
# NOTE(review): cluster-specific paths — /vol/bitbucket looks like a university
# scratch filesystem; adjust USERNAME/BASE_PATH when running elsewhere.
USERNAME = "jc4121" # Your username
BASE_PATH = f"/vol/bitbucket/{USERNAME}"
# presumably where the training script wrote its checkpoints — verify against train.py
OUTPUT_PATH = f"{BASE_PATH}/gptneo-2.7Bloratunning/output"
def parse_args():
    """Build the CLI parser for lyric generation and return the parsed args."""
    parser = argparse.ArgumentParser(
        description="Generate lyrics using fine-tuned GPT-Neo 2.7B"
    )
    # Model location (defaults to the final checkpoint written by training).
    parser.add_argument(
        "--model_path",
        type=str,
        default=f"{OUTPUT_PATH}/final_model",
        help="Path to the fine-tuned model",
    )
    # Conditioning and sampling knobs.
    parser.add_argument(
        "--artist",
        type=str,
        default="Taylor Swift",
        help="Artist name for conditioning",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=512,
        help="Maximum length of generated text",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p sampling parameter",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=50,
        help="Top-k sampling parameter",
    )
    parser.add_argument(
        "--num_return_sequences",
        type=int,
        default=1,
        help="Number of sequences to generate",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    return parser.parse_args()
def main():
    """Load the fine-tuned model and print sampled lyrics for one artist."""
    args = parse_args()
    # Seeds both CPU and CUDA RNGs so sampling is reproducible.
    torch.manual_seed(args.seed)

    print(f"Loading model from {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # GPT-Neo ships without a pad token; define one so any padded tokenization
    # works (generate() below also falls back to EOS via pad_token_id).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # FIX: `load_in_8bit=True` as a bare from_pretrained kwarg is deprecated
    # (removed in newer transformers); the supported path is a BitsAndBytesConfig.
    # torch_dtype applies to the modules that stay unquantized.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
        torch_dtype=torch.float16,
    )
    model.eval()

    # Prompt format must match the "Artist: X\nLyrics:" pattern used in training
    # — TODO confirm against the training script.
    prompt = f"Artist: {args.artist}\nLyrics:"
    print(f"Generating lyrics for artist: {args.artist}")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=args.max_length,  # total length, prompt tokens included
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            num_return_sequences=args.num_return_sequences,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )

    # Decode and print each sampled continuation.
    for i, output in enumerate(outputs):
        generated_text = tokenizer.decode(output, skip_special_tokens=True)
        print(f"\n--- Generated Lyrics {i+1} ---\n")
        print(generated_text)
        print("\n" + "-" * 50)


if __name__ == "__main__":
    main()