from datasets import load_dataset
from evaluate import load as load_metric
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    EvalPrediction,
    PrinterCallback,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


def train(batch_size: int, model_name: str = "t5-small", max_steps: int = 10_000) -> None:
    # Keep the effective batch size at 512 examples per optimizer step
    # (assuming single-device training) by scaling gradient accumulation
    # inversely with the per-device batch size.
    total_batch_size_per_step = 512
    grad_acc_steps = total_batch_size_per_step // batch_size
    assert grad_acc_steps * batch_size == total_batch_size_per_step, (
        f"batch_size={batch_size} must evenly divide {total_batch_size_per_step}"
    )

    model_name_for_path = model_name.split("/")[-1]
    output_dir = f"wmt19-ende-{model_name_for_path}"
    args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        learning_rate=1e-4,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        gradient_accumulation_steps=grad_acc_steps,
        max_steps=max_steps,
        weight_decay=1e-2,
        optim="adamw_torch_fused",
        lr_scheduler_type="constant",
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=1,
        save_safetensors=True,
        # load_best_model_at_end is required for metric_for_best_model to
        # influence which checkpoint is kept and pushed.
        load_best_model_at_end=True,
        metric_for_best_model="bleu",
        greater_is_better=True,
        push_to_hub=True,
        bf16=True,
        bf16_full_eval=True,
        seed=42,
        predict_with_generate=True,
        log_level="error",
        logging_steps=1,
        logging_dir=output_dir,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bleu = load_metric("bleu")

    def compute_metrics(eval_preds: EvalPrediction):
        # With predict_with_generate=True the "predictions" are generated
        # token ids, not logits. Labels are padded with -100, which the
        # tokenizer cannot decode, so map those positions back to the pad
        # token id first.
        generated_ids, label_ids = eval_preds
        label_ids[label_ids == -100] = tokenizer.pad_token_id
        references = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        bleu_outputs = bleu.compute(predictions=predictions, references=references)
        return {
            # evaluate's BLEU is in [0, 1]; scale to the conventional 0-100 range.
            "bleu": 100 * bleu_outputs["bleu"],
            "brevity_penalty": bleu_outputs["brevity_penalty"],
        }

    def map_fn(inputs):
        # Tokenize one side ("de" or "en") of a batch of translation pairs.
        # Note: T5 checkpoints were pretrained with explicit task prefixes
        # (e.g. "translate English to German: "); none is added here, so the
        # de->en mapping is learned entirely from this fine-tuning.
        def tokenize_side(side):
            return tokenizer(
                [pair[side] for pair in inputs["translation"]],
                return_attention_mask=False,
                max_length=64,
                truncation=True,
            ).input_ids

        return {
            "input_ids": tokenize_side("de"),
            "labels": tokenize_side("en"),
        }

    def get_dataset_split(split):
        # Stream WMT19 de-en so the full corpus never has to fit on disk.
        return load_dataset("wmt19", "de-en", split=split, streaming=True).map(map_fn, batched=True)

    def apply_length_filter(dataset):
        # Drop pairs where either side is shorter than 8 tokens.
        return dataset.filter(lambda e: len(e["input_ids"]) >= 8 and len(e["labels"]) >= 8)

    trainer = Seq2SeqTrainer(
        model=AutoModelForSeq2SeqLM.from_pretrained(model_name),
        args=args,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
        train_dataset=apply_length_filter(get_dataset_split("train")),
        eval_dataset=get_dataset_split("validation"),
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    # Drop the callback that prints raw log dicts to stdout at every step.
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    trainer.push_to_hub()
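

# Example invocation: a minimal sketch, assuming a single GPU. The per-device
# batch size of 32 is an arbitrary choice; it must evenly divide the effective
# batch size of 512 asserted inside train().
if __name__ == "__main__":
    train(batch_size=32)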