The Cell Observatory Platform is a comprehensive framework for training and evaluating machine learning models on biological image and video datasets. It is built on PyTorch, orchestrated and scaled with Ray, sharded via DeepSpeed or native PyTorch parallelism (TorchTitan), and flexibly configured with Hydra, providing a modular architecture that is easy to customize and extend.
- Cell Observatory Platform
- Installation
- Running docker image
- Get started
- Architecture Overview
- Data Pipeline
- Models
- Training
- Logging & Experiment Tracking
- Inference
- Evaluation
- Profiling
- Configuration layout
- License
Docker images
Our prebuilt image ships with Python, PyTorch, and all required packages installed for you.
```bash
docker pull ghcr.io/cell-observatory/cell_observatory_platform:develop_torch_25_08
```

Clone the repository with its submodules:

```bash
git clone --recurse-submodules https://github.com/cell-observatory/cell_observatory_platform.git
```

To later update to the latest version:

```bash
git pull --recurse-submodules
```

> **Note**
> If you want to run a local version of the image, see the Dockerfile.
You will need to create a Supabase and W&B account to use the platform. Supabase can be found at Cell Observatory Database, and W&B can be found at Cell Observatory Dashboard.
Once you have created your Supabase and W&B accounts, you'll need to add your API keys in the environment variables as described below.
Rename the `.env.example` file to `.env`; it will be automatically loaded into the container and is gitignored. The Supabase-related environment variables enable database functionality, and the W&B API key enables logging. The `REPO_NAME`, `DATA_DIR`, and `STORAGE_SERVER_DIR` environment variables are leveraged in the `configs/paths` configuration files to ensure that jobs run and save outputs as expected.
> **Note**
> `STORAGE_SERVER_DIR` is usually set to the root directory of your files. See the example below:
>
> ```bash
> STORAGE_SERVER_DIR="/clusterfs/scratch/user/"
> REPO_DIR="/clusterfs/scratch/user/cell_observatory_platform"
> DATA_DIR="/clusterfs/scratch/user/cell_observatory_data"
> ```

```bash
SUPABASE_USER=REPLACE_ME_WITH_YOUR_SUPABASE_USERNAME
SUPABASE_PASS=REPLACE_ME_WITH_YOUR_SUPABASE_PASSWORD
TRINO_USER=REPLACE_ME_WITH_YOUR_TRINO_USERNAME
TRINO_PASS=REPLACE_ME_WITH_YOUR_TRINO_PASSWORD
SUPABASE_STAGING_ID=REPLACE_ME_WITH_YOUR_SUPABASE_STAGING_ID
SUPABASE_PROD_ID=REPLACE_ME_WITH_YOUR_SUPABASE_PROD_ID
WANDB_API_KEY=REPLACE_ME_WITH_YOUR_WANDB_API_KEY
SUPABASE_STAGING_URI="postgresql://${SUPABASE_USER}.${SUPABASE_STAGING_ID}:${SUPABASE_PASS}@aws-0-us-east-1.pooler.supabase.com:5432/postgres"
SUPABASE_PROD_URI="postgresql://${SUPABASE_USER}.${SUPABASE_PROD_ID}:${SUPABASE_PASS}@aws-0-us-east-1.pooler.supabase.com:5432/postgres"
REPO_NAME=cell_observatory_platform # TODO: replace with your repo name if you renamed it
REPO_DIR=REPLACE_ME_WITH_YOUR_ROOT_REPO_DIR
DATA_DIR=REPLACE_ME_WITH_YOUR_ROOT_DATA_DIR_WHERE_DATA_WILL_BE_SAVED
STORAGE_SERVER_DIR=REPLACE_ME_WITH_YOUR_STORAGE_SERVER_DIR_WHERE_DATA_SERVER_IS_MOUNTED
PYTHONPATH=REPLACE_ME_WITH_YOUR_ROOT_REPO_DIR
```

> **Important**
> Username/password and IDs for Supabase will be provided upon request.
To run the Docker image, `cd` to the repository directory or replace `$(pwd)` with your local path to the repository.
```bash
docker run --network host -u 1000 --privileged \
  -v $(pwd):/workspace/cell_observatory_platform \
  -w /workspace/cell_observatory_platform \
  --env PYTHONUNBUFFERED=1 --pull missing \
  -it --rm --ipc host --gpus all \
  ghcr.io/cell-observatory/cell_observatory_platform:develop_torch_25_08 bash
```

Running an image on a cluster typically requires an Apptainer version of the image, which can be generated by:
```bash
# amd64
apptainer pull --arch amd64 --force develop_torch_25_08.sif docker://ghcr.io/cell-observatory/cell_observatory_platform:develop_torch_25_08

# arm64
apptainer pull --arch arm64 --force develop_torch_25_08.sif docker://ghcr.io/cell-observatory/cell_observatory_platform:develop_torch_25_08
```

Alternatively, you can build the image yourself. First, build an Apptainer image for PyTorch from the containers provided by NVIDIA (e.g., 25.08-py3 from this catalog):
```bash
apptainer pull --arch amd64 --force pytorch_25.08-py3.sif docker://nvcr.io/nvidia/pytorch:25.08-py3
```

Then you can run the following command to build the complete image:
```bash
apptainer build --arch amd64 --nv --force develop_torch_25_08.sif apptainerfile.def
```

> **Important**
> Make sure to pass the right architecture argument for your system (amd64 or arm64).
All jobs are orchestrated on top of a Ray cluster and launched through our `manager.py` script, which handles cluster resource allocation and Ray cluster setup. You can choose whether to run jobs locally or on a cluster by setting the `launcher_type` variable in `configs/clusters/*.yaml`. Below, we show how to run jobs locally and on SLURM or LSF clusters.
Example job configs are located in the configs/experiments folder. For local jobs, you can use our existing configs/paths/local.yaml and configs/clusters/local.yaml configurations.
```yaml
experiment_name: test_cell_observatory_platform
wandb_project: test_cell_observatory_platform
```

```yaml
paths:
  outdir: ${paths.data_path}/pretrained_models/${experiment_name}
  resume_checkpointdir: null
  pretrained_checkpointdir: null
```

```yaml
clusters:
  batch_size: 2
  worker_nodes: 1
  gpus_per_worker: 1
  cpus_per_gpu: 4
  mem_per_cpu: 16000
```

Then launch the job:

```bash
python manager.py --config-name=configs/test_pretrain_4d_mae_local.yaml
```

To launch multiple training jobs, set `run_type` to `multi_run` and define a `runs` list. For Ray Tune hyperparameter sweeps, set `run_type` to `tune`.
Running a job on a cluster is very similar to the local setup. Override the defaults in your config file to match your cluster:
```yaml
defaults:
  - clusters: abc_a100
  - paths: abc
```

```yaml
defaults:
  - clusters: janelia_h100
  - paths: janelia
```

Architecture Overview

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                                 Ray Cluster                                  │
│ ┌─────────────────────────────────────────────────────────────────────────┐ │
│ │                               manager.py                                │ │
│ │                          (SLURM / LSF / Local)                          │ │
│ └─────────────────────────────────────────────────────────────────────────┘ │
│                                      │                                      │
│           ┌──────────────────────────┼──────────────────────────┐           │
│           ▼                          ▼                          ▼           │
│  ┌─────────────────┐        ┌─────────────────┐        ┌─────────────────┐  │
│  │   EpochBased    │        │  ParallelEpoch  │        │   Inferencer    │  │
│  │     Trainer     │        │  BasedTrainer   │        │                 │  │
│  │   (DeepSpeed)   │        │  (TorchTitan)   │        │  (Distributed)  │  │
│  └─────────────────┘        └─────────────────┘        └─────────────────┘  │
│           │                          │                          │           │
│           └──────────────────────────┼──────────────────────────┘           │
│                                      ▼                                      │
│ ┌─────────────────────────────────────────────────────────────────────────┐ │
│ │                            Ray Data Pipeline                            │ │
│ │     LoaderActor → SharedMemory → CollatorActor → DeviceBuffer → GPU     │ │
│ └─────────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────┘
```
The data pipeline is built on Ray Data with a custom queue-based system for high-throughput data loading:
```
Database (Supabase/Local CSV) -> LoaderActor -> Host SharedMemory Buffer -> CollatorActor -> Device Buffer -> Preprocessor -> Model
```
ParentDatabase / SupabaseDatabase: Flexible database classes supporting both remote (Supabase/PostgreSQL) and local (cached CSV) data sources.
- Remote mode: Queries Supabase via ConnectorX (Arrow path) for high-performance data fetching
- Local mode: Loads from cached CSV + JSON config for offline/fast iteration
- Data filtering: ROI/tile/HPF selection, occupancy thresholds, CDF thresholds, annotation filtering
```yaml
# Example database config
datasets:
  databases:
    _target_: data.databases.supabase_database.SupabaseDatabase
    use_cached_hypercubes_dataframe: true
    hypercubes_dataframe_path: ${paths.data_path}/databases/hypercubes.csv
    server_folder_path: /groups/betzig/data
```

The Ray-based dataloader uses a multi-actor architecture for maximum throughput:
- **LoaderActor**: Reads hypercubes from Zarr via TensorStore into pinned shared-memory buffers.
- **CollatorActor / FinetuneCollatorActor**: Transfers batches from host shared memory to GPU device buffers, with optional transforms, mask extraction, and target building.
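To make the actor hand-off concrete, here is a minimal sketch of the pattern with hypothetical, simplified actors (the platform's real `LoaderActor`/`CollatorActor` use pinned shared memory and TensorStore reads instead of the stand-ins below):

```python
import numpy as np
import ray

# Hypothetical, simplified actors illustrating the loader -> collator hand-off.
@ray.remote
class LoaderActor:
    def load(self, index: int) -> np.ndarray:
        # Stand-in for a TensorStore read of one hypercube from Zarr.
        return np.random.rand(32, 32, 32).astype(np.float32)

@ray.remote
class CollatorActor:
    def collate(self, samples: list) -> np.ndarray:
        # Stand-in for host->device transfer, transforms, and target building.
        return np.stack(samples)

ray.init()
loader, collator = LoaderActor.remote(), CollatorActor.remote()
refs = [loader.load.remote(i) for i in range(4)]         # queued on the loader actor
batch = ray.get(collator.collate.remote(ray.get(refs)))  # shape (4, 32, 32, 32)
```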
Preprocessors provide a unified interface for task-specific data preparation:
| Preprocessor | Task | Description |
|---|---|---|
| `RayPreprocessor` | Pretraining | dtype normalization, masking, transforms |
| `ChannelSplitPreprocessor` | Channel split | Predicts per-channel outputs from averaged input |
| `UpsamplePreprocessor` | Super-resolution | NA-mask downsampling for space/time upsampling |
| `InstanceSegmentationPreprocessor` | Instance segmentation | Mask/bbox extraction, target building |
Transforms can be applied either in the Collator (CPU, during data loading) or in the Preprocessor (GPU, before the model forward pass). Available transforms in `data/transforms/` (a composition sketch follows the list):

- `Resize`: Spatial resizing with mask/bbox scaling
- `Crop`: Random and center cropping
- `Normalize`: Percentile-based normalization
- `ProbabilisticChoice`: Randomly selects between transform pipelines
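Here is a minimal sketch of composing such a pipeline, assuming 3D tensors and simplified stand-ins for the transforms above (the real implementations live in `data/transforms/` and also handle masks and bounding boxes):

```python
import random
import torch

# Simplified stand-ins for Crop, Normalize, and ProbabilisticChoice.
def center_crop(x: torch.Tensor, size: int) -> torch.Tensor:
    d, h, w = x.shape[-3:]
    sd, sh, sw = (d - size) // 2, (h - size) // 2, (w - size) // 2
    return x[..., sd:sd + size, sh:sh + size, sw:sw + size]

def normalize_percentile(x: torch.Tensor, lo=0.01, hi=0.99) -> torch.Tensor:
    qlo, qhi = torch.quantile(x, torch.tensor([lo, hi]))
    return (x - qlo) / (qhi - qlo + 1e-6)

# ProbabilisticChoice-style selection between two candidate pipelines.
pipelines = [
    lambda x: normalize_percentile(center_crop(x, 64)),
    lambda x: normalize_percentile(x),
]
volume = torch.rand(1, 128, 128, 128)  # (C, Z, Y, X)
out = random.choice(pipelines)(volume)
```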
Mask generation produces patch-level masks for self-supervised learning with explicit time/space awareness (a sketch of the two main styles follows the list):

- `BLOCKED` / `BLOCKED_TIME_ONLY` / `BLOCKED_SPACE_ONLY`: Block-based masking
- `RANDOM` / `RANDOM_SPACE_ONLY`: MAE-style masking
- `BLOCKED_PATTERNED`: Deterministic time downsampling patterns
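As a rough illustration, here is a sketch of the two main styles under an assumed token layout (the platform's generators also handle the space-only and patterned variants):

```python
import torch

# RANDOM: MAE-style masking over all tokens, irrespective of time/space.
def random_mask(num_tokens: int, mask_ratio: float = 0.75) -> torch.Tensor:
    perm = torch.randperm(num_tokens)
    mask = torch.zeros(num_tokens, dtype=torch.bool)
    mask[perm[: int(num_tokens * mask_ratio)]] = True  # True = masked
    return mask

# BLOCKED_TIME_ONLY-style: mask whole time steps, keeping space intact.
def blocked_time_mask(t: int, tokens_per_frame: int, mask_ratio: float = 0.5) -> torch.Tensor:
    masked_frames = torch.rand(t) < mask_ratio
    return masked_frames.repeat_interleave(tokens_per_frame)

mask = random_mask(8 * 4 * 4 * 4)        # 8 time steps of 4x4x4 patches
tmask = blocked_time_mask(8, 4 * 4 * 4)  # masks entire time steps
```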
| Model | Location | Description |
|---|---|---|
| MAE | `models/meta_arch/maskedautoencoder.py` | Masked Autoencoder for 3D/4D volumes |
| JEPA | `models/meta_arch/jepa.py` | Joint-Embedding Predictive Architecture |

| Model | Location | Description |
|---|---|---|
| plainDETR | `models/meta_arch/plainDETR.py` | 3D object detection |
| MaskDINO | `models/meta_arch/maskdino.py` | 3D instance segmentation |
| Mask2Former | | 3D semantic segmentation |
- ViT (`models/backbones/vit.py`): Vision Transformer with RoPE/sincos positional encoding
- ConvNeXt (`models/backbones/convnext.py`): ConvNeXt backbone
- MaskedEncoder (`models/backbones/maskedencoder.py`): Encoder with masking support
Key layer implementations:
- Attention: Multi-head self-attention with flash attention support and deformable attention
- Transformer: Standard and deformable transformer blocks
- Patch Embeddings: 3D/4D patch embedding with multiple layout support
- Positional Encoding: Sinusoidal, learned, and RoPE encodings
- Matchers: Hungarian matcher for detection/segmentation (a matching sketch follows this list)
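The core idea of the matcher is solving an optimal one-to-one assignment over a prediction/ground-truth cost matrix. A minimal sketch using SciPy's Hungarian solver (illustrative costs only; the real matcher builds its cost matrix from class scores, boxes, and masks):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# cost[i, j]: cost of assigning prediction i to ground-truth object j.
cost = np.array([
    [0.1, 0.9, 0.8],
    [0.7, 0.2, 0.9],
    [0.6, 0.8, 0.3],
])
pred_idx, gt_idx = linear_sum_assignment(cost)        # Hungarian algorithm
print(list(zip(pred_idx.tolist(), gt_idx.tolist())))  # [(0, 0), (1, 1), (2, 2)]
```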
Standard training loop with DeepSpeed ZeRO optimization; a generic training-step sketch follows the feature list.
Features:
- DeepSpeed ZeRO stages 1-3
- Mixed precision (bf16/fp16)
- Checkpoint saving/resuming
- Hook-based extensibility
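For orientation, here is a generic DeepSpeed training-step sketch (not the platform's `EpochBasedTrainer`; the minimal config dict below is an assumption and would normally come from `configs/deepspeed/`):

```python
import torch
import deepspeed

# Generic DeepSpeed setup: the engine wraps the model and owns the
# optimizer, ZeRO partitioning, and mixed-precision handling.
model = torch.nn.Linear(128, 10)
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
}
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config
)

# One training step (run under the `deepspeed` launcher or torchrun).
x = torch.randn(2, 128, device=engine.device, dtype=torch.bfloat16)
loss = engine(x).float().mean()
engine.backward(loss)  # handles scaling and ZeRO gradient partitioning
engine.step()          # optimizer step (+ LR scheduler if configured)
```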
Advanced parallelism support via TorchTitan integration (work in progress); a device-mesh sketch follows the feature list.
Features:
- TP (Tensor Parallelism): Shards model tensors across devices
- CP (Context Parallelism): Shards sequence dimension
- FSDP (Model and Data Parallelism): FSDP-based sharding
- Torch Compile support
- Activation checkpointing
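A rough sketch of how a 2D mesh for composing these styles can be declared with native PyTorch (assumes 8 GPUs under torchrun; the platform's actual TorchTitan wiring may differ):

```python
from torch.distributed.device_mesh import init_device_mesh

# Run under: torchrun --nproc_per_node=8 this_script.py
# 2D mesh: 2-way data parallel (FSDP) x 4-way tensor parallel.
mesh = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh["dp"], mesh["tp"]
# Tensor-parallel plans are applied over tp_mesh and FSDP sharding over
# dp_mesh; torch.compile and activation checkpointing are layered on top.
```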
The training loop is extensible via a priority-based hook system. Hooks can execute at various points: before_train, before_epoch, before_step, after_backward, after_step, after_epoch, after_train, and validation/test equivalents.
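The pattern is roughly as follows; this is a hypothetical, simplified sketch rather than the platform's actual base class:

```python
# Hypothetical sketch of a priority-based hook system.
class Hook:
    priority = 50  # lower runs first

    def before_train(self, trainer): pass
    def after_step(self, trainer): pass

class LossLogger(Hook):
    priority = 10

    def after_step(self, trainer):
        print(f"step {trainer.step}: loss={trainer.loss:.4f}")

class Trainer:
    def __init__(self, hooks):
        self.hooks = sorted(hooks, key=lambda h: h.priority)
        self.step, self.loss = 0, 0.0

    def fire(self, event):
        # Invoke each hook's handler for `event`, in priority order.
        for hook in self.hooks:
            getattr(hook, event)(self)

    def train(self, steps=3):
        self.fire("before_train")
        for self.step in range(steps):
            self.loss = 1.0 / (self.step + 1)  # stand-in for a real step
            self.fire("after_step")

Trainer([LossLogger()]).train()
```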
| Hook | Purpose |
|---|---|
| `IterationTimer` | Tracks step/epoch/validation timing, logs ETA |
| `LRScheduler` | Executes LR scheduler steps, logs learning rate |
| `WeightDecayScheduleHook` | Updates weight decay on schedule |
| `PeriodicWriter` | Writes metrics to loggers at epoch end |
| `PeriodicCheckpointer` | Saves checkpoints at configurable intervals |
| `BestCheckpointer` | Reports best checkpoint to Ray for model selection |
| `BestMetricSaver` | Tracks and updates best validation metric |
| `TorchMemoryStats` | Logs CUDA memory usage (reserved/allocated) |
| `TorchProfiler` | PyTorch profiler with TensorBoard traces and memory snapshots |
| `NsysProfilerHook` | NVIDIA Nsight Systems profiling for GPU timeline analysis |
| `EarlyStopHook` | Stops training if the validation metric plateaus |
| `EMASchedulerHook` | Updates EMA beta for target networks (JEPA) |
| `FreeDeviceBufferHook` | Releases Ray device buffers to prevent deadlocks |
| `AnomalyDetector` | Detects NaN/Inf losses with `torch.autograd.detect_anomaly` |
| `GarbageCollectionHook` | Periodic synchronized GC to prevent memory fragmentation |
| `MemoryDebugHook` | Detailed memory dumps (proc, CUDA, Ray, /dev/shm) |
| `AdjustTimeoutHook` | Adjusts distributed timeout for long-running ops |
The EventRecorder is the central hub for collecting metrics during training. It supports step-scoped and epoch-scoped scalars with configurable reduction methods (mean, median, sum, min, max).
```python
# Record a scalar at step scope
trainer.event_recorder.put_scalar("loss", loss_value, scope="step")

# Record multiple scalars with a prefix
trainer.event_recorder.put_scalars(scope="epoch", prefix="val_", accuracy=0.95, f1=0.92)
```

The WandBEventWriter provides seamless Weights & Biases integration.
Features:

- Automatic login via `WANDB_API_KEY` from `.env`
- Custom metric namespacing (`step/*`, `epoch/*`)
- Tags and notes for experiment organization
The LocalEventWriter saves metrics to CSV files for offline analysis.
For advanced training loops (TorchTitan), the MetricsProcessor computes detailed performance metrics (work in progress); a back-of-the-envelope sketch follows the list.
- Throughput: Tokens per second (TPS)
- MFU: Model FLOPS Utilization percentage
- TFLOPS: Achieved TFLOPS
- Timing: Forward/backward/optimizer step times (ms)
- Memory: Peak active/reserved GPU memory, allocation retries, OOMs
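As an illustration of how these metrics relate (the numbers below are assumptions, not platform defaults):

```python
# MFU = achieved FLOPS / theoretical peak FLOPS of the hardware.
def training_metrics(tokens_per_step, step_time_s, flops_per_token, peak_tflops):
    tps = tokens_per_step / step_time_s             # throughput, tokens/s
    achieved_tflops = tps * flops_per_token / 1e12  # achieved TFLOPS
    return tps, achieved_tflops, achieved_tflops / peak_tflops

# E.g. 32768 tokens/step at 0.5 s/step for a ~300M-parameter model
# (~6 FLOPs per parameter per token), on an H100 (bf16 peak ~989 TFLOPS):
tps, tflops, mfu = training_metrics(32768, 0.5, 6 * 300e6, 989)
print(f"TPS={tps:.0f}, TFLOPS={tflops:.1f}, MFU={mfu:.1%}")
```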
The InferencerWorker provides distributed inference with two modes:
- `stitch_volume`: Reconstructs full volumes from hypercube predictions via all-to-all communication (a stitching sketch follows).
- `save_local`: Saves per-sample predictions locally (for detection/segmentation tasks).
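A minimal sketch of the stitching idea, with assumed shapes and naive overlap averaging (the real mode distributes this step via all-to-all communication):

```python
import numpy as np

def stitch(volume_shape, tiles):
    # tiles: list of (origin, prediction) pairs; overlaps are averaged.
    out = np.zeros(volume_shape, dtype=np.float32)
    count = np.zeros(volume_shape, dtype=np.float32)
    for (z, y, x), pred in tiles:
        dz, dy, dx = pred.shape
        out[z:z + dz, y:y + dy, x:x + dx] += pred
        count[z:z + dz, y:y + dy, x:x + dx] += 1
    return out / np.maximum(count, 1)  # avoid division by zero in gaps

tiles = [((0, 0, 0), np.ones((4, 4, 4))),
         ((2, 0, 0), np.ones((4, 4, 4)))]  # overlapping along z
full = stitch((8, 4, 4), tiles)
```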
Supported tasks:

- `detection` (plainDETR)
- `instance_segmentation` (MaskDINO)
- `semantic_segmentation` (Mask2Former)
- `dense_prediction` (upsampling, channel split)
- `pretrain` (reconstruction)
- `feature_extractor` (feature visualization)
Multiple profiling tools are supported:
| Profiler | Use Case |
|---|---|
| pprof (gperftools) | CPU profiling with `@pprof_func` / `@pprof_class` decorators |
| NVIDIA Nsys | GPU timeline profiling via NsysProfilerHook |
| PyTorch Profiler | Operator-level profiling via TorchProfiler hook |
| Memory Profiler | PyTorch CUDA memory snapshots and TorchMemoryStats hook |
| Ray Profiler | Ray actor/task profiling via MemoryDebugHook |
Here's what each configuration subdirectory handles:
- `configs/models/` - Model architectures (MAE, JEPA, plainDETR, MaskDINO, backbones, heads)
- `configs/datasets/` - Dataset classes, databases, and preprocessor parameters
- `configs/tasks/` - Task-specific configs (channel_split, instance_segmentation, upsample_*)
- `configs/optimizers/` - Optimizer configurations (AdamW, LAMB, Lion, Muon)
- `configs/schedulers/` - Learning rate and weight decay schedulers
- `configs/optimizations/` - Model optimizations (torch.compile, activation checkpointing)
- `configs/hooks/` - Training hooks configuration
- `configs/checkpoint/` - Checkpointing configurations
- `configs/deepspeed/` - DeepSpeed ZeRO configurations
- `configs/parallelism/` - TorchTitan parallelism settings (TP, CP, PP, DP)
- `configs/clusters/` - Cluster configurations (SLURM, LSF, local)
- `configs/paths/` - Path configurations for different environments (ABC, Janelia, CoreWeave)
- `configs/loggers/` - Logging configurations (WandB, Local CSV)
- `configs/profiling/` - Profiling configurations (pprof, nsys, torch profiler)
- `configs/trainer/` - Training loop configurations
- `configs/evaluation/` - Evaluation configurations
- `configs/inference/` - Inference and prediction configurations
- `configs/tune/` - Ray Tune hyperparameter sweep configurations
- `configs/benchmarks/` - Benchmarking configurations for throughput testing
- `configs/experiments/` - Complete experiment configs and examples
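Since the configs follow Hydra's `_target_` convention (as in the database example above), a composed config can also be instantiated programmatically. A minimal sketch, with assumed config names:

```python
from hydra import compose, initialize
from hydra.utils import instantiate

# Compose an experiment config relative to this file (names are assumptions).
with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="test_pretrain_4d_mae_local")

# Hydra builds the object named by `_target_`, passing the remaining keys.
database = instantiate(cfg.datasets.databases)
```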
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at:

http://www.apache.org/licenses/LICENSE-2.0
Copyright 2025 Cell Observatory.