diff --git a/deprecated/sample_eqm_two.py b/deprecated/sample_eqm_two.py index 71fa4c5..fba9882 100644 --- a/deprecated/sample_eqm_two.py +++ b/deprecated/sample_eqm_two.py @@ -77,7 +77,7 @@ def main(args): save_steps_list = [] if args.save_steps is not None: save_steps_list = [int(s.strip()) for s in args.save_steps.split(",")] - img_saver = IntermediateImageSaver(save_steps_list, args.out) + img_saver = IntermediateImageSaver(save_steps_list, output_folder=args.out) hooks.append(img_saver) print(f"Created IntermediateImageSaver hook for steps: {save_steps_list}") @@ -117,7 +117,7 @@ def main(args): # Finalize gradient norm statistics if enabled if grad_tracker is not None: print("Computing gradient norm statistics...") - grad_tracker.finalize(args, args.out) + grad_tracker.finalize(args.out, args.num_sampling_steps, args.stepsize, args.sampler) # Create .npz files for FID evaluation print("Creating .npz file for final samples...") @@ -168,7 +168,10 @@ def main(args): help="Comma-separated list of sampling steps to save intermediate images (e.g., '0,50,100,249')", ) parser.add_argument( - "--track-grad-norm", action="store_true", help="Enable gradient norm tracking and visualization" + "--track-grad-norm", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable gradient norm tracking and visualization", ) parser.add_argument("--uncond", type=bool, default=True, help="Disable/enable noise conditioning (default: True)") parser.add_argument( diff --git a/experiments/cfg_effect/main.py b/experiments/cfg_effect/main.py new file mode 100644 index 0000000..2f1926c --- /dev/null +++ b/experiments/cfg_effect/main.py @@ -0,0 +1,444 @@ +""" +Experiment: CFG Effect on Image Collapse + +Investigates how classifier-free guidance (CFG) affects image collapse during +extended EqM sampling. The hypothesis is that collapse occurs when combining +model outputs from original class labels and null class labels (CFG > 1.0). 
+ +Experiment Design: +- Baseline: Full sampling with cfg_scale for all steps +- Switch experiments: For each switch_step in switch_steps list, use cfg_scale + for first switch_step steps, then switch to null class label only (cfg=1.0) + for remaining steps. + +Implementation: +- Run sample_eqm() with cfg_scale, capture intermediate latents at switch steps +- For each captured latent, resume sample_eqm() with cfg=1.0 for remaining steps +""" + +import argparse +import json +import math +import os + +import numpy as np +import torch + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + +from diffusers.models import AutoencoderKL +from tqdm import tqdm + +from download import find_model +from models import EqM_models +from utils.sampling_hooks import ( + GradientNormTracker, + IntermediateImageSaver, + SamplingHookContext, +) +from utils.sampling_utils import create_npz_from_sample_folder, sample_eqm + + +class LatentCaptureHook: + """Hook to capture intermediate latents at specified steps.""" + + def __init__(self, capture_steps): + self.capture_steps = set(capture_steps) + self.captured_latents = {} # step -> latent tensor + + def __call__(self, context: SamplingHookContext): + if context.step_idx not in self.capture_steps: + return + + # Extract conditional part if using CFG + xt_save = context.xt + if context.use_cfg: + batch_size = context.xt.shape[0] // 2 + xt_save = context.xt[:batch_size] + + self.captured_latents[context.step_idx] = xt_save.clone() + + +def run_baseline_and_capture( + model: torch.nn.Module, + vae: AutoencoderKL, + device: torch.device, + initial_latent: torch.Tensor, + class_labels: torch.Tensor, + latent_size: int, + num_sampling_steps: int, + stepsize: float, + cfg_scale: float, + sampler: str, + mu: float, + save_steps_list: list[int], + switch_steps: list[int], + output_folder: str, + batch_size: int, + num_samples: int, + track_grad_norm: bool = False, +): + """ + Run baseline with cfg_scale for all 
steps and capture intermediate latents. + + Returns: + captured_latents: dict mapping switch_step -> list of latent tensors (per batch) + """ + os.makedirs(output_folder, exist_ok=True) + + total_samples = int(math.ceil(num_samples / batch_size) * batch_size) + iterations = int(total_samples // batch_size) + + # Store captured latents per switch step + all_captured = {step: [] for step in switch_steps} + + # Create hooks + hooks = [] + + capture_hook = LatentCaptureHook(switch_steps) + hooks.append(capture_hook) + + if num_sampling_steps not in save_steps_list: + save_steps_list.append(num_sampling_steps) + img_saver = IntermediateImageSaver(save_steps_list, output_folder=output_folder) + hooks.append(img_saver) + + grad_tracker = None + if track_grad_norm: + grad_tracker = GradientNormTracker(num_sampling_steps) + hooks.append(grad_tracker) + + # Sampling loop + total_saved = 0 + for batch_idx in tqdm(range(iterations), desc="Baseline CFG"): + batch_start = batch_idx * batch_size + batch_end = min(batch_start + batch_size, num_samples) + actual_batch_size = batch_end - batch_start + + batch_latent = initial_latent[batch_start:batch_end] + batch_labels = class_labels[batch_start:batch_end] + + sample_eqm( + model=model, + vae=vae, + device=device, + batch_size=actual_batch_size, + latent_size=latent_size, + initial_latent=batch_latent, + class_labels=batch_labels, + num_sampling_steps=num_sampling_steps, + stepsize=stepsize, + cfg_scale=cfg_scale, + sampler=sampler, + mu=mu, + hooks=hooks, + ) + + # Store captured latents + for step in switch_steps: + if step in capture_hook.captured_latents: + all_captured[step].append(capture_hook.captured_latents[step].cpu()) + + total_saved += actual_batch_size + + if save_steps_list: + for step in save_steps_list: + step_folder = f"{output_folder}/step_{step:03d}" + if os.path.exists(step_folder): + create_npz_from_sample_folder(step_folder, num_samples) + + # Finalize gradient norm tracking + if grad_tracker is not None: + # 
Create args-like object for finalize + grad_tracker.finalize(output_folder, num_sampling_steps, stepsize, sampler) + + return all_captured, grad_tracker + + +def run_switch_experiment( + model: torch.nn.Module, + vae: AutoencoderKL, + device: torch.device, + captured_latents: list[torch.Tensor], + latent_size: int, + num_sampling_steps: int, + switch_step: int, + stepsize: float, + sampler: str, + mu: float, + save_steps_list: list[int], + output_folder: str, + track_grad_norm: bool = False, + baseline_grad_tracker: GradientNormTracker | None = None, +): + """ + Run switch experiment: start from captured latent at switch_step, + continue with null class label (cfg=1.0) for remaining steps. + """ + os.makedirs(output_folder, exist_ok=True) + remaining_steps = num_sampling_steps - switch_step + + # Create hooks + hooks = [] + + adjusted_save_steps = [s - switch_step for s in (save_steps_list or []) if s > switch_step] + if remaining_steps not in adjusted_save_steps: # Always save the final step + adjusted_save_steps.append(remaining_steps) + img_saver = IntermediateImageSaver( + adjusted_save_steps, + folder_pattern=lambda ctx: f"{output_folder}/step_{ctx.step_idx + switch_step:03d}", + ) + hooks.append(img_saver) + + grad_tracker = None + if track_grad_norm: + grad_tracker = GradientNormTracker(remaining_steps) + hooks.append(grad_tracker) + + # Sampling loop + total_saved = 0 + for _, batch_latent in enumerate(tqdm(captured_latents, desc=f"Switch {switch_step}")): + batch_latent = batch_latent.to(device) + actual_batch_size = batch_latent.shape[0] + + # Use null class labels (1000) + null_labels = torch.tensor([1000] * actual_batch_size, device=device) + + sample_eqm( + model=model, + vae=vae, + device=device, + batch_size=actual_batch_size, + latent_size=latent_size, + initial_latent=batch_latent, + class_labels=null_labels, + num_sampling_steps=remaining_steps, + stepsize=stepsize, + cfg_scale=1.0, # No guidance for remaining steps + sampler=sampler, + mu=mu, + 
hooks=hooks, + ) + + total_saved += actual_batch_size + + # Create .npz files for intermediate steps + for step in adjusted_save_steps: + step_folder = f"{output_folder}/step_{step + switch_step:03d}" + if os.path.exists(step_folder): + create_npz_from_sample_folder(step_folder, total_saved) + + # Finalize gradient norm tracking + if grad_tracker is not None: + # Prepend baseline gradient norms to get full trajectory + if baseline_grad_tracker is not None: + baseline_norms = baseline_grad_tracker.gradient_norms[:switch_step] + grad_tracker.gradient_norms = baseline_norms + grad_tracker.gradient_norms + + baseline_norms_cond = baseline_grad_tracker.gradient_norms_cond[:switch_step] + grad_tracker.gradient_norms_cond = baseline_norms_cond + grad_tracker.gradient_norms_cond + + baseline_norms_uncond = baseline_grad_tracker.gradient_norms_uncond[:switch_step] + grad_tracker.gradient_norms_uncond = baseline_norms_uncond + grad_tracker.gradient_norms_uncond + + grad_tracker.finalize(output_folder, num_sampling_steps, stepsize, sampler) + + +def main(args): + """Main function to run CFG effect experiments.""" + assert torch.cuda.is_available(), "Sampling requires at least one GPU." 
+ + if args.ebm != "none": + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_cudnn_sdp(False) + torch.backends.cuda.enable_math_sdp(True) + + device = torch.device("cuda:0") + torch.manual_seed(args.seed) + torch.cuda.set_device(device) + print(f"Using device: {device}, seed: {args.seed}") + + assert args.image_size % 8 == 0, "Image size must be divisible by 8" + latent_size = args.image_size // 8 + ema_model = EqM_models[args.model]( + input_size=latent_size, num_classes=args.num_classes, uncond=args.uncond, ebm=args.ebm + ).to(device) + + if args.ckpt is None: + raise ValueError("Checkpoint is required") + print(f"Loading checkpoint from {args.ckpt}") + state_dict = find_model(args.ckpt) + if "model" in state_dict.keys(): + ema_model.load_state_dict(state_dict["ema"]) + else: + ema_model.load_state_dict(state_dict) + ema_model.eval() + print(f"EqM Parameters: {sum(p.numel() for p in ema_model.parameters()):,}") + + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + + os.makedirs(args.out, exist_ok=True) + + switch_steps = [int(s.strip()) for s in args.switch_steps.split(",")] + print(f"Switch steps: {switch_steps}") + + save_steps_list = [] + if args.save_steps is not None: + save_steps_list = [int(s.strip()) for s in args.save_steps.split(",")] + print(f"Save steps: {save_steps_list}") + + # Generate shared initial latent and class labels + print("Generating shared initial latent and class labels...") + total_samples = int(math.ceil(args.num_samples / args.batch_size) * args.batch_size) + initial_latent = torch.randn(total_samples, 4, latent_size, latent_size, device=device) + + if args.class_labels is not None: + class_ids = [int(c.strip()) for c in args.class_labels.split(",")] + class_ids_tensor = torch.tensor(class_ids, device=device, dtype=torch.long) + class_labels = class_ids_tensor[torch.randint(0, class_ids_tensor.numel(), (total_samples,), 
device=device)] + else: + class_labels = torch.randint(0, 1000, (total_samples,), device=device) + + # Save initial latent + latent_path = f"{args.out}/initial_latent.npz" + np.savez( + latent_path, + latent=initial_latent.cpu().numpy(), + class_labels=class_labels.cpu().numpy(), + ) + print(f"Saved initial latent and class labels to {latent_path}") + + # Save metadata + metadata = { + "seed": args.seed, + "model": args.model, + "image_size": args.image_size, + "num_classes": args.num_classes, + "batch_size": args.batch_size, + "num_samples": args.num_samples, + "num_sampling_steps": args.num_sampling_steps, + "stepsize": args.stepsize, + "cfg_scale": args.cfg_scale, + "switch_steps": switch_steps, + "save_steps": save_steps_list, + "sampler": args.sampler, + "mu": args.mu, + "vae": args.vae, + "ckpt": args.ckpt, + "class_labels_arg": args.class_labels, + "uncond": args.uncond, + "ebm": args.ebm, + "track_grad_norm": args.track_grad_norm, + } + with open(f"{args.out}/metadata.json", "w") as f: + json.dump(metadata, f, indent=2) + print(f"Saved metadata to {args.out}/metadata.json") + + # Run baseline with full CFG and capture intermediate latents + print(f"\n{'=' * 60}") + print(f"Running baseline with cfg_scale={args.cfg_scale} (capturing latents at switch steps)") + print(f"{'=' * 60}") + captured_latents, baseline_grad_tracker = run_baseline_and_capture( + model=ema_model, + vae=vae, + device=device, + initial_latent=initial_latent, + class_labels=class_labels, + latent_size=latent_size, + num_sampling_steps=args.num_sampling_steps, + stepsize=args.stepsize, + cfg_scale=args.cfg_scale, + sampler=args.sampler, + mu=args.mu, + save_steps_list=save_steps_list, + switch_steps=switch_steps, + output_folder=f"{args.out}/baseline_cfg{args.cfg_scale}", + batch_size=args.batch_size, + num_samples=args.num_samples, + track_grad_norm=args.track_grad_norm, + ) + + # Run switch experiments using captured latents + for switch_step in switch_steps: + print(f"\n{'=' * 60}") + 
print(f"Running switch experiment: cfg={args.cfg_scale} for steps 1-{switch_step}, then null-only") + print(f"{'=' * 60}") + run_switch_experiment( + model=ema_model, + vae=vae, + device=device, + captured_latents=captured_latents[switch_step], + latent_size=latent_size, + num_sampling_steps=args.num_sampling_steps, + switch_step=switch_step, + stepsize=args.stepsize, + sampler=args.sampler, + mu=args.mu, + save_steps_list=save_steps_list, + output_folder=f"{args.out}/switch_{switch_step:03d}", + track_grad_norm=args.track_grad_norm, + baseline_grad_tracker=baseline_grad_tracker, + ) + + print("\nAll experiments completed!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="CFG Effect Experiment") + parser.add_argument("--model", type=str, choices=list(EqM_models.keys()), default="EqM-XL/2") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--batch-size", type=int, required=True, help="Batch size for sampling") + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") + parser.add_argument("--cfg-scale", type=float, default=4.0, help="CFG scale for initial steps") + parser.add_argument("--ckpt", type=str, required=True, help="Path to EqM checkpoint") + parser.add_argument( + "--class-labels", + type=str, + default=None, + help="Class labels to sample (comma-separated). 
If not specified, samples random classes.", + ) + parser.add_argument("--stepsize", type=float, default=0.0017, help="Step size eta") + parser.add_argument("--num-sampling-steps", type=int, default=1000, help="Total sampling steps") + parser.add_argument("--out", type=str, required=True, help="Output directory") + parser.add_argument( + "--sampler", + type=str, + default="gd", + choices=["gd", "ngd"], + help="Sampler type: 'gd' or 'ngd'", + ) + parser.add_argument("--mu", type=float, default=0.3, help="NAG-GD momentum hyperparameter") + parser.add_argument("--num-samples", type=int, required=True, help="Total samples to generate") + parser.add_argument( + "--save-steps", + type=str, + default=None, + help="Comma-separated list of steps to save intermediate images", + ) + parser.add_argument( + "--switch-steps", + type=str, + required=True, + help="Comma-separated list of steps at which to switch to null-only", + ) + parser.add_argument( + "--track-grad-norm", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable gradient norm tracking and visualization", + ) + parser.add_argument("--uncond", type=bool, default=True, help="Enable noise conditioning") + parser.add_argument( + "--ebm", + type=str, + choices=["none", "l2", "dot", "mean"], + default="none", + help="Energy formulation", + ) + + args = parser.parse_args() + main(args) diff --git a/flow_eqm_hybrid.py b/flow_eqm_hybrid.py index 53290f7..a187c4c 100644 --- a/flow_eqm_hybrid.py +++ b/flow_eqm_hybrid.py @@ -22,7 +22,8 @@ from models import EqM_models from transport import Sampler, create_transport from utils.arg_utils import parse_ode_args, parse_sde_args, parse_transport_args -from utils.sampling_utils import IntermediateImageSaver, decode_latents, sample_eqm +from utils.sampling_hooks import IntermediateImageSaver +from utils.sampling_utils import decode_latents, sample_eqm def main(mode, args): @@ -129,7 +130,7 @@ def main(mode, args): eqm_hooks_by_fm_step = {} for flow_step_idx in 
fm_save_steps_list: step_folder = f"{output_dir}/fm_step_{flow_step_idx:03d}" - img_saver = IntermediateImageSaver(eqm_save_steps_list, step_folder) + img_saver = IntermediateImageSaver(eqm_save_steps_list, output_folder=step_folder) eqm_hooks_by_fm_step[flow_step_idx] = [img_saver] print(f"Starting batch processing: {num_batches} batches of size {batch_size}...") diff --git a/models.py b/models.py index 6473154..061ee4a 100644 --- a/models.py +++ b/models.py @@ -278,9 +278,15 @@ def forward(self, x0, t, y, return_act=False, get_energy=False, train=False): return x, act return x - def forward_with_cfg(self, x, t, y, cfg_scale, return_act=False, get_energy=False, train=False): + def forward_with_cfg( + self, x, t, y, cfg_scale, return_act=False, get_energy=False, train=False, return_components=False + ): """ Forward pass of EqM, but also batches the uncondional forward pass for classifier-free guidance. + + Args: + return_components: If True, returns (combined_out, cond_out, uncond_out) tuple + for tracking individual outputs before CFG combination. 
""" # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb half = x[: len(x) // 2] @@ -305,9 +311,12 @@ def forward_with_cfg(self, x, t, y, cfg_scale, return_act=False, get_energy=Fals cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) + combined_out = torch.cat([eps, rest], dim=1) + if return_components: + return combined_out, cond_eps, uncond_eps if get_energy: - return torch.cat([eps, rest], dim=1), E - return torch.cat([eps, rest], dim=1) + return combined_out, E + return combined_out ################################################################################# diff --git a/sample_eqm.py b/sample_eqm.py index 0a88467..85bdbb7 100644 --- a/sample_eqm.py +++ b/sample_eqm.py @@ -18,7 +18,8 @@ from download import find_model from models import EqM_models -from utils.sampling_utils import GradientNormTracker, IntermediateImageSaver, create_npz_from_sample_folder, sample_eqm +from utils.sampling_hooks import GradientNormTracker, IntermediateImageSaver +from utils.sampling_utils import create_npz_from_sample_folder, sample_eqm def main(args): @@ -80,7 +81,7 @@ def main(args): if args.save_steps is not None: save_steps_list = [int(s.strip()) for s in args.save_steps.split(",")] img_saver = IntermediateImageSaver( - save_steps_list, f"{args.out}/{args.sampler}-{args.stepsize}-cfg{args.cfg_scale}" + save_steps_list, output_folder=f"{args.out}/{args.sampler}-{args.stepsize}-cfg{args.cfg_scale}" ) hooks.append(img_saver) print(f"Created IntermediateImageSaver hook for steps: {save_steps_list}") @@ -130,7 +131,7 @@ def main(args): # Finalize gradient norm statistics if enabled if grad_tracker is not None: print("Computing gradient norm statistics...") - grad_tracker.finalize(args, args.out) + grad_tracker.finalize(args.out, args.num_sampling_steps, args.stepsize, args.sampler) # Create .npz files for FID evaluation print("Creating .npz 
file for final samples...") @@ -187,7 +188,10 @@ def main(args): help="Comma-separated list of sampling steps to save intermediate images (e.g., '0,50,100,249')", ) parser.add_argument( - "--track-grad-norm", action="store_true", help="Enable gradient norm tracking and visualization" + "--track-grad-norm", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable gradient norm tracking and visualization", ) parser.add_argument("--uncond", type=bool, default=True, help="Disable/enable noise conditioning (default: True)") parser.add_argument( diff --git a/sample_from_clean.py b/sample_from_clean.py index c076af4..82ba141 100644 --- a/sample_from_clean.py +++ b/sample_from_clean.py @@ -169,7 +169,7 @@ def main(args): hooks = [] if len(save_steps_list) > 0: - img_saver = IntermediateImageSaver(save_steps_list, args.out) + img_saver = IntermediateImageSaver(save_steps_list, output_folder=args.out) hooks.append(img_saver) print(f"Created IntermediateImageSaver hook for steps: {save_steps_list}") @@ -248,7 +248,7 @@ def main(args): # Finalize gradient norm statistics if enabled if grad_tracker is not None: print("Computing gradient norm statistics...") - grad_tracker.finalize(args, args.out) + grad_tracker.finalize(args.out, args.num_sampling_steps, args.stepsize, args.sampler) # Finalize distortion analysis print("Finalizing distortion analysis...") @@ -398,7 +398,8 @@ def main(args): ) parser.add_argument( "--track-grad-norm", - action="store_true", + action=argparse.BooleanOptionalAction, + default=True, help="Enable gradient norm tracking and visualization", ) diff --git a/train.py b/train.py index e11696d..d79e542 100644 --- a/train.py +++ b/train.py @@ -36,11 +36,8 @@ parse_sample_args, parse_transport_args, ) -from utils.sampling_utils import ( - GradientNormTracker, - WandBImageLogger, - sample_eqm, -) +from utils.sampling_hooks import GradientNormTracker, WandBImageLogger +from utils.sampling_utils import sample_eqm from utils.utils import 
imagenet_label_from_idx try: diff --git a/utils/sampling_hooks.py b/utils/sampling_hooks.py new file mode 100644 index 0000000..db110bf --- /dev/null +++ b/utils/sampling_hooks.py @@ -0,0 +1,572 @@ +""" +Sampling hooks for EqM models. + +This module provides hook classes for monitoring, logging, and analysis during sampling. + +Hook System: + Hooks are callables that receive a SamplingHookContext at each sampling step. + They enable monitoring, logging, and analysis without cluttering the core sampling loop. + +Main Components: + - SamplingHookContext: Context object passed to hooks during sampling + - IntermediateImageSaver: Hook for saving intermediate images at specified steps + - WandBImageLogger: Hook for logging intermediate images to WandB + - DistortionTracker: Hook for analyzing image distortion and high-frequency correlation + - GradientNormTracker: Hook for tracking and analyzing gradient norms + +Example Usage: + >>> from utils.sampling_hooks import IntermediateImageSaver, GradientNormTracker + >>> img_hook = IntermediateImageSaver([0, 50, 100], "outputs") + >>> grad_hook = GradientNormTracker(num_sampling_steps=250) + >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32, hooks=[img_hook, grad_hook]) + >>> grad_hook.finalize("outputs", num_sampling_steps=250, stepsize=1.0, sampler="euler") +""" + +import json +import os +from dataclasses import dataclass +from typing import Any, Callable + +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image + + +@dataclass +class SamplingHookContext: + """Context object passed to sampling hooks containing all relevant state.""" + + xt: torch.Tensor # Current latent state + t: torch.Tensor # Current timestep + y: torch.Tensor # Class labels + out: torch.Tensor # Model output/gradient (CFG-combined when use_cfg=True) + step_idx: int # Current step index (1-indexed) + use_cfg: bool # Whether CFG is enabled + vae: Any # VAE decoder for image conversion + device: 
torch.device # Device + total_steps: int # Total number of sampling steps + out_cond: torch.Tensor | None = None # Class label output before CFG (only when use_cfg=True) + out_uncond: torch.Tensor | None = None # Null label output before CFG (only when use_cfg=True) + + +class IntermediateImageSaver: + """ + Hook for saving intermediate images during sampling. + + Args: + save_steps: List of step indices at which to save images (e.g., [0, 50, 100, 250]) + output_folder: Base folder for saving images. Required if folder_pattern is not provided. + folder_pattern: Callable (context) -> str that returns the folder path. + If not provided, defaults to "{output_folder}/step_{step_idx:03d}" + """ + + def __init__( + self, + save_steps: list[int], + output_folder: str | None = None, + folder_pattern: Callable[[SamplingHookContext], str] | None = None, + ): + if output_folder is None and folder_pattern is None: + raise ValueError("Either output_folder or folder_pattern must be provided") + self.save_steps = set(save_steps) # Use set for O(1) lookup + # Track global sample counter for each step to avoid overwriting + self.step_counters = dict.fromkeys(save_steps, 0) + if folder_pattern is None: + self.folder_pattern = lambda ctx: f"{output_folder}/step_{ctx.step_idx:03d}" + else: + self.folder_pattern = folder_pattern + + def __call__(self, context: SamplingHookContext): + """Save images if current step is in save_steps list.""" + from utils.sampling_utils import decode_latents + + if context.step_idx not in self.save_steps: + return + + step_folder = self.folder_pattern(context) + os.makedirs(step_folder, exist_ok=True) + + # Extract conditional part if using CFG + xt_save = context.xt + if context.use_cfg: + batch_size = context.xt.shape[0] // 2 + xt_save = context.xt[:batch_size] + + samples = decode_latents(context.vae, xt_save) + + # Save images with global sequential indexing across batches + start_idx = self.step_counters[context.step_idx] + for i_sample, sample in 
enumerate(samples): + global_idx = start_idx + i_sample + Image.fromarray(sample).save(f"{step_folder}/{global_idx:06d}.png") + + # Update counter for this step + self.step_counters[context.step_idx] += len(samples) + + +class WandBImageLogger: + """ + Hook for logging intermediate images during sampling directly to WandB. + + Args: + save_steps: List of step indices at which to log images (e.g., [5, 10, 250]) + train_step: Current training step (for WandB logging) + output_folder: Folder to save logged images + wandb_module: wandb module (pass wandb if imported, or None to skip logging) + """ + + def __init__(self, save_steps, train_step, output_folder, wandb_module=None): + self.save_steps = set(save_steps) + self.train_step = train_step + self.output_folder = output_folder + self.wandb = wandb_module + self.logged_images = {step: [] for step in save_steps} + self.step_counters = dict.fromkeys(save_steps, 0) + + def __call__(self, context: SamplingHookContext): + """Log images to WandB if current step is in save_steps list.""" + from utils.sampling_utils import decode_latents + + if context.step_idx not in self.save_steps or self.wandb is None: + return + + folder = f"{self.output_folder}/train_{self.train_step:04d}/sample_{context.step_idx:03d}" + os.makedirs(folder, exist_ok=True) + + # Extract conditional part if using CFG + xt_save = context.xt + if context.use_cfg: + batch_size = context.xt.shape[0] // 2 + xt_save = context.xt[:batch_size] + + # Decode latents to images + samples = decode_latents(context.vae, xt_save) + + # Convert to wandb.Image objects + start_idx = self.step_counters[context.step_idx] + for i_sample, sample in enumerate(samples): + global_idx = start_idx + i_sample + img = Image.fromarray(sample) + img.save(f"{folder}/{global_idx:03d}.png") + self.logged_images[context.step_idx].append(self.wandb.Image(img, caption=f"Sample {global_idx:03d}")) + + def finalize(self): + """Log all collected images to WandB. 
Call this after sampling is complete.""" + if self.wandb is None: + return + + for step_idx in sorted(self.save_steps): + if len(self.logged_images[step_idx]) > 0: + self.wandb.log({f"samples/step_{step_idx:03d}": self.logged_images[step_idx]}, step=self.train_step) + + +class DistortionTracker: + """ + Hook for tracking distortion of clean images during sampling. + + Computes L2 distance between original and current latents at specified steps, + correlates with high-frequency content, and saves top distorted/undistorted images. + + Args: + original_latents: Original clean latents, shape (batch_size, 4, H, W) + high_freq_metrics: High-frequency content metrics per image, shape (batch_size,) + save_steps: List of step indices at which to track distortion + output_folder: Base folder for saving results + top_n: Number of top distorted/undistorted images to save per step + """ + + def __init__(self, original_latents, high_freq_metrics, save_steps, output_folder, top_n=10): + self.original_latents = original_latents.clone() + self.high_freq_metrics = high_freq_metrics + self.save_steps = set(save_steps) + self.output_folder = output_folder + self.top_n = top_n + + # Storage for distortion metrics at each step + # Key: step_idx, Value: list of L2 distances + self.distortions = {step: [] for step in save_steps} + + # Storage for batch information + # Key: step_idx, Value: list of (batch_start_idx, batch_latents) + self.latent_batches = {step: [] for step in save_steps} + + self.batch_counter = 0 + + def __call__(self, context: SamplingHookContext): + """Track distortion if current step is in save_steps list.""" + if context.step_idx not in self.save_steps: + return + + # Extract conditional part if using CFG + xt_current = context.xt + if context.use_cfg: + batch_size = context.xt.shape[0] // 2 + xt_current = context.xt[:batch_size] + + # Get corresponding original latents for this batch + batch_size = xt_current.shape[0] + batch_start = self.batch_counter + batch_end = 
batch_start + batch_size + original_batch = self.original_latents[batch_start:batch_end] + + # Compute L2 distance in latent space for each sample + diff = xt_current - original_batch + l2_distances = torch.linalg.norm(diff.reshape(diff.shape[0], -1), dim=1) # shape: (batch_size,) + + # Store distortion metrics + self.distortions[context.step_idx].extend(l2_distances.cpu().tolist()) + + # Store latents for later saving + self.latent_batches[context.step_idx].append((batch_start, xt_current.cpu().clone())) + + def on_batch_complete(self, batch_size): + """Call this after each batch to update the batch counter.""" + self.batch_counter += batch_size + + def finalize(self, vae, output_folder): + """ + Compute statistics, save top images, and create visualizations. + Call this after all sampling is complete. + + Args: + vae: VAE decoder for converting latents to images + output_folder: Output directory for saving results + """ + from utils.sampling_utils import decode_latents + + print("Analyzing distortion metrics and creating visualizations...") + + results = {"steps": {}, "high_freq_metrics": self.high_freq_metrics.tolist()} + + for step_idx in sorted(self.save_steps): + print(f" Processing step {step_idx}...") + + # Get all distortions for this step + distortions = np.array(self.distortions[step_idx]) + + if len(distortions) == 0: + print(f" Warning: No distortions recorded for step {step_idx}") + continue + + # Reconstruct full latent tensor from batches + latent_list = [] + for _batch_start, batch_latents in sorted(self.latent_batches[step_idx]): + latent_list.append(batch_latents) + all_latents = torch.cat(latent_list, dim=0) + + # Get indices of top-N most and least distorted + top_distorted_indices = np.argsort(distortions)[-self.top_n :][::-1] + top_undistorted_indices = np.argsort(distortions)[: self.top_n] + + # Save top distorted images + distorted_folder = f"{output_folder}/step_{step_idx:03d}/top_distorted" + os.makedirs(distorted_folder, exist_ok=True) + 
for rank, idx in enumerate(top_distorted_indices): + latent = all_latents[idx : idx + 1].to(vae.device) # Move to VAE device + image = decode_latents(vae, latent)[0] + Image.fromarray(image).save( + f"{distorted_folder}/rank{rank:02d}_idx{idx:06d}_dist{distortions[idx]:.4f}.png" + ) + + # Save top undistorted images + undistorted_folder = f"{output_folder}/step_{step_idx:03d}/top_undistorted" + os.makedirs(undistorted_folder, exist_ok=True) + for rank, idx in enumerate(top_undistorted_indices): + latent = all_latents[idx : idx + 1].to(vae.device) # Move to VAE device + image = decode_latents(vae, latent)[0] + Image.fromarray(image).save( + f"{undistorted_folder}/rank{rank:02d}_idx{idx:06d}_dist{distortions[idx]:.4f}.png" + ) + + # Compute correlation with high-frequency content + # Use only valid indices (in case of batch size mismatch) + valid_indices = min(len(distortions), len(self.high_freq_metrics)) + distortions_valid = distortions[:valid_indices] + high_freq_valid = self.high_freq_metrics[:valid_indices] + + correlation = np.corrcoef(high_freq_valid, distortions_valid)[0, 1] + + # Create scatter plot + plt.figure(figsize=(10, 8)) + plt.scatter(high_freq_valid, distortions_valid, alpha=0.5, s=20) + plt.xlabel("High-Frequency Content (Ratio)", fontsize=12) + plt.ylabel("L2 Distance from Original (Latent Space)", fontsize=12) + plt.title( + f"Distortion vs High-Frequency Content at Step {step_idx}\nPearson Correlation: {correlation:.4f}", + fontsize=14, + ) + plt.grid(True, alpha=0.3) + plt.tight_layout() + + plot_path = f"{output_folder}/correlation_step_{step_idx:03d}.png" + plt.savefig(plot_path, dpi=150) + plt.close() + print(f" Saved correlation plot to {plot_path}") + + # Store statistics + results["steps"][str(step_idx)] = { + "mean_distortion": float(np.mean(distortions)), + "std_distortion": float(np.std(distortions)), + "min_distortion": float(np.min(distortions)), + "max_distortion": float(np.max(distortions)), + "correlation_with_high_freq": 
class GradientNormTracker:
    """
    Hook for tracking gradient L2 norms during sampling.

    When CFG is used, tracks three norms:
    - gradient_norms: CFG-applied output (combined)
    - gradient_norms_cond: conditional output (out_cond)
    - gradient_norms_uncond: unconditional output (out_uncond)

    When CFG is not used, only gradient_norms is tracked.

    Args:
        num_steps: Number of sampling steps (for pre-allocating storage)
    """

    def __init__(self, num_steps: int):
        # One list of per-sample norms per sampling step; step_idx is 1-indexed
        # in the hook context and maps to index step_idx - 1 here.
        self.gradient_norms = [[] for _ in range(num_steps)]  # CFG-applied output norms
        self.gradient_norms_cond = [[] for _ in range(num_steps)]  # conditional output norms
        self.gradient_norms_uncond = [[] for _ in range(num_steps)]  # unconditional output norms

    # String annotation: avoids evaluating SamplingHookContext at class-definition
    # time, so this hook does not depend on definition/import order.
    def __call__(self, context: "SamplingHookContext"):
        """Accumulate gradient L2 norms for the current step."""
        # Track CFG-applied output norm. When CFG duplicates the batch, the
        # first half carries the combined (guided) output.
        out_for_norm = context.out
        if context.use_cfg:
            batch_size = context.out.shape[0] // 2
            out_for_norm = context.out[:batch_size]
        norms = torch.linalg.norm(out_for_norm.reshape(out_for_norm.shape[0], -1), dim=1)
        self.gradient_norms[context.step_idx - 1].extend(norms.cpu().tolist())

        # Track cond/uncond norms when the sampler provides CFG components.
        if context.use_cfg and context.out_cond is not None and context.out_uncond is not None:
            out_cond = context.out_cond
            out_uncond = context.out_uncond

            norms_cond = torch.linalg.norm(out_cond.reshape(out_cond.shape[0], -1), dim=1)
            self.gradient_norms_cond[context.step_idx - 1].extend(norms_cond.cpu().tolist())

            norms_uncond = torch.linalg.norm(out_uncond.reshape(out_uncond.shape[0], -1), dim=1)
            self.gradient_norms_uncond[context.step_idx - 1].extend(norms_uncond.cpu().tolist())

    @staticmethod
    def _compute_stats(norm_list):
        """Return per-step (means, stds) over accumulated norms; 0.0 for empty steps.

        Shared by finalize() and finalize_for_wandb() — previously duplicated
        verbatim as a nested function in both.
        """
        means, stds = [], []
        for step_norms in norm_list:
            if step_norms:
                means.append(np.mean(step_norms))
                stds.append(np.std(step_norms))
            else:
                means.append(0.0)
                stds.append(0.0)
        return means, stds

    def finalize(self, folder: str, num_sampling_steps: int, stepsize: float, sampler: str):
        """
        Compute statistics and create visualization for gradient norms.
        Call this after all sampling is complete.

        Args:
            folder: Output directory for saving JSON and plot
            num_sampling_steps: Number of sampling steps
            stepsize: Step size used during sampling
            sampler: Sampler name (e.g., 'euler', 'heun')
        """
        print("Computing gradient norm statistics...")

        # Compute stats for CFG-applied output
        cfg_means, cfg_stds = self._compute_stats(self.gradient_norms)

        # Compute stats for cond/uncond outputs (populated only when CFG was used)
        has_cond = any(len(s) > 0 for s in self.gradient_norms_cond)
        has_uncond = any(len(s) > 0 for s in self.gradient_norms_uncond)
        cond_means, cond_stds = self._compute_stats(self.gradient_norms_cond) if has_cond else ([], [])
        uncond_means, uncond_stds = self._compute_stats(self.gradient_norms_uncond) if has_uncond else ([], [])

        # Save statistics to JSON
        stats = {
            "num_sampling_steps": num_sampling_steps,
            # Guard against an empty tracker (num_steps == 0) instead of indexing blindly.
            "total_samples": len(self.gradient_norms[0]) if self.gradient_norms else 0,
            "cfg_output": {"mean": cfg_means, "std": cfg_stds},
            "stepsize": stepsize,
            "sampler": sampler,
            "note": "Statistics computed from individual gradient L2 norms across all samples",
        }
        if has_cond:
            stats["cond_output"] = {"mean": cond_means, "std": cond_stds}
        if has_uncond:
            stats["uncond_output"] = {"mean": uncond_means, "std": uncond_stds}

        json_path = f"{folder}/gradient_norms.json"
        with open(json_path, "w") as f:
            json.dump(stats, f, indent=2)
        print(f"Saved gradient norm statistics to {json_path}")

        # Create plot
        print("Creating gradient norm plot...")
        cfg_means = np.array(cfg_means)
        cfg_stds = np.array(cfg_stds)
        # x-axis follows the actual number of tracked steps so the plot cannot
        # desynchronize from the data if num_sampling_steps disagrees with
        # the num_steps this tracker was constructed with.
        steps = np.arange(len(cfg_means))

        plt.figure(figsize=(10, 6))

        # Plot CFG-applied output norms
        plt.plot(steps, cfg_means, linewidth=2, label="CFG Output Mean", color="green")
        plt.fill_between(
            steps,
            cfg_means - cfg_stds,
            cfg_means + cfg_stds,
            alpha=0.2,
            color="green",
            label="CFG Output ± Std",
        )

        # Plot conditional output norms if available
        if has_cond:
            cond_means = np.array(cond_means)
            cond_stds = np.array(cond_stds)
            plt.plot(steps, cond_means, linewidth=2, label="Cond Output Mean", color="blue")
            plt.fill_between(
                steps,
                cond_means - cond_stds,
                cond_means + cond_stds,
                alpha=0.2,
                color="blue",
                label="Cond Output ± Std",
            )

        # Plot unconditional output norms if available
        if has_uncond:
            uncond_means = np.array(uncond_means)
            uncond_stds = np.array(uncond_stds)
            plt.plot(steps, uncond_means, linewidth=2, label="Uncond Output Mean", color="orange")
            plt.fill_between(
                steps,
                uncond_means - uncond_stds,
                uncond_means + uncond_stds,
                alpha=0.2,
                color="orange",
                label="Uncond Output ± Std",
            )

        plt.xlabel("Sampling Step", fontsize=12)
        plt.ylabel("Gradient L2 Norm", fontsize=12)
        plt.title(
            f"Gradient L2 Norm during Sampling ({sampler.upper()}, stepsize={stepsize})",
            fontsize=14,
        )
        plt.legend(fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        plot_path = f"{folder}/gradient_norms.png"
        plt.savefig(plot_path, dpi=150)
        print(f"Saved gradient norm plot to {plot_path}")
        plt.close()

    def finalize_for_wandb(self, wandb_module, train_step):
        """
        Log gradient norm statistics to WandB.
        Call this after all sampling is complete.

        Args:
            wandb_module: wandb module for logging
            train_step: Current training step for WandB logging
        """
        if wandb_module is None:
            return

        # Compute stats for CFG-applied output
        cfg_means, cfg_stds = self._compute_stats(self.gradient_norms)

        # Compute stats for cond/uncond outputs (populated only when CFG was used)
        has_cond = any(len(s) > 0 for s in self.gradient_norms_cond)
        has_uncond = any(len(s) > 0 for s in self.gradient_norms_uncond)
        cond_means, cond_stds = self._compute_stats(self.gradient_norms_cond) if has_cond else ([], [])
        uncond_means, uncond_stds = self._compute_stats(self.gradient_norms_uncond) if has_uncond else ([], [])

        # Create a table for gradient norms vs sampling steps so WandB can
        # plot with sampling_step on the x-axis.
        data = []
        for step in range(len(cfg_means)):
            row = [
                step,
                cfg_means[step],
                cfg_stds[step],
                cfg_means[step] + cfg_stds[step],
                cfg_means[step] - cfg_stds[step],
            ]
            if has_cond:
                row.extend(
                    [
                        cond_means[step],
                        cond_stds[step],
                        cond_means[step] + cond_stds[step],
                        cond_means[step] - cond_stds[step],
                    ]
                )
            if has_uncond:
                row.extend(
                    [
                        uncond_means[step],
                        uncond_stds[step],
                        uncond_means[step] + uncond_stds[step],
                        uncond_means[step] - uncond_stds[step],
                    ]
                )
            data.append(row)

        columns = [
            "sampling_step",
            "cfg_output_norm_mean",
            "cfg_output_norm_std",
            "cfg_output_norm_upper",
            "cfg_output_norm_lower",
        ]
        if has_cond:
            columns.extend(
                [
                    "cond_output_norm_mean",
                    "cond_output_norm_std",
                    "cond_output_norm_upper",
                    "cond_output_norm_lower",
                ]
            )
        if has_uncond:
            columns.extend(
                [
                    "uncond_output_norm_mean",
                    "uncond_output_norm_std",
                    "uncond_output_norm_upper",
                    "uncond_output_norm_lower",
                ]
            )

        table = wandb_module.Table(columns=columns, data=data)
        wandb_module.log({"gradient_norms": table}, step=train_step)
Example Usage: >>> # Basic sampling >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32) >>> >>> # Sampling with hooks + >>> from utils.sampling_hooks import IntermediateImageSaver, GradientNormTracker >>> img_hook = IntermediateImageSaver([0, 50, 100], "outputs") >>> grad_hook = GradientNormTracker(num_sampling_steps=250) >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32, hooks=[img_hook, grad_hook]) - >>> grad_hook.finalize(args, "outputs") + >>> grad_hook.finalize("outputs", num_sampling_steps=250, stepsize=1.0, sampler="euler") """ -import json import os -from dataclasses import dataclass -from typing import Any -import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image from tqdm import tqdm +from utils.sampling_hooks import GradientNormTracker, SamplingHookContext + @torch.no_grad() def sample_eqm( @@ -56,6 +49,7 @@ def sample_eqm( sampler="gd", mu=0.3, hooks=None, + return_cfg_components=False, ): """ Generate samples using EqM model with gradient descent sampling. @@ -79,6 +73,8 @@ def sample_eqm( hooks: List of hook callables. Each hook receives a SamplingHookContext object at each sampling step. Use for monitoring, logging, or saving intermediate results. Example hooks: IntermediateImageSaver, GradientNormTracker. + return_cfg_components: If True and CFG is enabled, passes individual cond/uncond outputs + to hooks via out_cond and out_uncond context fields. 
Returns: samples: Generated images as numpy array, shape (batch_size, H, W, 3), dtype uint8 @@ -89,14 +85,18 @@ def sample_eqm( >>> >>> # With hooks for monitoring >>> from sampling_utils import IntermediateImageSaver, GradientNormTracker - >>> img_hook = IntermediateImageSaver([0, 50, 100, 249], "outputs") + >>> img_hook = IntermediateImageSaver([0, 50, 100, 249], output_folder="outputs") >>> grad_hook = GradientNormTracker(num_sampling_steps=250) >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32, hooks=[img_hook, grad_hook]) - >>> grad_hook.finalize(args, "outputs") + >>> grad_hook.finalize("outputs", num_sampling_steps=250, stepsize=1.0, sampler="euler") """ if hooks is None: hooks = [] + # Auto-enable return_cfg_components if GradientNormTracker is present + if any(isinstance(h, GradientNormTracker) for h in hooks): + return_cfg_components = True + use_cfg = cfg_scale > 1.0 n = batch_size @@ -132,17 +132,26 @@ def sample_eqm( # Sampling loop with torch.no_grad(): for step_idx in tqdm(range(1, num_sampling_steps + 1)): + out_cond = None + out_uncond = None + if sampler == "gd": # Standard gradient descent - out = model_fn(xt, t, y, cfg_scale) - if not torch.is_tensor(out): - out = out[0] + if use_cfg and return_cfg_components: + out, out_cond, out_uncond = model_fn(xt, t, y, cfg_scale, return_components=True) + else: + out = model_fn(xt, t, y, cfg_scale) + if not torch.is_tensor(out): + out = out[0] elif sampler == "ngd": # Nesterov accelerated gradient descent x_ = xt + stepsize * m * mu - out = model_fn(x_, t, y, cfg_scale) - if not torch.is_tensor(out): - out = out[0] + if use_cfg and return_cfg_components: + out, out_cond, out_uncond = model_fn(x_, t, y, cfg_scale, return_components=True) + else: + out = model_fn(x_, t, y, cfg_scale) + if not torch.is_tensor(out): + out = out[0] m = out else: raise ValueError(f"Unknown sampler: {sampler}") @@ -163,6 +172,8 @@ def sample_eqm( vae=vae, device=device, total_steps=num_sampling_steps, + 
out_cond=out_cond, + out_uncond=out_uncond, ) for hook in hooks: hook(context) @@ -191,6 +202,7 @@ def sample_eqm_two( sampler="gd", mu=0.3, hooks=None, + return_cfg_components=False, ): """ Generate samples using EqM model with gradient descent sampling. @@ -214,6 +226,8 @@ def sample_eqm_two( hooks: List of hook callables. Each hook receives a SamplingHookContext object at each sampling step. Use for monitoring, logging, or saving intermediate results. Example hooks: IntermediateImageSaver, GradientNormTracker. + return_cfg_components: If True and CFG is enabled, passes individual cond/uncond outputs + to hooks via out_cond and out_uncond context fields. Returns: samples: Generated images as numpy array, shape (batch_size, H, W, 3), dtype uint8 @@ -227,11 +241,15 @@ def sample_eqm_two( >>> img_hook = IntermediateImageSaver([0, 50, 100, 249], "outputs") >>> grad_hook = GradientNormTracker(num_sampling_steps=250) >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32, hooks=[img_hook, grad_hook]) - >>> grad_hook.finalize(args, "outputs") + >>> grad_hook.finalize("outputs", num_sampling_steps=250, stepsize=1.0, sampler="euler") """ if hooks is None: hooks = [] + # Auto-enable return_cfg_components if GradientNormTracker is present + if any(isinstance(h, GradientNormTracker) for h in hooks): + return_cfg_components = True + use_cfg = cfg_scale > 1.0 n = batch_size @@ -267,17 +285,26 @@ def sample_eqm_two( # Sampling loop with torch.no_grad(): for step_idx in range(1, num_sampling_steps + 1): + out_cond = None + out_uncond = None + if sampler == "gd": # Standard gradient descent - out = model_fn(xt, t, y, cfg_scale) - if not torch.is_tensor(out): - out = out[0] + if use_cfg and return_cfg_components: + out, out_cond, out_uncond = model_fn(xt, t, y, cfg_scale, return_components=True) + else: + out = model_fn(xt, t, y, cfg_scale) + if not torch.is_tensor(out): + out = out[0] elif sampler == "ngd": # Nesterov accelerated gradient descent x_ = xt + 
stepsize * m * mu - out = model_fn(x_, t, y, cfg_scale) - if not torch.is_tensor(out): - out = out[0] + if use_cfg and return_cfg_components: + out, out_cond, out_uncond = model_fn(x_, t, y, cfg_scale, return_components=True) + else: + out = model_fn(x_, t, y, cfg_scale) + if not torch.is_tensor(out): + out = out[0] m = out else: raise ValueError(f"Unknown sampler: {sampler}") @@ -320,6 +347,8 @@ def sample_eqm_two( vae=vae, device=device, total_steps=num_sampling_steps, + out_cond=out_cond, + out_uncond=out_uncond, ) for hook in hooks: hook(context) @@ -333,420 +362,6 @@ def sample_eqm_two( return samples -@dataclass -class SamplingHookContext: - """Context object passed to sampling hooks containing all relevant state.""" - - xt: torch.Tensor # Current latent state - t: torch.Tensor # Current timestep - y: torch.Tensor # Class labels - out: torch.Tensor # Model output/gradient - step_idx: int # Current step index (1-indexed) - use_cfg: bool # Whether CFG is enabled - vae: Any # VAE decoder for image conversion - device: torch.device # Device - total_steps: int # Total number of sampling steps - - -class IntermediateImageSaver: - """ - Hook for saving intermediate images during sampling. 
- - Args: - save_steps: List of step indices at which to save images (e.g., [0, 50, 100, 250]) - output_folder: Base folder for saving images - """ - - def __init__(self, save_steps, output_folder): - self.save_steps = set(save_steps) # Use set for O(1) lookup - self.output_folder = output_folder - # Track global sample counter for each step to avoid overwriting - self.step_counters = dict.fromkeys(save_steps, 0) - # self.return_images = return_images - - def __call__(self, context: SamplingHookContext): - """Save images if current step is in save_steps list.""" - if context.step_idx not in self.save_steps: - return - - step_folder = f"{self.output_folder}/step_{context.step_idx:03d}" - os.makedirs(step_folder, exist_ok=True) - - # Extract conditional part if using CFG - xt_save = context.xt - if context.use_cfg: - batch_size = context.xt.shape[0] // 2 - xt_save = context.xt[:batch_size] - - samples = decode_latents(context.vae, xt_save) - - # Save images with global sequential indexing across batches - start_idx = self.step_counters[context.step_idx] - for i_sample, sample in enumerate(samples): - global_idx = start_idx + i_sample - Image.fromarray(sample).save(f"{step_folder}/{global_idx:06d}.png") - - # Update counter for this step - self.step_counters[context.step_idx] += len(samples) - - -class WandBImageLogger: - """ - Hook for logging intermediate images during sampling directly to WandB. 
- - Args: - save_steps: List of step indices at which to log images (e.g., [5, 10, 250]) - train_step: Current training step (for WandB logging) - output_folder: Folder to save logged images - wandb_module: wandb module (pass wandb if imported, or None to skip logging) - """ - - def __init__(self, save_steps, train_step, output_folder, wandb_module=None): - self.save_steps = set(save_steps) - self.train_step = train_step - self.output_folder = output_folder - self.wandb = wandb_module - self.logged_images = {step: [] for step in save_steps} - self.step_counters = dict.fromkeys(save_steps, 0) - - def __call__(self, context: SamplingHookContext): - """Log images to WandB if current step is in save_steps list.""" - if context.step_idx not in self.save_steps or self.wandb is None: - return - - folder = f"{self.output_folder}/train_{self.train_step:04d}/sample_{context.step_idx:03d}" - os.makedirs(folder, exist_ok=True) - - # Extract conditional part if using CFG - xt_save = context.xt - if context.use_cfg: - batch_size = context.xt.shape[0] // 2 - xt_save = context.xt[:batch_size] - - # Decode latents to images - samples = decode_latents(context.vae, xt_save) - - # Convert to wandb.Image objects - start_idx = self.step_counters[context.step_idx] - for i_sample, sample in enumerate(samples): - global_idx = start_idx + i_sample - img = Image.fromarray(sample) - img.save(f"{folder}/{global_idx:03d}.png") - self.logged_images[context.step_idx].append(self.wandb.Image(img, caption=f"Sample {global_idx:03d}")) - - def finalize(self): - """Log all collected images to WandB. Call this after sampling is complete.""" - if self.wandb is None: - return - - for step_idx in sorted(self.save_steps): - if len(self.logged_images[step_idx]) > 0: - self.wandb.log({f"samples/step_{step_idx:03d}": self.logged_images[step_idx]}, step=self.train_step) - - -class DistortionTracker: - """ - Hook for tracking distortion of clean images during sampling. 
- - Computes L2 distance between original and current latents at specified steps, - correlates with high-frequency content, and saves top distorted/undistorted images. - - Args: - original_latents: Original clean latents, shape (batch_size, 4, H, W) - high_freq_metrics: High-frequency content metrics per image, shape (batch_size,) - save_steps: List of step indices at which to track distortion - output_folder: Base folder for saving results - top_n: Number of top distorted/undistorted images to save per step - """ - - def __init__(self, original_latents, high_freq_metrics, save_steps, output_folder, top_n=10): - self.original_latents = original_latents.clone() - self.high_freq_metrics = high_freq_metrics - self.save_steps = set(save_steps) - self.output_folder = output_folder - self.top_n = top_n - - # Storage for distortion metrics at each step - # Key: step_idx, Value: list of L2 distances - self.distortions = {step: [] for step in save_steps} - - # Storage for batch information - # Key: step_idx, Value: list of (batch_start_idx, batch_latents) - self.latent_batches = {step: [] for step in save_steps} - - self.batch_counter = 0 - - def __call__(self, context: SamplingHookContext): - """Track distortion if current step is in save_steps list.""" - if context.step_idx not in self.save_steps: - return - - # Extract conditional part if using CFG - xt_current = context.xt - if context.use_cfg: - batch_size = context.xt.shape[0] // 2 - xt_current = context.xt[:batch_size] - - # Get corresponding original latents for this batch - batch_size = xt_current.shape[0] - batch_start = self.batch_counter - batch_end = batch_start + batch_size - original_batch = self.original_latents[batch_start:batch_end] - - # Compute L2 distance in latent space for each sample - diff = xt_current - original_batch - l2_distances = torch.linalg.norm(diff.reshape(diff.shape[0], -1), dim=1) # shape: (batch_size,) - - # Store distortion metrics - 
self.distortions[context.step_idx].extend(l2_distances.cpu().tolist()) - - # Store latents for later saving - self.latent_batches[context.step_idx].append((batch_start, xt_current.cpu().clone())) - - def on_batch_complete(self, batch_size): - """Call this after each batch to update the batch counter.""" - self.batch_counter += batch_size - - def finalize(self, vae, output_folder): - """ - Compute statistics, save top images, and create visualizations. - Call this after all sampling is complete. - - Args: - vae: VAE decoder for converting latents to images - output_folder: Output directory for saving results - """ - print("Analyzing distortion metrics and creating visualizations...") - - results = {"steps": {}, "high_freq_metrics": self.high_freq_metrics.tolist()} - - for step_idx in sorted(self.save_steps): - print(f" Processing step {step_idx}...") - - # Get all distortions for this step - distortions = np.array(self.distortions[step_idx]) - - if len(distortions) == 0: - print(f" Warning: No distortions recorded for step {step_idx}") - continue - - # Reconstruct full latent tensor from batches - latent_list = [] - for _batch_start, batch_latents in sorted(self.latent_batches[step_idx]): - latent_list.append(batch_latents) - all_latents = torch.cat(latent_list, dim=0) - - # Get indices of top-N most and least distorted - top_distorted_indices = np.argsort(distortions)[-self.top_n :][::-1] - top_undistorted_indices = np.argsort(distortions)[: self.top_n] - - # Save top distorted images - distorted_folder = f"{output_folder}/step_{step_idx:03d}/top_distorted" - os.makedirs(distorted_folder, exist_ok=True) - for rank, idx in enumerate(top_distorted_indices): - latent = all_latents[idx : idx + 1].to(vae.device) # Move to VAE device - image = decode_latents(vae, latent)[0] - Image.fromarray(image).save( - f"{distorted_folder}/rank{rank:02d}_idx{idx:06d}_dist{distortions[idx]:.4f}.png" - ) - - # Save top undistorted images - undistorted_folder = 
f"{output_folder}/step_{step_idx:03d}/top_undistorted" - os.makedirs(undistorted_folder, exist_ok=True) - for rank, idx in enumerate(top_undistorted_indices): - latent = all_latents[idx : idx + 1].to(vae.device) # Move to VAE device - image = decode_latents(vae, latent)[0] - Image.fromarray(image).save( - f"{undistorted_folder}/rank{rank:02d}_idx{idx:06d}_dist{distortions[idx]:.4f}.png" - ) - - # Compute correlation with high-frequency content - # Use only valid indices (in case of batch size mismatch) - valid_indices = min(len(distortions), len(self.high_freq_metrics)) - distortions_valid = distortions[:valid_indices] - high_freq_valid = self.high_freq_metrics[:valid_indices] - - correlation = np.corrcoef(high_freq_valid, distortions_valid)[0, 1] - - # Create scatter plot - plt.figure(figsize=(10, 8)) - plt.scatter(high_freq_valid, distortions_valid, alpha=0.5, s=20) - plt.xlabel("High-Frequency Content (Ratio)", fontsize=12) - plt.ylabel("L2 Distance from Original (Latent Space)", fontsize=12) - plt.title( - f"Distortion vs High-Frequency Content at Step {step_idx}\nPearson Correlation: {correlation:.4f}", - fontsize=14, - ) - plt.grid(True, alpha=0.3) - plt.tight_layout() - - plot_path = f"{output_folder}/correlation_step_{step_idx:03d}.png" - plt.savefig(plot_path, dpi=150) - plt.close() - print(f" Saved correlation plot to {plot_path}") - - # Store statistics - results["steps"][str(step_idx)] = { - "mean_distortion": float(np.mean(distortions)), - "std_distortion": float(np.std(distortions)), - "min_distortion": float(np.min(distortions)), - "max_distortion": float(np.max(distortions)), - "correlation_with_high_freq": float(correlation), - "num_samples": len(distortions), - "top_distorted_indices": top_distorted_indices.tolist(), - "top_undistorted_indices": top_undistorted_indices.tolist(), - "top_distorted_values": distortions[top_distorted_indices].tolist(), - "top_undistorted_values": distortions[top_undistorted_indices].tolist(), - } - - # Save results to 
JSON - json_path = f"{output_folder}/distortion_analysis.json" - with open(json_path, "w") as f: - json.dump(results, f, indent=2) - print(f"Saved distortion analysis to {json_path}") - - -class GradientNormTracker: - """ - Hook for tracking gradient L2 norms during sampling. - - Args: - num_steps: Number of sampling steps (for pre-allocating storage) - """ - - def __init__(self, num_steps): - self.gradient_norms = [[] for _ in range(num_steps)] - - def __call__(self, context: SamplingHookContext): - """Accumulate gradient L2 norms for the current step.""" - # Extract conditional part if using CFG - out_for_norm = context.out - if context.use_cfg: - batch_size = context.out.shape[0] // 2 - out_for_norm = context.out[:batch_size] - - # Compute L2 norm for each sample in the batch - norms = torch.linalg.norm(out_for_norm.reshape(out_for_norm.shape[0], -1), dim=1) # shape: (batch_size,) - self.gradient_norms[context.step_idx - 1].extend(norms.cpu().tolist()) - - def finalize(self, args, folder): - """ - Compute statistics and create visualization for gradient norms. - Call this after all sampling is complete. 
- - Args: - args: Arguments containing sampling parameters (num_sampling_steps, stepsize, sampler) - folder: Output directory for saving JSON and plot - """ - print("Computing gradient norm statistics...") - gradient_means = [] - gradient_stds = [] - - for step_norms in self.gradient_norms: - if len(step_norms) > 0: - gradient_means.append(np.mean(step_norms)) - gradient_stds.append(np.std(step_norms)) - else: - gradient_means.append(0.0) - gradient_stds.append(0.0) - - # Save statistics to JSON - stats = { - "num_sampling_steps": args.num_sampling_steps, - "total_samples": len(self.gradient_norms[0]) if len(self.gradient_norms[0]) > 0 else 0, - "mean": gradient_means, - "std": gradient_stds, - "stepsize": args.stepsize, - "sampler": args.sampler, - "note": "Statistics computed from individual gradient L2 norms across all samples (batch-size independent)", - } - json_path = f"{folder}/gradient_norms.json" - with open(json_path, "w") as f: - json.dump(stats, f, indent=2) - print(f"Saved gradient norm statistics to {json_path}") - - # Create plot - print("Creating gradient norm plot...") - steps = np.arange(0, args.num_sampling_steps) - gradient_means = np.array(gradient_means) - gradient_stds = np.array(gradient_stds) - - plt.figure(figsize=(10, 6)) - plt.plot(steps, gradient_means, linewidth=2, label="Mean L2 Norm") - plt.fill_between( - steps, - gradient_means - gradient_stds, - gradient_means + gradient_stds, - alpha=0.3, - label="Mean ± Std", - ) - plt.xlabel("Sampling Step", fontsize=12) - plt.ylabel("Gradient L2 Norm", fontsize=12) - plt.title( - f"Gradient L2 Norm during Sampling ({args.sampler.upper()}, stepsize={args.stepsize})", - fontsize=14, - ) - plt.legend(fontsize=10) - plt.grid(True, alpha=0.3) - plt.tight_layout() - - plot_path = f"{folder}/gradient_norms.png" - plt.savefig(plot_path, dpi=150) - print(f"Saved gradient norm plot to {plot_path}") - plt.close() - - def finalize_for_wandb(self, wandb_module, train_step): - """ - Log gradient norm 
statistics to WandB. - Call this after all sampling is complete. - - Args: - wandb_module: wandb module for logging - train_step: Current training step for WandB logging - """ - if wandb_module is None: - return - - # Compute statistics - gradient_means = [] - gradient_stds = [] - - for step_norms in self.gradient_norms: - if len(step_norms) > 0: - gradient_means.append(np.mean(step_norms)) - gradient_stds.append(np.std(step_norms)) - else: - gradient_means.append(0.0) - gradient_stds.append(0.0) - - # Create a table for gradient norms vs sampling steps - # This allows WandB to plot with sampling_step on x-axis - data = [] - for sampling_step, (mean_norm, std_norm) in enumerate(zip(gradient_means, gradient_stds)): - data.append( - [ - sampling_step, - mean_norm, - std_norm, - mean_norm + std_norm, # Upper bound - mean_norm - std_norm, # Lower bound - ] - ) - - table = wandb_module.Table( - columns=[ - "sampling_step", - "gradient_norm_mean", - "gradient_norm_std", - "gradient_norm_upper", - "gradient_norm_lower", - ], - data=data, - ) - # Log the table with a consistent key (train_step is already in the table data) - wandb_module.log({"gradient_norms": table}, step=train_step) - - def create_npz_from_sample_folder(sample_dir, num=None): """ Builds a single .npz file from a folder of .png samples.