diff --git a/.gitignore b/.gitignore index 27b361f..aac9f87 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,8 @@ env/ *.egg-info dist/ build/ -data/ + +/data/ weights/ output/ *.jpg diff --git a/trolo/configs/yaml/include/dataloader.yml b/trolo/configs/yaml/include/dataloader.yml new file mode 100644 index 0000000..d55a411 --- /dev/null +++ b/trolo/configs/yaml/include/dataloader.yml @@ -0,0 +1,38 @@ + +train_dataloader: + dataset: + transforms: + ops: + - {type: RandomPhotometricDistort, p: 0.5} + - {type: RandomZoomOut, fill: 0} + - {type: RandomIoUCrop, p: 0.8} + - {type: SanitizeBoundingBoxes, min_size: 1} + - {type: RandomHorizontalFlip} + - {type: Resize, size: [640, 640], } + - {type: SanitizeBoundingBoxes, min_size: 1} + - {type: ConvertPILImage, dtype: 'float32', scale: True} + - {type: ConvertBoxes, fmt: 'cxcywh', normalize: True} + policy: + name: stop_epoch + epoch: 71 # epoch in [71, ~) stop `ops` + ops: ['RandomPhotometricDistort', 'RandomZoomOut', 'RandomIoUCrop'] + + collate_fn: + type: BatchImageCollateFunction + scales: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800] + stop_epoch: 71 # epoch in [71, ~) stop `multiscales` + + shuffle: True + total_batch_size: 16 # total batch size equals to 16 (4 * 4) + num_workers: 4 + + +val_dataloader: + dataset: + transforms: + ops: + - {type: Resize, size: [640, 640]} + - {type: ConvertPILImage, dtype: 'float32', scale: True} + shuffle: False + total_batch_size: 32 + num_workers: 4 \ No newline at end of file diff --git a/trolo/configs/yaml/include/optimizer.yml b/trolo/configs/yaml/include/optimizer.yml new file mode 100644 index 0000000..189a9a1 --- /dev/null +++ b/trolo/configs/yaml/include/optimizer.yml @@ -0,0 +1,37 @@ + +use_amp: True +use_ema: True +ema: + type: ModelEMA + decay: 0.9999 + warmups: 2000 + + +epoches: 72 +clip_max_norm: 0.1 + + +optimizer: + type: AdamW + params: + - + params: '^(?=.*backbone)(?!.*norm).*$' + lr: 0.00001 + - + params: '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn)).*$' + weight_decay: 0. + + lr: 0.0001 + betas: [0.9, 0.999] + weight_decay: 0.0001 + + +lr_scheduler: + type: MultiStepLR + milestones: [1000] + gamma: 0.1 + + +lr_warmup_scheduler: + type: LinearWarmup + warmup_duration: 2000 \ No newline at end of file diff --git a/trolo/configs/yaml/rtdetrv2/base.yml b/trolo/configs/yaml/rtdetrv2/base.yml new file mode 100644 index 0000000..d7936f6 --- /dev/null +++ b/trolo/configs/yaml/rtdetrv2/base.yml @@ -0,0 +1,82 @@ +task: detection + +model: RTDETR +criterion: RTDETRCriterionv2 +postprocessor: RTDETRPostProcessor + + +use_focal_loss: True +eval_spatial_size: [640, 640] # h w + + +RTDETR: + backbone: PResNet + encoder: HybridEncoder + decoder: RTDETRTransformerv2 + + +PResNet: + depth: 50 + variant: d + freeze_at: 0 + return_idx: [1, 2, 3] + num_stages: 4 + freeze_norm: True + pretrained: True + + +HybridEncoder: + in_channels: [512, 1024, 2048] + feat_strides: [8, 16, 32] + + # intra + hidden_dim: 256 + use_encoder_idx: [2] + num_encoder_layers: 1 + nhead: 8 + dim_feedforward: 1024 + dropout: 0. + enc_act: 'gelu' + + # cross + expansion: 1.0 + depth_mult: 1 + act: 'silu' + + +RTDETRTransformerv2: + feat_channels: [256, 256, 256] + feat_strides: [8, 16, 32] + hidden_dim: 256 + num_levels: 3 + + num_layers: 6 + num_queries: 300 + + num_denoising: 100 + label_noise_ratio: 0.5 + box_noise_scale: 1.0 # 1.0 0.4 + + eval_idx: -1 + + # NEW + num_points: [4, 4, 4] # [3,3,3] [2,2,2] + cross_attn_method: default # default, discrete + query_select_method: default # default, agnostic + + +RTDETRPostProcessor: + num_top_queries: 300 + + +RTDETRCriterionv2: + weight_dict: {loss_vfl: 1, loss_bbox: 5, loss_giou: 2,} + losses: ['vfl', 'boxes', ] + alpha: 0.75 + gamma: 2.0 + + matcher: + type: HungarianMatcher + weight_dict: {cost_class: 2, cost_bbox: 5, cost_giou: 2} + alpha: 0.25 + gamma: 2.0 diff --git a/trolo/configs/yaml/rt-detrv2/rtdetrv2_s_coco.yml b/trolo/configs/yaml/rtdetrv2/rtdetrv2_s.yml similarity index 73% rename from trolo/configs/yaml/rt-detrv2/rtdetrv2_s_coco.yml rename to trolo/configs/yaml/rtdetrv2/rtdetrv2_s.yml index ec44690..5b52e6b 100644 --- a/trolo/configs/yaml/rt-detrv2/rtdetrv2_s_coco.yml +++ b/trolo/configs/yaml/rtdetrv2/rtdetrv2_s.yml @@ -1,13 +1,13 @@ __include__: [ - '../dataset/coco_detection.yml', + '../dataset/dummy_coco.yml', '../runtime.yml', - './include/dataloader.yml', - './include/optimizer.yml', - './include/rtdetrv2_r50vd.yml', + '../include/dataloader.yml', + '../include/optimizer.yml', + 'base.yml', ] -output_dir: ./output/rtdetrv2_s_coco +output_dir: ./output/rtdetrv2_r18vd_120e_coco PResNet: diff --git a/trolo/data/dataloader.py b/trolo/data/dataloader.py index fa7a4eb..ab451f5 100644 --- a/trolo/data/dataloader.py +++ b/trolo/data/dataloader.py @@ -89,10 +89,13 @@ def __init__( ema_restart_decay=0.9999, base_size=640, base_size_repeat=None, + scales=None, ) -> None: super().__init__() self.base_size = base_size - self.scales = generate_scales(base_size, base_size_repeat) if base_size_repeat is not None else None + self.scales = scales + if scales is None: + self.scales = generate_scales(base_size, base_size_repeat) if base_size_repeat is not None else None self.stop_epoch = stop_epoch if stop_epoch is not None else 100000000 self.ema_restart_decay = ema_restart_decay # self.interpolation = interpolation diff --git a/trolo/loaders/maps.py b/trolo/loaders/maps.py index b8a7de7..a702c12 100644 --- a/trolo/loaders/maps.py +++ b/trolo/loaders/maps.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Dict from trolo.models.dfine.maps import MODEL_CONFIG_MAP as DFINE_MODEL_CONFIG_MAP +from trolo.models.rtdetrv2.maps import MODEL_CONFIG_MAP as RTDETRV2_MODEL_CONFIG_MAP # Get package root directory PKG_ROOT = Path(__file__).parent.parent @@ -9,6 +10,7 @@ # Map of model names to their config files MODEL_CONFIG_MAP = { **DFINE_MODEL_CONFIG_MAP, + **RTDETRV2_MODEL_CONFIG_MAP, } diff --git a/trolo/models/__init__.py b/trolo/models/__init__.py index 473164e..874db02 100644 --- a/trolo/models/__init__.py +++ b/trolo/models/__init__.py @@ -1,2 +1,2 @@ from . import dfine -from . import rtdetr +from . import rtdetrv2 diff --git a/trolo/models/dfine/box_ops.py b/trolo/models/dfine/box_ops.py index 78b4761..87572ac 100644 --- a/trolo/models/dfine/box_ops.py +++ b/trolo/models/dfine/box_ops.py @@ -1,88 +1,3 @@ -import torch -from torch import Tensor -from torchvision.ops.boxes import box_area - -def box_cxcywh_to_xyxy(x): - x_c, y_c, w, h = x.unbind(-1) - b = [ - (x_c - 0.5 * w.clamp(min=0.0)), - (y_c - 0.5 * h.clamp(min=0.0)), - (x_c + 0.5 * w.clamp(min=0.0)), - (y_c + 0.5 * h.clamp(min=0.0)), - ] - return torch.stack(b, dim=-1) - - -def box_xyxy_to_cxcywh(x: Tensor) -> Tensor: - x0, y0, x1, y1 = x.unbind(-1) - b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] - return torch.stack(b, dim=-1) - - -# modified from torchvision to also return the union -def box_iou(boxes1: Tensor, boxes2: Tensor): - area1 = box_area(boxes1) - area2 = box_area(boxes2) - - lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] - - wh = (rb - lt).clamp(min=0) # [N,M,2] - inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] - - union = area1[:, None] + area2 - inter - - iou = inter / union - return iou, union - - -def generalized_box_iou(boxes1, boxes2): - """ - Generalized IoU from https://giou.stanford.edu/ - - The boxes should be in [x0, y0, x1, y1] format - - Returns a [N, M] pairwise matrix, where N = len(boxes1) - and M = len(boxes2) - """ - # degenerate boxes gives inf / nan results - # so do an early check - assert (boxes1[:, 2:] >= boxes1[:, :2]).all() - assert (boxes2[:, 2:] >= boxes2[:, :2]).all() - iou, union = box_iou(boxes1, boxes2) - - lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) - rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - - wh = (rb - lt).clamp(min=0) # [N,M,2] - area = wh[:, :, 0] * wh[:, :, 1] - - return iou - (area - union) / area - - -def masks_to_boxes(masks): - """Compute the bounding boxes around the provided masks - - The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. - - Returns a [N, 4] tensors, with the boxes in xyxy format - """ - if masks.numel() == 0: - return torch.zeros((0, 4), device=masks.device) - - h, w = masks.shape[-2:] - - y = torch.arange(0, h, dtype=torch.float) - x = torch.arange(0, w, dtype=torch.float) - y, x = torch.meshgrid(y, x) - - x_mask = masks * x.unsqueeze(0) - x_max = x_mask.flatten(1).max(-1)[0] - x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] - - y_mask = masks * y.unsqueeze(0) - y_max = y_mask.flatten(1).max(-1)[0] - y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] - - return torch.stack([x_min, y_min, x_max, y_max], 1) +## THIS IS TOTAL TECH DEBT +from trolo.utils.box_ops import * diff --git a/trolo/models/rtdetr/__init__.py b/trolo/models/rtdetr/__init__.py deleted file mode 100644 index 194e148..0000000 --- a/trolo/models/rtdetr/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .decoder import RTDETRTransformerv2 diff --git a/trolo/models/rtdetr/decoder.py b/trolo/models/rtdetr/decoder.py deleted file mode 100644 index 77021bd..0000000 --- a/trolo/models/rtdetr/decoder.py +++ /dev/null @@ -1,487 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.nn.init as init -from collections import OrderedDict -from typing import List -import copy - -from trolo.loaders.registry import register -from ..dfine.denoising import get_contrastive_denoising_training_group, inverse_sigmoid -from ..dfine.dfine_decoder import MLP, bias_init_with_prob, MSDeformableAttention, get_activation - - -class TransformerDecoderLayer(nn.Module): - def __init__( - self, - d_model=256, - n_head=8, - dim_feedforward=1024, - dropout=0.0, - activation="relu", - n_levels=4, - n_points=4, - cross_attn_method="default", - ): - super(TransformerDecoderLayer, self).__init__() - - # self attention - self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True) - self.dropout1 = nn.Dropout(dropout) - self.norm1 = nn.LayerNorm(d_model) - - # cross attention - self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, n_points, method=cross_attn_method) - self.dropout2 = nn.Dropout(dropout) - self.norm2 = nn.LayerNorm(d_model) - - # ffn - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.activation = get_activation(activation) - self.dropout3 = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - self.dropout4 = nn.Dropout(dropout) - self.norm3 = nn.LayerNorm(d_model) - - self._reset_parameters() - - def _reset_parameters(self): - init.xavier_uniform_(self.linear1.weight) - init.xavier_uniform_(self.linear2.weight) - - def with_pos_embed(self, tensor, pos): - return tensor if pos is None else tensor + pos - - def forward_ffn(self, tgt): - return self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) - - def forward( - self, - target, - reference_points, - memory, - memory_spatial_shapes, - attn_mask=None, - memory_mask=None, - query_pos_embed=None, - ): - # self attention - q = k = self.with_pos_embed(target, query_pos_embed) - - target2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask) - target = target + self.dropout1(target2) - target = self.norm1(target) - - # cross attention - target2 = self.cross_attn( - self.with_pos_embed(target, query_pos_embed), reference_points, memory, memory_spatial_shapes, memory_mask - ) - target = target + self.dropout2(target2) - target = self.norm2(target) - - # ffn - target2 = self.forward_ffn(target) - target = target + self.dropout4(target2) - target = self.norm3(target) - - return target - - -class TransformerDecoder(nn.Module): - def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1): - super(TransformerDecoder, self).__init__() - self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)]) - self.hidden_dim = hidden_dim - self.num_layers = num_layers - self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx - - def forward( - self, - target, - ref_points_unact, - memory, - memory_spatial_shapes, - bbox_head, - score_head, - query_pos_head, - attn_mask=None, - memory_mask=None, - ): - dec_out_bboxes = [] - dec_out_logits = [] - ref_points_detach = F.sigmoid(ref_points_unact) - - output = target - for i, layer in enumerate(self.layers): - ref_points_input = ref_points_detach.unsqueeze(2) - query_pos_embed = query_pos_head(ref_points_detach) - - output = layer( - output, ref_points_input, memory, memory_spatial_shapes, attn_mask, memory_mask, query_pos_embed - ) - - inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach)) - - if self.training: - dec_out_logits.append(score_head[i](output)) - if i == 0: - dec_out_bboxes.append(inter_ref_bbox) - else: - dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points))) - - elif i == self.eval_idx: - dec_out_logits.append(score_head[i](output)) - dec_out_bboxes.append(inter_ref_bbox) - break - - ref_points = inter_ref_bbox - ref_points_detach = inter_ref_bbox.detach() - - return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits) - - -@register() -class RTDETRTransformerv2(nn.Module): - __share__ = ["num_classes", "eval_spatial_size"] - - def __init__( - self, - num_classes=80, - hidden_dim=256, - num_queries=300, - feat_channels=[512, 1024, 2048], - feat_strides=[8, 16, 32], - num_levels=3, - num_points=4, - nhead=8, - num_layers=6, - dim_feedforward=1024, - dropout=0.0, - activation="relu", - num_denoising=100, - label_noise_ratio=0.5, - box_noise_scale=1.0, - learn_query_content=False, - eval_spatial_size=None, - eval_idx=-1, - eps=1e-2, - aux_loss=True, - cross_attn_method="default", - query_select_method="default", - ): - super().__init__() - assert len(feat_channels) <= num_levels - assert len(feat_strides) == len(feat_channels) - - for _ in range(num_levels - len(feat_strides)): - feat_strides.append(feat_strides[-1] * 2) - - self.hidden_dim = hidden_dim - self.nhead = nhead - self.feat_strides = feat_strides - self.num_levels = num_levels - self.num_classes = num_classes - self.num_queries = num_queries - self.eps = eps - self.num_layers = num_layers - self.eval_spatial_size = eval_spatial_size - self.aux_loss = aux_loss - - assert query_select_method in ("default", "one2many", "agnostic"), "" - assert cross_attn_method in ("default", "discrete"), "" - self.cross_attn_method = cross_attn_method - self.query_select_method = query_select_method - - # backbone feature projection - self._build_input_proj_layer(feat_channels) - - # Transformer module - decoder_layer = TransformerDecoderLayer( - hidden_dim, - nhead, - dim_feedforward, - dropout, - activation, - num_levels, - num_points, - cross_attn_method=cross_attn_method, - ) - self.decoder = TransformerDecoder(hidden_dim, decoder_layer, num_layers, eval_idx) - - # denoising - self.num_denoising = num_denoising - self.label_noise_ratio = label_noise_ratio - self.box_noise_scale = box_noise_scale - if num_denoising > 0: - self.denoising_class_embed = nn.Embedding(num_classes + 1, hidden_dim, padding_idx=num_classes) - init.normal_(self.denoising_class_embed.weight[:-1]) - - # decoder embedding - self.learn_query_content = learn_query_content - if learn_query_content: - self.tgt_embed = nn.Embedding(num_queries, hidden_dim) - self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2) - - # if num_select_queries != self.num_queries: - # layer = TransformerEncoderLayer(hidden_dim, nhead, dim_feedforward, activation='gelu') - # self.encoder = TransformerEncoder(layer, 1) - - self.enc_output = nn.Sequential( - OrderedDict( - [ - ("proj", nn.Linear(hidden_dim, hidden_dim)), - ( - "norm", - nn.LayerNorm( - hidden_dim, - ), - ), - ] - ) - ) - - if query_select_method == "agnostic": - self.enc_score_head = nn.Linear(hidden_dim, 1) - else: - self.enc_score_head = nn.Linear(hidden_dim, num_classes) - - self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) - - # decoder head - self.dec_score_head = nn.ModuleList([nn.Linear(hidden_dim, num_classes) for _ in range(num_layers)]) - self.dec_bbox_head = nn.ModuleList([MLP(hidden_dim, hidden_dim, 4, 3) for _ in range(num_layers)]) - - # init encoder output anchors and valid_mask - if self.eval_spatial_size: - anchors, valid_mask = self._generate_anchors() - self.register_buffer("anchors", anchors) - self.register_buffer("valid_mask", valid_mask) - - self._reset_parameters() - - def _reset_parameters(self): - bias = bias_init_with_prob(0.01) - init.constant_(self.enc_score_head.bias, bias) - init.constant_(self.enc_bbox_head.layers[-1].weight, 0) - init.constant_(self.enc_bbox_head.layers[-1].bias, 0) - - for _cls, _reg in zip(self.dec_score_head, self.dec_bbox_head): - init.constant_(_cls.bias, bias) - init.constant_(_reg.layers[-1].weight, 0) - init.constant_(_reg.layers[-1].bias, 0) - - init.xavier_uniform_(self.enc_output[0].weight) - if self.learn_query_content: - init.xavier_uniform_(self.tgt_embed.weight) - init.xavier_uniform_(self.query_pos_head.layers[0].weight) - init.xavier_uniform_(self.query_pos_head.layers[1].weight) - for m in self.input_proj: - init.xavier_uniform_(m[0].weight) - - def _build_input_proj_layer(self, feat_channels): - self.input_proj = nn.ModuleList() - for in_channels in feat_channels: - self.input_proj.append( - nn.Sequential( - OrderedDict( - [ - ("conv", nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)), - ( - "norm", - nn.BatchNorm2d( - self.hidden_dim, - ), - ), - ] - ) - ) - ) - - in_channels = feat_channels[-1] - - for _ in range(self.num_levels - len(feat_channels)): - self.input_proj.append( - nn.Sequential( - OrderedDict( - [ - ("conv", nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)), - ("norm", nn.BatchNorm2d(self.hidden_dim)), - ] - ) - ) - ) - in_channels = self.hidden_dim - - def _get_encoder_input(self, feats: List[torch.Tensor]): - # get projection features - proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] - if self.num_levels > len(proj_feats): - len_srcs = len(proj_feats) - for i in range(len_srcs, self.num_levels): - if i == len_srcs: - proj_feats.append(self.input_proj[i](feats[-1])) - else: - proj_feats.append(self.input_proj[i](proj_feats[-1])) - - # get encoder inputs - feat_flatten = [] - spatial_shapes = [] - for i, feat in enumerate(proj_feats): - _, _, h, w = feat.shape - # [b, c, h, w] -> [b, h*w, c] - feat_flatten.append(feat.flatten(2).permute(0, 2, 1)) - # [num_levels, 2] - spatial_shapes.append([h, w]) - # [b, l, c] - feat_flatten = torch.concat(feat_flatten, 1) - return feat_flatten, spatial_shapes - - def _generate_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device="cpu"): - if spatial_shapes is None: - spatial_shapes = [] - eval_h, eval_w = self.eval_spatial_size - for s in self.feat_strides: - spatial_shapes.append([int(eval_h / s), int(eval_w / s)]) - - anchors = [] - for lvl, (h, w) in enumerate(spatial_shapes): - grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") - grid_xy = torch.stack([grid_x, grid_y], dim=-1) - grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype) - wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl) - lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4) - anchors.append(lvl_anchors) - - anchors = torch.concat(anchors, dim=1).to(device) - valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True) - anchors = torch.log(anchors / (1 - anchors)) - anchors = torch.where(valid_mask, anchors, torch.inf) - - return anchors, valid_mask - - def _get_decoder_input( - self, memory: torch.Tensor, spatial_shapes, denoising_logits=None, denoising_bbox_unact=None - ): - # prepare input for decoder - if self.training or self.eval_spatial_size is None: - anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device) - else: - anchors = self.anchors - valid_mask = self.valid_mask - - # memory = torch.where(valid_mask, memory, 0) - # TODO fix type error for onnx export - memory = valid_mask.to(memory.dtype) * memory - - output_memory: torch.Tensor = self.enc_output(memory) - enc_outputs_logits: torch.Tensor = self.enc_score_head(output_memory) - enc_outputs_coord_unact: torch.Tensor = self.enc_bbox_head(output_memory) + anchors - - enc_topk_bboxes_list, enc_topk_logits_list = [], [] - enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = self._select_topk( - output_memory, enc_outputs_logits, enc_outputs_coord_unact, self.num_queries - ) - - if self.training: - enc_topk_bboxes = F.sigmoid(enc_topk_bbox_unact) - enc_topk_bboxes_list.append(enc_topk_bboxes) - enc_topk_logits_list.append(enc_topk_logits) - - # if self.num_select_queries != self.num_queries: - # raise NotImplementedError('') - - if self.learn_query_content: - content = self.tgt_embed.weight.unsqueeze(0).tile([memory.shape[0], 1, 1]) - else: - content = enc_topk_memory.detach() - - enc_topk_bbox_unact = enc_topk_bbox_unact.detach() - - if denoising_bbox_unact is not None: - enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1) - content = torch.concat([denoising_logits, content], dim=1) - - return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list - - def _select_topk( - self, memory: torch.Tensor, outputs_logits: torch.Tensor, outputs_coords_unact: torch.Tensor, topk: int - ): - if self.query_select_method == "default": - _, topk_ind = torch.topk(outputs_logits.max(-1).values, topk, dim=-1) - - elif self.query_select_method == "one2many": - _, topk_ind = torch.topk(outputs_logits.flatten(1), topk, dim=-1) - topk_ind = topk_ind // self.num_classes - - elif self.query_select_method == "agnostic": - _, topk_ind = torch.topk(outputs_logits.squeeze(-1), topk, dim=-1) - - topk_ind: torch.Tensor - - topk_coords = outputs_coords_unact.gather( - dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_coords_unact.shape[-1]) - ) - - topk_logits = outputs_logits.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1])) - - topk_memory = memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1])) - - return topk_memory, topk_logits, topk_coords - - def forward(self, feats, targets=None): - # input projection and embedding - memory, spatial_shapes = self._get_encoder_input(feats) - - # prepare denoising training - if self.training and self.num_denoising > 0: - denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = get_contrastive_denoising_training_group( - targets, - self.num_classes, - self.num_queries, - self.denoising_class_embed, - num_denoising=self.num_denoising, - label_noise_ratio=self.label_noise_ratio, - box_noise_scale=self.box_noise_scale, - ) - else: - denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None - - init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = self._get_decoder_input( - memory, spatial_shapes, denoising_logits, denoising_bbox_unact - ) - - # decoder - out_bboxes, out_logits = self.decoder( - init_ref_contents, - init_ref_points_unact, - memory, - spatial_shapes, - self.dec_bbox_head, - self.dec_score_head, - self.query_pos_head, - attn_mask=attn_mask, - ) - - if self.training and dn_meta is not None: - dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta["dn_num_split"], dim=2) - dn_out_logits, out_logits = torch.split(out_logits, dn_meta["dn_num_split"], dim=2) - - out = {"pred_logits": out_logits[-1], "pred_boxes": out_bboxes[-1]} - - if self.training and self.aux_loss: - out["aux_outputs"] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1]) - out["enc_aux_outputs"] = self._set_aux_loss(enc_topk_logits_list, enc_topk_bboxes_list) - out["enc_meta"] = {"class_agnostic": self.query_select_method == "agnostic"} - - if dn_meta is not None: - out["dn_aux_outputs"] = self._set_aux_loss(dn_out_logits, dn_out_bboxes) - out["dn_meta"] = dn_meta - - return out - - @torch.jit.unused - def _set_aux_loss(self, outputs_class, outputs_coord): - # this is a workaround to make torchscript happy, as torchscript - # doesn't support dictionary with non-homogeneous values, such - # as a dict having both a Tensor and a list. - return [{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] diff --git a/trolo/models/rtdetrv2/__init__.py b/trolo/models/rtdetrv2/__init__.py new file mode 100644 index 0000000..5faa58e --- /dev/null +++ b/trolo/models/rtdetrv2/__init__.py @@ -0,0 +1,4 @@ +from .decoder import RTDETRTransformerv2 +from .rtdetr import RTDETR +from .criterion import RTDETRCriterionv2 +from .preprocessor import RTDETRPostProcessor diff --git a/trolo/models/rtdetrv2/criterion.py b/trolo/models/rtdetrv2/criterion.py new file mode 100644 index 0000000..074f9b0 --- /dev/null +++ b/trolo/models/rtdetrv2/criterion.py @@ -0,0 +1,265 @@ +"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torch.nn as nn +import torch.distributed +import torch.nn.functional as F +import torchvision + +import copy + +from trolo.utils.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou +from trolo.utils.dist_utils import get_world_size, is_dist_available_and_initialized +from trolo.loaders import register + + +@register() +class RTDETRCriterionv2(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + __share__ = ['num_classes', ] + __inject__ = ['matcher', ] + + def __init__(self, \ + matcher, + weight_dict, + losses, + alpha=0.2, + gamma=2.0, + num_classes=80, + boxes_weight_format=None, + share_matched_indices=False): + """Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + num_classes: number of object categories, omitting the special no-object category + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. + boxes_weight_format: format for boxes weight (iou, ) + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.boxes_weight_format = boxes_weight_format + self.share_matched_indices = share_matched_indices + self.alpha = alpha + self.gamma = gamma + + def loss_labels_focal(self, outputs, targets, indices, num_boxes): + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + target = F.one_hot(target_classes, num_classes=self.num_classes+1)[..., :-1] + loss = torchvision.ops.sigmoid_focal_loss(src_logits, target, self.alpha, self.gamma, reduction='none') + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + + return {'loss_focal': loss} + + def loss_labels_vfl(self, outputs, targets, indices, num_boxes, values=None): + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + if values is None: + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes)) + ious = torch.diag(ious).detach() + else: + ious = values + + src_logits = outputs['pred_logits'] + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + + target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype) + target_score_o[idx] = ious.to(target_score_o.dtype) + target_score = target_score_o.unsqueeze(-1) * target + + pred_score = F.sigmoid(src_logits).detach() + weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score + + loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction='none') + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + return {'loss_vfl': loss} + + def loss_boxes(self, outputs, targets, indices, num_boxes, boxes_weight=None): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + losses = {} + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(generalized_box_iou(\ + box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))) + loss_giou = loss_giou if boxes_weight is None else loss_giou * boxes_weight + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'boxes': self.loss_boxes, + 'focal': self.loss_labels_focal, + 'vfl': self.loss_labels_vfl, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets, **kwargs): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k} + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + if is_dist_available_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Retrieve the matching between the outputs of the last layer and the targets + matched = self.matcher(outputs_without_aux, targets) + indices = matched['indices'] + + # Compute all the requested losses + losses = {} + for loss in self.losses: + meta = self.get_loss_meta_info(loss, outputs, targets, indices) + l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes, **meta) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + losses.update(l_dict) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + if not self.share_matched_indices: + matched = self.matcher(aux_outputs, targets) + indices = matched['indices'] + for loss in self.losses: + meta = self.get_loss_meta_info(loss, aux_outputs, targets, indices) + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **meta) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + # In case of cdn auxiliary losses. For rtdetr + if 'dn_aux_outputs' in outputs: + assert 'dn_meta' in outputs, '' + indices = self.get_cdn_matched_indices(outputs['dn_meta'], targets) + dn_num_boxes = num_boxes * outputs['dn_meta']['dn_num_group'] + for i, aux_outputs in enumerate(outputs['dn_aux_outputs']): + for loss in self.losses: + meta = self.get_loss_meta_info(loss, aux_outputs, targets, indices) + l_dict = self.get_loss(loss, aux_outputs, targets, indices, dn_num_boxes, **meta) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + l_dict = {k + f'_dn_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + # In case of encoder auxiliary losses. For rtdetr v2 + if 'enc_aux_outputs' in outputs: + assert 'enc_meta' in outputs, '' + class_agnostic = outputs['enc_meta']['class_agnostic'] + if class_agnostic: + orig_num_classes = self.num_classes + self.num_classes = 1 + enc_targets = copy.deepcopy(targets) + for t in enc_targets: + t['labels'] = torch.zeros_like(t["labels"]) + else: + enc_targets = targets + + for i, aux_outputs in enumerate(outputs['enc_aux_outputs']): + matched = self.matcher(aux_outputs, targets) + indices = matched['indices'] + for loss in self.losses: + meta = self.get_loss_meta_info(loss, aux_outputs, enc_targets, indices) + l_dict = self.get_loss(loss, aux_outputs, enc_targets, indices, num_boxes, **meta) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + l_dict = {k + f'_enc_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + if class_agnostic: + self.num_classes = orig_num_classes + + return losses + + def get_loss_meta_info(self, loss, outputs, targets, indices): + if self.boxes_weight_format is None: + return {} + + src_boxes = outputs['pred_boxes'][self._get_src_permutation_idx(indices)] + target_boxes = torch.cat([t['boxes'][j] for t, (_, j) in zip(targets, indices)], dim=0) + + if self.boxes_weight_format == 'iou': + iou, _ = box_iou(box_cxcywh_to_xyxy(src_boxes.detach()), box_cxcywh_to_xyxy(target_boxes)) + iou = torch.diag(iou) + elif self.boxes_weight_format == 'giou': + iou = torch.diag(generalized_box_iou(\ + box_cxcywh_to_xyxy(src_boxes.detach()), box_cxcywh_to_xyxy(target_boxes))) + else: + raise AttributeError() + + if loss in ('boxes', ): + meta = {'boxes_weight': iou} + elif loss in ('vfl', ): + meta = {'values': iou} + else: + meta = {} + + return meta + + @staticmethod + def get_cdn_matched_indices(dn_meta, targets): + """get_cdn_matched_indices + """ + dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"] + num_gts = [len(t['labels']) for t in targets] + device = targets[0]['labels'].device + + dn_match_indices = [] + for i, num_gt in enumerate(num_gts): + if num_gt > 0: + gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device) + gt_idx = gt_idx.tile(dn_num_group) + assert len(dn_positive_idx[i]) == len(gt_idx) + dn_match_indices.append((dn_positive_idx[i], gt_idx)) + else: + dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device), \ + torch.zeros(0, dtype=torch.int64, device=device))) + + return dn_match_indices \ No newline at end of file diff --git a/trolo/models/rtdetrv2/decoder.py b/trolo/models/rtdetrv2/decoder.py new file mode 100644 index 0000000..6cdf3c1 --- /dev/null +++ b/trolo/models/rtdetrv2/decoder.py @@ -0,0 +1,710 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from collections import OrderedDict +from typing import List +import copy +import functools +import math + +from trolo.loaders.registry import register +from ..dfine.denoising import get_contrastive_denoising_training_group, inverse_sigmoid +from ..dfine.dfine_decoder import bias_init_with_prob + + +__all__ = ['RTDETRTransformerv2'] + + +def deformable_attention_core_func_v2(\ + value: torch.Tensor, + value_spatial_shapes, + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, + num_points_list: List[int], + method='default'): + """ + Args: + value (Tensor): [bs, value_length, n_head, c] + value_spatial_shapes (Tensor|List): [n_levels, 2] + value_level_start_index (Tensor|List): [n_levels] + sampling_locations (Tensor): [bs, query_length, n_head, n_levels * n_points, 2] + attention_weights (Tensor): [bs, query_length, n_head, n_levels * n_points] + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, _, n_head, c = value.shape + _, Len_q, _, _, _ = sampling_locations.shape + + split_shape = [h * w for h, w in value_spatial_shapes] + value_list = value.permute(0, 2, 3, 1).flatten(0, 1).split(split_shape, dim=-1) + + # sampling_offsets [8, 480, 8, 12, 2] + if method == 'default': + sampling_grids = 2 * sampling_locations - 1 + + elif method == 'discrete': + sampling_grids = sampling_locations + + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_locations_list = sampling_grids.split(num_points_list, dim=-2) + + sampling_value_list = [] + for level, (h, w) in enumerate(value_spatial_shapes): + value_l = value_list[level].reshape(bs * n_head, c, h, w) + sampling_grid_l: torch.Tensor = sampling_locations_list[level] + + if method == 'default': + sampling_value_l = F.grid_sample( + value_l, + sampling_grid_l, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + + elif method == 'discrete': + # n * m, seq, n, 2 + sampling_coord = (sampling_grid_l * torch.tensor([[w, h]], device=value.device) + 0.5).to(torch.int64) + + # FIX ME? for rectangle input + sampling_coord = sampling_coord.clamp(0, h - 1) + sampling_coord = sampling_coord.reshape(bs * n_head, Len_q * num_points_list[level], 2) + + s_idx = torch.arange(sampling_coord.shape[0], device=value.device).unsqueeze(-1).repeat(1, sampling_coord.shape[1]) + sampling_value_l: torch.Tensor = value_l[s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] # n l c + + sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(bs * n_head, c, Len_q, num_points_list[level]) + + sampling_value_list.append(sampling_value_l) + + attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(bs * n_head, 1, Len_q, sum(num_points_list)) + weighted_sample_locs = torch.concat(sampling_value_list, dim=-1) * attn_weights + output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q) + + return output.permute(0, 2, 1) + + +def get_activation(act: str, inpace: bool=True): + """get activation + """ + if act is None: + return nn.Identity() + + elif isinstance(act, nn.Module): + return act + + act = act.lower() + + if act == 'silu' or act == 'swish': + m = nn.SiLU() + + elif act == 'relu': + m = nn.ReLU() + + elif act == 'leaky_relu': + m = nn.LeakyReLU() + + elif act == 'silu': + m = nn.SiLU() + + elif act == 'gelu': + m = nn.GELU() + + elif act == 'hardsigmoid': + m = nn.Hardsigmoid() + + else: + raise RuntimeError('') + + if hasattr(m, 'inplace'): + m.inplace = inpace + + return m + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act='relu'): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.act = get_activation(act) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class MSDeformableAttention(nn.Module): + def __init__( + self, + embed_dim=256, + num_heads=8, + num_levels=4, + num_points=4, + method='default', + offset_scale=0.5, + ): + """Multi-Scale Deformable Attention + """ + super(MSDeformableAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_levels = num_levels + self.offset_scale = offset_scale + + if isinstance(num_points, list): + assert len(num_points) == num_levels, '' + num_points_list = num_points + else: + num_points_list = [num_points for _ in range(num_levels)] + + self.num_points_list = num_points_list + + num_points_scale = [1/n for n in num_points_list for _ in range(n)] + self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32)) + + self.total_points = num_heads * sum(num_points_list) + self.method = method + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2) + self.attention_weights = nn.Linear(embed_dim, self.total_points) + self.value_proj = nn.Linear(embed_dim, embed_dim) + self.output_proj = nn.Linear(embed_dim, embed_dim) + + self.ms_deformable_attn_core = functools.partial(deformable_attention_core_func_v2, method=self.method) + + self._reset_parameters() + + if method == 'discrete': + for p in self.sampling_offsets.parameters(): + p.requires_grad = False + + def _reset_parameters(self): + # sampling_offsets + init.constant_(self.sampling_offsets.weight, 0) + thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values + grid_init = grid_init.reshape(self.num_heads, 1, 2).tile([1, sum(self.num_points_list), 1]) + scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape(1, -1, 1) + grid_init *= scaling + self.sampling_offsets.bias.data[...] = grid_init.flatten() + + # attention_weights + init.constant_(self.attention_weights.weight, 0) + init.constant_(self.attention_weights.bias, 0) + + # proj + init.xavier_uniform_(self.value_proj.weight) + init.constant_(self.value_proj.bias, 0) + init.xavier_uniform_(self.output_proj.weight) + init.constant_(self.output_proj.bias, 0) + + + def forward(self, + query: torch.Tensor, + reference_points: torch.Tensor, + value: torch.Tensor, + value_spatial_shapes: List[int], + value_mask: torch.Tensor=None): + """ + Args: + query (Tensor): [bs, query_length, C] + reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area + value (Tensor): [bs, value_length, C] + value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, Len_q = query.shape[:2] + Len_v = value.shape[1] + + value = self.value_proj(value) + if value_mask is not None: + value = value * value_mask.to(value.dtype).unsqueeze(-1) + + value = value.reshape(bs, Len_v, self.num_heads, self.head_dim) + + sampling_offsets: torch.Tensor = self.sampling_offsets(query) + sampling_offsets = sampling_offsets.reshape(bs, Len_q, self.num_heads, sum(self.num_points_list), 2) + + attention_weights = self.attention_weights(query).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list)) + attention_weights = F.softmax(attention_weights, dim=-1).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list)) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.tensor(value_spatial_shapes) + offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.num_levels, 1, 2) + sampling_locations = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + sampling_offsets / offset_normalizer + elif reference_points.shape[-1] == 4: + # reference_points [8, 480, None, 1, 4] + # sampling_offsets [8, 480, 8, 12, 2] + num_points_scale = self.num_points_scale.to(dtype=query.dtype).unsqueeze(-1) + offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.". + format(reference_points.shape[-1])) + + output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list) + + output = self.output_proj(output) + + return output + + +class TransformerDecoderLayer(nn.Module): + def __init__(self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0., + activation='relu', + n_levels=4, + n_points=4, + cross_attn_method='default'): + super(TransformerDecoderLayer, self).__init__() + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # cross attention + self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, n_points, method=cross_attn_method) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = get_activation(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + self._reset_parameters() + + def _reset_parameters(self): + init.xavier_uniform_(self.linear1.weight) + init.xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + return self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + + def forward(self, + target, + reference_points, + memory, + memory_spatial_shapes, + attn_mask=None, + memory_mask=None, + query_pos_embed=None): + # self attention + q = k = self.with_pos_embed(target, query_pos_embed) + + target2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask) + target = target + self.dropout1(target2) + target = self.norm1(target) + + # cross attention + target2 = self.cross_attn(\ + self.with_pos_embed(target, query_pos_embed), + reference_points, + memory, + memory_spatial_shapes, + memory_mask) + target = target + self.dropout2(target2) + target = self.norm2(target) + + # ffn + target2 = self.forward_ffn(target) + target = target + self.dropout4(target2) + target = self.norm3(target) + + return target + + +class TransformerDecoder(nn.Module): + def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1): + super(TransformerDecoder, self).__init__() + self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)]) + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + + def forward(self, + target, + ref_points_unact, + memory, + memory_spatial_shapes, + bbox_head, + score_head, + query_pos_head, + attn_mask=None, + memory_mask=None): + dec_out_bboxes = [] + dec_out_logits = [] + ref_points_detach = F.sigmoid(ref_points_unact) + + output = target + for i, layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + query_pos_embed = query_pos_head(ref_points_detach) + + output = layer(output, ref_points_input, memory, memory_spatial_shapes, attn_mask, memory_mask, query_pos_embed) + + inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach)) + + if self.training: + dec_out_logits.append(score_head[i](output)) + if i == 0: + dec_out_bboxes.append(inter_ref_bbox) + else: + dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points))) + + elif i == self.eval_idx: + dec_out_logits.append(score_head[i](output)) + dec_out_bboxes.append(inter_ref_bbox) + break + + ref_points = inter_ref_bbox + ref_points_detach = inter_ref_bbox.detach() + + return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits) + + +@register() +class RTDETRTransformerv2(nn.Module): + __share__ = ['num_classes', 'eval_spatial_size'] + + def __init__(self, + num_classes=80, + hidden_dim=256, + num_queries=300, + feat_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + num_levels=3, + num_points=4, + nhead=8, + num_layers=6, + dim_feedforward=1024, + dropout=0., + activation="relu", + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_query_content=False, + eval_spatial_size=None, + eval_idx=-1, + eps=1e-2, + aux_loss=True, + cross_attn_method='default', + query_select_method='default'): + super().__init__() + assert len(feat_channels) <= num_levels + assert len(feat_strides) == len(feat_channels) + + for _ in range(num_levels - len(feat_strides)): + feat_strides.append(feat_strides[-1] * 2) + + self.hidden_dim = hidden_dim + self.nhead = nhead + self.feat_strides = feat_strides + self.num_levels = num_levels + self.num_classes = num_classes + self.num_queries = num_queries + self.eps = eps + self.num_layers = num_layers + self.eval_spatial_size = eval_spatial_size + self.aux_loss = aux_loss + + assert query_select_method in ('default', 'one2many', 'agnostic'), '' + assert cross_attn_method in ('default', 'discrete'), '' + self.cross_attn_method = cross_attn_method + self.query_select_method = query_select_method + + # backbone feature projection + self._build_input_proj_layer(feat_channels) + + # Transformer module + decoder_layer = TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, \ + activation, num_levels, num_points, cross_attn_method=cross_attn_method) + self.decoder = TransformerDecoder(hidden_dim, decoder_layer, num_layers, eval_idx) + + # denoising + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + if num_denoising > 0: + self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes) + init.normal_(self.denoising_class_embed.weight[:-1]) + + # decoder embedding + self.learn_query_content = learn_query_content + if learn_query_content: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2) + + # if num_select_queries != self.num_queries: + # layer = TransformerEncoderLayer(hidden_dim, nhead, dim_feedforward, activation='gelu') + # self.encoder = TransformerEncoder(layer, 1) + + self.enc_output = nn.Sequential(OrderedDict([ + ('proj', nn.Linear(hidden_dim, hidden_dim)), + ('norm', nn.LayerNorm(hidden_dim,)), + ])) + + if query_select_method == 'agnostic': + self.enc_score_head = nn.Linear(hidden_dim, 1) + else: + self.enc_score_head = nn.Linear(hidden_dim, num_classes) + + self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3) + + # decoder head + self.dec_score_head = nn.ModuleList([ + nn.Linear(hidden_dim, num_classes) for _ in range(num_layers) + ]) + self.dec_bbox_head = nn.ModuleList([ + MLP(hidden_dim, hidden_dim, 4, 3) for _ in range(num_layers) + ]) + + # init encoder output anchors and valid_mask + if self.eval_spatial_size: + anchors, valid_mask = self._generate_anchors() + self.register_buffer('anchors', anchors) + self.register_buffer('valid_mask', valid_mask) + + self._reset_parameters() + + def _reset_parameters(self): + bias = bias_init_with_prob(0.01) + init.constant_(self.enc_score_head.bias, bias) + init.constant_(self.enc_bbox_head.layers[-1].weight, 0) + init.constant_(self.enc_bbox_head.layers[-1].bias, 0) + + for _cls, _reg in zip(self.dec_score_head, self.dec_bbox_head): + init.constant_(_cls.bias, bias) + init.constant_(_reg.layers[-1].weight, 0) + init.constant_(_reg.layers[-1].bias, 0) + + init.xavier_uniform_(self.enc_output[0].weight) + if self.learn_query_content: + init.xavier_uniform_(self.tgt_embed.weight) + init.xavier_uniform_(self.query_pos_head.layers[0].weight) + init.xavier_uniform_(self.query_pos_head.layers[1].weight) + for m in self.input_proj: + init.xavier_uniform_(m[0].weight) + + def _build_input_proj_layer(self, feat_channels): + self.input_proj = nn.ModuleList() + for in_channels in feat_channels: + self.input_proj.append( + nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)), + ('norm', nn.BatchNorm2d(self.hidden_dim,))]) + ) + ) + + in_channels = feat_channels[-1] + + for _ in range(self.num_levels - len(feat_channels)): + self.input_proj.append( + nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)), + ('norm', nn.BatchNorm2d(self.hidden_dim))]) + ) + ) + in_channels = self.hidden_dim + + def _get_encoder_input(self, feats: List[torch.Tensor]): + # get projection features + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + if self.num_levels > len(proj_feats): + len_srcs = len(proj_feats) + for i in range(len_srcs, self.num_levels): + if i == len_srcs: + proj_feats.append(self.input_proj[i](feats[-1])) + else: + proj_feats.append(self.input_proj[i](proj_feats[-1])) + + # get encoder inputs + feat_flatten = [] + spatial_shapes = [] + for i, feat in enumerate(proj_feats): + _, _, h, w = feat.shape + # [b, c, h, w] -> [b, h*w, c] + feat_flatten.append(feat.flatten(2).permute(0, 2, 1)) + # [num_levels, 2] + spatial_shapes.append([h, w]) + # [b, l, c] + feat_flatten = torch.concat(feat_flatten, 1) + return feat_flatten, spatial_shapes + + def _generate_anchors(self, + spatial_shapes=None, + grid_size=0.05, + dtype=torch.float32, + device='cpu'): + if spatial_shapes is None: + spatial_shapes = [] + eval_h, eval_w = self.eval_spatial_size + for s in self.feat_strides: + spatial_shapes.append([int(eval_h / s), int(eval_w / s)]) + + anchors = [] + for lvl, (h, w) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') + grid_xy = torch.stack([grid_x, grid_y], dim=-1) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / torch.tensor([w, h], dtype=dtype) + wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl) + lvl_anchors = torch.concat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4) + anchors.append(lvl_anchors) + + anchors = torch.concat(anchors, dim=1).to(device) + valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + anchors = torch.where(valid_mask, anchors, torch.inf) + + return anchors, valid_mask + + + def _get_decoder_input(self, + memory: torch.Tensor, + spatial_shapes, + denoising_logits=None, + denoising_bbox_unact=None): + + # prepare input for decoder + if self.training or self.eval_spatial_size is None: + anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device) + else: + anchors = self.anchors + valid_mask = self.valid_mask + + # memory = torch.where(valid_mask, memory, 0) + # TODO fix type error for onnx export + memory = valid_mask.to(memory.dtype) * memory + + output_memory :torch.Tensor = self.enc_output(memory) + enc_outputs_logits :torch.Tensor = self.enc_score_head(output_memory) + enc_outputs_coord_unact :torch.Tensor = self.enc_bbox_head(output_memory) + anchors + + enc_topk_bboxes_list, enc_topk_logits_list = [], [] + enc_topk_memory, enc_topk_logits, enc_topk_bbox_unact = \ + self._select_topk(output_memory, enc_outputs_logits, enc_outputs_coord_unact, self.num_queries) + + if self.training: + enc_topk_bboxes = F.sigmoid(enc_topk_bbox_unact) + enc_topk_bboxes_list.append(enc_topk_bboxes) + enc_topk_logits_list.append(enc_topk_logits) + + # if self.num_select_queries != self.num_queries: + # raise NotImplementedError('') + + if self.learn_query_content: + content = self.tgt_embed.weight.unsqueeze(0).tile([memory.shape[0], 1, 1]) + else: + content = enc_topk_memory.detach() + + enc_topk_bbox_unact = enc_topk_bbox_unact.detach() + + if denoising_bbox_unact is not None: + enc_topk_bbox_unact = torch.concat([denoising_bbox_unact, enc_topk_bbox_unact], dim=1) + content = torch.concat([denoising_logits, content], dim=1) + + return content, enc_topk_bbox_unact, enc_topk_bboxes_list, enc_topk_logits_list + + def _select_topk(self, memory: torch.Tensor, outputs_logits: torch.Tensor, outputs_coords_unact: torch.Tensor, topk: int): + if self.query_select_method == 'default': + _, topk_ind = torch.topk(outputs_logits.max(-1).values, topk, dim=-1) + + elif self.query_select_method == 'one2many': + _, topk_ind = torch.topk(outputs_logits.flatten(1), topk, dim=-1) + topk_ind = topk_ind // self.num_classes + + elif self.query_select_method == 'agnostic': + _, topk_ind = torch.topk(outputs_logits.squeeze(-1), topk, dim=-1) + + topk_ind: torch.Tensor + + topk_coords = outputs_coords_unact.gather(dim=1, \ + index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_coords_unact.shape[-1])) + + topk_logits = outputs_logits.gather(dim=1, \ + index=topk_ind.unsqueeze(-1).repeat(1, 1, outputs_logits.shape[-1])) + + topk_memory = memory.gather(dim=1, \ + index=topk_ind.unsqueeze(-1).repeat(1, 1, memory.shape[-1])) + + return topk_memory, topk_logits, topk_coords + + + def forward(self, feats, targets=None): + # input projection and embedding + memory, spatial_shapes = self._get_encoder_input(feats) + + # prepare denoising training + if self.training and self.num_denoising > 0: + denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = \ + get_contrastive_denoising_training_group(targets, \ + self.num_classes, + self.num_queries, + self.denoising_class_embed, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, ) + else: + denoising_logits, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None + + init_ref_contents, init_ref_points_unact, enc_topk_bboxes_list, enc_topk_logits_list = \ + self._get_decoder_input(memory, spatial_shapes, denoising_logits, denoising_bbox_unact) + + # decoder + out_bboxes, out_logits = self.decoder( + init_ref_contents, + init_ref_points_unact, + memory, + spatial_shapes, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + attn_mask=attn_mask) + + if self.training and dn_meta is not None: + dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2) + dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2) + + out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]} + + if self.training and self.aux_loss: + out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1]) + out['enc_aux_outputs'] = self._set_aux_loss(enc_topk_logits_list, enc_topk_bboxes_list) + out['enc_meta'] = {'class_agnostic': self.query_select_method == 'agnostic'} + + if dn_meta is not None: + out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes) + out['dn_meta'] = dn_meta + + return out + + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class, outputs_coord)] \ No newline at end of file diff --git a/trolo/models/rtdetrv2/maps.py b/trolo/models/rtdetrv2/maps.py new file mode 100644 index 0000000..5061ced --- /dev/null +++ b/trolo/models/rtdetrv2/maps.py @@ -0,0 +1,13 @@ +from pathlib import Path + +MODEL_CONFIG_MAP = { + + # Automatically map all yml files from dfine config directory + **{ + Path(f).name: str( + Path("rtdetrv2") / Path(f).relative_to(Path(__file__).parent.parent.parent / "configs" / "yaml" / "rtdetrv2") + ) + for f in (Path(__file__).parent.parent.parent / "configs" / "yaml" / "rtdetrv2").rglob("*.yml") + if not Path(f).name.startswith("_") # Skip include files + }, +} diff --git a/trolo/models/rtdetrv2/preprocessor.py b/trolo/models/rtdetrv2/preprocessor.py new file mode 100644 index 0000000..e29eeca --- /dev/null +++ b/trolo/models/rtdetrv2/preprocessor.py @@ -0,0 +1,93 @@ +"""Copyright(c) 2023 lyuwenyu. All Rights Reserved. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torchvision + +from trolo.loaders import register + + +__all__ = ['RTDETRPostProcessor'] + + +def mod(a, b): + out = a - a // b * b + return out + + +@register() +class RTDETRPostProcessor(nn.Module): + __share__ = [ + 'num_classes', + 'use_focal_loss', + 'num_top_queries', + 'remap_mscoco_category' + ] + + def __init__( + self, + num_classes=80, + use_focal_loss=True, + num_top_queries=300, + remap_mscoco_category=False + ) -> None: + super().__init__() + self.use_focal_loss = use_focal_loss + self.num_top_queries = num_top_queries + self.num_classes = int(num_classes) + self.remap_mscoco_category = remap_mscoco_category + self.deploy_mode = False + + def extra_repr(self) -> str: + return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}' + + # def forward(self, outputs, orig_target_sizes): + def forward(self, outputs, orig_target_sizes: torch.Tensor): + logits, boxes = outputs['pred_logits'], outputs['pred_boxes'] + # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + + bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy') + bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1) + + if self.use_focal_loss: + scores = F.sigmoid(logits) + scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1) + # TODO for older tensorrt + # labels = index % self.num_classes + labels = mod(index, self.num_classes) + index = index // self.num_classes + boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1])) + + else: + scores = F.softmax(logits)[:, :, :-1] + scores, labels = scores.max(dim=-1) + if scores.shape[1] > self.num_top_queries: + scores, index = torch.topk(scores, self.num_top_queries, dim=-1) + labels = torch.gather(labels, dim=1, index=index) + boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) + + # TODO for onnx export + if self.deploy_mode: + return labels, boxes, scores + + # TODO + if self.remap_mscoco_category: + from ...data.dataset import mscoco_label2category + labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\ + .to(boxes.device).reshape(labels.shape) + + results = [] + for lab, box, sco in zip(labels, boxes, scores): + result = dict(labels=lab, boxes=box, scores=sco) + results.append(result) + + return results + + + def deploy(self, ): + self.eval() + self.deploy_mode = True + return self \ No newline at end of file diff --git a/trolo/models/rtdetrv2/rtdetr.py b/trolo/models/rtdetrv2/rtdetr.py new file mode 100644 index 0000000..1bf856a --- /dev/null +++ b/trolo/models/rtdetrv2/rtdetr.py @@ -0,0 +1,8 @@ +from trolo.loaders import register +from ..dfine.dfine import DFINE + +@register() +class RTDETR(DFINE): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + diff --git a/trolo/utils/box_ops.py b/trolo/utils/box_ops.py index 333b949..e3de0f4 100644 --- a/trolo/utils/box_ops.py +++ b/trolo/utils/box_ops.py @@ -4,6 +4,8 @@ import torchvision from torch import Tensor import supervision as sv +from torchvision.ops.boxes import box_area + def to_sv(results: Dict) -> sv.Detections: @@ -112,3 +114,89 @@ def point_distance_box(points: Tensor, distances: Tensor) -> Tensor: x2y2 = rb + points boxes = torch.concat([x1y1, x2y2], dim=-1) return boxes + + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [ + (x_c - 0.5 * w.clamp(min=0.0)), + (y_c - 0.5 * h.clamp(min=0.0)), + (x_c + 0.5 * w.clamp(min=0.0)), + (y_c + 0.5 * h.clamp(min=0.0)), + ] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x: Tensor) -> Tensor: + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1: Tensor, boxes2: Tensor): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = masks * x.unsqueeze(0) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = masks * y.unsqueeze(0) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1)