"""Console logger utilities. Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging """ import logging import fsspec import lightning import torch from timm.scheduler import CosineLRScheduler import argparse import numpy as np import random import os import time, torch from collections import defaultdict from contextlib import contextmanager class StepTimer: def __init__(self, device=None): self.times = defaultdict(list) self.device = device self._use_cuda_sync = ( isinstance(device, torch.device) and device.type == "cuda" ) or (isinstance(device, str) and "cuda" in device) @contextmanager def section(self, name): if self._use_cuda_sync: torch.cuda.synchronize() t0 = time.perf_counter() try: yield finally: if self._use_cuda_sync: torch.cuda.synchronize() dt = time.perf_counter() - t0 self.times[name].append(dt) def summary(self, top_k=None): # returns (name, count, total, mean, p50, p95) import numpy as np rows = [] for k, v in self.times.items(): a = np.array(v, dtype=float) rows.append((k, len(a), a.sum(), a.mean(), np.median(a), np.percentile(a, 95))) rows.sort(key=lambda r: r[2], reverse=True) # by total time return rows[:top_k] if top_k else rows def sample_categorical_logits(logits, dtype=torch.float64): # do not require logits to be log-softmaxed gumbel_noise = -(1e-10 - (torch.rand_like(logits, dtype=dtype) + 1e-10).log()).log() return (logits + gumbel_noise).argmax(dim=-1) def fsspec_exists(filename): """Check if a file exists using fsspec.""" fs, _ = fsspec.core.url_to_fs(filename) return fs.exists(filename) def fsspec_listdir(dirname): """Listdir in manner compatible with fsspec.""" fs, _ = fsspec.core.url_to_fs(dirname) return fs.ls(dirname) def fsspec_mkdirs(dirname, exist_ok=True): """Mkdirs in manner compatible with fsspec.""" fs, _ = fsspec.core.url_to_fs(dirname) fs.makedirs(dirname, exist_ok=exist_ok) def print_nans(tensor, name): if torch.isnan(tensor).any(): print(name, tensor) class CosineDecayWarmupLRScheduler( CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): """Wrap timm.scheduler.CosineLRScheduler Enables calling scheduler.step() without passing in epoch. Supports resuming as well. Adapted from: https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._last_epoch = -1 self.step(epoch=0) def step(self, epoch=None): if epoch is None: self._last_epoch += 1 else: self._last_epoch = epoch # We call either step or step_update, depending on # whether we're using the scheduler every epoch or every # step. # Otherwise, lightning will always call step (i.e., # meant for each epoch), and if we set scheduler # interval to "step", then the learning rate update will # be wrong. 
class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """Wraps timm.scheduler.CosineLRScheduler.

    Enables calling scheduler.step() without passing in an epoch and
    supports resuming from a checkpoint. Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        if epoch is None:
            self._last_epoch += 1
        else:
            self._last_epoch = epoch
        # Call either step or step_update, depending on whether the
        # scheduler runs every epoch or every step. Lightning always
        # calls step (i.e., meant for each epoch), so if the scheduler
        # interval were set to "step", the learning rate update would
        # otherwise be wrong.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)


class LoggingContext:
    """Context manager for selectively changing a logger's level/handler."""

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
        if self.handler and self.close:
            self.handler.close()
        # Implicit return of None => exceptions are not swallowed.


def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes a multi-GPU-friendly python logger."""
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # This ensures all logging levels get marked with the rank-zero
    # decorator; otherwise logs would get multiplied for each GPU
    # process in a multi-GPU setup.
    for level_name in ('debug', 'info', 'warning', 'error',
                       'exception', 'fatal', 'critical'):
        setattr(logger, level_name,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, level_name)))

    return logger


def str2bool(v):
    """Parses a boolean-like string for argparse."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def set_seed(seed, use_cuda):
    """Seeds the python, numpy, and torch RNGs for reproducibility."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True
    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    print(f'=> Seed of the run set to {seed}')
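
# Illustrative sketch (not part of the original module): wiring
# CosineDecayWarmupLRScheduler for per-step updates (t_in_epochs=False),
# together with set_seed and the rank-zero logger. The model, learning
# rates, and step counts below are placeholder assumptions.
def _demo_scheduler_setup():
    set_seed(42, use_cuda=torch.cuda.is_available())
    log = get_logger(__name__)
    model = torch.nn.Linear(8, 8)  # placeholder model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = CosineDecayWarmupLRScheduler(
        optimizer,
        t_initial=1000,      # total number of steps, since t_in_epochs=False
        lr_min=1e-6,
        warmup_t=100,        # linear warmup steps
        warmup_lr_init=1e-7,
        t_in_epochs=False)   # step() then maps to timm's step_update()
    for _ in range(5):
        optimizer.step()
        scheduler.step()     # no epoch/step argument needed
    log.info('lr after 5 steps: %s', optimizer.param_groups[0]['lr'])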