Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
### Forked Detrex branch

This is a fork of detrex that makes the MaskDINO project installable as a Python package by running `pip install .`
from the `projects/maskdino` folder.
The resulting maskdino project is also uploaded to devpi.

Originally we used the standalone MaskDINO git repository, but because of the differing config versions (using omegaconf),
that repo and the MaskDINO project inside detrex are quite different, and we ran into training problems.

The MaskDINO parameters have been organised in one place so that they are easy to change and are used more consistently.

<div align="center">
<img src="./assets/logo_2.png" width="30%">
</div>
Expand Down
128 changes: 61 additions & 67 deletions projects/maskdino/configs/models/maskdino_r50.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,42 @@
from detrex.modeling.backbone import ResNet, BasicStem

from detectron2.config import LazyCall as L

from projects.maskdino.modeling.meta_arch.maskdino_head import MaskDINOHead
from projects.maskdino.modeling.pixel_decoder.maskdino_encoder import MaskDINOEncoder
from projects.maskdino.modeling.transformer_decoder.maskdino_decoder import MaskDINODecoder
from projects.maskdino.modeling.criterion import SetCriterion
from projects.maskdino.modeling.matcher import HungarianMatcher
from projects.maskdino.maskdino import MaskDINO
from detectron2.data import MetadataCatalog
from detectron2.layers import Conv2d, ShapeSpec, get_norm

from omegaconf import OmegaConf

from ...modeling.meta_arch.maskdino_head import MaskDINOHead
from ...modeling.pixel_decoder.maskdino_encoder import MaskDINOEncoder
from ...modeling.transformer_decoder.maskdino_decoder import MaskDINODecoder
from ...modeling.weighted_criterion import WeightedCriterion
from ...modeling.matcher import HungarianMatcher
from ...maskdino import MaskDINO



