# /// script
# requires-python = "==3.10"
# dependencies = ["torch==2.7.0", "triton", "numpy", "kernels"]
# [tool.uv.sources]
# kernels = { git = "https://github.com/huggingface/kernels.git" }
# ///
"""Benchmark the drbh/yamoe MoE experts kernel against the reference
GptOssExperts implementation shipped with the kernel, comparing latency,
peak CUDA memory, and numerical agreement."""
import time

import torch
from kernels import get_local_kernel
from kernels import get_kernel
from pathlib import Path
from torch.nn import functional as F

# Set seeds and deterministic flags for reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

yamoe = get_kernel("drbh/yamoe", revision="v0.1.0")

# Configuration
batch_size, seq_len, hidden_dim = 4, 1024, 2880
# batch_size, seq_len, hidden_dim = 4, 32, 1024
num_experts, top_k = 8, 2

# Create routing weights: a dense (batch*seq, num_experts) tensor holding the
# top-k softmax probabilities at the selected expert columns, zero elsewhere.
logits = torch.randn(batch_size, seq_len, num_experts)
probs = F.softmax(logits, dim=-1)
weights, indices = torch.topk(probs, top_k, dim=-1)

batch_seq = batch_size * seq_len
routing_weights = torch.zeros(batch_seq, num_experts, dtype=weights.dtype)
flat_indices, flat_weights = indices.reshape(-1, top_k), weights.reshape(-1, top_k)
batch_indices = torch.arange(batch_seq).unsqueeze(1).expand(-1, top_k)
# Scatter the top-k weights into their expert columns.
routing_weights[batch_indices, flat_indices] = flat_weights

# Create model tensors
hidden_states = torch.randn(batch_size, seq_len, hidden_dim).cuda()
gate_up_proj_bias = torch.zeros(num_experts, 2 * hidden_dim).cuda()
down_proj_bias = torch.zeros(num_experts, hidden_dim).cuda()
router_indices = flat_indices.cuda()

# Expert weights: gate_up_proj is (E, H, 2*H) and down_proj is (E, H, H),
# i.e. the expert intermediate dim equals hidden_dim for this benchmark.
gate_up_proj = torch.empty(num_experts, hidden_dim, 2 * hidden_dim, device="cuda")
down_proj = torch.empty(num_experts, hidden_dim, hidden_dim, device="cuda")
torch.nn.init.trunc_normal_(gate_up_proj, std=0.02)
torch.nn.init.trunc_normal_(down_proj, std=0.02)
routing_weights = routing_weights.to(dtype=torch.float32, device="cuda")

# Warmup — run under no_grad to match the timed run and avoid autograd
# bookkeeping skewing the first measurements.
with torch.no_grad():
    for _ in range(5):
        _ = yamoe.experts(
            hidden_states.view(-1, hidden_dim),
            router_indices,
            routing_weights.view(-1, num_experts),
            gate_up_proj,
            gate_up_proj_bias,
            down_proj,
            down_proj_bias,
            seq_len,
            num_experts,
            top_k,
        )

# Benchmark the kernel: synchronize around the region so wall-clock time
# covers the actual GPU work, and track peak allocation from a clean slate.
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
start = time.perf_counter()
with torch.no_grad():
    output = yamoe.experts(
        hidden_states.view(-1, hidden_dim),
        router_indices,
        routing_weights.view(-1, num_experts),
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        seq_len,
        num_experts,
        top_k,
    )
torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - start) * 1e3
peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

# Store kernel results
kernel_output = output.clone()
kernel_time = elapsed_ms
kernel_memory = peak_mem_mb

## OPTIONAL
# Compare to reference implementation
config = type("Config", (), {})()
config.hidden_size = hidden_dim
# gate_up_proj is (E, H, 2*H), so the expert intermediate size is hidden_dim.
# (Previously set to 4 * hidden_dim, which allocated 4x-oversized parameters
# that were immediately discarded by the .data overwrites below.)
config.intermediate_size = hidden_dim
config.num_local_experts = num_experts

model = yamoe.reference.GptOssExperts(config)

# set the weights and biases from above to the reference model
model.gate_up_proj.data = gate_up_proj
model.gate_up_proj_bias.data = gate_up_proj_bias
model.down_proj.data = down_proj
model.down_proj_bias.data = down_proj_bias

