I’m training a perturbation‑prediction model using datasets managed via GEARS PertData, and I need to run multi‑GPU training with PyTorch Distributed Data Parallel (DDP). What’s the recommended way to connect a PertData dataset to a DataLoader that uses torch.utils.data.distributed.DistributedSampler so that each rank gets a disjoint shard, preserves split integrity (train/val/test), and supports epoch‑level shuffling?
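One pattern that may work here (a sketch, not an official GEARS recipe): materialize each PertData split as a plain indexable collection and wrap it in a `torch.utils.data.Dataset`, since `DistributedSampler` only needs `__len__` and `__getitem__`. GEARS builds PyTorch Geometric `Data` objects internally, but how you pull the per-sample list out of `PertData` is an assumption below; the wrapper itself is split-agnostic.

```python
import torch
from torch.utils.data import Dataset, DataLoader, DistributedSampler


class PertSplitDataset(Dataset):
    """Wraps one split (train/val/test) extracted from a PertData-style object.

    `samples` is any indexable collection, e.g. a list of PyG Data objects
    pulled from PertData (the exact accessor on PertData is an assumption).
    """

    def __init__(self, samples):
        self.samples = list(samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def make_distributed_loader(samples, rank, world_size, batch_size=32, seed=0):
    """Build a per-rank DataLoader whose sampler shards `samples` disjointly.

    Passing num_replicas/rank explicitly means this also works before (or
    without) torch.distributed.init_process_group, which makes it testable
    on a single process.
    """
    ds = PertSplitDataset(samples)
    sampler = DistributedSampler(
        ds,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,   # permutation is derived from seed + epoch
        seed=seed,      # must be identical on every rank
        drop_last=False,
    )
    # For PyG Data objects, swap in torch_geometric.loader.DataLoader,
    # which accepts the same `sampler` argument.
    loader = DataLoader(ds, batch_size=batch_size, sampler=sampler)
    return loader, sampler
```

Because every rank constructs the sampler with the same `seed` and the full dataset, the ranks agree on one global permutation and each takes its own stride of it, so the shards are disjoint and their union covers the split (with padding when the length is not divisible by `world_size`).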
I’m specifically looking for guidance or example code for:
- Converting or wrapping PertData into a PyTorch‑style dataset that DistributedSampler understands
- Handling split selection (train, val, test) so that all ranks see consistent subsets
- Ensuring deterministic shuffling across epochs via sampler.set_epoch(epoch)
- Best practices around rank/world size, seed control, and worker initialization
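On the last two bullets, `DistributedSampler` derives each epoch's permutation from `seed + epoch`, so calling `sampler.set_epoch(epoch)` on every rank before iterating gives all ranks the same permutation while still reshuffling between epochs. A sketch of the training-loop side, together with the standard worker-seeding recipe (`run_epochs` and `train_step` are illustrative names, not GEARS APIs):

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler


def seed_worker(worker_id):
    """worker_init_fn: give each DataLoader worker a reproducible seed.

    torch.initial_seed() inside a worker already incorporates the base seed
    and the worker id, so deriving numpy/random seeds from it keeps worker
    RNG streams distinct but deterministic.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def run_epochs(loader, sampler, num_epochs, train_step):
    for epoch in range(num_epochs):
        # Must run on every rank before the iterator is created: it feeds
        # `epoch` into the sampler's RNG so shuffling is deterministic
        # across ranks and differs per epoch.
        sampler.set_epoch(epoch)
        for batch in loader:
            train_step(batch)
```

A loader built for multi-process loading would then pass `worker_init_fn=seed_worker` (and typically a seeded `torch.Generator`) to `DataLoader`. Forgetting `set_epoch` is the classic failure mode: every epoch replays the epoch-0 permutation on all ranks.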