Skip to content
9 changes: 6 additions & 3 deletions deprecated/sample_eqm_two.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def main(args):
save_steps_list = []
if args.save_steps is not None:
save_steps_list = [int(s.strip()) for s in args.save_steps.split(",")]
img_saver = IntermediateImageSaver(save_steps_list, args.out)
img_saver = IntermediateImageSaver(save_steps_list, output_folder=args.out)
hooks.append(img_saver)
print(f"Created IntermediateImageSaver hook for steps: {save_steps_list}")

Expand Down Expand Up @@ -117,7 +117,7 @@ def main(args):
# Finalize gradient norm statistics if enabled
if grad_tracker is not None:
print("Computing gradient norm statistics...")
grad_tracker.finalize(args, args.out)
grad_tracker.finalize(args.out, args.num_sampling_steps, args.stepsize, args.sampler)

# Create .npz files for FID evaluation
print("Creating .npz file for final samples...")
Expand Down Expand Up @@ -168,7 +168,10 @@ def main(args):
help="Comma-separated list of sampling steps to save intermediate images (e.g., '0,50,100,249')",
)
parser.add_argument(
"--track-grad-norm", action="store_true", help="Enable gradient norm tracking and visualization"
"--track-grad-norm",
action=argparse.BooleanOptionalAction,
default=True,
help="Enable gradient norm tracking and visualization",
)
parser.add_argument("--uncond", type=bool, default=True, help="Disable/enable noise conditioning (default: True)")
parser.add_argument(
Expand Down
Loading