# Grad-CDM / train.py
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
from datetime import datetime
from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from dataloader import get_dataloaders
from loss import diffusion_loss
from sample import sample


def train():
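    """Train SmoothDiffusionUNet with the frequency-aware noise schedule.

    Logs batch/epoch statistics to TensorBoard, periodically generates samples,
    and saves checkpoints under a timestamped log directory.
    """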
    config = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup logging
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = os.path.join(config.log_dir, timestamp)
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    # Initialize components
    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
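    # Note: FrequencyAwareNoise (see noise_scheduler.py) is assumed to implement the
    # forward noising process with a frequency-dependent schedule; this script only
    # relies on it for loss computation and the debug_noise_stats() helper below.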
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
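    # ReduceLROnPlateau halves the learning rate (factor=0.5) whenever the average
    # epoch loss fails to improve for 5 consecutive epochs (patience=5).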

    train_loader, val_loader = get_dataloaders(config)

    # Training loop
    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for batch_idx, (x0, _) in enumerate(train_loader):
            x0 = x0.to(device)

            # Sample random timesteps
            t = torch.randint(0, config.T, (x0.size(0),), device=device)
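            # (each sample in the batch gets an independent timestep drawn uniformly
            # from {0, ..., T-1})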

            # Compute loss
            loss = diffusion_loss(model, x0, t, noise_scheduler, config)
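            # diffusion_loss (see loss.py) is assumed to noise x0 to x_t via the
            # scheduler and score the model's prediction at timestep t; this script
            # treats it as a black box returning a scalar loss.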

            # Optimize
            optimizer.zero_grad()
            loss.backward()
            # Add gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # Increased from 1.0
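            # clip_grad_norm_ rescales all gradients in place so their global L2 norm
            # is at most 5.0, limiting the size of any single update step.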
            optimizer.step()

            # Track epoch loss for scheduler
            epoch_loss += loss.item()
            num_batches += 1

            # Logging with more details
            if batch_idx % 100 == 0:
                # Check for NaN values
                if torch.isnan(loss):
                    print(f"WARNING: NaN loss detected at Epoch {epoch}, Batch {batch_idx}")

                # Check gradient norms
                total_norm = 0.0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2
                total_norm = total_norm ** 0.5
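                # Note: because this runs after clipping, total_norm is the clipped
                # gradient norm; clip_grad_norm_ above also returns the pre-clip norm
                # if that is the quantity you want to log.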

                # Debug noise statistics less frequently (every 5 epochs)
                if batch_idx == 0 and epoch % 5 == 0:
                    print(f"Debug for Epoch {epoch}:")
                    noise_scheduler.debug_noise_stats(x0[:1], t[:1])

                # Re-enable batch logging since training is stable
                if batch_idx % 500 == 0:  # Less frequent logging
                    print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}, Grad Norm: {total_norm:.4f}")
                    writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + batch_idx)
                    writer.add_scalar('Grad_Norm/train', total_norm, epoch * len(train_loader) + batch_idx)

        # Update learning rate based on epoch loss
        avg_epoch_loss = epoch_loss / num_batches
        scheduler.step(avg_epoch_loss)

        # Log epoch statistics
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch} completed. Avg Loss: {avg_epoch_loss:.4f}, LR: {current_lr:.2e}")
        writer.add_scalar('Loss/epoch', avg_epoch_loss, epoch)
        writer.add_scalar('Learning_Rate', current_lr, epoch)

        # Validation: periodically generate samples (val_loader is currently unused)
        if epoch % config.sample_every == 0:
            sample(model, noise_scheduler, device, epoch, writer)
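            # sample() (see sample.py) is assumed to run the reverse diffusion chain
            # from noise and log the generated images to the TensorBoard writer.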

        # Save model checkpoints at epoch 30 and every 30 epochs
        if epoch == 30 or (epoch > 30 and epoch % 30 == 0):
            checkpoint_path = os.path.join(log_dir, f"model_epoch_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_epoch_loss,
                'config': config,
            }, checkpoint_path)
            print(f"Model checkpoint saved at epoch {epoch}: {checkpoint_path}")

    torch.save(model.state_dict(), os.path.join(log_dir, "model_final.pth"))


if __name__ == "__main__":
    train()