diff --git a/README.md b/README.md index b4666332..f058824b 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,14 @@ +### Forked Detrex branch + +This branch is a fork of detrex that makes the maskdino project installable as a Python package by running `pip install .` +from the `projects/maskdino` folder. +The resulting maskdino project package is also uploaded to devpi. + +This work originally used the maskdino git repository, but because the two use different versions of the omegaconf-based config, the maskdino repo +and the maskdino project inside detrex are quite different, and we ran into training problems. + +The maskdino parameters have been organised so that they can be changed easily and are more consistent. +
diff --git a/projects/maskdino/configs/models/maskdino_r50.py b/projects/maskdino/configs/models/maskdino_r50.py index 89cbe5b1..7a08fe8f 100644 --- a/projects/maskdino/configs/models/maskdino_r50.py +++ b/projects/maskdino/configs/models/maskdino_r50.py @@ -3,24 +3,42 @@ from detrex.modeling.backbone import ResNet, BasicStem from detectron2.config import LazyCall as L - -from projects.maskdino.modeling.meta_arch.maskdino_head import MaskDINOHead -from projects.maskdino.modeling.pixel_decoder.maskdino_encoder import MaskDINOEncoder -from projects.maskdino.modeling.transformer_decoder.maskdino_decoder import MaskDINODecoder -from projects.maskdino.modeling.criterion import SetCriterion -from projects.maskdino.modeling.matcher import HungarianMatcher -from projects.maskdino.maskdino import MaskDINO from detectron2.data import MetadataCatalog from detectron2.layers import Conv2d, ShapeSpec, get_norm +from omegaconf import OmegaConf + +from ...modeling.meta_arch.maskdino_head import MaskDINOHead +from ...modeling.pixel_decoder.maskdino_encoder import MaskDINOEncoder +from ...modeling.transformer_decoder.maskdino_decoder import MaskDINODecoder +from ...modeling.weighted_criterion import WeightedCriterion +from ...modeling.matcher import HungarianMatcher +from ...maskdino import MaskDINO + -dim=256 -n_class=80 -dn="seg" -dec_layers = 9 -input_shape={'res2': ShapeSpec(channels=256, height=None, width=None, stride=4), 'res3': ShapeSpec(channels=512, height=None, width=None, stride=8), 'res4': ShapeSpec(channels=1024, height=None, width=None, stride=16), 'res5': ShapeSpec(channels=2048, height=None, width=None, stride=32)} model = L(MaskDINO)( + # parameters in one place. 
+ params=OmegaConf.create(dict( + input_shape={ + 'res2': L(ShapeSpec)(channels=256, height=None, width=None, stride=4), + 'res3': L(ShapeSpec)(channels=512, height=None, width=None, stride=8), + 'res4': L(ShapeSpec)(channels=1024, height=None, width=None, stride=16), + 'res5': L(ShapeSpec)(channels=2048, height=None, width=None, stride=32) + }, + dim=256, + hidden_dim=256, + query_dim=4, + num_classes=80, + dec_layers=9, + enc_layers=6, + feed_forward=2048, + n_heads=8, + num_queries=300, + dn_num=100, + dn_mode="seg", + show_weights=True + )), backbone=L(ResNet)( stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"), stages=L(ResNet.make_default_stages)( @@ -32,16 +50,16 @@ freeze_at=1, ), sem_seg_head=L(MaskDINOHead)( - input_shape=input_shape, - num_classes=n_class, + input_shape="${..params.input_shape}", + num_classes="${..params.num_classes}", pixel_decoder=L(MaskDINOEncoder)( - input_shape=input_shape, + input_shape="${...params.input_shape}", transformer_dropout=0.0, - transformer_nheads=8, - transformer_dim_feedforward=2048, - transformer_enc_layers=6, - conv_dim=dim, - mask_dim=dim, + transformer_nheads="${...params.n_heads}", + transformer_dim_feedforward="${...params.feed_forward}", + transformer_enc_layers="${...params.enc_layers}", + conv_dim="${...params.dim}", + mask_dim="${...params.dim}", norm = 'GN', transformer_in_features=['res3', 'res4', 'res5'], common_stride=4, @@ -52,35 +70,35 @@ loss_weight= 1.0, ignore_value= -1, transformer_predictor=L(MaskDINODecoder)( - in_channels=dim, + in_channels="${...params.dim}", mask_classification=True, - num_classes="${..num_classes}", - hidden_dim=dim, - num_queries=300, - nheads=8, - dim_feedforward=2048, - dec_layers=dec_layers, - mask_dim=dim, + num_classes="${...params.num_classes}", + hidden_dim="${...params.hidden_dim}", + num_queries="${...params.num_queries}", + nheads="${...params.n_heads}", + dim_feedforward="${...params.feed_forward}", + dec_layers="${...params.dec_layers}", + 
mask_dim="${...params.dim}", enforce_input_project=False, two_stage=True, - dn=dn, + dn="${...params.dn_mode}", noise_scale=0.4, - dn_num=100, + dn_num="${...params.dn_num}", initialize_box_type='mask2box', initial_pred=True, learn_tgt=False, total_num_feature_levels= "${..pixel_decoder.total_num_feature_levels}", dropout = 0.0, activation= 'relu', - nhead= 8, + nhead= "${...params.n_heads}", dec_n_points= 4, return_intermediate_dec = True, - query_dim= 4, + query_dim= "${...params.query_dim}", dec_layer_share = False, semantic_ce_loss = False, ), ), - criterion=L(SetCriterion)( + criterion=L(WeightedCriterion)( num_classes="${..sem_seg_head.num_classes}", matcher=L(HungarianMatcher)( cost_class = 4.0, @@ -91,18 +109,26 @@ cost_giou=2.0, panoptic_on="${..panoptic_on}", ), + # Params for aux loss weight + class_weight=4.0, + mask_weight=5.0, + dice_weight=5.0, + box_weight=5.0, + giou_weight=2.0, + dec_layers="${..params.dec_layers}", + # Default mask dino options for set criterion. weight_dict=dict(), eos_coef=0.1, losses=['labels', 'masks', 'boxes'], num_points=12544, oversample_ratio=3.0, importance_sample_ratio=0.75, - dn=dn, + dn="${..params.dn_mode}", dn_losses=['labels', 'masks', 'boxes'], panoptic_on="${..panoptic_on}", semantic_ce_loss=False ), - num_queries=300, + num_queries="${.params.num_queries}", object_mask_threshold=0.25, overlap_threshold=0.8, metadata=MetadataCatalog.get('coco_2017_train'), @@ -119,35 +145,3 @@ focus_on_box = False, transform_eval = True, ) - -# set aux loss weight dict -class_weight=4.0 -mask_weight=5.0 -dice_weight=5.0 -box_weight=5.0 -giou_weight=2.0 -weight_dict = {"loss_ce": class_weight} -weight_dict.update({"loss_mask": mask_weight, "loss_dice": dice_weight}) -weight_dict.update({"loss_bbox": box_weight, "loss_giou": giou_weight}) -# two stage is the query selection scheme - -interm_weight_dict = {} -interm_weight_dict.update({k + f'_interm': v for k, v in weight_dict.items()}) -weight_dict.update(interm_weight_dict) -# 
denoising training - -if dn == "standard": - weight_dict.update({k + f"_dn": v for k, v in weight_dict.items() if k != "loss_mask" and k != "loss_dice"}) - dn_losses = ["labels", "boxes"] -elif dn == "seg": - weight_dict.update({k + f"_dn": v for k, v in weight_dict.items()}) - dn_losses = ["labels", "masks", "boxes"] -else: - dn_losses = [] -# if deep_supervision: - -aux_weight_dict = {} -for i in range(dec_layers): - aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) -weight_dict.update(aux_weight_dict) -model.criterion.weight_dict=weight_dict \ No newline at end of file diff --git a/projects/maskdino/data/__init__.py b/projects/maskdino/data/__init__.py index 7f209a83..2c7e0831 100644 --- a/projects/maskdino/data/__init__.py +++ b/projects/maskdino/data/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Facebook, Inc. and its affiliates. -from . import datasets +# from . import datasets # from . import datasets_detr diff --git a/projects/maskdino/maskdino.py b/projects/maskdino/maskdino.py index 75496f61..7f1cefc2 100644 --- a/projects/maskdino/maskdino.py +++ b/projects/maskdino/maskdino.py @@ -50,6 +50,7 @@ def __init__( focus_on_box: bool = False, transform_eval: bool = False, semantic_ce_loss: bool = False, + params: dict = None # Add params option for omegaconf dict node of info. 
from .criterion import SetCriterion


class WeightedCriterion(SetCriterion):
    """``SetCriterion`` that builds its full ``weight_dict`` from base weights.

    The constructor takes the same arguments as the call it forwards to
    ``SetCriterion.__init__`` plus five scalar loss weights and the number
    of decoder layers. From those it derives the expanded weight dictionary
    (``*_interm``, ``*_dn`` and per-layer ``*_<i>`` keys) and the list of
    denoising losses before delegating to the parent initialiser.

    Args:
        num_classes, matcher, eos_coef, losses, num_points,
        oversample_ratio, importance_sample_ratio, panoptic_on,
        semantic_ce_loss: Forwarded unchanged to ``SetCriterion``.
        weight_dict: Base loss weights. When it is not a ``dict`` or is
            empty, the base weights are built from ``class_weight``,
            ``mask_weight``, ``dice_weight``, ``box_weight`` and
            ``giou_weight``. A non-empty dict is copied, never mutated.
        dn: Denoising mode. ``"standard"`` and ``"seg"`` add ``*_dn``
            weights; any other value adds none.
        dn_losses: Accepted so callers that pass it keep working; the value
            is always replaced by the list derived from ``dn``.
        class_weight: Weight stored under ``"loss_ce"``.
        mask_weight: Weight stored under ``"loss_mask"``.
        dice_weight: Weight stored under ``"loss_dice"``.
        box_weight: Weight stored under ``"loss_bbox"``.
        giou_weight: Weight stored under ``"loss_giou"``.
        dec_layers: Number of per-layer suffixes (``_0`` .. ``_{n-1}``)
            generated for every weight key.
    """

    def __init__(
        self,
        num_classes,
        matcher,
        weight_dict,
        eos_coef,
        losses,
        num_points,
        oversample_ratio,
        importance_sample_ratio,
        dn="no",
        dn_losses=...,
        panoptic_on=False,
        semantic_ce_loss=False,
        class_weight=4.0,
        mask_weight=5.0,
        dice_weight=5.0,
        box_weight=5.0,
        giou_weight=2.0,
        dec_layers=9,
    ):
        # NOTE(review): any non-dict value (None, or another mapping type if
        # that is what the config system hands over) is treated as "not
        # provided" and replaced by the scalar weights -- confirm intended.
        if not isinstance(weight_dict, dict) or not weight_dict:
            base_weights = {
                "loss_ce": class_weight,
                "loss_mask": mask_weight,
                "loss_dice": dice_weight,
                "loss_bbox": box_weight,
                "loss_giou": giou_weight,
            }
        else:
            # Copy so the caller's dictionary is not modified by the expansion.
            # NOTE(review): caller-supplied weights are expanded as well;
            # confirm callers pass base weights, not an already expanded dict.
            base_weights = dict(weight_dict)

        full_weights, dn_losses = self._expand_weights(base_weights, dn, dec_layers)

        super().__init__(
            num_classes=num_classes,
            matcher=matcher,
            weight_dict=full_weights,
            eos_coef=eos_coef,
            losses=losses,
            num_points=num_points,
            oversample_ratio=oversample_ratio,
            importance_sample_ratio=importance_sample_ratio,
            dn=dn,
            dn_losses=dn_losses,
            panoptic_on=panoptic_on,
            semantic_ce_loss=semantic_ce_loss,
        )

    @staticmethod
    def _expand_weights(base_weights, dn, dec_layers):
        """Return ``(weight_dict, dn_losses)`` expanded from ``base_weights``.

        Expansion order matters because each step copies every key present
        at that point: first ``*_interm``, then ``*_dn`` (depending on
        ``dn``), then one ``*_<i>`` copy per decoder layer.
        """
        weights = dict(base_weights)

        # Intermediate weights: one "_interm" twin per base key.
        weights.update({f"{key}_interm": value for key, value in weights.items()})

        # Denoising weights. The comprehension is evaluated in full before
        # update() runs, so the dict is not resized while being iterated.
        if dn == "standard":
            # Only the exact keys "loss_mask"/"loss_dice" are skipped; their
            # "_interm" twins still receive a "_dn" copy.
            weights.update(
                {
                    f"{key}_dn": value
                    for key, value in weights.items()
                    if key not in ("loss_mask", "loss_dice")
                }
            )
            dn_losses = ["labels", "boxes"]
        elif dn == "seg":
            weights.update({f"{key}_dn": value for key, value in weights.items()})
            dn_losses = ["labels", "masks", "boxes"]
        else:
            dn_losses = []

        # Per-decoder-layer weights, collected separately so that every
        # layer copies the same set of keys.
        per_layer = {
            f"{key}_{layer}": value
            for layer in range(dec_layers)
            for key, value in weights.items()
        }
        weights.update(per_layer)

        return weights, dn_losses