model = model.cuda()
model.eval()

# Benchmark the reference implementation with the same protocol.
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
start = time.perf_counter()
with torch.no_grad():
    ref_output = model(hidden_states, router_indices, routing_weights)
torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - start) * 1e3
peak_mem_mb = torch.cuda.max_memory_allocated() / (1024 * 1024)

# Store reference results
ref_time = elapsed_ms
ref_memory = peak_mem_mb

# Reshape reference output to match kernel output
ref_output_reshaped = ref_output.view(kernel_output.shape)

# Calculate similarity metrics
mse = torch.nn.functional.mse_loss(kernel_output, ref_output_reshaped).item()
mae = torch.nn.functional.l1_loss(kernel_output, ref_output_reshaped).item()

# Cosine similarity
kernel_flat = kernel_output.view(-1)
ref_flat = ref_output_reshaped.view(-1)
cosine_sim = torch.nn.functional.cosine_similarity(
    kernel_flat.unsqueeze(0), ref_flat.unsqueeze(0)
).item()

# Relative error (L2 norm of difference / L2 norm of reference)
diff_norm = torch.norm(kernel_output - ref_output_reshaped).item()
ref_norm = torch.norm(ref_output_reshaped).item()
rel_error = diff_norm / ref_norm if ref_norm > 0 else float("inf")

# Max absolute difference
max_abs_diff = torch.max(torch.abs(kernel_output - ref_output_reshaped)).item()

# Print comparison table
print("\n" + "=" * 80)
print(f"{'METRIC':<20} {'KERNEL':<15} {'REFERENCE':<15} {'SIMILARITY/SPEEDUP':<15}")
print("=" * 80)
print(
    f"{'Sum':<20} {kernel_output.sum().item():<15.4f} {ref_output_reshaped.sum().item():<15.4f} {'N/A':<15}"
)
print(
    f"{'Min':<20} {kernel_output.min().item():<15.4f} {ref_output_reshaped.min().item():<15.4f} {'N/A':<15}"
)
print(
    f"{'Max':<20} {kernel_output.max().item():<15.4f} {ref_output_reshaped.max().item():<15.4f} {'N/A':<15}"
)
print(
    f"{'Norm (L2)':<20} {kernel_output.norm().item():<15.4f} {ref_output_reshaped.norm().item():<15.4f} {'N/A':<15}"
)
print(
    f"{'Std':<20} {kernel_output.std().item():<15.4f} {ref_output_reshaped.std().item():<15.4f} {'N/A':<15}"
)
print("-" * 80)
print(
    f"{'Time (ms)':<20} {kernel_time:<15.3f} {ref_time:<15.3f} {ref_time / kernel_time:<15.2f}x"
)
print(
    f"{'Memory (MB)':<20} {kernel_memory:<15.2f} {ref_memory:<15.2f} {ref_memory / kernel_memory:<15.2f}x"
)
print("-" * 80)
print("SIMILARITY METRICS")
print("-" * 80)
print(f"{'METRIC':<20} {'VALUE':<15} {'DIFFERENCE':<15}")
print("-" * 80)
print(f"{'MSE':<20} {mse:<15.6e} {'N/A':<15}")
print(f"{'MAE':<20} {mae:<15.6e} {'N/A':<15}")
print(f"{'Cosine Similarity':<20} {cosine_sim:<15.6f} {abs(1.0 - cosine_sim):<15.6f}")
print(f"{'Relative Error':<20} {rel_error:<15.6e} {'N/A':<15}")
print(f"{'Max Abs Diff':<20} {max_abs_diff:<15.6e} {'N/A':<15}")
print("-" * 80)
print("FIRST 10 ELEMENTS COMPARISON")
print("-" * 80)

# Get first 10 elements as numpy arrays for nice display
kernel_first_10 = kernel_flat[:10].cpu().numpy()
ref_first_10 = ref_flat[:10].cpu().numpy()
diff_first_10 = kernel_first_10 - ref_first_10

print(f"{'INDEX':<5} {'KERNEL':<12} {'REFERENCE':<12} {'DIFF':<12}")
print("-" * 45)
for i in range(10):
    print(
        f"{i:<5} {kernel_first_10[i]:<12.6f} {ref_first_10[i]:<12.6f} {diff_first_10[i]:<12.6f}"
    )
print("=" * 80)