import os
from typing import Optional

import torch
from transformers import Trainer
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import logging

logger = logging.get_logger(__name__)


def maybe_zero_3(param, ignore_status=False, name=None):
    """Materialize a (possibly ZeRO-3 partitioned) parameter on the CPU.

    Under DeepSpeed ZeRO-3, parameter storage is partitioned across ranks;
    ``zero.GatheredParameters`` reassembles the full tensor before it is
    detached and cloned to CPU memory.
    """
    # Imported lazily so the module can be imported without DeepSpeed installed.
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
            logger.warning(f"{name}: param has ds_status NOT_AVAILABLE")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect (gathered, CPU) copies of every parameter whose name contains
    any of the substrings in ``keys_to_match``."""
    to_return = {k: t for k, t in named_params if any(key in k for key in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
    return to_return
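# Illustrative sketch (not part of the original file): what the helper above
# returns for a toy module. It assumes DeepSpeed is installed, since
# maybe_zero_3 imports it at call time; the module and names are hypothetical.
#
#     from torch import nn
#     toy = nn.ModuleDict({"mm_projector": nn.Linear(8, 8), "lm_head": nn.Linear(8, 8)})
#     state = get_mm_adapter_state_maybe_zero_3(toy.named_parameters(), ["mm_projector"])
#     sorted(state)  # ['mm_projector.bias', 'mm_projector.weight'] -- lm_head is filtered out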
class LLaVATrainer(Trainer):

    def _save_checkpoint(self, model, trial, metrics=None):
        if getattr(self.args, "tune_mm_mlp_adapter", False):
            # Save only the multimodal adapter weights instead of the full model.
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            keys_to_match = ["mm_projector"]
            if getattr(self.args, "use_im_start_end", False):
                keys_to_match.extend(["embed_tokens", "embed_in"])

            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)

            # local_rank is -1 when not running distributed; only rank 0 writes files.
            if self.args.local_rank in (0, -1):
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, "mm_projector.bin"))
        else:
            super()._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # When only the adapter is tuned, saving is handled entirely by
        # _save_checkpoint above, so the default full-model save is skipped.
        if not getattr(self.args, "tune_mm_mlp_adapter", False):
            super()._save(output_dir, state_dict)

    def update_history_loss_dict(self, outputs):
        """Remember the most recent non-zero value of each auxiliary loss."""
        if not hasattr(self, "history_loss_dict"):
            self.history_loss_dict = {}
        for name, value in outputs.items():
            if "loss" in name and name != "loss":
                if name not in self.history_loss_dict or value.item() != 0:
                    self.history_loss_dict[name] = value.item()

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the
        loss in the first element. Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are "
                    f"{','.join(inputs.keys())}."
                )
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        if isinstance(outputs, dict) and "loss_dice" in outputs:
            # Log every auxiliary loss. When a loss is 0 on this step (e.g. its
            # head was not exercised by this batch), fall back to its last
            # non-zero value so the logged curves stay continuous.
            loss_dict = {}
            for name, value in outputs.items():
                if "loss" in name and name != "loss":
                    loss_value = value.item()
                    if loss_value == 0 and hasattr(self, "history_loss_dict"):
                        loss_value = self.history_loss_dict.get(name, loss_value)
                    loss_dict[name] = loss_value
            self.update_history_loss_dict(outputs)
            self.log(loss_dict)

        return (loss, outputs) if return_outputs else loss
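# Illustrative sketch (not part of the original file): how the history-based
# fallback behaves. SimpleNamespace stands in for a trainer instance here,
# since update_history_loss_dict only touches self.history_loss_dict.
#
#     from types import SimpleNamespace
#     holder = SimpleNamespace()
#     step1 = {"loss": torch.tensor(1.0), "loss_dice": torch.tensor(0.4)}
#     LLaVATrainer.update_history_loss_dict(holder, step1)
#     step2 = {"loss": torch.tensor(0.9), "loss_dice": torch.tensor(0.0)}
#     LLaVATrainer.update_history_loss_dict(holder, step2)
#     holder.history_loss_dict  # {'loss_dice': 0.4} -- the zero did not overwrite it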
class LLaVATrainerSSL(LLaVATrainer):
    """Trainer for the SSL training path.

    Its behavior is currently identical to ``LLaVATrainer`` (adapter-only
    checkpointing plus auxiliary-loss logging with history fallback); it is
    kept as a separate class so the SSL path can diverge without affecting
    the default trainer.
    """
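# Illustrative wiring sketch (not part of the original file). The model and
# dataset are hypothetical placeholders; build them with the surrounding
# project's own code. tune_mm_mlp_adapter / use_im_start_end are read with
# getattr, so they are expected as custom fields on the TrainingArguments
# object.
#
#     from transformers import TrainingArguments
#
#     args = TrainingArguments(output_dir="./checkpoints", per_device_train_batch_size=4)
#     args.tune_mm_mlp_adapter = True   # checkpoint only the mm_projector weights
#
#     trainer = LLaVATrainer(model=my_llava_model,            # hypothetical
#                            args=args,
#                            train_dataset=my_train_dataset)  # hypothetical
#     trainer.train()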