From ff24872e3e1cd0c8638a122a85f493d411505e5a Mon Sep 17 00:00:00 2001 From: Kevin On Date: Thu, 11 Dec 2025 07:29:44 +0000 Subject: [PATCH 1/7] Add CFG effect experiment implementation - Introduced a new script to investigate the impact of classifier-free guidance (CFG) on image collapse during extended EqM sampling. - Implemented baseline sampling with CFG and a switch experiment to analyze the effects of transitioning to null class labels. - Captured intermediate latents and saved generated images at specified steps. - Added argument parsing for various experiment configurations including model selection, image size, and sampling parameters. --- experiments/cfg_effect/main.py | 437 +++++++++++++++++++++++++++++++++ 1 file changed, 437 insertions(+) create mode 100644 experiments/cfg_effect/main.py diff --git a/experiments/cfg_effect/main.py b/experiments/cfg_effect/main.py new file mode 100644 index 0000000..57d8a3b --- /dev/null +++ b/experiments/cfg_effect/main.py @@ -0,0 +1,437 @@ +""" +Experiment: CFG Effect on Image Collapse + +Investigates how classifier-free guidance (CFG) affects image collapse during +extended EqM sampling. The hypothesis is that collapse occurs when combining +model outputs from original class labels and null class labels (CFG > 1.0). + +Experiment Design: +- Baseline: Full sampling with cfg_scale for all steps +- Switch experiments: For each switch_step in switch_steps list, use cfg_scale + for first switch_step steps, then switch to null class label only (cfg=1.0) + for remaining steps. + +Implementation: +- Run sample_eqm() with cfg_scale, capture intermediate latents at switch steps +- For each captured latent, resume sample_eqm() with cfg=1.0 for remaining steps +""" + +import argparse +import json +import math +import os +import sys + +import numpy as np +import torch + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + +from diffusers.models import AutoencoderKL +from PIL import Image +from tqdm import tqdm + +# Add parent directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +from download import find_model +from models import EqM_models +from utils.sampling_utils import ( + IntermediateImageSaver, + SamplingHookContext, + create_npz_from_sample_folder, + sample_eqm, +) + + +class LatentCaptureHook: + """Hook to capture intermediate latents at specified steps.""" + + def __init__(self, capture_steps): + self.capture_steps = set(capture_steps) + self.captured_latents = {} # step -> latent tensor + + def __call__(self, context: SamplingHookContext): + if context.step_idx not in self.capture_steps: + return + + # Extract conditional part if using CFG + xt_save = context.xt + if context.use_cfg: + batch_size = context.xt.shape[0] // 2 + xt_save = context.xt[:batch_size] + + self.captured_latents[context.step_idx] = xt_save.clone() + + +def run_baseline_and_capture( + model, + vae, + device, + initial_latent, + class_labels, + latent_size, + num_sampling_steps, + stepsize, + cfg_scale, + sampler, + mu, + save_steps_list, + switch_steps, + output_folder, + batch_size, + num_samples, +): + """ + Run baseline with cfg_scale for all steps and capture intermediate latents. + + Returns: + captured_latents: dict mapping switch_step -> list of latent tensors (per batch) + """ + os.makedirs(output_folder, exist_ok=True) + + final_step_folder = f"{output_folder}/step_{num_sampling_steps:04d}" + os.makedirs(final_step_folder, exist_ok=True) + + total_samples = int(math.ceil(num_samples / batch_size) * batch_size) + iterations = int(total_samples // batch_size) + + # Store captured latents per switch step + all_captured = {step: [] for step in switch_steps} + + total_saved = 0 + for batch_idx in tqdm(range(iterations), desc="Baseline CFG"): + batch_start = batch_idx * batch_size + batch_end = min(batch_start + batch_size, num_samples) + actual_batch_size = batch_end - batch_start + + batch_latent = initial_latent[batch_start:batch_end] + batch_labels = class_labels[batch_start:batch_end] + + # Create hooks + hooks = [] + + # Capture hook for switch steps + capture_hook = LatentCaptureHook(switch_steps) + hooks.append(capture_hook) + + # Image saver hook + if save_steps_list: + img_saver = IntermediateImageSaver(save_steps_list, output_folder) + hooks.append(img_saver) + + samples = sample_eqm( + model=model, + vae=vae, + device=device, + batch_size=actual_batch_size, + latent_size=latent_size, + initial_latent=batch_latent, + class_labels=batch_labels, + num_sampling_steps=num_sampling_steps, + stepsize=stepsize, + cfg_scale=cfg_scale, + sampler=sampler, + mu=mu, + hooks=hooks, + ) + + # Store captured latents + for step in switch_steps: + if step in capture_hook.captured_latents: + all_captured[step].append(capture_hook.captured_latents[step].cpu()) + + # Save final samples + for i_sample, sample in enumerate(samples): + index = total_saved + i_sample + Image.fromarray(sample).save(f"{final_step_folder}/{index:06d}.png") + + total_saved += actual_batch_size + + print(f"Saved {total_saved} samples to {final_step_folder}") + create_npz_from_sample_folder(final_step_folder, num_samples) + + if save_steps_list: + for step in save_steps_list: + step_folder = f"{output_folder}/step_{step:04d}" + if os.path.exists(step_folder): + create_npz_from_sample_folder(step_folder, num_samples) + + return all_captured + + +def run_switch_experiment( + model, + vae, + device, + captured_latents, + latent_size, + num_sampling_steps, + switch_step, + stepsize, + sampler, + mu, + save_steps_list, + output_folder, + batch_size, + num_samples, +): + """ + Run switch experiment: start from captured latent at switch_step, + continue with null class label (cfg=1.0) for remaining steps. + """ + os.makedirs(output_folder, exist_ok=True) + + remaining_steps = num_sampling_steps - switch_step + final_step_folder = f"{output_folder}/step_{num_sampling_steps:04d}" + os.makedirs(final_step_folder, exist_ok=True) + + # Adjust save_steps for remaining sampling + adjusted_save_steps = [] + if save_steps_list: + for s in save_steps_list: + if s > switch_step: + adjusted_save_steps.append(s - switch_step) + + total_saved = 0 + for _, batch_latent in enumerate(tqdm(captured_latents, desc=f"Switch {switch_step}")): + batch_latent = batch_latent.to(device) + actual_batch_size = batch_latent.shape[0] + + # Use null class labels (1000) + null_labels = torch.tensor([1000] * actual_batch_size, device=device) + + # Create hooks + hooks = [] + if adjusted_save_steps: + # Custom saver that maps adjusted steps back to original steps + img_saver = IntermediateImageSaver(adjusted_save_steps, output_folder) + hooks.append(img_saver) + + samples = sample_eqm( + model=model, + vae=vae, + device=device, + batch_size=actual_batch_size, + latent_size=latent_size, + initial_latent=batch_latent, + class_labels=null_labels, + num_sampling_steps=remaining_steps, + stepsize=stepsize, + cfg_scale=1.0, # No guidance for remaining steps + sampler=sampler, + mu=mu, + hooks=hooks, + ) + + # Save final samples + for i_sample, sample in enumerate(samples): + index = total_saved + i_sample + Image.fromarray(sample).save(f"{final_step_folder}/{index:06d}.png") + + total_saved += actual_batch_size + + print(f"Saved {total_saved} samples to {final_step_folder}") + create_npz_from_sample_folder(final_step_folder, num_samples) + + # Rename intermediate step folders to match original step numbering + if adjusted_save_steps: + for adj_step in adjusted_save_steps: + orig_step = adj_step + switch_step + src_folder = f"{output_folder}/step_{adj_step:03d}" + dst_folder = f"{output_folder}/step_{orig_step:04d}" + if os.path.exists(src_folder): + os.rename(src_folder, dst_folder) + create_npz_from_sample_folder(dst_folder, num_samples) + + +def main(args): + """Main function to run CFG effect experiments.""" + assert torch.cuda.is_available(), "Sampling requires at least one GPU." + + if args.ebm != "none": + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_cudnn_sdp(False) + torch.backends.cuda.enable_math_sdp(True) + + device = torch.device("cuda:0") + torch.manual_seed(args.seed) + torch.cuda.set_device(device) + print(f"Using device: {device}, seed: {args.seed}") + + assert args.image_size % 8 == 0, "Image size must be divisible by 8" + latent_size = args.image_size // 8 + ema_model = EqM_models[args.model]( + input_size=latent_size, num_classes=args.num_classes, uncond=args.uncond, ebm=args.ebm + ).to(device) + + if args.ckpt is None: + raise ValueError("Checkpoint is required") + print(f"Loading checkpoint from {args.ckpt}") + state_dict = find_model(args.ckpt) + if "model" in state_dict.keys(): + ema_model.load_state_dict(state_dict["ema"]) + else: + ema_model.load_state_dict(state_dict) + ema_model.eval() + print(f"EqM Parameters: {sum(p.numel() for p in ema_model.parameters()):,}") + + vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) + + os.makedirs(args.out, exist_ok=True) + + switch_steps = [int(s.strip()) for s in args.switch_steps.split(",")] + print(f"Switch steps: {switch_steps}") + + save_steps_list = [] + if args.save_steps is not None: + save_steps_list = [int(s.strip()) for s in args.save_steps.split(",")] + print(f"Save steps: {save_steps_list}") + + # Generate shared initial latent and class labels + print("Generating shared initial latent and class labels...") + total_samples = int(math.ceil(args.num_samples / args.batch_size) * args.batch_size) + initial_latent = torch.randn(total_samples, 4, latent_size, latent_size, device=device) + + if args.class_labels is not None: + class_ids = [int(c.strip()) for c in args.class_labels.split(",")] + class_ids_tensor = torch.tensor(class_ids, device=device, dtype=torch.long) + class_labels = class_ids_tensor[torch.randint(0, class_ids_tensor.numel(), (total_samples,), device=device)] + else: + class_labels = torch.randint(0, 1000, (total_samples,), device=device) + + # Save initial latent + latent_path = f"{args.out}/initial_latent.npz" + np.savez( + latent_path, + latent=initial_latent.cpu().numpy(), + class_labels=class_labels.cpu().numpy(), + ) + print(f"Saved initial latent and class labels to {latent_path}") + + # Save metadata + metadata = { + "seed": args.seed, + "model": args.model, + "image_size": args.image_size, + "num_classes": args.num_classes, + "batch_size": args.batch_size, + "num_samples": args.num_samples, + "num_sampling_steps": args.num_sampling_steps, + "stepsize": args.stepsize, + "cfg_scale": args.cfg_scale, + "switch_steps": switch_steps, + "save_steps": save_steps_list, + "sampler": args.sampler, + "mu": args.mu, + "vae": args.vae, + "ckpt": args.ckpt, + "class_labels_arg": args.class_labels, + "uncond": args.uncond, + "ebm": args.ebm, + } + with open(f"{args.out}/metadata.json", "w") as f: + json.dump(metadata, f, indent=2) + print(f"Saved metadata to {args.out}/metadata.json") + + # Run baseline with full CFG and capture intermediate latents + print(f"\n{'=' * 60}") + print(f"Running baseline with cfg_scale={args.cfg_scale} (capturing latents at switch steps)") + print(f"{'=' * 60}") + captured_latents = run_baseline_and_capture( + model=ema_model, + vae=vae, + device=device, + initial_latent=initial_latent, + class_labels=class_labels, + latent_size=latent_size, + num_sampling_steps=args.num_sampling_steps, + stepsize=args.stepsize, + cfg_scale=args.cfg_scale, + sampler=args.sampler, + mu=args.mu, + save_steps_list=save_steps_list, + switch_steps=switch_steps, + output_folder=f"{args.out}/baseline_cfg{args.cfg_scale}", + batch_size=args.batch_size, + num_samples=args.num_samples, + ) + + # Run switch experiments using captured latents + for switch_step in switch_steps: + print(f"\n{'=' * 60}") + print(f"Running switch experiment: cfg={args.cfg_scale} for steps 1-{switch_step}, then null-only") + print(f"{'=' * 60}") + run_switch_experiment( + model=ema_model, + vae=vae, + device=device, + captured_latents=captured_latents[switch_step], + latent_size=latent_size, + num_sampling_steps=args.num_sampling_steps, + switch_step=switch_step, + stepsize=args.stepsize, + sampler=args.sampler, + mu=args.mu, + save_steps_list=save_steps_list, + output_folder=f"{args.out}/switch_{switch_step:04d}", + batch_size=args.batch_size, + num_samples=args.num_samples, + ) + + print("\nAll experiments completed!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="CFG Effect Experiment") + parser.add_argument("--model", type=str, choices=list(EqM_models.keys()), default="EqM-XL/2") + parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) + parser.add_argument("--num-classes", type=int, default=1000) + parser.add_argument("--batch-size", type=int, required=True, help="Batch size for sampling") + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") + parser.add_argument("--cfg-scale", type=float, default=4.0, help="CFG scale for initial steps") + parser.add_argument("--ckpt", type=str, required=True, help="Path to EqM checkpoint") + parser.add_argument( + "--class-labels", + type=str, + default=None, + help="Class labels to sample (comma-separated). If not specified, samples random classes.", + ) + parser.add_argument("--stepsize", type=float, default=0.0017, help="Step size eta") + parser.add_argument("--num-sampling-steps", type=int, default=1000, help="Total sampling steps") + parser.add_argument("--out", type=str, required=True, help="Output directory") + parser.add_argument( + "--sampler", + type=str, + default="gd", + choices=["gd", "ngd"], + help="Sampler type: 'gd' or 'ngd'", + ) + parser.add_argument("--mu", type=float, default=0.3, help="NAG-GD momentum hyperparameter") + parser.add_argument("--num-samples", type=int, required=True, help="Total samples to generate") + parser.add_argument( + "--save-steps", + type=str, + default=None, + help="Comma-separated list of steps to save intermediate images", + ) + parser.add_argument( + "--switch-steps", + type=str, + required=True, + help="Comma-separated list of steps at which to switch to null-only", + ) + parser.add_argument("--uncond", type=bool, default=True, help="Enable noise conditioning") + parser.add_argument( + "--ebm", + type=str, + choices=["none", "l2", "dot", "mean"], + default="none", + help="Energy formulation", + ) + + args = parser.parse_args() + main(args) From 78e0cb2b144c1f694b1d8156506716e9c0b96232 Mon Sep 17 00:00:00 2001 From: Kevin On Date: Thu, 11 Dec 2025 08:25:11 +0000 Subject: [PATCH 2/7] Add gradient norm tracking to CFG effect experiments --- experiments/cfg_effect/main.py | 80 ++++++++++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 8 deletions(-) diff --git a/experiments/cfg_effect/main.py b/experiments/cfg_effect/main.py index 57d8a3b..562b046 100644 --- a/experiments/cfg_effect/main.py +++ b/experiments/cfg_effect/main.py @@ -38,6 +38,7 @@ from download import find_model from models import EqM_models from utils.sampling_utils import ( + GradientNormTracker, IntermediateImageSaver, SamplingHookContext, create_npz_from_sample_folder, @@ -82,6 +83,7 @@ def run_baseline_and_capture( output_folder, batch_size, num_samples, + track_grad_norm=False, ): """ Run baseline with cfg_scale for all steps and capture intermediate latents. @@ -100,6 +102,12 @@ def run_baseline_and_capture( # Store captured latents per switch step all_captured = {step: [] for step in switch_steps} + # Create gradient tracker (persistent across batches) + grad_tracker = None + if track_grad_norm: + grad_tracker = GradientNormTracker(num_sampling_steps) + print("Created GradientNormTracker hook for baseline") + total_saved = 0 for batch_idx in tqdm(range(iterations), desc="Baseline CFG"): batch_start = batch_idx * batch_size @@ -121,6 +129,10 @@ def run_baseline_and_capture( img_saver = IntermediateImageSaver(save_steps_list, output_folder) hooks.append(img_saver) + # Gradient norm tracker hook + if grad_tracker is not None: + hooks.append(grad_tracker) + samples = sample_eqm( model=model, vae=vae, @@ -158,7 +170,19 @@ def run_baseline_and_capture( if os.path.exists(step_folder): create_npz_from_sample_folder(step_folder, num_samples) - return all_captured + # Finalize gradient norm tracking + if grad_tracker is not None: + # Create args-like object for finalize + class GradArgs: + pass + + grad_args = GradArgs() + grad_args.num_sampling_steps = num_sampling_steps + grad_args.stepsize = stepsize + grad_args.sampler = sampler + grad_tracker.finalize(grad_args, output_folder) + + return all_captured, grad_tracker def run_switch_experiment( @@ -176,6 +200,8 @@ def run_switch_experiment( output_folder, batch_size, num_samples, + track_grad_norm=False, + baseline_grad_tracker=None, ): """ Run switch experiment: start from captured latent at switch_step, @@ -194,6 +220,17 @@ def run_switch_experiment( if s > switch_step: adjusted_save_steps.append(s - switch_step) + # Create gradient tracker (persistent across batches) + grad_tracker = None + if track_grad_norm: + grad_tracker = GradientNormTracker(remaining_steps) + print(f"Created GradientNormTracker hook for switch_{switch_step}") + + # Create image saver (persistent across batches) + img_saver = None + if adjusted_save_steps: + img_saver = IntermediateImageSaver(adjusted_save_steps, output_folder) + total_saved = 0 for _, batch_latent in enumerate(tqdm(captured_latents, desc=f"Switch {switch_step}")): batch_latent = batch_latent.to(device) @@ -204,11 +241,13 @@ def run_switch_experiment( # Create hooks hooks = [] - if adjusted_save_steps: - # Custom saver that maps adjusted steps back to original steps - img_saver = IntermediateImageSaver(adjusted_save_steps, output_folder) + if img_saver is not None: hooks.append(img_saver) + # Gradient norm tracker hook + if grad_tracker is not None: + hooks.append(grad_tracker) + samples = sample_eqm( model=model, vae=vae, @@ -233,17 +272,35 @@ def run_switch_experiment( total_saved += actual_batch_size print(f"Saved {total_saved} samples to {final_step_folder}") - create_npz_from_sample_folder(final_step_folder, num_samples) + create_npz_from_sample_folder(final_step_folder, total_saved) # Rename intermediate step folders to match original step numbering - if adjusted_save_steps: + if img_saver is not None: for adj_step in adjusted_save_steps: orig_step = adj_step + switch_step src_folder = f"{output_folder}/step_{adj_step:03d}" dst_folder = f"{output_folder}/step_{orig_step:04d}" if os.path.exists(src_folder): os.rename(src_folder, dst_folder) - create_npz_from_sample_folder(dst_folder, num_samples) + step_count = img_saver.step_counters[adj_step] + create_npz_from_sample_folder(dst_folder, step_count) + + # Finalize gradient norm tracking + if grad_tracker is not None: + # Prepend baseline gradient norms to get full trajectory + if baseline_grad_tracker is not None: + baseline_norms = baseline_grad_tracker.gradient_norms[:switch_step] + grad_tracker.gradient_norms = baseline_norms + grad_tracker.gradient_norms + + # Create args-like object for finalize + class GradArgs: + pass + + grad_args = GradArgs() + grad_args.num_sampling_steps = num_sampling_steps # Full trajectory length + grad_args.stepsize = stepsize + grad_args.sampler = sampler + grad_tracker.finalize(grad_args, output_folder) def main(args): @@ -331,6 +388,7 @@ def main(args): "class_labels_arg": args.class_labels, "uncond": args.uncond, "ebm": args.ebm, + "track_grad_norm": args.track_grad_norm, } with open(f"{args.out}/metadata.json", "w") as f: json.dump(metadata, f, indent=2) @@ -340,7 +398,7 @@ def main(args): print(f"\n{'=' * 60}") print(f"Running baseline with cfg_scale={args.cfg_scale} (capturing latents at switch steps)") print(f"{'=' * 60}") - captured_latents = run_baseline_and_capture( + captured_latents, baseline_grad_tracker = run_baseline_and_capture( model=ema_model, vae=vae, device=device, @@ -357,6 +415,7 @@ def main(args): output_folder=f"{args.out}/baseline_cfg{args.cfg_scale}", batch_size=args.batch_size, num_samples=args.num_samples, + track_grad_norm=args.track_grad_norm, ) # Run switch experiments using captured latents @@ -379,6 +438,8 @@ def main(args): output_folder=f"{args.out}/switch_{switch_step:04d}", batch_size=args.batch_size, num_samples=args.num_samples, + track_grad_norm=args.track_grad_norm, + baseline_grad_tracker=baseline_grad_tracker, ) print("\nAll experiments completed!") @@ -424,6 +485,9 @@ def main(args): required=True, help="Comma-separated list of steps at which to switch to null-only", ) + parser.add_argument( + "--track-grad-norm", action="store_true", help="Enable gradient norm tracking and visualization" + ) parser.add_argument("--uncond", type=bool, default=True, help="Enable noise conditioning") parser.add_argument( "--ebm", From 4cf5c321977345d6720ffd4f09228c9fdf53e0a7 Mon Sep 17 00:00:00 2001 From: Kevin On Date: Fri, 12 Dec 2025 06:02:01 +0000 Subject: [PATCH 3/7] Refactor IntermediateImageSaver and GradientNormTracker to improve output folder handling - Updated IntermediateImageSaver to accept a folder pattern for dynamic folder naming. - Modified GradientNormTracker's finalize method to directly accept parameters for sampling statistics. - Adjusted calls to these classes across multiple scripts to ensure consistent usage of the new parameters. --- deprecated/sample_eqm_two.py | 2 +- experiments/cfg_effect/main.py | 201 +++++++++++---------------------- flow_eqm_hybrid.py | 2 +- sample_eqm.py | 4 +- sample_from_clean.py | 2 +- utils/sampling_utils.py | 45 +++++--- 6 files changed, 102 insertions(+), 154 deletions(-) diff --git a/deprecated/sample_eqm_two.py b/deprecated/sample_eqm_two.py index 71fa4c5..05c8b62 100644 --- a/deprecated/sample_eqm_two.py +++ b/deprecated/sample_eqm_two.py @@ -117,7 +117,7 @@ def main(args): # Finalize gradient norm statistics if enabled if grad_tracker is not None: print("Computing gradient norm statistics...") - grad_tracker.finalize(args, args.out) + grad_tracker.finalize(args.out, args.num_sampling_steps, args.stepsize, args.sampler) # Create .npz files for FID evaluation print("Creating .npz file for final samples...") diff --git a/experiments/cfg_effect/main.py b/experiments/cfg_effect/main.py index 562b046..046a07b 100644 --- a/experiments/cfg_effect/main.py +++ b/experiments/cfg_effect/main.py @@ -20,7 +20,6 @@ import json import math import os -import sys import numpy as np import torch @@ -29,12 +28,8 @@ torch.backends.cudnn.allow_tf32 = True from diffusers.models import AutoencoderKL -from PIL import Image from tqdm import tqdm -# Add parent directory to path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) - from download import find_model from models import EqM_models from utils.sampling_utils import ( @@ -67,23 +62,23 @@ def __call__(self, context: SamplingHookContext): def run_baseline_and_capture( - model, - vae, - device, - initial_latent, - class_labels, - latent_size, - num_sampling_steps, - stepsize, - cfg_scale, - sampler, - mu, - save_steps_list, - switch_steps, - output_folder, - batch_size, - num_samples, - track_grad_norm=False, + model: torch.nn.Module, + vae: AutoencoderKL, + device: torch.device, + initial_latent: torch.Tensor, + class_labels: torch.Tensor, + latent_size: int, + num_sampling_steps: int, + stepsize: float, + cfg_scale: float, + sampler: str, + mu: float, + save_steps_list: list[int], + switch_steps: list[int], + output_folder: str, + batch_size: int, + num_samples: int, + track_grad_norm: bool = False, ): """ Run baseline with cfg_scale for all steps and capture intermediate latents. @@ -93,21 +88,29 @@ def run_baseline_and_capture( """ os.makedirs(output_folder, exist_ok=True) - final_step_folder = f"{output_folder}/step_{num_sampling_steps:04d}" - os.makedirs(final_step_folder, exist_ok=True) - total_samples = int(math.ceil(num_samples / batch_size) * batch_size) iterations = int(total_samples // batch_size) # Store captured latents per switch step all_captured = {step: [] for step in switch_steps} - # Create gradient tracker (persistent across batches) + # Create hooks + hooks = [] + + capture_hook = LatentCaptureHook(switch_steps) + hooks.append(capture_hook) + + if num_sampling_steps not in save_steps_list: + save_steps_list.append(num_sampling_steps) + img_saver = IntermediateImageSaver(save_steps_list, output_folder=output_folder) + hooks.append(img_saver) + grad_tracker = None if track_grad_norm: grad_tracker = GradientNormTracker(num_sampling_steps) - print("Created GradientNormTracker hook for baseline") + hooks.append(grad_tracker) + # Sampling loop total_saved = 0 for batch_idx in tqdm(range(iterations), desc="Baseline CFG"): batch_start = batch_idx * batch_size @@ -117,23 +120,7 @@ def run_baseline_and_capture( batch_latent = initial_latent[batch_start:batch_end] batch_labels = class_labels[batch_start:batch_end] - # Create hooks - hooks = [] - - # Capture hook for switch steps - capture_hook = LatentCaptureHook(switch_steps) - hooks.append(capture_hook) - - # Image saver hook - if save_steps_list: - img_saver = IntermediateImageSaver(save_steps_list, output_folder) - hooks.append(img_saver) - - # Gradient norm tracker hook - if grad_tracker is not None: - hooks.append(grad_tracker) - - samples = sample_eqm( + sample_eqm( model=model, vae=vae, device=device, @@ -154,83 +141,63 @@ def run_baseline_and_capture( if step in capture_hook.captured_latents: all_captured[step].append(capture_hook.captured_latents[step].cpu()) - # Save final samples - for i_sample, sample in enumerate(samples): - index = total_saved + i_sample - Image.fromarray(sample).save(f"{final_step_folder}/{index:06d}.png") - total_saved += actual_batch_size - print(f"Saved {total_saved} samples to {final_step_folder}") - create_npz_from_sample_folder(final_step_folder, num_samples) - if save_steps_list: for step in save_steps_list: - step_folder = f"{output_folder}/step_{step:04d}" + step_folder = f"{output_folder}/step_{step:03d}" if os.path.exists(step_folder): create_npz_from_sample_folder(step_folder, num_samples) # Finalize gradient norm tracking if grad_tracker is not None: # Create args-like object for finalize - class GradArgs: - pass - - grad_args = GradArgs() - grad_args.num_sampling_steps = num_sampling_steps - grad_args.stepsize = stepsize - grad_args.sampler = sampler - grad_tracker.finalize(grad_args, output_folder) + grad_tracker.finalize(output_folder, num_sampling_steps, stepsize, sampler) return all_captured, grad_tracker def run_switch_experiment( - model, - vae, - device, - captured_latents, - latent_size, - num_sampling_steps, - switch_step, - stepsize, - sampler, - mu, - save_steps_list, - output_folder, - batch_size, - num_samples, - track_grad_norm=False, - baseline_grad_tracker=None, + model: torch.nn.Module, + vae: AutoencoderKL, + device: torch.device, + captured_latents: list[torch.Tensor], + latent_size: int, + num_sampling_steps: int, + switch_step: int, + stepsize: float, + sampler: str, + mu: float, + save_steps_list: list[int], + output_folder: str, + track_grad_norm: bool = False, + baseline_grad_tracker: GradientNormTracker | None = None, ): """ Run switch experiment: start from captured latent at switch_step, continue with null class label (cfg=1.0) for remaining steps. """ os.makedirs(output_folder, exist_ok=True) - remaining_steps = num_sampling_steps - switch_step - final_step_folder = f"{output_folder}/step_{num_sampling_steps:04d}" - os.makedirs(final_step_folder, exist_ok=True) - # Adjust save_steps for remaining sampling - adjusted_save_steps = [] - if save_steps_list: - for s in save_steps_list: - if s > switch_step: - adjusted_save_steps.append(s - switch_step) + # Create hooks + hooks = [] + + adjusted_save_steps = [s - switch_step for s in (save_steps_list or []) if s > switch_step] + if remaining_steps not in adjusted_save_steps: # Always save the final step + adjusted_save_steps.append(remaining_steps) + img_saver = IntermediateImageSaver( + adjusted_save_steps, + folder_pattern=lambda ctx: f"{output_folder}/step_{ctx.step_idx + switch_step:03d}", + ) + hooks.append(img_saver) - # Create gradient tracker (persistent across batches) grad_tracker = None if track_grad_norm: grad_tracker = GradientNormTracker(remaining_steps) - print(f"Created GradientNormTracker hook for switch_{switch_step}") - - # Create image saver (persistent across batches) - img_saver = None - if adjusted_save_steps: - img_saver = IntermediateImageSaver(adjusted_save_steps, output_folder) + hooks.append(grad_tracker) + # Sampling loop total_saved = 0 for _, batch_latent in enumerate(tqdm(captured_latents, desc=f"Switch {switch_step}")): batch_latent = batch_latent.to(device) @@ -239,16 +206,7 @@ def run_switch_experiment( # Use null class labels (1000) null_labels = torch.tensor([1000] * actual_batch_size, device=device) - # Create hooks - hooks = [] - if img_saver is not None: - hooks.append(img_saver) - - # Gradient norm tracker hook - if grad_tracker is not None: - hooks.append(grad_tracker) - - samples = sample_eqm( + sample_eqm( model=model, vae=vae, device=device, @@ -264,26 +222,13 @@ def run_switch_experiment( hooks=hooks, ) - # Save final samples - for i_sample, sample in enumerate(samples): - index = total_saved + i_sample - Image.fromarray(sample).save(f"{final_step_folder}/{index:06d}.png") - total_saved += actual_batch_size - print(f"Saved {total_saved} samples to {final_step_folder}") - create_npz_from_sample_folder(final_step_folder, total_saved) - - # Rename intermediate step folders to match original step numbering - if img_saver is not None: - for adj_step in adjusted_save_steps: - orig_step = adj_step + switch_step - src_folder = f"{output_folder}/step_{adj_step:03d}" - dst_folder = f"{output_folder}/step_{orig_step:04d}" - if os.path.exists(src_folder): - os.rename(src_folder, dst_folder) - step_count = img_saver.step_counters[adj_step] - create_npz_from_sample_folder(dst_folder, step_count) + # Create .npz files for intermediate steps + for step in adjusted_save_steps: + step_folder = f"{output_folder}/step_{step + switch_step:03d}" + if os.path.exists(step_folder): + create_npz_from_sample_folder(step_folder, total_saved) # Finalize gradient norm tracking if grad_tracker is not None: @@ -292,15 +237,7 @@ def run_switch_experiment( baseline_norms = baseline_grad_tracker.gradient_norms[:switch_step] grad_tracker.gradient_norms = baseline_norms + grad_tracker.gradient_norms - # Create args-like object for finalize - class GradArgs: - pass - - grad_args = GradArgs() - grad_args.num_sampling_steps = num_sampling_steps # Full trajectory length - grad_args.stepsize = stepsize - grad_args.sampler = sampler - grad_tracker.finalize(grad_args, output_folder) + grad_tracker.finalize(output_folder, num_sampling_steps, stepsize, sampler) def main(args): @@ -435,9 +372,7 @@ def main(args): sampler=args.sampler, mu=args.mu, save_steps_list=save_steps_list, - output_folder=f"{args.out}/switch_{switch_step:04d}", - batch_size=args.batch_size, - num_samples=args.num_samples, + output_folder=f"{args.out}/switch_{switch_step:03d}", track_grad_norm=args.track_grad_norm, baseline_grad_tracker=baseline_grad_tracker, ) diff --git a/flow_eqm_hybrid.py b/flow_eqm_hybrid.py index 53290f7..9c810a5 100644 --- a/flow_eqm_hybrid.py +++ b/flow_eqm_hybrid.py @@ -129,7 +129,7 @@ def main(mode, args): eqm_hooks_by_fm_step = {} for flow_step_idx in fm_save_steps_list: step_folder = f"{output_dir}/fm_step_{flow_step_idx:03d}" - img_saver = IntermediateImageSaver(eqm_save_steps_list, step_folder) + img_saver = IntermediateImageSaver(eqm_save_steps_list, output_folder=step_folder) eqm_hooks_by_fm_step[flow_step_idx] = [img_saver] print(f"Starting batch processing: {num_batches} batches of size {batch_size}...") diff --git a/sample_eqm.py b/sample_eqm.py index 0a88467..972b52d 100644 --- a/sample_eqm.py +++ b/sample_eqm.py @@ -80,7 +80,7 @@ def main(args): if args.save_steps is not None: save_steps_list = [int(s.strip()) for s in args.save_steps.split(",")] img_saver = IntermediateImageSaver( - save_steps_list, f"{args.out}/{args.sampler}-{args.stepsize}-cfg{args.cfg_scale}" + save_steps_list, output_folder=f"{args.out}/{args.sampler}-{args.stepsize}-cfg{args.cfg_scale}" ) hooks.append(img_saver) print(f"Created IntermediateImageSaver hook for steps: {save_steps_list}") @@ -130,7 +130,7 @@ def main(args): # Finalize gradient norm statistics if enabled if grad_tracker is not None: print("Computing gradient norm statistics...") - grad_tracker.finalize(args, args.out) + grad_tracker.finalize(args.out, args.num_sampling_steps, args.stepsize, args.sampler) # Create .npz files for FID evaluation print("Creating .npz file for final samples...") diff --git a/sample_from_clean.py b/sample_from_clean.py index c076af4..2421f97 100644 --- a/sample_from_clean.py +++ b/sample_from_clean.py @@ -248,7 +248,7 @@ def main(args): # Finalize gradient norm statistics if enabled if grad_tracker is not None: print("Computing gradient norm statistics...") - grad_tracker.finalize(args, args.out) + grad_tracker.finalize(args.out, args.num_sampling_steps, args.stepsize, args.sampler) # Finalize distortion analysis print("Finalizing distortion analysis...") diff --git a/utils/sampling_utils.py b/utils/sampling_utils.py index 50df38c..cfdeade 100644 --- a/utils/sampling_utils.py +++ b/utils/sampling_utils.py @@ -26,13 +26,13 @@ >>> img_hook = IntermediateImageSaver([0, 50, 100], "outputs") >>> grad_hook = GradientNormTracker(num_sampling_steps=250) >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32, hooks=[img_hook, grad_hook]) - >>> grad_hook.finalize(args, "outputs") + >>> grad_hook.finalize("outputs", num_sampling_steps=250, stepsize=1.0, sampler="euler") """ import json import os from dataclasses import dataclass -from typing import Any +from typing import Any, Callable import matplotlib.pyplot as plt import numpy as np @@ -92,7 +92,7 @@ def sample_eqm( >>> img_hook = IntermediateImageSaver([0, 50, 100, 249], "outputs") >>> grad_hook = GradientNormTracker(num_sampling_steps=250) >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32, hooks=[img_hook, grad_hook]) - >>> grad_hook.finalize(args, "outputs") + >>> grad_hook.finalize("outputs", num_sampling_steps=250, stepsize=1.0, sampler="euler") """ if hooks is None: hooks = [] @@ -227,7 +227,7 @@ def sample_eqm_two( >>> img_hook = IntermediateImageSaver([0, 50, 100, 249], "outputs") >>> grad_hook = GradientNormTracker(num_sampling_steps=250) >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32, hooks=[img_hook, grad_hook]) - >>> grad_hook.finalize(args, "outputs") + >>> grad_hook.finalize("outputs", num_sampling_steps=250, stepsize=1.0, sampler="euler") """ if hooks is None: hooks = [] @@ -354,22 +354,33 @@ class IntermediateImageSaver: Args: save_steps: List of step indices at which to save images (e.g., [0, 50, 100, 250]) - output_folder: Base folder for saving images + output_folder: Base folder for saving images. Required if folder_pattern is not provided. + folder_pattern: Callable (context) -> str that returns the folder path. + If not provided, defaults to "{output_folder}/step_{step_idx:03d}" """ - def __init__(self, save_steps, output_folder): + def __init__( + self, + save_steps: list[int], + output_folder: str | None = None, + folder_pattern: Callable[[SamplingHookContext], str] | None = None, + ): + if output_folder is None and folder_pattern is None: + raise ValueError("Either output_folder or folder_pattern must be provided") self.save_steps = set(save_steps) # Use set for O(1) lookup - self.output_folder = output_folder # Track global sample counter for each step to avoid overwriting self.step_counters = dict.fromkeys(save_steps, 0) - # self.return_images = return_images + if folder_pattern is None: + self.folder_pattern = lambda ctx: f"{output_folder}/step_{ctx.step_idx:03d}" + else: + self.folder_pattern = folder_pattern def __call__(self, context: SamplingHookContext): """Save images if current step is in save_steps list.""" if context.step_idx not in self.save_steps: return - step_folder = f"{self.output_folder}/step_{context.step_idx:03d}" + step_folder = self.folder_pattern(context) os.makedirs(step_folder, exist_ok=True) # Extract conditional part if using CFG @@ -629,14 +640,16 @@ def __call__(self, context: SamplingHookContext): norms = torch.linalg.norm(out_for_norm.reshape(out_for_norm.shape[0], -1), dim=1) # shape: (batch_size,) self.gradient_norms[context.step_idx - 1].extend(norms.cpu().tolist()) - def finalize(self, args, folder): + def finalize(self, folder: str, num_sampling_steps: int, stepsize: float, sampler: str): """ Compute statistics and create visualization for gradient norms. Call this after all sampling is complete. Args: - args: Arguments containing sampling parameters (num_sampling_steps, stepsize, sampler) folder: Output directory for saving JSON and plot + num_sampling_steps: Number of sampling steps + stepsize: Step size used during sampling + sampler: Sampler name (e.g., 'euler', 'heun') """ print("Computing gradient norm statistics...") gradient_means = [] @@ -652,12 +665,12 @@ def finalize(self, args, folder): # Save statistics to JSON stats = { - "num_sampling_steps": args.num_sampling_steps, + "num_sampling_steps": num_sampling_steps, "total_samples": len(self.gradient_norms[0]) if len(self.gradient_norms[0]) > 0 else 0, "mean": gradient_means, "std": gradient_stds, - "stepsize": args.stepsize, - "sampler": args.sampler, + "stepsize": stepsize, + "sampler": sampler, "note": "Statistics computed from individual gradient L2 norms across all samples (batch-size independent)", } json_path = f"{folder}/gradient_norms.json" @@ -667,7 +680,7 @@ def finalize(self, args, folder): # Create plot print("Creating gradient norm plot...") - steps = np.arange(0, args.num_sampling_steps) + steps = np.arange(0, num_sampling_steps) gradient_means = np.array(gradient_means) gradient_stds = np.array(gradient_stds) @@ -683,7 +696,7 @@ def finalize(self, args, folder): plt.xlabel("Sampling Step", fontsize=12) plt.ylabel("Gradient L2 Norm", fontsize=12) plt.title( - f"Gradient L2 Norm during Sampling ({args.sampler.upper()}, stepsize={args.stepsize})", + f"Gradient L2 Norm during Sampling ({sampler.upper()}, stepsize={stepsize})", fontsize=14, ) plt.legend(fontsize=10) From 722695ae5e85b8a61e87d793bcffaa5448d33135 Mon Sep 17 00:00:00 2001 From: Kevin On Date: Fri, 12 Dec 2025 06:15:33 +0000 Subject: [PATCH 4/7] Update IntermediateImageSaver instantiation to use 'output_folder' parameter for consistency across scripts --- deprecated/sample_eqm_two.py | 2 +- sample_from_clean.py | 2 +- utils/sampling_utils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/deprecated/sample_eqm_two.py b/deprecated/sample_eqm_two.py index 05c8b62..ad24c47 100644 --- a/deprecated/sample_eqm_two.py +++ b/deprecated/sample_eqm_two.py @@ -77,7 +77,7 @@ def main(args): save_steps_list = [] if args.save_steps is not None: save_steps_list = [int(s.strip()) for s in args.save_steps.split(",")] - img_saver = IntermediateImageSaver(save_steps_list, args.out) + img_saver = IntermediateImageSaver(save_steps_list, output_folder=args.out) hooks.append(img_saver) print(f"Created IntermediateImageSaver hook for steps: {save_steps_list}") diff --git a/sample_from_clean.py b/sample_from_clean.py index 2421f97..d478ff6 100644 --- a/sample_from_clean.py +++ b/sample_from_clean.py @@ -169,7 +169,7 @@ def main(args): hooks = [] if len(save_steps_list) > 0: - img_saver = IntermediateImageSaver(save_steps_list, args.out) + img_saver = IntermediateImageSaver(save_steps_list, output_folder=args.out) hooks.append(img_saver) print(f"Created IntermediateImageSaver hook for steps: {save_steps_list}") diff --git a/utils/sampling_utils.py b/utils/sampling_utils.py index c3128e6..e25700c 100644 --- a/utils/sampling_utils.py +++ b/utils/sampling_utils.py @@ -89,7 +89,7 @@ def sample_eqm( >>> >>> # With hooks for monitoring >>> from sampling_utils import IntermediateImageSaver, GradientNormTracker - >>> img_hook = IntermediateImageSaver([0, 50, 100, 249], "outputs") + >>> img_hook = IntermediateImageSaver([0, 50, 100, 249], output_folder="outputs") >>> grad_hook = GradientNormTracker(num_sampling_steps=250) >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32, hooks=[img_hook, grad_hook]) >>> grad_hook.finalize("outputs", num_sampling_steps=250, stepsize=1.0, sampler="euler") From acfd9fd4e642530eed6e38573723b5990f1a2109 Mon Sep 17 00:00:00 2001 From: Kevin On Date: Fri, 12 Dec 2025 06:55:18 +0000 Subject: [PATCH 5/7] Refactor '--track-grad-norm' argument to use BooleanOptionalAction across multiple scripts for improved clarity and default behavior --- deprecated/sample_eqm_two.py | 5 ++++- experiments/cfg_effect/main.py | 5 ++++- sample_eqm.py | 5 ++++- sample_from_clean.py | 3 ++- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/deprecated/sample_eqm_two.py b/deprecated/sample_eqm_two.py index ad24c47..fba9882 100644 --- a/deprecated/sample_eqm_two.py +++ b/deprecated/sample_eqm_two.py @@ -168,7 +168,10 @@ def main(args): help="Comma-separated list of sampling steps to save intermediate images (e.g., '0,50,100,249')", ) parser.add_argument( - "--track-grad-norm", action="store_true", help="Enable gradient norm tracking and visualization" + "--track-grad-norm", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable gradient norm tracking and visualization", ) parser.add_argument("--uncond", type=bool, default=True, help="Disable/enable noise conditioning (default: True)") parser.add_argument( diff --git a/experiments/cfg_effect/main.py b/experiments/cfg_effect/main.py index 046a07b..a42141d 100644 --- a/experiments/cfg_effect/main.py +++ b/experiments/cfg_effect/main.py @@ -421,7 +421,10 @@ def main(args): help="Comma-separated list of steps at which to switch to null-only", ) parser.add_argument( - "--track-grad-norm", action="store_true", help="Enable gradient norm tracking and visualization" + "--track-grad-norm", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable gradient norm tracking and visualization", ) parser.add_argument("--uncond", type=bool, default=True, help="Enable noise conditioning") parser.add_argument( diff --git a/sample_eqm.py b/sample_eqm.py index 972b52d..82fce7e 100644 --- a/sample_eqm.py +++ b/sample_eqm.py @@ -187,7 +187,10 @@ def main(args): help="Comma-separated list of sampling steps to save intermediate images (e.g., '0,50,100,249')", ) parser.add_argument( - "--track-grad-norm", action="store_true", help="Enable gradient norm tracking and visualization" + "--track-grad-norm", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable gradient norm tracking and visualization", ) parser.add_argument("--uncond", type=bool, default=True, help="Disable/enable noise conditioning (default: True)") parser.add_argument( diff --git a/sample_from_clean.py b/sample_from_clean.py index d478ff6..82ba141 100644 --- a/sample_from_clean.py +++ b/sample_from_clean.py @@ -398,7 +398,8 @@ def main(args): ) parser.add_argument( "--track-grad-norm", - action="store_true", + action=argparse.BooleanOptionalAction, + default=True, help="Enable gradient norm tracking and visualization", ) From c84d146f1e6a7c9b699242462ba1b12c66c8210c Mon Sep 17 00:00:00 2001 From: Kevin On Date: Fri, 12 Dec 2025 07:03:39 +0000 Subject: [PATCH 6/7] Refactor sampling utilities to consolidate hooks into a new module - Moved IntermediateImageSaver, GradientNormTracker, and other sampling hooks from utils.sampling_utils to utils.sampling_hooks for better organization. - Updated import statements across multiple scripts to reflect the new module structure. - Improved documentation to clarify the usage of the new hook system. --- experiments/cfg_effect/main.py | 5 +- flow_eqm_hybrid.py | 3 +- sample_eqm.py | 3 +- train.py | 7 +- utils/sampling_hooks.py | 466 +++++++++++++++++++++++++++++++++ utils/sampling_utils.py | 444 +------------------------------ 6 files changed, 479 insertions(+), 449 deletions(-) create mode 100644 utils/sampling_hooks.py diff --git a/experiments/cfg_effect/main.py b/experiments/cfg_effect/main.py index a42141d..8e08d56 100644 --- a/experiments/cfg_effect/main.py +++ b/experiments/cfg_effect/main.py @@ -32,13 +32,12 @@ from download import find_model from models import EqM_models -from utils.sampling_utils import ( +from utils.sampling_hooks import ( GradientNormTracker, IntermediateImageSaver, SamplingHookContext, - create_npz_from_sample_folder, - sample_eqm, ) +from utils.sampling_utils import create_npz_from_sample_folder, sample_eqm class LatentCaptureHook: diff --git a/flow_eqm_hybrid.py b/flow_eqm_hybrid.py index 9c810a5..a187c4c 100644 --- a/flow_eqm_hybrid.py +++ b/flow_eqm_hybrid.py @@ -22,7 +22,8 @@ from models import EqM_models from transport import Sampler, create_transport from utils.arg_utils import parse_ode_args, parse_sde_args, parse_transport_args -from utils.sampling_utils import IntermediateImageSaver, decode_latents, sample_eqm +from utils.sampling_hooks import IntermediateImageSaver +from utils.sampling_utils import decode_latents, sample_eqm def main(mode, args): diff --git a/sample_eqm.py b/sample_eqm.py index 82fce7e..85bdbb7 100644 --- a/sample_eqm.py +++ b/sample_eqm.py @@ -18,7 +18,8 @@ from download import find_model from models import EqM_models -from utils.sampling_utils import GradientNormTracker, IntermediateImageSaver, create_npz_from_sample_folder, sample_eqm +from utils.sampling_hooks import GradientNormTracker, IntermediateImageSaver +from utils.sampling_utils import create_npz_from_sample_folder, sample_eqm def main(args): diff --git a/train.py b/train.py index e11696d..d79e542 100644 --- a/train.py +++ b/train.py @@ -36,11 +36,8 @@ parse_sample_args, parse_transport_args, ) -from utils.sampling_utils import ( - GradientNormTracker, - WandBImageLogger, - sample_eqm, -) +from utils.sampling_hooks import GradientNormTracker, WandBImageLogger +from utils.sampling_utils import sample_eqm from utils.utils import imagenet_label_from_idx try: diff --git a/utils/sampling_hooks.py b/utils/sampling_hooks.py new file mode 100644 index 0000000..452e17f --- /dev/null +++ b/utils/sampling_hooks.py @@ -0,0 +1,466 @@ +""" +Sampling hooks for EqM models. + +This module provides hook classes for monitoring, logging, and analysis during sampling. + +Hook System: + Hooks are callables that receive a SamplingHookContext at each sampling step. + They enable monitoring, logging, and analysis without cluttering the core sampling loop. + +Main Components: + - SamplingHookContext: Context object passed to hooks during sampling + - IntermediateImageSaver: Hook for saving intermediate images at specified steps + - WandBImageLogger: Hook for logging intermediate images to WandB + - DistortionTracker: Hook for analyzing image distortion and high-frequency correlation + - GradientNormTracker: Hook for tracking and analyzing gradient norms + +Example Usage: + >>> from utils.sampling_hooks import IntermediateImageSaver, GradientNormTracker + >>> img_hook = IntermediateImageSaver([0, 50, 100], "outputs") + >>> grad_hook = GradientNormTracker(num_sampling_steps=250) + >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32, hooks=[img_hook, grad_hook]) + >>> grad_hook.finalize("outputs", num_sampling_steps=250, stepsize=1.0, sampler="euler") +""" + +import json +import os +from dataclasses import dataclass +from typing import Any, Callable + +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image + + +@dataclass +class SamplingHookContext: + """Context object passed to sampling hooks containing all relevant state.""" + + xt: torch.Tensor # Current latent state + t: torch.Tensor # Current timestep + y: torch.Tensor # Class labels + out: torch.Tensor # Model output/gradient + step_idx: int # Current step index (1-indexed) + use_cfg: bool # Whether CFG is enabled + vae: Any # VAE decoder for image conversion + device: torch.device # Device + total_steps: int # Total number of sampling steps + + +class IntermediateImageSaver: + """ + Hook for saving intermediate images during sampling. + + Args: + save_steps: List of step indices at which to save images (e.g., [0, 50, 100, 250]) + output_folder: Base folder for saving images. Required if folder_pattern is not provided. + folder_pattern: Callable (context) -> str that returns the folder path. + If not provided, defaults to "{output_folder}/step_{step_idx:03d}" + """ + + def __init__( + self, + save_steps: list[int], + output_folder: str | None = None, + folder_pattern: Callable[[SamplingHookContext], str] | None = None, + ): + if output_folder is None and folder_pattern is None: + raise ValueError("Either output_folder or folder_pattern must be provided") + self.save_steps = set(save_steps) # Use set for O(1) lookup + # Track global sample counter for each step to avoid overwriting + self.step_counters = dict.fromkeys(save_steps, 0) + if folder_pattern is None: + self.folder_pattern = lambda ctx: f"{output_folder}/step_{ctx.step_idx:03d}" + else: + self.folder_pattern = folder_pattern + + def __call__(self, context: SamplingHookContext): + """Save images if current step is in save_steps list.""" + from utils.sampling_utils import decode_latents + + if context.step_idx not in self.save_steps: + return + + step_folder = self.folder_pattern(context) + os.makedirs(step_folder, exist_ok=True) + + # Extract conditional part if using CFG + xt_save = context.xt + if context.use_cfg: + batch_size = context.xt.shape[0] // 2 + xt_save = context.xt[:batch_size] + + samples = decode_latents(context.vae, xt_save) + + # Save images with global sequential indexing across batches + start_idx = self.step_counters[context.step_idx] + for i_sample, sample in enumerate(samples): + global_idx = start_idx + i_sample + Image.fromarray(sample).save(f"{step_folder}/{global_idx:06d}.png") + + # Update counter for this step + self.step_counters[context.step_idx] += len(samples) + + +class WandBImageLogger: + """ + Hook for logging intermediate images during sampling directly to WandB. + + Args: + save_steps: List of step indices at which to log images (e.g., [5, 10, 250]) + train_step: Current training step (for WandB logging) + output_folder: Folder to save logged images + wandb_module: wandb module (pass wandb if imported, or None to skip logging) + """ + + def __init__(self, save_steps, train_step, output_folder, wandb_module=None): + self.save_steps = set(save_steps) + self.train_step = train_step + self.output_folder = output_folder + self.wandb = wandb_module + self.logged_images = {step: [] for step in save_steps} + self.step_counters = dict.fromkeys(save_steps, 0) + + def __call__(self, context: SamplingHookContext): + """Log images to WandB if current step is in save_steps list.""" + from utils.sampling_utils import decode_latents + + if context.step_idx not in self.save_steps or self.wandb is None: + return + + folder = f"{self.output_folder}/train_{self.train_step:04d}/sample_{context.step_idx:03d}" + os.makedirs(folder, exist_ok=True) + + # Extract conditional part if using CFG + xt_save = context.xt + if context.use_cfg: + batch_size = context.xt.shape[0] // 2 + xt_save = context.xt[:batch_size] + + # Decode latents to images + samples = decode_latents(context.vae, xt_save) + + # Convert to wandb.Image objects + start_idx = self.step_counters[context.step_idx] + for i_sample, sample in enumerate(samples): + global_idx = start_idx + i_sample + img = Image.fromarray(sample) + img.save(f"{folder}/{global_idx:03d}.png") + self.logged_images[context.step_idx].append(self.wandb.Image(img, caption=f"Sample {global_idx:03d}")) + + def finalize(self): + """Log all collected images to WandB. Call this after sampling is complete.""" + if self.wandb is None: + return + + for step_idx in sorted(self.save_steps): + if len(self.logged_images[step_idx]) > 0: + self.wandb.log({f"samples/step_{step_idx:03d}": self.logged_images[step_idx]}, step=self.train_step) + + +class DistortionTracker: + """ + Hook for tracking distortion of clean images during sampling. + + Computes L2 distance between original and current latents at specified steps, + correlates with high-frequency content, and saves top distorted/undistorted images. + + Args: + original_latents: Original clean latents, shape (batch_size, 4, H, W) + high_freq_metrics: High-frequency content metrics per image, shape (batch_size,) + save_steps: List of step indices at which to track distortion + output_folder: Base folder for saving results + top_n: Number of top distorted/undistorted images to save per step + """ + + def __init__(self, original_latents, high_freq_metrics, save_steps, output_folder, top_n=10): + self.original_latents = original_latents.clone() + self.high_freq_metrics = high_freq_metrics + self.save_steps = set(save_steps) + self.output_folder = output_folder + self.top_n = top_n + + # Storage for distortion metrics at each step + # Key: step_idx, Value: list of L2 distances + self.distortions = {step: [] for step in save_steps} + + # Storage for batch information + # Key: step_idx, Value: list of (batch_start_idx, batch_latents) + self.latent_batches = {step: [] for step in save_steps} + + self.batch_counter = 0 + + def __call__(self, context: SamplingHookContext): + """Track distortion if current step is in save_steps list.""" + if context.step_idx not in self.save_steps: + return + + # Extract conditional part if using CFG + xt_current = context.xt + if context.use_cfg: + batch_size = context.xt.shape[0] // 2 + xt_current = context.xt[:batch_size] + + # Get corresponding original latents for this batch + batch_size = xt_current.shape[0] + batch_start = self.batch_counter + batch_end = batch_start + batch_size + original_batch = self.original_latents[batch_start:batch_end] + + # Compute L2 distance in latent space for each sample + diff = xt_current - original_batch + l2_distances = torch.linalg.norm(diff.reshape(diff.shape[0], -1), dim=1) # shape: (batch_size,) + + # Store distortion metrics + self.distortions[context.step_idx].extend(l2_distances.cpu().tolist()) + + # Store latents for later saving + self.latent_batches[context.step_idx].append((batch_start, xt_current.cpu().clone())) + + def on_batch_complete(self, batch_size): + """Call this after each batch to update the batch counter.""" + self.batch_counter += batch_size + + def finalize(self, vae, output_folder): + """ + Compute statistics, save top images, and create visualizations. + Call this after all sampling is complete. + + Args: + vae: VAE decoder for converting latents to images + output_folder: Output directory for saving results + """ + from utils.sampling_utils import decode_latents + + print("Analyzing distortion metrics and creating visualizations...") + + results = {"steps": {}, "high_freq_metrics": self.high_freq_metrics.tolist()} + + for step_idx in sorted(self.save_steps): + print(f" Processing step {step_idx}...") + + # Get all distortions for this step + distortions = np.array(self.distortions[step_idx]) + + if len(distortions) == 0: + print(f" Warning: No distortions recorded for step {step_idx}") + continue + + # Reconstruct full latent tensor from batches + latent_list = [] + for _batch_start, batch_latents in sorted(self.latent_batches[step_idx]): + latent_list.append(batch_latents) + all_latents = torch.cat(latent_list, dim=0) + + # Get indices of top-N most and least distorted + top_distorted_indices = np.argsort(distortions)[-self.top_n :][::-1] + top_undistorted_indices = np.argsort(distortions)[: self.top_n] + + # Save top distorted images + distorted_folder = f"{output_folder}/step_{step_idx:03d}/top_distorted" + os.makedirs(distorted_folder, exist_ok=True) + for rank, idx in enumerate(top_distorted_indices): + latent = all_latents[idx : idx + 1].to(vae.device) # Move to VAE device + image = decode_latents(vae, latent)[0] + Image.fromarray(image).save( + f"{distorted_folder}/rank{rank:02d}_idx{idx:06d}_dist{distortions[idx]:.4f}.png" + ) + + # Save top undistorted images + undistorted_folder = f"{output_folder}/step_{step_idx:03d}/top_undistorted" + os.makedirs(undistorted_folder, exist_ok=True) + for rank, idx in enumerate(top_undistorted_indices): + latent = all_latents[idx : idx + 1].to(vae.device) # Move to VAE device + image = decode_latents(vae, latent)[0] + Image.fromarray(image).save( + f"{undistorted_folder}/rank{rank:02d}_idx{idx:06d}_dist{distortions[idx]:.4f}.png" + ) + + # Compute correlation with high-frequency content + # Use only valid indices (in case of batch size mismatch) + valid_indices = min(len(distortions), len(self.high_freq_metrics)) + distortions_valid = distortions[:valid_indices] + high_freq_valid = self.high_freq_metrics[:valid_indices] + + correlation = np.corrcoef(high_freq_valid, distortions_valid)[0, 1] + + # Create scatter plot + plt.figure(figsize=(10, 8)) + plt.scatter(high_freq_valid, distortions_valid, alpha=0.5, s=20) + plt.xlabel("High-Frequency Content (Ratio)", fontsize=12) + plt.ylabel("L2 Distance from Original (Latent Space)", fontsize=12) + plt.title( + f"Distortion vs High-Frequency Content at Step {step_idx}\nPearson Correlation: {correlation:.4f}", + fontsize=14, + ) + plt.grid(True, alpha=0.3) + plt.tight_layout() + + plot_path = f"{output_folder}/correlation_step_{step_idx:03d}.png" + plt.savefig(plot_path, dpi=150) + plt.close() + print(f" Saved correlation plot to {plot_path}") + + # Store statistics + results["steps"][str(step_idx)] = { + "mean_distortion": float(np.mean(distortions)), + "std_distortion": float(np.std(distortions)), + "min_distortion": float(np.min(distortions)), + "max_distortion": float(np.max(distortions)), + "correlation_with_high_freq": float(correlation), + "num_samples": len(distortions), + "top_distorted_indices": top_distorted_indices.tolist(), + "top_undistorted_indices": top_undistorted_indices.tolist(), + "top_distorted_values": distortions[top_distorted_indices].tolist(), + "top_undistorted_values": distortions[top_undistorted_indices].tolist(), + } + + # Save results to JSON + json_path = f"{output_folder}/distortion_analysis.json" + with open(json_path, "w") as f: + json.dump(results, f, indent=2) + print(f"Saved distortion analysis to {json_path}") + + +class GradientNormTracker: + """ + Hook for tracking gradient L2 norms during sampling. + + Args: + num_steps: Number of sampling steps (for pre-allocating storage) + """ + + def __init__(self, num_steps): + self.gradient_norms = [[] for _ in range(num_steps)] + + def __call__(self, context: SamplingHookContext): + """Accumulate gradient L2 norms for the current step.""" + # Extract conditional part if using CFG + out_for_norm = context.out + if context.use_cfg: + batch_size = context.out.shape[0] // 2 + out_for_norm = context.out[:batch_size] + + # Compute L2 norm for each sample in the batch + norms = torch.linalg.norm(out_for_norm.reshape(out_for_norm.shape[0], -1), dim=1) # shape: (batch_size,) + self.gradient_norms[context.step_idx - 1].extend(norms.cpu().tolist()) + + def finalize(self, folder: str, num_sampling_steps: int, stepsize: float, sampler: str): + """ + Compute statistics and create visualization for gradient norms. + Call this after all sampling is complete. + + Args: + folder: Output directory for saving JSON and plot + num_sampling_steps: Number of sampling steps + stepsize: Step size used during sampling + sampler: Sampler name (e.g., 'euler', 'heun') + """ + print("Computing gradient norm statistics...") + gradient_means = [] + gradient_stds = [] + + for step_norms in self.gradient_norms: + if len(step_norms) > 0: + gradient_means.append(np.mean(step_norms)) + gradient_stds.append(np.std(step_norms)) + else: + gradient_means.append(0.0) + gradient_stds.append(0.0) + + # Save statistics to JSON + stats = { + "num_sampling_steps": num_sampling_steps, + "total_samples": len(self.gradient_norms[0]) if len(self.gradient_norms[0]) > 0 else 0, + "mean": gradient_means, + "std": gradient_stds, + "stepsize": stepsize, + "sampler": sampler, + "note": "Statistics computed from individual gradient L2 norms across all samples (batch-size independent)", + } + json_path = f"{folder}/gradient_norms.json" + with open(json_path, "w") as f: + json.dump(stats, f, indent=2) + print(f"Saved gradient norm statistics to {json_path}") + + # Create plot + print("Creating gradient norm plot...") + steps = np.arange(0, num_sampling_steps) + gradient_means = np.array(gradient_means) + gradient_stds = np.array(gradient_stds) + + plt.figure(figsize=(10, 6)) + plt.plot(steps, gradient_means, linewidth=2, label="Mean L2 Norm") + plt.fill_between( + steps, + gradient_means - gradient_stds, + gradient_means + gradient_stds, + alpha=0.3, + label="Mean ± Std", + ) + plt.xlabel("Sampling Step", fontsize=12) + plt.ylabel("Gradient L2 Norm", fontsize=12) + plt.title( + f"Gradient L2 Norm during Sampling ({sampler.upper()}, stepsize={stepsize})", + fontsize=14, + ) + plt.legend(fontsize=10) + plt.grid(True, alpha=0.3) + plt.tight_layout() + + plot_path = f"{folder}/gradient_norms.png" + plt.savefig(plot_path, dpi=150) + print(f"Saved gradient norm plot to {plot_path}") + plt.close() + + def finalize_for_wandb(self, wandb_module, train_step): + """ + Log gradient norm statistics to WandB. + Call this after all sampling is complete. + + Args: + wandb_module: wandb module for logging + train_step: Current training step for WandB logging + """ + if wandb_module is None: + return + + # Compute statistics + gradient_means = [] + gradient_stds = [] + + for step_norms in self.gradient_norms: + if len(step_norms) > 0: + gradient_means.append(np.mean(step_norms)) + gradient_stds.append(np.std(step_norms)) + else: + gradient_means.append(0.0) + gradient_stds.append(0.0) + + # Create a table for gradient norms vs sampling steps + # This allows WandB to plot with sampling_step on x-axis + data = [] + for sampling_step, (mean_norm, std_norm) in enumerate(zip(gradient_means, gradient_stds)): + data.append( + [ + sampling_step, + mean_norm, + std_norm, + mean_norm + std_norm, # Upper bound + mean_norm - std_norm, # Lower bound + ] + ) + + table = wandb_module.Table( + columns=[ + "sampling_step", + "gradient_norm_mean", + "gradient_norm_std", + "gradient_norm_upper", + "gradient_norm_lower", + ], + data=data, + ) + # Log the table with a consistent key (train_step is already in the table data) + wandb_module.log({"gradient_norms": table}, step=train_step) diff --git a/utils/sampling_utils.py b/utils/sampling_utils.py index e25700c..cfcbce3 100644 --- a/utils/sampling_utils.py +++ b/utils/sampling_utils.py @@ -1,45 +1,38 @@ """ Sampling utility functions for EqM models. -This module provides reusable sampling functions and hooks for Equilibrium Matching models. +This module provides reusable sampling functions for Equilibrium Matching models. Main Components: - sample_eqm(): Core sampling function with GD/NAG-GD support - - SamplingHookContext: Context object passed to hooks during sampling - - IntermediateImageSaver: Hook for saving intermediate images at specified steps - - GradientNormTracker: Hook for tracking and analyzing gradient norms - - DistortionTracker: Hook for analyzing image distortion and high-frequency correlation - create_npz_from_sample_folder(): Create .npz file for FID evaluation - decode_latents(): Decode VAE latents to images - encode_images_to_latent(): Encode images to VAE latents - compute_high_frequency_content(): Compute high-frequency content metric using FFT -Hook System: - Hooks are callables that receive a SamplingHookContext at each sampling step. - They enable monitoring, logging, and analysis without cluttering the core sampling loop. +For sampling hooks, see utils.sampling_hooks module. Example Usage: >>> # Basic sampling >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32) >>> >>> # Sampling with hooks + >>> from utils.sampling_hooks import IntermediateImageSaver, GradientNormTracker >>> img_hook = IntermediateImageSaver([0, 50, 100], "outputs") >>> grad_hook = GradientNormTracker(num_sampling_steps=250) >>> samples = sample_eqm(model, vae, device, batch_size=16, latent_size=32, hooks=[img_hook, grad_hook]) >>> grad_hook.finalize("outputs", num_sampling_steps=250, stepsize=1.0, sampler="euler") """ -import json import os -from dataclasses import dataclass -from typing import Any, Callable -import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image from tqdm import tqdm +from utils.sampling_hooks import SamplingHookContext + @torch.no_grad() def sample_eqm( @@ -333,433 +326,6 @@ def sample_eqm_two( return samples -@dataclass -class SamplingHookContext: - """Context object passed to sampling hooks containing all relevant state.""" - - xt: torch.Tensor # Current latent state - t: torch.Tensor # Current timestep - y: torch.Tensor # Class labels - out: torch.Tensor # Model output/gradient - step_idx: int # Current step index (1-indexed) - use_cfg: bool # Whether CFG is enabled - vae: Any # VAE decoder for image conversion - device: torch.device # Device - total_steps: int # Total number of sampling steps - - -class IntermediateImageSaver: - """ - Hook for saving intermediate images during sampling. - - Args: - save_steps: List of step indices at which to save images (e.g., [0, 50, 100, 250]) - output_folder: Base folder for saving images. Required if folder_pattern is not provided. - folder_pattern: Callable (context) -> str that returns the folder path. - If not provided, defaults to "{output_folder}/step_{step_idx:03d}" - """ - - def __init__( - self, - save_steps: list[int], - output_folder: str | None = None, - folder_pattern: Callable[[SamplingHookContext], str] | None = None, - ): - if output_folder is None and folder_pattern is None: - raise ValueError("Either output_folder or folder_pattern must be provided") - self.save_steps = set(save_steps) # Use set for O(1) lookup - # Track global sample counter for each step to avoid overwriting - self.step_counters = dict.fromkeys(save_steps, 0) - if folder_pattern is None: - self.folder_pattern = lambda ctx: f"{output_folder}/step_{ctx.step_idx:03d}" - else: - self.folder_pattern = folder_pattern - - def __call__(self, context: SamplingHookContext): - """Save images if current step is in save_steps list.""" - if context.step_idx not in self.save_steps: - return - - step_folder = self.folder_pattern(context) - os.makedirs(step_folder, exist_ok=True) - - # Extract conditional part if using CFG - xt_save = context.xt - if context.use_cfg: - batch_size = context.xt.shape[0] // 2 - xt_save = context.xt[:batch_size] - - samples = decode_latents(context.vae, xt_save) - - # Save images with global sequential indexing across batches - start_idx = self.step_counters[context.step_idx] - for i_sample, sample in enumerate(samples): - global_idx = start_idx + i_sample - Image.fromarray(sample).save(f"{step_folder}/{global_idx:06d}.png") - - # Update counter for this step - self.step_counters[context.step_idx] += len(samples) - - -class WandBImageLogger: - """ - Hook for logging intermediate images during sampling directly to WandB. - - Args: - save_steps: List of step indices at which to log images (e.g., [5, 10, 250]) - train_step: Current training step (for WandB logging) - output_folder: Folder to save logged images - wandb_module: wandb module (pass wandb if imported, or None to skip logging) - """ - - def __init__(self, save_steps, train_step, output_folder, wandb_module=None): - self.save_steps = set(save_steps) - self.train_step = train_step - self.output_folder = output_folder - self.wandb = wandb_module - self.logged_images = {step: [] for step in save_steps} - self.step_counters = dict.fromkeys(save_steps, 0) - - def __call__(self, context: SamplingHookContext): - """Log images to WandB if current step is in save_steps list.""" - if context.step_idx not in self.save_steps or self.wandb is None: - return - - folder = f"{self.output_folder}/train_{self.train_step:04d}/sample_{context.step_idx:03d}" - os.makedirs(folder, exist_ok=True) - - # Extract conditional part if using CFG - xt_save = context.xt - if context.use_cfg: - batch_size = context.xt.shape[0] // 2 - xt_save = context.xt[:batch_size] - - # Decode latents to images - samples = decode_latents(context.vae, xt_save) - - # Convert to wandb.Image objects - start_idx = self.step_counters[context.step_idx] - for i_sample, sample in enumerate(samples): - global_idx = start_idx + i_sample - img = Image.fromarray(sample) - img.save(f"{folder}/{global_idx:03d}.png") - self.logged_images[context.step_idx].append(self.wandb.Image(img, caption=f"Sample {global_idx:03d}")) - - def finalize(self): - """Log all collected images to WandB. Call this after sampling is complete.""" - if self.wandb is None: - return - - for step_idx in sorted(self.save_steps): - if len(self.logged_images[step_idx]) > 0: - self.wandb.log({f"samples/step_{step_idx:03d}": self.logged_images[step_idx]}, step=self.train_step) - - -class DistortionTracker: - """ - Hook for tracking distortion of clean images during sampling. - - Computes L2 distance between original and current latents at specified steps, - correlates with high-frequency content, and saves top distorted/undistorted images. - - Args: - original_latents: Original clean latents, shape (batch_size, 4, H, W) - high_freq_metrics: High-frequency content metrics per image, shape (batch_size,) - save_steps: List of step indices at which to track distortion - output_folder: Base folder for saving results - top_n: Number of top distorted/undistorted images to save per step - """ - - def __init__(self, original_latents, high_freq_metrics, save_steps, output_folder, top_n=10): - self.original_latents = original_latents.clone() - self.high_freq_metrics = high_freq_metrics - self.save_steps = set(save_steps) - self.output_folder = output_folder - self.top_n = top_n - - # Storage for distortion metrics at each step - # Key: step_idx, Value: list of L2 distances - self.distortions = {step: [] for step in save_steps} - - # Storage for batch information - # Key: step_idx, Value: list of (batch_start_idx, batch_latents) - self.latent_batches = {step: [] for step in save_steps} - - self.batch_counter = 0 - - def __call__(self, context: SamplingHookContext): - """Track distortion if current step is in save_steps list.""" - if context.step_idx not in self.save_steps: - return - - # Extract conditional part if using CFG - xt_current = context.xt - if context.use_cfg: - batch_size = context.xt.shape[0] // 2 - xt_current = context.xt[:batch_size] - - # Get corresponding original latents for this batch - batch_size = xt_current.shape[0] - batch_start = self.batch_counter - batch_end = batch_start + batch_size - original_batch = self.original_latents[batch_start:batch_end] - - # Compute L2 distance in latent space for each sample - diff = xt_current - original_batch - l2_distances = torch.linalg.norm(diff.reshape(diff.shape[0], -1), dim=1) # shape: (batch_size,) - - # Store distortion metrics - self.distortions[context.step_idx].extend(l2_distances.cpu().tolist()) - - # Store latents for later saving - self.latent_batches[context.step_idx].append((batch_start, xt_current.cpu().clone())) - - def on_batch_complete(self, batch_size): - """Call this after each batch to update the batch counter.""" - self.batch_counter += batch_size - - def finalize(self, vae, output_folder): - """ - Compute statistics, save top images, and create visualizations. - Call this after all sampling is complete. - - Args: - vae: VAE decoder for converting latents to images - output_folder: Output directory for saving results - """ - print("Analyzing distortion metrics and creating visualizations...") - - results = {"steps": {}, "high_freq_metrics": self.high_freq_metrics.tolist()} - - for step_idx in sorted(self.save_steps): - print(f" Processing step {step_idx}...") - - # Get all distortions for this step - distortions = np.array(self.distortions[step_idx]) - - if len(distortions) == 0: - print(f" Warning: No distortions recorded for step {step_idx}") - continue - - # Reconstruct full latent tensor from batches - latent_list = [] - for _batch_start, batch_latents in sorted(self.latent_batches[step_idx]): - latent_list.append(batch_latents) - all_latents = torch.cat(latent_list, dim=0) - - # Get indices of top-N most and least distorted - top_distorted_indices = np.argsort(distortions)[-self.top_n :][::-1] - top_undistorted_indices = np.argsort(distortions)[: self.top_n] - - # Save top distorted images - distorted_folder = f"{output_folder}/step_{step_idx:03d}/top_distorted" - os.makedirs(distorted_folder, exist_ok=True) - for rank, idx in enumerate(top_distorted_indices): - latent = all_latents[idx : idx + 1].to(vae.device) # Move to VAE device - image = decode_latents(vae, latent)[0] - Image.fromarray(image).save( - f"{distorted_folder}/rank{rank:02d}_idx{idx:06d}_dist{distortions[idx]:.4f}.png" - ) - - # Save top undistorted images - undistorted_folder = f"{output_folder}/step_{step_idx:03d}/top_undistorted" - os.makedirs(undistorted_folder, exist_ok=True) - for rank, idx in enumerate(top_undistorted_indices): - latent = all_latents[idx : idx + 1].to(vae.device) # Move to VAE device - image = decode_latents(vae, latent)[0] - Image.fromarray(image).save( - f"{undistorted_folder}/rank{rank:02d}_idx{idx:06d}_dist{distortions[idx]:.4f}.png" - ) - - # Compute correlation with high-frequency content - # Use only valid indices (in case of batch size mismatch) - valid_indices = min(len(distortions), len(self.high_freq_metrics)) - distortions_valid = distortions[:valid_indices] - high_freq_valid = self.high_freq_metrics[:valid_indices] - - correlation = np.corrcoef(high_freq_valid, distortions_valid)[0, 1] - - # Create scatter plot - plt.figure(figsize=(10, 8)) - plt.scatter(high_freq_valid, distortions_valid, alpha=0.5, s=20) - plt.xlabel("High-Frequency Content (Ratio)", fontsize=12) - plt.ylabel("L2 Distance from Original (Latent Space)", fontsize=12) - plt.title( - f"Distortion vs High-Frequency Content at Step {step_idx}\nPearson Correlation: {correlation:.4f}", - fontsize=14, - ) - plt.grid(True, alpha=0.3) - plt.tight_layout() - - plot_path = f"{output_folder}/correlation_step_{step_idx:03d}.png" - plt.savefig(plot_path, dpi=150) - plt.close() - print(f" Saved correlation plot to {plot_path}") - - # Store statistics - results["steps"][str(step_idx)] = { - "mean_distortion": float(np.mean(distortions)), - "std_distortion": float(np.std(distortions)), - "min_distortion": float(np.min(distortions)), - "max_distortion": float(np.max(distortions)), - "correlation_with_high_freq": float(correlation), - "num_samples": len(distortions), - "top_distorted_indices": top_distorted_indices.tolist(), - "top_undistorted_indices": top_undistorted_indices.tolist(), - "top_distorted_values": distortions[top_distorted_indices].tolist(), - "top_undistorted_values": distortions[top_undistorted_indices].tolist(), - } - - # Save results to JSON - json_path = f"{output_folder}/distortion_analysis.json" - with open(json_path, "w") as f: - json.dump(results, f, indent=2) - print(f"Saved distortion analysis to {json_path}") - - -class GradientNormTracker: - """ - Hook for tracking gradient L2 norms during sampling. - - Args: - num_steps: Number of sampling steps (for pre-allocating storage) - """ - - def __init__(self, num_steps): - self.gradient_norms = [[] for _ in range(num_steps)] - - def __call__(self, context: SamplingHookContext): - """Accumulate gradient L2 norms for the current step.""" - # Extract conditional part if using CFG - out_for_norm = context.out - if context.use_cfg: - batch_size = context.out.shape[0] // 2 - out_for_norm = context.out[:batch_size] - - # Compute L2 norm for each sample in the batch - norms = torch.linalg.norm(out_for_norm.reshape(out_for_norm.shape[0], -1), dim=1) # shape: (batch_size,) - self.gradient_norms[context.step_idx - 1].extend(norms.cpu().tolist()) - - def finalize(self, folder: str, num_sampling_steps: int, stepsize: float, sampler: str): - """ - Compute statistics and create visualization for gradient norms. - Call this after all sampling is complete. - - Args: - folder: Output directory for saving JSON and plot - num_sampling_steps: Number of sampling steps - stepsize: Step size used during sampling - sampler: Sampler name (e.g., 'euler', 'heun') - """ - print("Computing gradient norm statistics...") - gradient_means = [] - gradient_stds = [] - - for step_norms in self.gradient_norms: - if len(step_norms) > 0: - gradient_means.append(np.mean(step_norms)) - gradient_stds.append(np.std(step_norms)) - else: - gradient_means.append(0.0) - gradient_stds.append(0.0) - - # Save statistics to JSON - stats = { - "num_sampling_steps": num_sampling_steps, - "total_samples": len(self.gradient_norms[0]) if len(self.gradient_norms[0]) > 0 else 0, - "mean": gradient_means, - "std": gradient_stds, - "stepsize": stepsize, - "sampler": sampler, - "note": "Statistics computed from individual gradient L2 norms across all samples (batch-size independent)", - } - json_path = f"{folder}/gradient_norms.json" - with open(json_path, "w") as f: - json.dump(stats, f, indent=2) - print(f"Saved gradient norm statistics to {json_path}") - - # Create plot - print("Creating gradient norm plot...") - steps = np.arange(0, num_sampling_steps) - gradient_means = np.array(gradient_means) - gradient_stds = np.array(gradient_stds) - - plt.figure(figsize=(10, 6)) - plt.plot(steps, gradient_means, linewidth=2, label="Mean L2 Norm") - plt.fill_between( - steps, - gradient_means - gradient_stds, - gradient_means + gradient_stds, - alpha=0.3, - label="Mean ± Std", - ) - plt.xlabel("Sampling Step", fontsize=12) - plt.ylabel("Gradient L2 Norm", fontsize=12) - plt.title( - f"Gradient L2 Norm during Sampling ({sampler.upper()}, stepsize={stepsize})", - fontsize=14, - ) - plt.legend(fontsize=10) - plt.grid(True, alpha=0.3) - plt.tight_layout() - - plot_path = f"{folder}/gradient_norms.png" - plt.savefig(plot_path, dpi=150) - print(f"Saved gradient norm plot to {plot_path}") - plt.close() - - def finalize_for_wandb(self, wandb_module, train_step): - """ - Log gradient norm statistics to WandB. - Call this after all sampling is complete. - - Args: - wandb_module: wandb module for logging - train_step: Current training step for WandB logging - """ - if wandb_module is None: - return - - # Compute statistics - gradient_means = [] - gradient_stds = [] - - for step_norms in self.gradient_norms: - if len(step_norms) > 0: - gradient_means.append(np.mean(step_norms)) - gradient_stds.append(np.std(step_norms)) - else: - gradient_means.append(0.0) - gradient_stds.append(0.0) - - # Create a table for gradient norms vs sampling steps - # This allows WandB to plot with sampling_step on x-axis - data = [] - for sampling_step, (mean_norm, std_norm) in enumerate(zip(gradient_means, gradient_stds)): - data.append( - [ - sampling_step, - mean_norm, - std_norm, - mean_norm + std_norm, # Upper bound - mean_norm - std_norm, # Lower bound - ] - ) - - table = wandb_module.Table( - columns=[ - "sampling_step", - "gradient_norm_mean", - "gradient_norm_std", - "gradient_norm_upper", - "gradient_norm_lower", - ], - data=data, - ) - # Log the table with a consistent key (train_step is already in the table data) - wandb_module.log({"gradient_norms": table}, step=train_step) - - def create_npz_from_sample_folder(sample_dir, num=None): """ Builds a single .npz file from a folder of .png samples. From 9acff90a249a6e6cd7c2f3f1067f00c894b95b44 Mon Sep 17 00:00:00 2001 From: Kevin On Date: Fri, 12 Dec 2025 09:10:31 +0000 Subject: [PATCH 7/7] Enhance EqM forward pass and gradient norm tracking for CFG - Updated the `forward_with_cfg` method in `models.py` to include a new argument `return_components`, allowing the return of individual conditional and unconditional outputs. - Modified `GradientNormTracker` in `sampling_hooks.py` to track and compute statistics for conditional and unconditional gradient norms when CFG is enabled. - Adjusted sampling functions in `sampling_utils.py` to pass conditional and unconditional outputs to hooks, improving monitoring capabilities during sampling. - Enhanced documentation to clarify the new functionality and its implications for gradient norm tracking. --- experiments/cfg_effect/main.py | 6 + models.py | 15 ++- utils/sampling_hooks.py | 216 ++++++++++++++++++++++++--------- utils/sampling_utils.py | 62 ++++++++-- 4 files changed, 228 insertions(+), 71 deletions(-) diff --git a/experiments/cfg_effect/main.py b/experiments/cfg_effect/main.py index 8e08d56..2f1926c 100644 --- a/experiments/cfg_effect/main.py +++ b/experiments/cfg_effect/main.py @@ -236,6 +236,12 @@ def run_switch_experiment( baseline_norms = baseline_grad_tracker.gradient_norms[:switch_step] grad_tracker.gradient_norms = baseline_norms + grad_tracker.gradient_norms + baseline_norms_cond = baseline_grad_tracker.gradient_norms_cond[:switch_step] + grad_tracker.gradient_norms_cond = baseline_norms_cond + grad_tracker.gradient_norms_cond + + baseline_norms_uncond = baseline_grad_tracker.gradient_norms_uncond[:switch_step] + grad_tracker.gradient_norms_uncond = baseline_norms_uncond + grad_tracker.gradient_norms_uncond + grad_tracker.finalize(output_folder, num_sampling_steps, stepsize, sampler) diff --git a/models.py b/models.py index 6473154..061ee4a 100644 --- a/models.py +++ b/models.py @@ -278,9 +278,15 @@ def forward(self, x0, t, y, return_act=False, get_energy=False, train=False): return x, act return x - def forward_with_cfg(self, x, t, y, cfg_scale, return_act=False, get_energy=False, train=False): + def forward_with_cfg( + self, x, t, y, cfg_scale, return_act=False, get_energy=False, train=False, return_components=False + ): """ Forward pass of EqM, but also batches the uncondional forward pass for classifier-free guidance. + + Args: + return_components: If True, returns (combined_out, cond_out, uncond_out) tuple + for tracking individual outputs before CFG combination. """ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb half = x[: len(x) // 2] @@ -305,9 +311,12 @@ def forward_with_cfg(self, x, t, y, cfg_scale, return_act=False, get_energy=Fals cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) + combined_out = torch.cat([eps, rest], dim=1) + if return_components: + return combined_out, cond_eps, uncond_eps if get_energy: - return torch.cat([eps, rest], dim=1), E - return torch.cat([eps, rest], dim=1) + return combined_out, E + return combined_out ################################################################################# diff --git a/utils/sampling_hooks.py b/utils/sampling_hooks.py index 452e17f..db110bf 100644 --- a/utils/sampling_hooks.py +++ b/utils/sampling_hooks.py @@ -40,12 +40,14 @@ class SamplingHookContext: xt: torch.Tensor # Current latent state t: torch.Tensor # Current timestep y: torch.Tensor # Class labels - out: torch.Tensor # Model output/gradient + out: torch.Tensor # Model output/gradient (CFG-combined when use_cfg=True) step_idx: int # Current step index (1-indexed) use_cfg: bool # Whether CFG is enabled vae: Any # VAE decoder for image conversion device: torch.device # Device total_steps: int # Total number of sampling steps + out_cond: torch.Tensor | None = None # Class label output before CFG (only when use_cfg=True) + out_uncond: torch.Tensor | None = None # Null label output before CFG (only when use_cfg=True) class IntermediateImageSaver: @@ -327,25 +329,43 @@ class GradientNormTracker: """ Hook for tracking gradient L2 norms during sampling. + When CFG is used, tracks three norms: + - gradient_norms: CFG-applied output (combined) + - gradient_norms_cond: conditional output (out_cond) + - gradient_norms_uncond: unconditional output (out_uncond) + + When CFG is not used, only gradient_norms is tracked. + Args: num_steps: Number of sampling steps (for pre-allocating storage) """ def __init__(self, num_steps): - self.gradient_norms = [[] for _ in range(num_steps)] + self.gradient_norms = [[] for _ in range(num_steps)] # CFG-applied output norms + self.gradient_norms_cond = [[] for _ in range(num_steps)] # conditional output norms + self.gradient_norms_uncond = [[] for _ in range(num_steps)] # unconditional output norms def __call__(self, context: SamplingHookContext): """Accumulate gradient L2 norms for the current step.""" - # Extract conditional part if using CFG + # Track CFG-applied output norm (first half, since duplicated) out_for_norm = context.out if context.use_cfg: batch_size = context.out.shape[0] // 2 out_for_norm = context.out[:batch_size] - - # Compute L2 norm for each sample in the batch - norms = torch.linalg.norm(out_for_norm.reshape(out_for_norm.shape[0], -1), dim=1) # shape: (batch_size,) + norms = torch.linalg.norm(out_for_norm.reshape(out_for_norm.shape[0], -1), dim=1) self.gradient_norms[context.step_idx - 1].extend(norms.cpu().tolist()) + # Track cond/uncond norms when CFG is used + if context.use_cfg and context.out_cond is not None and context.out_uncond is not None: + out_cond = context.out_cond + out_uncond = context.out_uncond + + norms_cond = torch.linalg.norm(out_cond.reshape(out_cond.shape[0], -1), dim=1) + self.gradient_norms_cond[context.step_idx - 1].extend(norms_cond.cpu().tolist()) + + norms_uncond = torch.linalg.norm(out_uncond.reshape(out_uncond.shape[0], -1), dim=1) + self.gradient_norms_uncond[context.step_idx - 1].extend(norms_uncond.cpu().tolist()) + def finalize(self, folder: str, num_sampling_steps: int, stepsize: float, sampler: str): """ Compute statistics and create visualization for gradient norms. @@ -358,27 +378,41 @@ def finalize(self, folder: str, num_sampling_steps: int, stepsize: float, sample sampler: Sampler name (e.g., 'euler', 'heun') """ print("Computing gradient norm statistics...") - gradient_means = [] - gradient_stds = [] - for step_norms in self.gradient_norms: - if len(step_norms) > 0: - gradient_means.append(np.mean(step_norms)) - gradient_stds.append(np.std(step_norms)) - else: - gradient_means.append(0.0) - gradient_stds.append(0.0) + def compute_stats(norm_list): + means, stds = [], [] + for step_norms in norm_list: + if len(step_norms) > 0: + means.append(np.mean(step_norms)) + stds.append(np.std(step_norms)) + else: + means.append(0.0) + stds.append(0.0) + return means, stds + + # Compute stats for CFG-applied output + cfg_means, cfg_stds = compute_stats(self.gradient_norms) + + # Compute stats for cond/uncond outputs + has_cond = any(len(s) > 0 for s in self.gradient_norms_cond) + has_uncond = any(len(s) > 0 for s in self.gradient_norms_uncond) + cond_means, cond_stds = compute_stats(self.gradient_norms_cond) if has_cond else ([], []) + uncond_means, uncond_stds = compute_stats(self.gradient_norms_uncond) if has_uncond else ([], []) # Save statistics to JSON stats = { "num_sampling_steps": num_sampling_steps, "total_samples": len(self.gradient_norms[0]) if len(self.gradient_norms[0]) > 0 else 0, - "mean": gradient_means, - "std": gradient_stds, + "cfg_output": {"mean": cfg_means, "std": cfg_stds}, "stepsize": stepsize, "sampler": sampler, - "note": "Statistics computed from individual gradient L2 norms across all samples (batch-size independent)", + "note": "Statistics computed from individual gradient L2 norms across all samples", } + if has_cond: + stats["cond_output"] = {"mean": cond_means, "std": cond_stds} + if has_uncond: + stats["uncond_output"] = {"mean": uncond_means, "std": uncond_stds} + json_path = f"{folder}/gradient_norms.json" with open(json_path, "w") as f: json.dump(stats, f, indent=2) @@ -387,18 +421,50 @@ def finalize(self, folder: str, num_sampling_steps: int, stepsize: float, sample # Create plot print("Creating gradient norm plot...") steps = np.arange(0, num_sampling_steps) - gradient_means = np.array(gradient_means) - gradient_stds = np.array(gradient_stds) + cfg_means = np.array(cfg_means) + cfg_stds = np.array(cfg_stds) plt.figure(figsize=(10, 6)) - plt.plot(steps, gradient_means, linewidth=2, label="Mean L2 Norm") + + # Plot CFG-applied output norms + plt.plot(steps, cfg_means, linewidth=2, label="CFG Output Mean", color="green") plt.fill_between( steps, - gradient_means - gradient_stds, - gradient_means + gradient_stds, - alpha=0.3, - label="Mean ± Std", + cfg_means - cfg_stds, + cfg_means + cfg_stds, + alpha=0.2, + color="green", + label="CFG Output ± Std", ) + + # Plot conditional output norms if available + if has_cond: + cond_means = np.array(cond_means) + cond_stds = np.array(cond_stds) + plt.plot(steps, cond_means, linewidth=2, label="Cond Output Mean", color="blue") + plt.fill_between( + steps, + cond_means - cond_stds, + cond_means + cond_stds, + alpha=0.2, + color="blue", + label="Cond Output ± Std", + ) + + # Plot unconditional output norms if available + if has_uncond: + uncond_means = np.array(uncond_means) + uncond_stds = np.array(uncond_stds) + plt.plot(steps, uncond_means, linewidth=2, label="Uncond Output Mean", color="orange") + plt.fill_between( + steps, + uncond_means - uncond_stds, + uncond_means + uncond_stds, + alpha=0.2, + color="orange", + label="Uncond Output ± Std", + ) + plt.xlabel("Sampling Step", fontsize=12) plt.ylabel("Gradient L2 Norm", fontsize=12) plt.title( @@ -426,41 +492,81 @@ def finalize_for_wandb(self, wandb_module, train_step): if wandb_module is None: return - # Compute statistics - gradient_means = [] - gradient_stds = [] - - for step_norms in self.gradient_norms: - if len(step_norms) > 0: - gradient_means.append(np.mean(step_norms)) - gradient_stds.append(np.std(step_norms)) - else: - gradient_means.append(0.0) - gradient_stds.append(0.0) + def compute_stats(norm_list): + means, stds = [], [] + for step_norms in norm_list: + if len(step_norms) > 0: + means.append(np.mean(step_norms)) + stds.append(np.std(step_norms)) + else: + means.append(0.0) + stds.append(0.0) + return means, stds + + # Compute stats for CFG-applied output + cfg_means, cfg_stds = compute_stats(self.gradient_norms) + + # Compute stats for cond/uncond outputs + has_cond = any(len(s) > 0 for s in self.gradient_norms_cond) + has_uncond = any(len(s) > 0 for s in self.gradient_norms_uncond) + cond_means, cond_stds = compute_stats(self.gradient_norms_cond) if has_cond else ([], []) + uncond_means, uncond_stds = compute_stats(self.gradient_norms_uncond) if has_uncond else ([], []) # Create a table for gradient norms vs sampling steps - # This allows WandB to plot with sampling_step on x-axis data = [] - for sampling_step, (mean_norm, std_norm) in enumerate(zip(gradient_means, gradient_stds)): - data.append( + for step in range(len(cfg_means)): + row = [ + step, + cfg_means[step], + cfg_stds[step], + cfg_means[step] + cfg_stds[step], + cfg_means[step] - cfg_stds[step], + ] + if has_cond: + row.extend( + [ + cond_means[step], + cond_stds[step], + cond_means[step] + cond_stds[step], + cond_means[step] - cond_stds[step], + ] + ) + if has_uncond: + row.extend( + [ + uncond_means[step], + uncond_stds[step], + uncond_means[step] + uncond_stds[step], + uncond_means[step] - uncond_stds[step], + ] + ) + data.append(row) + + columns = [ + "sampling_step", + "cfg_output_norm_mean", + "cfg_output_norm_std", + "cfg_output_norm_upper", + "cfg_output_norm_lower", + ] + if has_cond: + columns.extend( [ - sampling_step, - mean_norm, - std_norm, - mean_norm + std_norm, # Upper bound - mean_norm - std_norm, # Lower bound + "cond_output_norm_mean", + "cond_output_norm_std", + "cond_output_norm_upper", + "cond_output_norm_lower", + ] + ) + if has_uncond: + columns.extend( + [ + "uncond_output_norm_mean", + "uncond_output_norm_std", + "uncond_output_norm_upper", + "uncond_output_norm_lower", ] ) - table = wandb_module.Table( - columns=[ - "sampling_step", - "gradient_norm_mean", - "gradient_norm_std", - "gradient_norm_upper", - "gradient_norm_lower", - ], - data=data, - ) - # Log the table with a consistent key (train_step is already in the table data) + table = wandb_module.Table(columns=columns, data=data) wandb_module.log({"gradient_norms": table}, step=train_step) diff --git a/utils/sampling_utils.py b/utils/sampling_utils.py index cfcbce3..c5792bd 100644 --- a/utils/sampling_utils.py +++ b/utils/sampling_utils.py @@ -31,7 +31,7 @@ from PIL import Image from tqdm import tqdm -from utils.sampling_hooks import SamplingHookContext +from utils.sampling_hooks import GradientNormTracker, SamplingHookContext @torch.no_grad() @@ -49,6 +49,7 @@ def sample_eqm( sampler="gd", mu=0.3, hooks=None, + return_cfg_components=False, ): """ Generate samples using EqM model with gradient descent sampling. @@ -72,6 +73,8 @@ def sample_eqm( hooks: List of hook callables. Each hook receives a SamplingHookContext object at each sampling step. Use for monitoring, logging, or saving intermediate results. Example hooks: IntermediateImageSaver, GradientNormTracker. + return_cfg_components: If True and CFG is enabled, passes individual cond/uncond outputs + to hooks via out_cond and out_uncond context fields. Returns: samples: Generated images as numpy array, shape (batch_size, H, W, 3), dtype uint8 @@ -90,6 +93,10 @@ def sample_eqm( if hooks is None: hooks = [] + # Auto-enable return_cfg_components if GradientNormTracker is present + if any(isinstance(h, GradientNormTracker) for h in hooks): + return_cfg_components = True + use_cfg = cfg_scale > 1.0 n = batch_size @@ -125,17 +132,26 @@ def sample_eqm( # Sampling loop with torch.no_grad(): for step_idx in tqdm(range(1, num_sampling_steps + 1)): + out_cond = None + out_uncond = None + if sampler == "gd": # Standard gradient descent - out = model_fn(xt, t, y, cfg_scale) - if not torch.is_tensor(out): - out = out[0] + if use_cfg and return_cfg_components: + out, out_cond, out_uncond = model_fn(xt, t, y, cfg_scale, return_components=True) + else: + out = model_fn(xt, t, y, cfg_scale) + if not torch.is_tensor(out): + out = out[0] elif sampler == "ngd": # Nesterov accelerated gradient descent x_ = xt + stepsize * m * mu - out = model_fn(x_, t, y, cfg_scale) - if not torch.is_tensor(out): - out = out[0] + if use_cfg and return_cfg_components: + out, out_cond, out_uncond = model_fn(x_, t, y, cfg_scale, return_components=True) + else: + out = model_fn(x_, t, y, cfg_scale) + if not torch.is_tensor(out): + out = out[0] m = out else: raise ValueError(f"Unknown sampler: {sampler}") @@ -156,6 +172,8 @@ def sample_eqm( vae=vae, device=device, total_steps=num_sampling_steps, + out_cond=out_cond, + out_uncond=out_uncond, ) for hook in hooks: hook(context) @@ -184,6 +202,7 @@ def sample_eqm_two( sampler="gd", mu=0.3, hooks=None, + return_cfg_components=False, ): """ Generate samples using EqM model with gradient descent sampling. @@ -207,6 +226,8 @@ def sample_eqm_two( hooks: List of hook callables. Each hook receives a SamplingHookContext object at each sampling step. Use for monitoring, logging, or saving intermediate results. Example hooks: IntermediateImageSaver, GradientNormTracker. + return_cfg_components: If True and CFG is enabled, passes individual cond/uncond outputs + to hooks via out_cond and out_uncond context fields. Returns: samples: Generated images as numpy array, shape (batch_size, H, W, 3), dtype uint8 @@ -225,6 +246,10 @@ def sample_eqm_two( if hooks is None: hooks = [] + # Auto-enable return_cfg_components if GradientNormTracker is present + if any(isinstance(h, GradientNormTracker) for h in hooks): + return_cfg_components = True + use_cfg = cfg_scale > 1.0 n = batch_size @@ -260,17 +285,26 @@ def sample_eqm_two( # Sampling loop with torch.no_grad(): for step_idx in range(1, num_sampling_steps + 1): + out_cond = None + out_uncond = None + if sampler == "gd": # Standard gradient descent - out = model_fn(xt, t, y, cfg_scale) - if not torch.is_tensor(out): - out = out[0] + if use_cfg and return_cfg_components: + out, out_cond, out_uncond = model_fn(xt, t, y, cfg_scale, return_components=True) + else: + out = model_fn(xt, t, y, cfg_scale) + if not torch.is_tensor(out): + out = out[0] elif sampler == "ngd": # Nesterov accelerated gradient descent x_ = xt + stepsize * m * mu - out = model_fn(x_, t, y, cfg_scale) - if not torch.is_tensor(out): - out = out[0] + if use_cfg and return_cfg_components: + out, out_cond, out_uncond = model_fn(x_, t, y, cfg_scale, return_components=True) + else: + out = model_fn(x_, t, y, cfg_scale) + if not torch.is_tensor(out): + out = out[0] m = out else: raise ValueError(f"Unknown sampler: {sampler}") @@ -313,6 +347,8 @@ def sample_eqm_two( vae=vae, device=device, total_steps=num_sampling_steps, + out_cond=out_cond, + out_uncond=out_uncond, ) for hook in hooks: hook(context)