diff --git a/README.md b/README.md
index b4666332..f058824b 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,14 @@
+### Forked Detrex branch
+
+This is a fork of detrex that makes the MaskDINO project installable as a Python package by running `pip install .`
+from the `projects/maskdino` folder.
+The resulting MaskDINO project package is also uploaded to devpi.
+
+We originally used the standalone MaskDINO git repository, but it and the MaskDINO project inside detrex use different
+versions of the omegaconf-based config and differ considerably, which caused training problems.
+
+The MaskDINO model parameters are now organised in one place so that they can be changed easily and stay consistent.
+
diff --git a/projects/maskdino/configs/models/maskdino_r50.py b/projects/maskdino/configs/models/maskdino_r50.py
index 89cbe5b1..7a08fe8f 100644
--- a/projects/maskdino/configs/models/maskdino_r50.py
+++ b/projects/maskdino/configs/models/maskdino_r50.py
@@ -3,24 +3,42 @@
from detrex.modeling.backbone import ResNet, BasicStem
from detectron2.config import LazyCall as L
-
-from projects.maskdino.modeling.meta_arch.maskdino_head import MaskDINOHead
-from projects.maskdino.modeling.pixel_decoder.maskdino_encoder import MaskDINOEncoder
-from projects.maskdino.modeling.transformer_decoder.maskdino_decoder import MaskDINODecoder
-from projects.maskdino.modeling.criterion import SetCriterion
-from projects.maskdino.modeling.matcher import HungarianMatcher
-from projects.maskdino.maskdino import MaskDINO
from detectron2.data import MetadataCatalog
from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from omegaconf import OmegaConf
+
+from ...modeling.meta_arch.maskdino_head import MaskDINOHead
+from ...modeling.pixel_decoder.maskdino_encoder import MaskDINOEncoder
+from ...modeling.transformer_decoder.maskdino_decoder import MaskDINODecoder
+from ...modeling.weighted_criterion import WeightedCriterion
+from ...modeling.matcher import HungarianMatcher
+from ...maskdino import MaskDINO
+
-dim=256
-n_class=80
-dn="seg"
-dec_layers = 9
-input_shape={'res2': ShapeSpec(channels=256, height=None, width=None, stride=4), 'res3': ShapeSpec(channels=512, height=None, width=None, stride=8), 'res4': ShapeSpec(channels=1024, height=None, width=None, stride=16), 'res5': ShapeSpec(channels=2048, height=None, width=None, stride=32)}
model = L(MaskDINO)(
+ # parameters in one place.
+ params=OmegaConf.create(dict(
+ input_shape={
+ 'res2': L(ShapeSpec)(channels=256, height=None, width=None, stride=4),
+ 'res3': L(ShapeSpec)(channels=512, height=None, width=None, stride=8),
+ 'res4': L(ShapeSpec)(channels=1024, height=None, width=None, stride=16),
+ 'res5': L(ShapeSpec)(channels=2048, height=None, width=None, stride=32)
+ },
+ dim=256,
+ hidden_dim=256,
+ query_dim=4,
+ num_classes=80,
+ dec_layers=9,
+ enc_layers=6,
+ feed_forward=2048,
+ n_heads=8,
+ num_queries=300,
+ dn_num=100,
+ dn_mode="seg",
+ show_weights=True
+ )),
backbone=L(ResNet)(
stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
stages=L(ResNet.make_default_stages)(
@@ -32,16 +50,16 @@
freeze_at=1,
),
sem_seg_head=L(MaskDINOHead)(
- input_shape=input_shape,
- num_classes=n_class,
+ input_shape="${..params.input_shape}",
+ num_classes="${..params.num_classes}",
pixel_decoder=L(MaskDINOEncoder)(
- input_shape=input_shape,
+ input_shape="${...params.input_shape}",
transformer_dropout=0.0,
- transformer_nheads=8,
- transformer_dim_feedforward=2048,
- transformer_enc_layers=6,
- conv_dim=dim,
- mask_dim=dim,
+ transformer_nheads="${...params.n_heads}",
+ transformer_dim_feedforward="${...params.feed_forward}",
+ transformer_enc_layers="${...params.enc_layers}",
+ conv_dim="${...params.dim}",
+ mask_dim="${...params.dim}",
norm = 'GN',
transformer_in_features=['res3', 'res4', 'res5'],
common_stride=4,
@@ -52,35 +70,35 @@
loss_weight= 1.0,
ignore_value= -1,
transformer_predictor=L(MaskDINODecoder)(
- in_channels=dim,
+ in_channels="${...params.dim}",
mask_classification=True,
- num_classes="${..num_classes}",
- hidden_dim=dim,
- num_queries=300,
- nheads=8,
- dim_feedforward=2048,
- dec_layers=dec_layers,
- mask_dim=dim,
+ num_classes="${...params.num_classes}",
+ hidden_dim="${...params.hidden_dim}",
+ num_queries="${...params.num_queries}",
+ nheads="${...params.n_heads}",
+ dim_feedforward="${...params.feed_forward}",
+ dec_layers="${...params.dec_layers}",
+ mask_dim="${...params.dim}",
enforce_input_project=False,
two_stage=True,
- dn=dn,
+ dn="${...params.dn_mode}",
noise_scale=0.4,
- dn_num=100,
+ dn_num="${...params.dn_num}",
initialize_box_type='mask2box',
initial_pred=True,
learn_tgt=False,
total_num_feature_levels= "${..pixel_decoder.total_num_feature_levels}",
dropout = 0.0,
activation= 'relu',
- nhead= 8,
+ nhead= "${...params.n_heads}",
dec_n_points= 4,
return_intermediate_dec = True,
- query_dim= 4,
+ query_dim= "${...params.query_dim}",
dec_layer_share = False,
semantic_ce_loss = False,
),
),
- criterion=L(SetCriterion)(
+ criterion=L(WeightedCriterion)(
num_classes="${..sem_seg_head.num_classes}",
matcher=L(HungarianMatcher)(
cost_class = 4.0,
@@ -91,18 +109,26 @@
cost_giou=2.0,
panoptic_on="${..panoptic_on}",
),
+ # Params for aux loss weight
+ class_weight=4.0,
+ mask_weight=5.0,
+ dice_weight=5.0,
+ box_weight=5.0,
+ giou_weight=2.0,
+ dec_layers="${..params.dec_layers}",
+ # Default mask dino options for set criterion.
weight_dict=dict(),
eos_coef=0.1,
losses=['labels', 'masks', 'boxes'],
num_points=12544,
oversample_ratio=3.0,
importance_sample_ratio=0.75,
- dn=dn,
+ dn="${..params.dn_mode}",
dn_losses=['labels', 'masks', 'boxes'],
panoptic_on="${..panoptic_on}",
semantic_ce_loss=False
),
- num_queries=300,
+ num_queries="${.params.num_queries}",
object_mask_threshold=0.25,
overlap_threshold=0.8,
metadata=MetadataCatalog.get('coco_2017_train'),
@@ -119,35 +145,3 @@
focus_on_box = False,
transform_eval = True,
)
-
-# set aux loss weight dict
-class_weight=4.0
-mask_weight=5.0
-dice_weight=5.0
-box_weight=5.0
-giou_weight=2.0
-weight_dict = {"loss_ce": class_weight}
-weight_dict.update({"loss_mask": mask_weight, "loss_dice": dice_weight})
-weight_dict.update({"loss_bbox": box_weight, "loss_giou": giou_weight})
-# two stage is the query selection scheme
-
-interm_weight_dict = {}
-interm_weight_dict.update({k + f'_interm': v for k, v in weight_dict.items()})
-weight_dict.update(interm_weight_dict)
-# denoising training
-
-if dn == "standard":
- weight_dict.update({k + f"_dn": v for k, v in weight_dict.items() if k != "loss_mask" and k != "loss_dice"})
- dn_losses = ["labels", "boxes"]
-elif dn == "seg":
- weight_dict.update({k + f"_dn": v for k, v in weight_dict.items()})
- dn_losses = ["labels", "masks", "boxes"]
-else:
- dn_losses = []
-# if deep_supervision:
-
-aux_weight_dict = {}
-for i in range(dec_layers):
- aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
-weight_dict.update(aux_weight_dict)
-model.criterion.weight_dict=weight_dict
\ No newline at end of file
diff --git a/projects/maskdino/data/__init__.py b/projects/maskdino/data/__init__.py
index 7f209a83..2c7e0831 100644
--- a/projects/maskdino/data/__init__.py
+++ b/projects/maskdino/data/__init__.py
@@ -1,3 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates.
-from . import datasets
+# from . import datasets
# from . import datasets_detr
diff --git a/projects/maskdino/maskdino.py b/projects/maskdino/maskdino.py
index 75496f61..7f1cefc2 100644
--- a/projects/maskdino/maskdino.py
+++ b/projects/maskdino/maskdino.py
@@ -50,6 +50,7 @@ def __init__(
focus_on_box: bool = False,
transform_eval: bool = False,
semantic_ce_loss: bool = False,
+ params: dict = None # Add params option for omegaconf dict node of info.
):
"""
Args:
@@ -107,8 +108,9 @@ def __init__(
if not self.semantic_on:
assert self.sem_seg_postprocess_before_inference
-
- print('criterion.weight_dict ', self.criterion.weight_dict)
+
+        if params is not None and params.get('show_weights', False):  # .get reads the flag from a dict or a DictConfig
+            print('criterion.weight_dict ', self.criterion.weight_dict)
@property
def device(self):
diff --git a/projects/maskdino/modeling/criterion.py b/projects/maskdino/modeling/criterion.py
index bcd51fc9..41832dc8 100644
--- a/projects/maskdino/modeling/criterion.py
+++ b/projects/maskdino/modeling/criterion.py
@@ -18,7 +18,7 @@
)
from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
-from projects.maskdino.utils import box_ops
+from ..utils import box_ops
# from maskdino.maskformer_model import sigmoid_focal_loss
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
diff --git a/projects/maskdino/modeling/matcher.py b/projects/maskdino/modeling/matcher.py
index 050196a9..371ae4e9 100644
--- a/projects/maskdino/modeling/matcher.py
+++ b/projects/maskdino/modeling/matcher.py
@@ -12,7 +12,7 @@
from torch.cuda.amp import autocast
from detectron2.projects.point_rend.point_features import point_sample
-from projects.maskdino.utils.box_ops import generalized_box_iou,box_cxcywh_to_xyxy
+from ..utils.box_ops import generalized_box_iou,box_cxcywh_to_xyxy
import random
def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
"""
diff --git a/projects/maskdino/modeling/weighted_criterion.py b/projects/maskdino/modeling/weighted_criterion.py
new file mode 100644
index 00000000..27beb1a1
--- /dev/null
+++ b/projects/maskdino/modeling/weighted_criterion.py
@@ -0,0 +1,47 @@
+from .criterion import SetCriterion
+
+
+class WeightedCriterion(SetCriterion):
+    """SetCriterion whose loss ``weight_dict`` can be derived from per-loss base weights.
+
+    Args:
+        weight_dict: explicit loss weights. When empty or None it is derived from the
+            ``*_weight`` arguments and expanded with ``_interm``, ``_dn`` and per-layer
+            ``_{i}`` copies; a non-empty mapping is forwarded to SetCriterion as a dict.
+        dn: denoising mode; "standard" and "seg" add ``_dn`` weights, anything else none.
+        dn_losses: denoising losses. Replaced by the list matching ``dn`` when the weights
+            are derived, or when left at its ``...`` placeholder default.
+        class_weight, mask_weight, dice_weight, box_weight, giou_weight: base weights for
+            loss_ce, loss_mask, loss_dice, loss_bbox and loss_giou.
+        dec_layers: number of ``_{i}`` suffixed copies of the weights to add.
+
+    All remaining arguments are passed to SetCriterion unchanged.
+    """
+
+    def __init__(
+        self,
+        num_classes, matcher, weight_dict, eos_coef, losses,
+        num_points, oversample_ratio, importance_sample_ratio,
+        dn="no", dn_losses=..., panoptic_on=False, semantic_ce_loss=False,
+        class_weight=4.0, mask_weight=5.0, dice_weight=5.0, box_weight=5.0, giou_weight=2.0,
+        dec_layers=9,
+    ):
+        # dn_losses matching each denoising mode; any other mode has no denoising losses.
+        dn_default = {"standard": ["labels", "boxes"], "seg": ["labels", "masks", "boxes"]}.get(dn, [])
+        if weight_dict:
+            # Honour explicit weights (plain dict or OmegaConf mapping) instead of discarding them.
+            weight_dict = dict(weight_dict)
+            # The Ellipsis default is only a placeholder and must not reach SetCriterion.
+            dn_losses = dn_default if dn_losses is Ellipsis else dn_losses
+        else:
+            weight_dict = {"loss_ce": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight, "loss_bbox": box_weight, "loss_giou": giou_weight}
+            weight_dict.update({k + "_interm": v for k, v in weight_dict.items()})
+            # "standard" denoising has no mask/dice terms; "seg" duplicates every term.
+            dn_skip = {"standard": ("loss_mask", "loss_dice"), "seg": ()}.get(dn)
+            if dn_skip is not None:
+                weight_dict.update({k + "_dn": v for k, v in weight_dict.items() if k not in dn_skip})
+            dn_losses = dn_default
+            # One suffixed copy of every weight per decoder layer (deep supervision).
+            weight_dict.update({f"{k}_{i}": v for i in range(dec_layers) for k, v in weight_dict.items()})
+
+        super().__init__(num_classes, matcher, weight_dict, eos_coef, losses, num_points, oversample_ratio, importance_sample_ratio, dn, dn_losses, panoptic_on, semantic_ce_loss)
\ No newline at end of file
diff --git a/projects/maskdino/pyproject.toml b/projects/maskdino/pyproject.toml
new file mode 100644
index 00000000..8d74f0fc
--- /dev/null
+++ b/projects/maskdino/pyproject.toml
@@ -0,0 +1,14 @@
+[build-system]
+requires = ["setuptools>=60"]
+
+[project]
+name = "maskdino_project"
+version = "1.0.6"
+dynamic = ["dependencies"]
+requires-python = ">=3.6"
+description = "Installable project for MaskDINO"
+readme = "README.md"
+license = {file = "LICENSE"}
+
+[tool.setuptools.package-dir]
+maskdino_project = "."
\ No newline at end of file