dim=256
n_class=80
dn="seg"
dec_layers = 9
input_shape={'res2': ShapeSpec(channels=256, height=None, width=None, stride=4), 'res3': ShapeSpec(channels=512, height=None, width=None, stride=8), 'res4': ShapeSpec(channels=1024, height=None, width=None, stride=16), 'res5': ShapeSpec(channels=2048, height=None, width=None, stride=32)}
model = L(MaskDINO)(
# parameters in one place.
params=OmegaConf.create(dict(
input_shape={
'res2': L(ShapeSpec)(channels=256, height=None, width=None, stride=4),
'res3': L(ShapeSpec)(channels=512, height=None, width=None, stride=8),
'res4': L(ShapeSpec)(channels=1024, height=None, width=None, stride=16),
'res5': L(ShapeSpec)(channels=2048, height=None, width=None, stride=32)
},
dim=256,
hidden_dim=256,
query_dim=4,
num_classes=80,
dec_layers=9,
enc_layers=6,
feed_forward=2048,
n_heads=8,
num_queries=300,
dn_num=100,
dn_mode="seg",
show_weights=True
)),
backbone=L(ResNet)(
stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"),
stages=L(ResNet.make_default_stages)(
Expand All @@ -32,16 +50,16 @@
freeze_at=1,
),
sem_seg_head=L(MaskDINOHead)(
input_shape=input_shape,
num_classes=n_class,
input_shape="${..params.input_shape}",
num_classes="${..params.num_classes}",
pixel_decoder=L(MaskDINOEncoder)(
input_shape=input_shape,
input_shape="${...params.input_shape}",
transformer_dropout=0.0,
transformer_nheads=8,
transformer_dim_feedforward=2048,
transformer_enc_layers=6,
conv_dim=dim,
mask_dim=dim,
transformer_nheads="${...params.n_heads}",
transformer_dim_feedforward="${...params.feed_forward}",
transformer_enc_layers="${...params.enc_layers}",
conv_dim="${...params.dim}",
mask_dim="${...params.dim}",
norm = 'GN',
transformer_in_features=['res3', 'res4', 'res5'],
common_stride=4,
Expand All @@ -52,35 +70,35 @@
loss_weight= 1.0,
ignore_value= -1,
transformer_predictor=L(MaskDINODecoder)(
in_channels=dim,
in_channels="${...params.dim}",
mask_classification=True,
num_classes="${..num_classes}",
hidden_dim=dim,
num_queries=300,
nheads=8,
dim_feedforward=2048,
dec_layers=dec_layers,
mask_dim=dim,
num_classes="${...params.num_classes}",
hidden_dim="${...params.hidden_dim}",
num_queries="${...params.num_queries}",
nheads="${...params.n_heads}",
dim_feedforward="${...params.feed_forward}",
dec_layers="${...params.dec_layers}",
mask_dim="${...params.dim}",
enforce_input_project=False,
two_stage=True,
dn=dn,
dn="${...params.dn_mode}",
noise_scale=0.4,
dn_num=100,
dn_num="${...params.dn_num}",
initialize_box_type='mask2box',
initial_pred=True,
learn_tgt=False,
total_num_feature_levels= "${..pixel_decoder.total_num_feature_levels}",
dropout = 0.0,
activation= 'relu',
nhead= 8,
nhead= "${...params.n_heads}",
dec_n_points= 4,
return_intermediate_dec = True,
query_dim= 4,
query_dim= "${...params.query_dim}",
dec_layer_share = False,
semantic_ce_loss = False,
),
),
criterion=L(SetCriterion)(
criterion=L(WeightedCriterion)(
num_classes="${..sem_seg_head.num_classes}",
matcher=L(HungarianMatcher)(
cost_class = 4.0,
Expand All @@ -91,18 +109,26 @@
cost_giou=2.0,
panoptic_on="${..panoptic_on}",
),
# Params for aux loss weight
class_weight=4.0,
mask_weight=5.0,
dice_weight=5.0,
box_weight=5.0,
giou_weight=2.0,
dec_layers="${..params.dec_layers}",
# Default mask dino options for set criterion.
weight_dict=dict(),
eos_coef=0.1,
losses=['labels', 'masks', 'boxes'],
num_points=12544,
oversample_ratio=3.0,
importance_sample_ratio=0.75,
dn=dn,
dn="${..params.dn_mode}",
dn_losses=['labels', 'masks', 'boxes'],
panoptic_on="${..panoptic_on}",
semantic_ce_loss=False
),
num_queries=300,
num_queries="${.params.num_queries}",
object_mask_threshold=0.25,
overlap_threshold=0.8,
metadata=MetadataCatalog.get('coco_2017_train'),
Expand All @@ -119,35 +145,3 @@
focus_on_box = False,
transform_eval = True,
)

# set aux loss weight dict
class_weight=4.0
mask_weight=5.0
dice_weight=5.0
box_weight=5.0
giou_weight=2.0
weight_dict = {"loss_ce": class_weight}
weight_dict.update({"loss_mask": mask_weight, "loss_dice": dice_weight})
weight_dict.update({"loss_bbox": box_weight, "loss_giou": giou_weight})
# two stage is the query selection scheme

interm_weight_dict = {}
interm_weight_dict.update({k + f'_interm': v for k, v in weight_dict.items()})
weight_dict.update(interm_weight_dict)
# denoising training

if dn == "standard":
weight_dict.update({k + f"_dn": v for k, v in weight_dict.items() if k != "loss_mask" and k != "loss_dice"})
dn_losses = ["labels", "boxes"]
elif dn == "seg":
weight_dict.update({k + f"_dn": v for k, v in weight_dict.items()})
dn_losses = ["labels", "masks", "boxes"]
else:
dn_losses = []
# if deep_supervision:

aux_weight_dict = {}
for i in range(dec_layers):
aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
model.criterion.weight_dict=weight_dict
2 changes: 1 addition & 1 deletion projects/maskdino/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from . import datasets
# from . import datasets
# from . import datasets_detr
6 changes: 4 additions & 2 deletions projects/maskdino/maskdino.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
focus_on_box: bool = False,
transform_eval: bool = False,
semantic_ce_loss: bool = False,
params: dict = None # Add params option for omegaconf dict node of info.
):
"""
Args:
Expand Down Expand Up @@ -107,8 +108,9 @@ def __init__(

if not self.semantic_on:
assert self.sem_seg_postprocess_before_inference

print('criterion.weight_dict ', self.criterion.weight_dict)

# `params` is built with OmegaConf.create(...) in the model config, so it is an
# omegaconf DictConfig: a mapping, but not a `dict` subclass. Its entries are
# keys, not attributes of a plain dict, so the previous
# `isinstance(params, dict) and getattr(params, ...)` test could never be true.
# `.get` works for both a plain dict and a DictConfig.
if params is not None and params.get('show_weights', False):
    print('criterion.weight_dict ', self.criterion.weight_dict)

@property
def device(self):
Expand Down
2 changes: 1 addition & 1 deletion projects/maskdino/modeling/criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)

from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
from projects.maskdino.utils import box_ops
from ..utils import box_ops

# from maskdino.maskformer_model import sigmoid_focal_loss
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
Expand Down
2 changes: 1 addition & 1 deletion projects/maskdino/modeling/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from torch.cuda.amp import autocast

from detectron2.projects.point_rend.point_features import point_sample
from projects.maskdino.utils.box_ops import generalized_box_iou,box_cxcywh_to_xyxy
from ..utils.box_ops import generalized_box_iou,box_cxcywh_to_xyxy
import random
def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor):
"""
Expand Down
47 changes: 47 additions & 0 deletions projects/maskdino/modeling/weighted_criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from .criterion import SetCriterion


class WeightedCriterion(SetCriterion):
    """``SetCriterion`` that can derive its loss ``weight_dict`` by itself.

    When ``weight_dict`` is empty (or is not a mapping at all), the weights are
    generated from the five base weights, the denoising mode ``dn`` and the
    number of decoder layers, and ``dn_losses`` is derived from ``dn``.
    A non-empty ``weight_dict`` is forwarded to ``SetCriterion`` unchanged.
    """

    def __init__(
        self,
        num_classes,
        matcher,
        weight_dict,
        eos_coef,
        losses,
        num_points,
        oversample_ratio,
        importance_sample_ratio,
        dn="no",
        dn_losses=...,
        panoptic_on=False,
        semantic_ce_loss=False,
        class_weight=4.0,
        mask_weight=5.0,
        dice_weight=5.0,
        box_weight=5.0,
        giou_weight=2.0,
        dec_layers=9,
    ):
        """Build the weight dict if needed and initialise ``SetCriterion``.

        Args:
            weight_dict: Mapping of loss name to weight. If empty or not a
                mapping, it is generated from the ``*_weight`` arguments.
            dn: Denoising mode; ``"standard"`` and ``"seg"`` add ``_dn``
                weights, any other value adds none.
            dn_losses: Denoising losses. Replaced by the list matching ``dn``
                whenever the weights are generated, or when left at its
                ``...`` default.
            class_weight, mask_weight, dice_weight, box_weight, giou_weight:
                Base weights for ``loss_ce``, ``loss_mask``, ``loss_dice``,
                ``loss_bbox`` and ``loss_giou``.
            dec_layers: Number of per-layer (``_{i}``) copies of every weight.

        All remaining arguments are passed to ``SetCriterion`` untouched.
        """
        # Function-scope import keeps this block self-contained.
        from collections.abc import Mapping

        # Accept any mapping (e.g. an omegaconf DictConfig), not only ``dict``;
        # otherwise user-supplied weights coming from a config would be
        # silently replaced by the generated defaults.
        if not isinstance(weight_dict, Mapping) or len(weight_dict) == 0:
            base_weights = {
                "loss_ce": class_weight,
                "loss_mask": mask_weight,
                "loss_dice": dice_weight,
                "loss_bbox": box_weight,
                "loss_giou": giou_weight,
            }
            weight_dict = self._build_weight_dict(base_weights, dn, dec_layers)
            dn_losses = self._dn_losses_for(dn)
        elif dn_losses is ...:
            # Never forward the ``...`` placeholder default to the parent class.
            dn_losses = self._dn_losses_for(dn)

        super().__init__(
            num_classes,
            matcher,
            weight_dict,
            eos_coef,
            losses,
            num_points,
            oversample_ratio,
            importance_sample_ratio,
            dn,
            dn_losses,
            panoptic_on,
            semantic_ce_loss,
        )

    @staticmethod
    def _dn_losses_for(dn):
        """Return the denoising loss names that belong to denoising mode ``dn``."""
        if dn == "standard":
            return ["labels", "boxes"]
        if dn == "seg":
            return ["labels", "masks", "boxes"]
        return []

    @staticmethod
    def _build_weight_dict(base_weights, dn, dec_layers):
        """Expand the base weights with ``_interm``, ``_dn`` and per-layer keys."""
        weights = dict(base_weights)
        # Intermediate (``_interm``) copies of every base weight.
        weights.update({f"{k}_interm": v for k, v in weights.items()})
        # Denoising copies; "standard" skips the plain mask and dice losses.
        if dn == "standard":
            weights.update({
                f"{k}_dn": v
                for k, v in weights.items()
                if k not in ("loss_mask", "loss_dice")
            })
        elif dn == "seg":
            weights.update({f"{k}_dn": v for k, v in weights.items()})
        # One ``_{i}`` copy of every weight collected so far per decoder layer.
        aux_weights = {
            f"{k}_{i}": v
            for i in range(dec_layers)
            for k, v in weights.items()
        }
        weights.update(aux_weights)
        return weights
14 changes: 14 additions & 0 deletions projects/maskdino/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Packaging configuration that makes the MaskDINO project installable with pip.
[build-system]
requires = ["setuptools>=60"]

[project]
name = "maskdino_project"
version = "1.0.6"
# Dependencies are not listed statically in this file; they are declared dynamic.
dynamic = ["dependencies"]
requires-python = ">=3.6"
description = "Installable project for MaskDINO"
readme = "README.md"
license = {file = "LICENSE"}

# Map the "maskdino_project" package onto this directory.
[tool.setuptools.package-dir]
maskdino_project = "."