import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
from datetime import datetime

from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from dataloader import get_dataloaders
from loss import diffusion_loss
from sample import sample


def train():
    config = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

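    # Create a timestamped directory for TensorBoard logs and checkpoints.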
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join(config.log_dir, timestamp)
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

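    # Build the model, the frequency-aware noise scheduler, the optimizer,
    # an LR scheduler that halves the learning rate when the loss plateaus,
    # and the data loaders.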
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    train_loader, val_loader = get_dataloaders(config)

    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for batch_idx, (x0, _) in enumerate(train_loader):
            x0 = x0.to(device)

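            # Sample a random diffusion timestep for each image in the batch.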
            t = torch.randint(0, config.T, (x0.size(0),), device=device)

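            # Diffusion training loss for the noised batch at timesteps t.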
            loss = diffusion_loss(model, x0, t, noise_scheduler, config)

            optimizer.zero_grad()
            loss.backward()

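            # Clip gradients to stabilize training before the optimizer step.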
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

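            # Periodic diagnostics: NaN check, gradient-norm monitoring,
            # noise-scheduler debug output, and TensorBoard logging.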
            if batch_idx % 100 == 0:
                if torch.isnan(loss):
                    print(f"WARNING: NaN loss detected at Epoch {epoch}, Batch {batch_idx}")

                # Total L2 norm of all parameter gradients (after clipping).
                total_norm = 0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)

                if batch_idx == 0 and epoch % 5 == 0:
                    print(f"Debug for Epoch {epoch}:")
                    noise_scheduler.debug_noise_stats(x0[:1], t[:1])

                if batch_idx % 500 == 0:
                    print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}, Grad Norm: {total_norm:.4f}")
                    writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + batch_idx)
                    writer.add_scalar('Grad_Norm/train', total_norm, epoch * len(train_loader) + batch_idx)

        avg_epoch_loss = epoch_loss / num_batches
        scheduler.step(avg_epoch_loss)

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch} completed. Avg Loss: {avg_epoch_loss:.4f}, LR: {current_lr:.2e}")
        writer.add_scalar('Loss/epoch', avg_epoch_loss, epoch)
        writer.add_scalar('Learning_Rate', current_lr, epoch)

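        # Periodically generate samples and log them to TensorBoard.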
        if epoch % config.sample_every == 0:
            sample(model, noise_scheduler, device, epoch, writer)

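        # Save a full checkpoint at epoch 30 and every 30 epochs thereafter.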
        if epoch == 30 or (epoch > 30 and epoch % 30 == 0):
            checkpoint_path = os.path.join(log_dir, f"model_epoch_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_epoch_loss,
                'config': config
            }, checkpoint_path)
            print(f"Model checkpoint saved at epoch {epoch}: {checkpoint_path}")

    # Save the final weights and flush any buffered TensorBoard events.
    torch.save(model.state_dict(), os.path.join(log_dir, "model_final.pth"))
    writer.close()


if __name__ == "__main__":
    train()