|
|
"""Console logger utilities. |
|
|
|
|
|
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py |
|
|
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging |
|
|
""" |
|
|
|
|
|
import argparse
import logging
import os
import random
import time
from collections import defaultdict
from contextlib import contextmanager

import fsspec
import lightning
import numpy as np
import torch
from timm.scheduler import CosineLRScheduler
|
|
|
|
|
class StepTimer:
    """Accumulates wall-clock timings for named sections of a training step."""

    def __init__(self, device=None):
        self.times = defaultdict(list)
        self.device = device
        # Only synchronize when timing CUDA work; without a sync, asynchronous
        # kernel launches would make the measured intervals meaningless.
        self._use_cuda_sync = (
            isinstance(device, torch.device) and device.type == "cuda"
        ) or (isinstance(device, str) and "cuda" in device)

    @contextmanager
    def section(self, name):
        """Times the enclosed block and records the duration under `name`."""
        if self._use_cuda_sync:
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        try:
            yield
        finally:
            if self._use_cuda_sync:
                torch.cuda.synchronize()
            self.times[name].append(time.perf_counter() - t0)

    def summary(self, top_k=None):
        """Returns (name, count, total, mean, median, p95) rows, sorted by total time."""
        rows = []
        for k, v in self.times.items():
            a = np.array(v, dtype=float)
            rows.append((k, len(a), a.sum(), a.mean(),
                         np.median(a), np.percentile(a, 95)))
        rows.sort(key=lambda r: r[2], reverse=True)
        return rows[:top_k] if top_k else rows
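
# Illustrative usage sketch for StepTimer (not in the original sources);
# `model`, `batch`, and `loss` are assumed placeholders:
#
#     timer = StepTimer(device='cuda')
#     with timer.section('forward'):
#         loss = model(batch)
#     with timer.section('backward'):
#         loss.backward()
#     for name, count, total, mean, median, p95 in timer.summary(top_k=5):
#         print(f'{name}: {count} calls, {total:.3f}s total, {p95 * 1e3:.1f} ms p95')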
|
|
|
|
|
|
|
|
def sample_categorical_logits(logits, dtype=torch.float64):
    """Samples category indices from unnormalized logits via the Gumbel-max trick.

    Adding independent Gumbel(0, 1) noise to the logits and taking the argmax
    along the last dimension is equivalent to sampling from softmax(logits);
    the 1e-10 offsets guard against taking the log of zero.
    """
    gumbel_noise = -(1e-10 - (torch.rand_like(logits, dtype=dtype) + 1e-10).log()).log()
    return (logits + gumbel_noise).argmax(dim=-1)
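
# Example (sketch): drawing one token per row from a batch of logits.
#
#     logits = torch.randn(4, 32000)              # (batch, vocab)
#     tokens = sample_categorical_logits(logits)  # (batch,) int64 indices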
|
|
|
|
|
def fsspec_exists(filename):
    """Checks whether a file exists using fsspec."""
    fs, _ = fsspec.core.url_to_fs(filename)
    return fs.exists(filename)


def fsspec_listdir(dirname):
    """Lists a directory in a manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    return fs.ls(dirname)


def fsspec_mkdirs(dirname, exist_ok=True):
    """Makes directories in a manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    fs.makedirs(dirname, exist_ok=exist_ok)
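
# These helpers work with any fsspec-supported filesystem. A usage sketch,
# where the bucket path is a hypothetical example:
#
#     if not fsspec_exists('s3://my-bucket/checkpoints/last.ckpt'):
#         fsspec_mkdirs('s3://my-bucket/checkpoints')
#     print(fsspec_listdir('s3://my-bucket/checkpoints'))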
|
|
|
|
|
|
|
|
def print_nans(tensor, name):
    """Prints the tensor (prefixed with `name`) if it contains any NaNs."""
    if torch.isnan(tensor).any():
        print(name, tensor)
|
|
|
|
|
|
|
|
class CosineDecayWarmupLRScheduler(
    CosineLRScheduler,
    torch.optim.lr_scheduler._LRScheduler):
    """Wraps timm.scheduler.CosineLRScheduler.

    Enables calling scheduler.step() without passing in the epoch and
    supports resuming from a checkpoint.
    Adapted from:
        https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        # Track the step count internally so callers can simply call step().
        if epoch is None:
            self._last_epoch += 1
        else:
            self._last_epoch = epoch
        # timm schedulers are advanced either per epoch or per update,
        # depending on how they were configured; dispatch accordingly.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)
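
# Illustrative construction sketch (keyword arguments follow
# timm.scheduler.CosineLRScheduler; `optimizer` is an assumed placeholder):
#
#     scheduler = CosineDecayWarmupLRScheduler(
#         optimizer, t_initial=10_000, warmup_t=500,
#         warmup_lr_init=1e-6, lr_min=1e-5, t_in_epochs=False)
#     for _ in range(10_000):
#         ...  # forward/backward + optimizer.step()
#         scheduler.step()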
|
|
|
|
|
|
|
|
class LoggingContext:
    """Context manager for selective logging."""

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
        if self.handler and self.close:
            self.handler.close()
        # Implicit return of None => don't swallow exceptions.
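
# Usage sketch, mirroring the logging-cookbook recipe this class is copied from:
#
#     logger = logging.getLogger('foo')
#     with LoggingContext(logger, level=logging.DEBUG):
#         logger.debug('visible only inside the context')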
|
|
|
|
|
|
|
|
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes a multi-GPU-friendly python logger."""
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Mark all logging methods with the rank-zero-only decorator so that logs
    # are not duplicated once per process in a multi-GPU setup.
    for level_name in ('debug', 'info', 'warning', 'error',
                       'exception', 'fatal', 'critical'):
        setattr(logger,
                level_name,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, level_name)))

    return logger
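
# Example: in a multi-GPU run, only the rank-zero process emits these messages.
#
#     log = get_logger(__name__)
#     log.info('Initialized dataloaders.')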
|
|
|
|
|
|
|
|
def str2bool(v):
    """Parses a boolean-like string, for use as an argparse `type`."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
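
# Intended for argparse flags that accept explicit true/false values, e.g.:
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--use_ema', type=str2bool, default=False)
#     args = parser.parse_args(['--use_ema', 'yes'])  # args.use_ema == True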
|
|
|
|
|
|
|
|
def set_seed(seed, use_cuda):
    """Seeds the python, numpy, and torch RNGs for reproducibility."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    print(f'=> Seed of the run set to {seed}')
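
# Example: seed everything once at startup, before building dataloaders and
# the model.
#
#     set_seed(42, use_cuda=torch.cuda.is_available())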
|
|
|