| import torch | |
| from typing import List | |
class Singleton(type):
    """Metaclass that hands out a single shared instance per class.

    The first call to a class using this metaclass constructs the instance
    normally; every later call returns that same cached object and ignores
    the arguments passed.
    """

    # Maps each class using this metaclass to its one cached instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            # Already built: reuse it; new constructor arguments are ignored.
            return registry[cls]
        created = super().__call__(*args, **kwargs)
        registry[cls] = created
        return created
class Latents(metaclass=Singleton):
    """Process-wide accumulator of latent tensors.

    Because of the ``Singleton`` metaclass, every ``Latents()`` call returns
    the same object, so code in different places appends to one shared
    history.
    """

    def __init__(self) -> None:
        # Runs only once, when the singleton instance is first created.
        self.history: List[torch.FloatTensor] = []

    def is_empty(self) -> bool:
        """Return True when no latents have been recorded.

        The original check (``self.history is None``) was always False,
        because ``history`` is always a list. Truthiness covers an empty
        list and also tolerates ``None``.
        """
        return not self.history

    def add_latents(self, latents: torch.FloatTensor) -> None:
        """Append one latent tensor to the history."""
        self.history.append(latents)

    def clear(self) -> None:
        """Drop all recorded latents."""
        # Rebind rather than calling list.clear(), so any list previously
        # handed out (e.g. by dump_and_clear) is not emptied underneath the
        # caller.
        self.history = []

    def dump_and_clear(self) -> List[torch.FloatTensor]:
        """Return the recorded latents and reset the history to empty."""
        history = self.history
        self.clear()
        return history