From 83318a8767dc346d5e11a70521860e0da2e1099b Mon Sep 17 00:00:00 2001 From: MichaelRamamonjisoa Date: Thu, 20 Nov 2025 14:22:32 +0000 Subject: [PATCH] [convnext] add convnext distillation --- dinov3/configs/ssl_default_config.yaml | 1 + .../convnext_base_p16.yaml | 43 ++++ .../convnext_large_p16.yaml | 43 ++++ .../convnext_small_p16.yaml | 43 ++++ .../convnext_tiny_p16.yaml | 43 ++++ .../multi_distillation_convnext_test.yaml | 41 ++++ dinov3/fsdp/ac_compile_parallelize.py | 211 +++++++++++------- dinov3/models/__init__.py | 9 + dinov3/models/convnext.py | 22 +- dinov3/train/multidist_meta_arch.py | 10 + dinov3/train/ssl_meta_arch.py | 8 +- dinov3/train/train.py | 2 +- 12 files changed, 390 insertions(+), 86 deletions(-) create mode 100644 dinov3/configs/train/distillation_convnext/convnext_base_p16.yaml create mode 100644 dinov3/configs/train/distillation_convnext/convnext_large_p16.yaml create mode 100644 dinov3/configs/train/distillation_convnext/convnext_small_p16.yaml create mode 100644 dinov3/configs/train/distillation_convnext/convnext_tiny_p16.yaml create mode 100644 dinov3/configs/train/distillation_convnext/multi_distillation_convnext_test.yaml diff --git a/dinov3/configs/ssl_default_config.yaml b/dinov3/configs/ssl_default_config.yaml index 87ae3a0a..8002aacc 100644 --- a/dinov3/configs/ssl_default_config.yaml +++ b/dinov3/configs/ssl_default_config.yaml @@ -169,6 +169,7 @@ crops: - 0.229 - 0.224 - 0.225 + teacher_to_student_resolution_scale: 1.0 evaluation: eval_period_iterations: 12500 low_freq_every: 5 diff --git a/dinov3/configs/train/distillation_convnext/convnext_base_p16.yaml b/dinov3/configs/train/distillation_convnext/convnext_base_p16.yaml new file mode 100644 index 00000000..696c7a7d --- /dev/null +++ b/dinov3/configs/train/distillation_convnext/convnext_base_p16.yaml @@ -0,0 +1,43 @@ +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + mask_random_circular_shift: false + 
force_masking_even_with_zero_weight: false + separate_head: true + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 2048 +student: + arch: convnext_base + patch_size: 16 + drop_path_rate: 0.0 + block_chunks: 4 +optim: + epochs: 500 + clip_grad: 3.0 + layerwise_decay: 1.0 +schedules: + lr: + start: 1e-6 + peak: 1e-4 + end: 1e-6 + warmup_epochs: 80 + freeze_last_layer_epochs: 1 + weight_decay: + start: 0.02 + end: 0.2 + peak: 0.2 + warmup_epochs: 500 + teacher_temp: + start: 0.04 + peak: 0.07 + end: 0.07 + warmup_epochs: 120 + momentum: + start: 0.994 + peak: 1.0 + end: 1.0 + warmup_epochs: 500 diff --git a/dinov3/configs/train/distillation_convnext/convnext_large_p16.yaml b/dinov3/configs/train/distillation_convnext/convnext_large_p16.yaml new file mode 100644 index 00000000..ecec2ca6 --- /dev/null +++ b/dinov3/configs/train/distillation_convnext/convnext_large_p16.yaml @@ -0,0 +1,43 @@ +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + mask_random_circular_shift: false + force_masking_even_with_zero_weight: false + separate_head: true + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 2048 +student: + arch: convnext_large + patch_size: 16 + drop_path_rate: 0.0 + block_chunks: 4 +optim: + epochs: 500 + clip_grad: 3.0 + layerwise_decay: 1.0 +schedules: + lr: + start: 1e-6 + peak: 1e-4 + end: 1e-6 + warmup_epochs: 80 + freeze_last_layer_epochs: 1 + weight_decay: + start: 0.04 + end: 0.2 + peak: 0.2 + warmup_epochs: 500 + teacher_temp: + start: 0.04 + peak: 0.07 + end: 0.07 + warmup_epochs: 120 + momentum: + start: 0.994 + peak: 1.0 + end: 1.0 + warmup_epochs: 500 diff --git a/dinov3/configs/train/distillation_convnext/convnext_small_p16.yaml b/dinov3/configs/train/distillation_convnext/convnext_small_p16.yaml new file mode 100644 index 00000000..b89d5eb2 --- /dev/null +++ b/dinov3/configs/train/distillation_convnext/convnext_small_p16.yaml @@ -0,0 +1,43 @@ +ibot: + loss_weight: 1.0 + 
mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + mask_random_circular_shift: false + force_masking_even_with_zero_weight: false + separate_head: true + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 2048 +student: + arch: convnext_small + patch_size: 16 + drop_path_rate: 0.0 + block_chunks: 4 +optim: + epochs: 500 + clip_grad: 3.0 + layerwise_decay: 1.0 +schedules: + lr: + start: 1e-6 + peak: 2e-4 + end: 1e-6 + warmup_epochs: 80 + freeze_last_layer_epochs: 1 + weight_decay: + start: 0.04 + end: 0.2 + peak: 0.2 + warmup_epochs: 500 + teacher_temp: + start: 0.04 + peak: 0.07 + end: 0.07 + warmup_epochs: 120 + momentum: + start: 0.994 + peak: 1.0 + end: 1.0 + warmup_epochs: 500 diff --git a/dinov3/configs/train/distillation_convnext/convnext_tiny_p16.yaml b/dinov3/configs/train/distillation_convnext/convnext_tiny_p16.yaml new file mode 100644 index 00000000..a35030db --- /dev/null +++ b/dinov3/configs/train/distillation_convnext/convnext_tiny_p16.yaml @@ -0,0 +1,43 @@ +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + mask_random_circular_shift: false + force_masking_even_with_zero_weight: false + separate_head: true + head_norm_last_layer: false + head_nlayers: 3 + head_hidden_dim: 2048 +student: + arch: convnext_tiny + patch_size: 16 + drop_path_rate: 0.0 + block_chunks: 4 +optim: + epochs: 500 + clip_grad: 3.0 + layerwise_decay: 1.0 +schedules: + lr: + start: 1e-6 + peak: 2e-4 + end: 1e-6 + warmup_epochs: 80 + freeze_last_layer_epochs: 1 + weight_decay: + start: 0.04 + end: 0.2 + peak: 0.2 + warmup_epochs: 500 + teacher_temp: + start: 0.04 + peak: 0.07 + end: 0.07 + warmup_epochs: 120 + momentum: + start: 0.994 + peak: 1.0 + end: 1.0 + warmup_epochs: 500 diff --git a/dinov3/configs/train/distillation_convnext/multi_distillation_convnext_test.yaml b/dinov3/configs/train/distillation_convnext/multi_distillation_convnext_test.yaml new file mode 100644 index 00000000..178b856d --- 
/dev/null +++ b/dinov3/configs/train/distillation_convnext/multi_distillation_convnext_test.yaml @@ -0,0 +1,41 @@ +MODEL: + META_ARCHITECTURE: MultiDistillationMetaArch +multidistillation: + enabled: true + global_batch_size: 32 # 4096 for 16 nodes + students: + - name: convnext_tiny + config_path: dinov3/configs/train/distillation_convnext/convnext_tiny_p16.yaml + ranks_range: + - 0 + - 2 + - name: convnext_small + config_path: dinov3/configs/train/distillation_convnext/convnext_small_p16.yaml + ranks_range: + - 2 + - 4 + - name: convnext_base + config_path: dinov3/configs/train/distillation_convnext/convnext_base_p16.yaml + ranks_range: + - 4 + - 6 + - name: convnext_large + config_path: dinov3/configs/train/distillation_convnext/convnext_large_p16.yaml + ranks_range: + - 6 + - 8 +distillation: # teacher + enabled: true + full_cfg_path: dinov3/configs/train/vitl_im1k_lin834.yaml + checkpoint_path: ignore +crops: + global_crops_size: 512 + local_crops_size: 224 + teacher_to_student_resolution_scale: 2.0 +train: + dataset_path: ImageNet:split=TRAIN + cache_dataset: false + centering: "sinkhorn_knopp" + compile: true +ibot: + separate_head: true diff --git a/dinov3/fsdp/ac_compile_parallelize.py b/dinov3/fsdp/ac_compile_parallelize.py index 5706759d..dc7aa3e1 100644 --- a/dinov3/fsdp/ac_compile_parallelize.py +++ b/dinov3/fsdp/ac_compile_parallelize.py @@ -5,7 +5,7 @@ import logging from functools import partial -from typing import Any, List, Optional +from typing import Any, Dict, List import torch import torch.distributed as dist @@ -18,27 +18,117 @@ from dinov3.utils import utils + logger = logging.getLogger("dinov3") -def map_modules_and_blocks(models: list[nn.ModuleDict], callable) -> None: - for m in models: - assert isinstance(m, nn.ModuleDict) - for k in m.keys(): - if k == "backbone": - assert isinstance(m[k].blocks, nn.ModuleList) - for block_id, block in enumerate(m[k].blocks): - m[k].blocks[block_id] = callable(block, is_backbone_block=True) - else: - 
m[k] = callable(m[k], is_backbone_block=False) +def get_activation_checkpoint_wrapper(cfg): + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper + + if cfg.train.checkpointing_full: + _checkpointing_wrapper = checkpoint_wrapper + logger.info("using selective checkpointing on backbone with full checkpointing policy") + else: + _save_list = [ + # mm + torch.ops.aten.mm.default, + torch.ops.aten._scaled_mm.default, + # attentions + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + ] + _checkpointing_wrapper = partial( + checkpoint_wrapper, + context_fn=partial(create_selective_checkpoint_contexts, _save_list), + preserve_rng_state=True, + ) + logger.info("using selective checkpointing on backbone with selective policy") + return _checkpointing_wrapper + + +def activation_checkpoint_convnext(cfg, model: nn.Module): + _checkpointing_wrapper = get_activation_checkpoint_wrapper(cfg) + for stage_id, stage in enumerate(model.stages): + for block_id, block in enumerate(stage): + model.stages[stage_id][block_id] = _checkpointing_wrapper(block) + for dsl_id, dsl in enumerate(model.downsample_layers): + model.downsample_layers[dsl_id] = _checkpointing_wrapper(dsl) + + +def activation_checkpoint_transformer(cfg, model: nn.Module): + _checkpointing_wrapper = get_activation_checkpoint_wrapper(cfg) + for block_id, b in enumerate(model.blocks): + model.blocks[block_id] = _checkpointing_wrapper(b) + + +def wrap_compile_block(module: nn.Module, use_cuda_graphs: bool, is_backbone_block: bool) -> nn.Module: + if use_cuda_graphs and is_backbone_block: + module.compile(fullgraph=True, dynamic=False, options={"triton.cudagraphs": True}) + else: + module.compile() + return module + + +def compile_convnext(cfg, model: nn.Module): + assert isinstance(model.stages, nn.ModuleList) + # Compile at stage level + for 
stage_id, stage in enumerate(model.stages): + model.stages[stage_id] = wrap_compile_block(stage, cfg.train.cudagraphs, is_backbone_block=False) + assert isinstance(model.downsample_layers, nn.ModuleList) + for dsl_id, dsl in enumerate(model.downsample_layers): + model.downsample_layers[dsl_id] = wrap_compile_block(dsl, cfg.train.cudagraphs, is_backbone_block=False) + + +def compile_transformer(cfg, model: nn.Module): + assert isinstance(model.blocks, nn.ModuleList) + for block_id, block in enumerate(model.blocks): + model.blocks[block_id] = wrap_compile_block(block, cfg.train.cudagraphs, is_backbone_block=True) + + +def fsdp_convnext(fsdp_config: Dict[str, Any], model: nn.Module): + stages = model.stages + assert isinstance(stages, nn.ModuleList) + # FSDP wrap at stage level + for stage_id, stage in enumerate(stages): + stage_reshard: int | bool = True + stages[stage_id] = fully_shard(stage, **fsdp_config, reshard_after_forward=stage_reshard) + downsample_layers = model.downsample_layers + assert isinstance(downsample_layers, nn.ModuleList) + for dsl_id, dsl in enumerate(downsample_layers): + dsl_reshard: int | bool = True + downsample_layers[dsl_id] = fully_shard(dsl, **fsdp_config, reshard_after_forward=dsl_reshard) + dsl: FSDPState + stage: FSDPState + for dsl, stage in zip(downsample_layers, stages): + dsl.set_modules_to_forward_prefetch([stage]) + stage.set_modules_to_backward_prefetch([dsl]) + fully_shard(model, **fsdp_config, reshard_after_forward=True) + register_fsdp_forward_method(model, "get_intermediate_layers") + + +def fsdp_transformer(fsdp_config: Dict[str, Any], model: nn.Module): + # Backbone - FSDP every block + blocks = model.blocks + assert isinstance(blocks, nn.ModuleList) + for block_id, block in enumerate(blocks): + block_reshard: int | bool = True + blocks[block_id] = fully_shard(block, **fsdp_config, reshard_after_forward=block_reshard) + prev_block: FSDPState + next_block: FSDPState + for prev_block, next_block in zip(blocks, blocks[1:]): 
+ prev_block.set_modules_to_forward_prefetch([next_block]) + next_block.set_modules_to_backward_prefetch([prev_block]) + fully_shard(model, **fsdp_config, reshard_after_forward=True) + register_fsdp_forward_method(model, "get_intermediate_layers") def ac_compile_parallelize( trained_model: nn.ModuleDict, inference_only_models: List[nn.ModuleDict], cfg: Any, - trained_model_process_group: Optional[dist.ProcessGroup] = None, - inference_only_models_process_groups: Optional[List[dist.ProcessGroup]] = None, + trained_model_process_group: dist.ProcessGroup | None = None, + inference_only_models_process_groups: List[dist.ProcessGroup] | None = None, ) -> None: """ Order of the wrappers: @@ -53,33 +143,26 @@ def ac_compile_parallelize( if utils.has_batchnorms(trained_model): raise NotImplementedError - # 1/ AC on blocks - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper + from dinov3.models.convnext import ConvNeXt + from dinov3.models.vision_transformer import DinoVisionTransformer + + # FSDP utils for each architecture type + ARCH_TYPE_MAP = { + ConvNeXt: dict( + compile_fn=compile_convnext, + fsdp_fn=fsdp_convnext, + activation_checkpointing_fn=activation_checkpoint_convnext, + ), + DinoVisionTransformer: dict( + compile_fn=compile_transformer, + fsdp_fn=fsdp_transformer, + activation_checkpointing_fn=activation_checkpoint_transformer, + ), + } - backbone = trained_model.backbone + # 1/ AC on blocks if cfg.train.checkpointing: - if cfg.train.checkpointing_full: - _checkpointing_wrapper = checkpoint_wrapper - logger.info("using selective checkpointing on backbone with full checkpointing policy") - else: - _save_list = [ - # mm - torch.ops.aten.mm.default, - torch.ops.aten._scaled_mm.default, - # attentions - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - ] - _checkpointing_wrapper = 
partial( - checkpoint_wrapper, - context_fn=partial(create_selective_checkpoint_contexts, _save_list), - preserve_rng_state=True, - ) - logger.info("using selective checkpointing on backbone with selective policy") - for i, b in enumerate(backbone.blocks): - backbone.blocks[i] = _checkpointing_wrapper(b) - + ARCH_TYPE_MAP[type(trained_model.backbone)]["activation_checkpointing_fn"](cfg, trained_model["backbone"]) # 2/ Compile blocks all_models = [trained_model] + inference_only_models if trained_model_process_group is None and inference_only_models_process_groups is None: @@ -90,23 +173,13 @@ def ac_compile_parallelize( all_pgs = [trained_model_process_group] + [None] * len(inference_only_models_process_groups) else: all_pgs = [trained_model_process_group] + inference_only_models_process_groups - - def wrap_compile_block(m: nn.Module, is_backbone_block: bool) -> nn.Module: - if cfg.train.compile: - if is_backbone_block and cfg.train.cudagraphs: - m.compile(fullgraph=True, dynamic=False, options={"triton.cudagraphs": True}) - else: - m.compile() - return m - - map_modules_and_blocks(all_models, wrap_compile_block) - - # 3/ Wrap submodules with FSDP - world_mesh = init_device_mesh( - "cuda", - mesh_shape=(dist.get_world_size(),), - mesh_dim_names=("dp",), - ) + if cfg.train.compile: + for model in all_models: + for k in model.keys(): + if k == "backbone": + ARCH_TYPE_MAP[type(model[k])]["compile_fn"](cfg, model[k]) + else: + model[k] = wrap_compile_block(model[k], use_cuda_graphs=False, is_backbone_block=False) DTYPE_MAP = { "fp16": torch.float16, "bf16": torch.bfloat16, @@ -116,8 +189,7 @@ def wrap_compile_block(m: nn.Module, is_backbone_block: bool) -> nn.Module: param_dtype=DTYPE_MAP[cfg.compute_precision.param_dtype], reduce_dtype=DTYPE_MAP[cfg.compute_precision.reduce_dtype], ) - - for m, pg in zip(all_models, all_pgs): + for model, pg in zip(all_models, all_pgs): if pg is None: world_mesh = init_device_mesh( "cuda", @@ -127,26 +199,11 @@ def 
wrap_compile_block(m: nn.Module, is_backbone_block: bool) -> nn.Module: else: world_mesh = DeviceMesh.from_group(pg, "cuda") fsdp_config = {"mesh": world_mesh, "mp_policy": mp_policy} - for k in m.keys(): - if k != "backbone": - m[k] = fully_shard(m[k], **fsdp_config, reshard_after_forward=True) - continue - # Backbone - FSDP every block - blocks = m[k].blocks - - assert isinstance(blocks, nn.ModuleList) - for block_id, block in enumerate(blocks): - block_reshard: int | bool = True - # if m is trained_model and dist.get_world_size() % 8 == 0 and dist.get_world_size() > 8: - # block_reshard = 8 - blocks[block_id] = fully_shard(block, **fsdp_config, reshard_after_forward=block_reshard) - prev_block: FSDPState - next_block: FSDPState - for prev_block, next_block in zip(blocks, blocks[1:]): - prev_block.set_modules_to_forward_prefetch([next_block]) - next_block.set_modules_to_backward_prefetch([prev_block]) - fully_shard(m.backbone, **fsdp_config, reshard_after_forward=True) - register_fsdp_forward_method(m.backbone, "get_intermediate_layers") + for k in model.keys(): + if k == "backbone": + ARCH_TYPE_MAP[type(model[k])]["fsdp_fn"](fsdp_config, model[k]) + else: + model[k] = fully_shard(model[k], **fsdp_config, reshard_after_forward=True) # 4/ Move to `cuda` device for model in all_models: diff --git a/dinov3/models/__init__.py b/dinov3/models/__init__.py index b5427e2d..2cb785a0 100644 --- a/dinov3/models/__init__.py +++ b/dinov3/models/__init__.py @@ -14,6 +14,7 @@ from dinov3.layers.fp8_linear import convert_linears_to_fp8 from . import vision_transformer as vits +from . 
import convnext logger = logging.getLogger("dinov3") @@ -64,6 +65,14 @@ def build_model(args, only_teacher=False, img_size=224, device=None): drop_path_rate=args.drop_path_rate, ) embed_dim = student.embed_dim + elif "convnext" in args.arch: + convnext_cls = convnext.get_convnext_arch(args.arch) + convnext_kwargs = dict(patch_size=args.patch_size) + teacher = convnext_cls(**convnext_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = convnext_cls(**convnext_kwargs) + embed_dim = student.embed_dim else: raise NotImplementedError(f"Unrecognized architecture {args.arch}") student = init_fp8(student, args) diff --git a/dinov3/models/convnext.py b/dinov3/models/convnext.py index 7271ae51..6300de5a 100644 --- a/dinov3/models/convnext.py +++ b/dinov3/models/convnext.py @@ -60,6 +60,7 @@ def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6): self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(4 * dim, dim) + self.layer_scale_init_value = layer_scale_init_value self.gamma = ( nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) if layer_scale_init_value > 0 @@ -94,14 +95,18 @@ class LayerNorm(nn.Module): def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.weight = nn.Parameter(torch.empty(normalized_shape)) + self.bias = nn.Parameter(torch.empty(normalized_shape)) self.eps = eps self.data_format = data_format if self.data_format not in ["channels_last", "channels_first"]: raise NotImplementedError self.normalized_shape = (normalized_shape,) + def init_weights(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + def forward(self, x): if self.data_format == "channels_last": return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 
@@ -193,16 +198,21 @@ def __init__( def init_weights(self): self.apply(self._init_weights) + for stage_id, stage in enumerate(self.stages): + for block_id, block in enumerate(stage): + if block.gamma is not None: + nn.init.constant_(self.stages[stage_id][block_id].gamma, block.layer_scale_init_value) def _init_weights(self, module): if isinstance(module, nn.LayerNorm): module.reset_parameters() if isinstance(module, LayerNorm): - module.weight = nn.Parameter(torch.ones(module.normalized_shape)) - module.bias = nn.Parameter(torch.zeros(module.normalized_shape)) + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) if isinstance(module, (nn.Conv2d, nn.Linear)): - torch.nn.init.trunc_normal_(module.weight, std=0.02) - nn.init.constant_(module.bias, 0) + nn.init.trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) def forward_features(self, x: Tensor | List[Tensor], masks: Optional[Tensor] = None) -> List[Dict[str, Tensor]]: if isinstance(x, torch.Tensor): diff --git a/dinov3/train/multidist_meta_arch.py b/dinov3/train/multidist_meta_arch.py index d2238c4a..1b881935 100644 --- a/dinov3/train/multidist_meta_arch.py +++ b/dinov3/train/multidist_meta_arch.py @@ -46,6 +46,16 @@ def forward_backward( global_batch_size = data["global_batch_size"] # Multidistillation codepath: + + # Downsample teacher crops to match student resolution + downsampling_factor = getattr(self.cfg.crops, "teacher_to_student_resolution_scale", 1.0) + if downsampling_factor != 1.0: + global_crops = torch.nn.functional.interpolate( + global_crops, + scale_factor=1.0 / downsampling_factor, + mode="bilinear", + antialias=True, + ) global_crops_subgroup = self.broadcast_to_subgroups( global_crops.view(n_global_crops, -1, *global_crops.shape[1:]), 1, diff --git a/dinov3/train/ssl_meta_arch.py b/dinov3/train/ssl_meta_arch.py index bf591437..f42bed1c 100644 --- a/dinov3/train/ssl_meta_arch.py +++ b/dinov3/train/ssl_meta_arch.py @@ -268,8 +268,12 @@ def 
_setup_distillation(self): distillation_cfg = OmegaConf.merge(default_cfg, distillation_cfg) assert distillation_cfg.ibot.separate_head is True - assert distillation_cfg.ibot.head_n_prototypes == self.cfg.ibot.head_n_prototypes - assert distillation_cfg.dino.head_n_prototypes == self.cfg.dino.head_n_prototypes + assert distillation_cfg.ibot.head_n_prototypes == self.cfg.ibot.head_n_prototypes, ( + f"{distillation_cfg.ibot.head_n_prototypes} != {self.cfg.ibot.head_n_prototypes}" + ) + assert distillation_cfg.dino.head_n_prototypes == self.cfg.dino.head_n_prototypes, ( + f"{distillation_cfg.dino.head_n_prototypes} != {self.cfg.dino.head_n_prototypes}" + ) assert distillation_cfg.student.patch_size == self.cfg.student.patch_size teacher_model_dict = dict() diff --git a/dinov3/train/train.py b/dinov3/train/train.py index fe788765..95ab2e5a 100644 --- a/dinov3/train/train.py +++ b/dinov3/train/train.py @@ -272,7 +272,7 @@ def build_data_loader_from_cfg( ): # Collate function img_size = cfg.crops.global_crops_size - patch_size = cfg.student.patch_size + patch_size = int(cfg.student.patch_size * cfg.crops.teacher_to_student_resolution_scale) n_tokens = (img_size // patch_size) ** 2 mask_generator = MaskingGenerator( input_size=(img_size // patch_size, img_size // patch_size),