|
|
|
|
|
import os |
|
|
import torch |
|
|
import argparse |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel, PeftConfig |
|
|
|
|
|
|
|
|
# Cluster username used to build the scratch-space paths below.
# Generalized: can now be overridden via the LYRICS_USERNAME environment
# variable; defaults to the original hard-coded value, so existing
# behavior is unchanged when the variable is unset.
USERNAME = os.environ.get("LYRICS_USERNAME", "jc4121")

# Root of the user's bitbucket scratch volume on the cluster.
BASE_PATH = f"/vol/bitbucket/{USERNAME}"

# Directory where the LoRA fine-tuning run wrote its outputs
# (the final merged model lives under {OUTPUT_PATH}/final_model).
OUTPUT_PATH = f"{BASE_PATH}/gptneo-2.7Bloratunning/output"
|
|
|
|
|
def parse_args():
    """Parse command-line options for lyric generation.

    Returns:
        argparse.Namespace with model_path, artist, max_length,
        temperature, top_p, top_k, num_return_sequences, and seed.
    """
    parser = argparse.ArgumentParser(description="Generate lyrics using fine-tuned GPT-Neo 2.7B")

    # (flag, type, default, help) — registered in a single pass below.
    option_specs = [
        ("--model_path", str, f"{OUTPUT_PATH}/final_model", "Path to the fine-tuned model"),
        ("--artist", str, "Taylor Swift", "Artist name for conditioning"),
        ("--max_length", int, 512, "Maximum length of generated text"),
        ("--temperature", float, 0.7, "Sampling temperature"),
        ("--top_p", float, 0.9, "Top-p sampling parameter"),
        ("--top_k", int, 50, "Top-k sampling parameter"),
        ("--num_return_sequences", int, 1, "Number of sequences to generate"),
        ("--seed", int, 42, "Random seed"),
    ]
    for flag, flag_type, default, help_text in option_specs:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)

    return parser.parse_args()
|
|
|
|
|
def main():
    """Load the fine-tuned model and print generated lyrics to stdout.

    Reads all settings from the command line (see parse_args), seeds
    torch for reproducibility, generates one or more lyric samples
    conditioned on the artist name, and prints each to stdout.
    """
    args = parse_args()

    # Seed before any sampling so repeated runs with the same --seed match.
    torch.manual_seed(args.seed)

    print(f"Loading model from {args.model_path}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        load_in_8bit=True,          # 8-bit quantization so 2.7B fits in GPU memory
        device_map="auto",          # let accelerate place layers across devices
        torch_dtype=torch.float16,
    )
    model.eval()  # inference only — disable dropout etc.

    # Prompt format must match the conditioning format used at fine-tune time.
    prompt = f"Artist: {args.artist}\nLyrics:"
    print(f"Generating lyrics for artist: {args.artist}")

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        generations = model.generate(
            **encoded,
            max_length=args.max_length,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            num_return_sequences=args.num_return_sequences,
            # GPT-Neo has no dedicated pad token; reuse EOS to silence warnings.
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )

    # Decode and print every returned sequence, separated by a rule.
    for idx, sequence in enumerate(generations):
        text = tokenizer.decode(sequence, skip_special_tokens=True)
        print(f"\n--- Generated Lyrics {idx+1} ---\n")
        print(text)
        print("\n" + "-" * 50)
|
|
|
|
|
# Standard script entry guard: run generation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":


    main()