Commit 9fb73ab

vedantdalimkar, ved, and qubvel authored
Add DPT for segmentation (#1079)
* Initial timm vit encoder commit
* Add DPT model and update logic for TimmViTEncoder class
* Removed redundant documentation
* Added initial test and some minor code modifications
* Code refactor
* Added weight conversion script
* Moved conversion script to appropriate location
* Added logic in timm table generation for adding ViT encoders for DPT
* Ruff formatting
* Code revision
* Remove unnecessary comment
* Simplify ViT encoder
* Refactor ProjectionReadout
* Refactor modeling DPT
* Support more encoders
* Refactor a bit conversion, added validation
* Fixup
* Split forward for timm_vit
* Rename readout, remove feature_dim
* Refactor + add transform
* Fixup
* Refine docs a bit
* Refine docs
* Refine model size a bit and docs
* Add to docs
* Add note
* Remove txt
* Fix doc
* Fix docstring
* Fixing list in activation
* Fixing list
* Fixing list
* Fixup, fix type hint
* Add to README
* Add example
* Add decoder_readout according to initial impl
* Tests update
* Fix encoder tests
* Fix DPT tests
* Refactor a bit
* Tests
* Update gen test models
* Revert gitignore
* Fix test

---------

Co-authored-by: ved <[email protected]>
Co-authored-by: qubvel <[email protected]>
1 parent 3020a77 commit 9fb73ab

File tree

27 files changed, +1836 -47 lines changed

README.md

+3 -1
@@ -21,7 +21,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
 The main features of the library are:

 - Super simple high-level API (just two lines to create a neural network)
-- 11 encoder-decoder model architectures (Unet, Unet++, Segformer, ...)
+- 12 encoder-decoder model architectures (Unet, Unet++, Segformer, DPT, ...)
 - 800+ **pretrained** convolution- and transform-based encoders, including [timm](https://github.com/huggingface/pytorch-image-models) support
 - Popular metrics and losses for training routines (Dice, Jaccard, Tversky, ...)
 - ONNX export and torch script/trace/compile friendly
@@ -105,6 +105,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
 | **Train** multiclass segmentation on CamVid | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/camvid_segmentation_multiclass.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel-org/segmentation_models.pytorch/blob/main/examples/camvid_segmentation_multiclass.ipynb) |
 | **Train** clothes binary segmentation by @ternaus | [Repo](https://github.com/ternaus/cloths_segmentation) | |
 | **Load and inference** pretrained Segformer | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/segformer_inference_pretrained.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/segformer_inference_pretrained.ipynb) |
+| **Load and inference** pretrained DPT | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/dpt_inference_pretrained.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/dpt_inference_pretrained.ipynb) |
 | **Save and load** models locally / to HuggingFace Hub |[Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb)
 | **Export** trained model to ONNX | [Notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) |

@@ -123,6 +124,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
 - DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
 - UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)]
 - Segformer [[paper](https://arxiv.org/abs/2105.15203)] [[docs](https://smp.readthedocs.io/en/latest/models.html#segformer)]
+- DPT [[paper](https://arxiv.org/abs/2103.13413)] [[docs](https://smp.readthedocs.io/en/latest/models.html#dpt)]

 ### Encoders <a name="encoders"></a>

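For orientation (not part of the README diff itself): DPT follows the same two-line pattern as the other architectures. A minimal sketch; the encoder name is the one used elsewhere in this commit and is only an example, other DPT-compatible timm ViT encoders work the same way.

    import segmentation_models_pytorch as smp

    # DPT takes a timm ViT encoder (note the "tu-" prefix); see docs/encoders_dpt.rst
    # for the full list of compatible encoders added in this commit.
    model = smp.DPT(encoder_name="tu-vit_large_patch16_384", in_channels=3, classes=150)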

docs/encoders_dpt.rst

+461
Large diffs are not rendered by default.

docs/models.rst

+15
@@ -81,3 +81,18 @@ Segformer
 ~~~~~~~~~
 .. autoclass:: segmentation_models_pytorch.Segformer

+
+.. _dpt:
+
+DPT
+~~~
+
+.. note::
+
+   See full list of DPT-compatible timm encoders in :ref:`dpt-encoders`.
+
+.. note::
+
+   For some encoders, the model requires ``dynamic_img_size=True`` to be passed in order to work with resolutions different from what the encoder was trained for.
+
+.. autoclass:: segmentation_models_pytorch.DPT
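
To make the second note concrete: ``dynamic_img_size`` is forwarded to the timm ViT encoder at construction time, so positional embeddings can be resized for inputs that differ from the pretraining resolution. A minimal sketch mirroring the conversion script in this commit (encoder name and class count taken from that script):

    import segmentation_models_pytorch as smp

    # Without dynamic_img_size=True some ViT encoders accept only their pretraining
    # resolution (384x384 here); with it, other sizes divisible by the patch size work.
    model = smp.DPT(
        encoder_name="tu-vit_large_patch16_384",
        classes=150,
        dynamic_img_size=True,
    )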

examples/dpt_inference_pretrained.ipynb

+138
Large diffs are not rendered by default.
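
The notebook is not rendered in this view; the sketch below shows the kind of inference it performs, assuming the checkpoint produced by the conversion script in this commit has been pushed to the Hub under qubvel-hf/dpt-large-ade20k (classes, normalization, and padding follow that script).

    import torch
    import segmentation_models_pytorch as smp

    model = smp.from_pretrained("qubvel-hf/dpt-large-ade20k").eval()

    # Placeholder input; a real image should be normalized with mean=std=0.5 and
    # padded so height and width are multiples of 32, as in the saved transform.
    image = torch.rand(1, 3, 384, 384)
    with torch.no_grad():
        logits = model(image)      # (1, 150, H, W) ADE20K class logits
    mask = logits.argmax(dim=1)    # per-pixel class indices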

misc/generate_table_timm.py

+51 -8
@@ -17,30 +17,68 @@ def has_dilation_support(name):
     return False


+def valid_vit_encoder_for_dpt(name):
+    if "vit" not in name:
+        return False
+    encoder = timm.create_model(name)
+    feature_info = encoder.feature_info
+    feature_info_obj = timm.models.FeatureInfo(
+        feature_info=feature_info, out_indices=[0, 1, 2, 3]
+    )
+    reduction_scales = list(feature_info_obj.reduction())
+
+    if len(set(reduction_scales)) > 1:
+        return False
+
+    output_stride = reduction_scales[0]
+    if bin(output_stride).count("1") != 1:
+        return False
+
+    return True
+
+
 def make_table(data):
     names = data.keys()
     max_len1 = max([len(x) for x in names]) + 2
     max_len2 = len("support dilation") + 2
+    max_len3 = len("Supported for DPT") + 2

-    l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n"
-    l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n"
+    l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+" + "-" * max_len3 + "+\n"
+    l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+" + "-" * max_len3 + "+\n"
     top = (
         "| "
         + "Encoder name".ljust(max_len1 - 2)
         + " | "
         + "Support dilation".center(max_len2 - 2)
+        + " | "
+        + "Supported for DPT".center(max_len3 - 2)
         + " |\n"
     )

     table = l1 + top + l2

     for k in sorted(data.keys()):
-        support = (
-            "✅".center(max_len2 - 3)
-            if data[k]["has_dilation"]
-            else " ".center(max_len2 - 2)
+        if "has_dilation" in data[k] and data[k]["has_dilation"]:
+            support = "✅".center(max_len2 - 3)
+
+        else:
+            support = " ".center(max_len2 - 2)
+
+        if "supported_only_for_dpt" in data[k]:
+            supported_for_dpt = "✅".center(max_len3 - 3)
+
+        else:
+            supported_for_dpt = " ".center(max_len3 - 2)
+
+        table += (
+            "| "
+            + k.ljust(max_len1 - 2)
+            + " | "
+            + support
+            + " | "
+            + supported_for_dpt
+            + " |\n"
         )
-        table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n"
     table += l1

     return table
@@ -55,8 +93,13 @@ def make_table(data):
         check_features_and_reduction(name)
         has_dilation = has_dilation_support(name)
         supported_models[name] = dict(has_dilation=has_dilation)
+
     except Exception:
-        continue
+        try:
+            if valid_vit_encoder_for_dpt(name):
+                supported_models[name] = dict(supported_only_for_dpt=True)
+        except Exception:
+            continue

 table = make_table(supported_models)
 print(table)
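
The new valid_vit_encoder_for_dpt helper keeps only "isotropic" ViTs: all four selected feature maps must share a single reduction scale, and that scale must be a power of two. A standalone illustration of the power-of-two test (not part of the script):

    # A stride passes only if its binary representation has exactly one set bit.
    for stride in (8, 14, 16, 32):
        print(stride, bin(stride).count("1") == 1)  # 14 fails, so patch-14 ViTs are filtered out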

misc/generate_test_models.py

+31 -14
@@ -9,33 +9,50 @@

 api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN"))

-for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items():
-    model = model_class(encoder_name=ENCODER_NAME)
-    model = model.eval()
-
-    # generate test sample
-    torch.manual_seed(423553)
-    sample = torch.rand(1, 3, 256, 256)
-
-    with torch.no_grad():
-        output = model(sample)

+def save_and_push(model, inputs, outputs, model_name, encoder_name):
     with tempfile.TemporaryDirectory() as tmpdir:
         # save model
         model.save_pretrained(f"{tmpdir}")

         # save input and output
-        torch.save(sample, f"{tmpdir}/input-tensor.pth")
-        torch.save(output, f"{tmpdir}/output-tensor.pth")
+        torch.save(inputs, f"{tmpdir}/input-tensor.pth")
+        torch.save(outputs, f"{tmpdir}/output-tensor.pth")

         # create repo
-        repo_id = f"{HUB_REPO}/{model_name}-{ENCODER_NAME}"
+        repo_id = f"{HUB_REPO}/{model_name}-{encoder_name}"
         if not api.repo_exists(repo_id=repo_id):
             api.create_repo(repo_id=repo_id, repo_type="model")

         # upload to hub
         api.upload_folder(
             folder_path=tmpdir,
-            repo_id=f"{HUB_REPO}/{model_name}-{ENCODER_NAME}",
+            repo_id=f"{HUB_REPO}/{model_name}-{encoder_name}",
             repo_type="model",
         )
+
+
+for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items():
+    if model_name == "dpt":
+        encoder_name = "tu-test_vit"
+        model = smp.DPT(
+            encoder_name=encoder_name,
+            decoder_readout="cat",
+            decoder_intermediate_channels=(16, 32, 64, 64),
+            decoder_fusion_channels=16,
+            dynamic_img_size=True,
+        )
+    else:
+        encoder_name = ENCODER_NAME
+        model = model_class(encoder_name=encoder_name)
+
+    model = model.eval()
+
+    # generate test sample
+    torch.manual_seed(423553)
+    sample = torch.rand(1, 3, 256, 256)
+
+    with torch.no_grad():
+        output = model(sample)
+
+    save_and_push(model, sample, output, model_name, encoder_name)
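
The tensors pushed here are what regression tests can later compare against. A hedged sketch of consuming them; the repo id follows the {HUB_REPO}/{model_name}-{encoder_name} pattern above, but the HUB_REPO value used below is an assumption:

    import torch
    import huggingface_hub
    import segmentation_models_pytorch as smp

    repo_id = "smp-test-models/dpt-tu-test_vit"  # assumed HUB_REPO; naming pattern from the script above
    model = smp.from_pretrained(repo_id).eval()

    inputs = torch.load(huggingface_hub.hf_hub_download(repo_id, "input-tensor.pth"))
    expected = torch.load(huggingface_hub.hf_hub_download(repo_id, "output-tensor.pth"))

    with torch.no_grad():
        torch.testing.assert_close(model(inputs), expected, rtol=1e-4, atol=1e-4)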

DPT weight conversion script (new file; file path not shown in this view)

+122
@@ -0,0 +1,122 @@
+import cv2
+import torch
+import albumentations as A
+import segmentation_models_pytorch as smp
+
+MODEL_WEIGHTS_PATH = r"dpt_large-ade20k-b12dca68.pt"
+HF_HUB_PATH = "qubvel-hf/dpt-large-ade20k"
+PUSH_TO_HUB = False
+
+
+def get_transform():
+    return A.Compose(
+        [
+            A.LongestMaxSize(max_size=480, interpolation=cv2.INTER_CUBIC),
+            A.Normalize(
+                mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0
+            ),
+            # This is not correct transform, ideally image should resized without padding to multiple of 32,
+            # but we take there is no such transform in albumentations, here is closest one
+            A.PadIfNeeded(
+                min_height=None,
+                min_width=None,
+                pad_height_divisor=32,
+                pad_width_divisor=32,
+                border_mode=cv2.BORDER_CONSTANT,
+                value=0,
+                p=1,
+            ),
+        ]
+    )
+
+
+if __name__ == "__main__":
+    # fmt: off
+    smp_model = smp.DPT(encoder_name="tu-vit_large_patch16_384", classes=150, dynamic_img_size=True)
+    dpt_model_dict = torch.load(MODEL_WEIGHTS_PATH, weights_only=True)
+
+    for layer_index in range(0, 4):
+        for param in ["running_mean", "running_var", "num_batches_tracked", "weight", "bias"]:
+            for block_index in [1, 2]:
+                for bn_index in [1, 2]:
+                    # Assigning weights of 4th fusion layer of original model to 1st layer of SMP DPT model,
+                    # Assigning weights of 3rd fusion layer of original model to 2nd layer of SMP DPT model ...
+                    # and so on ...
+                    # This is because order of calling fusion layers is reversed in original DPT implementation
+                    dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.batch_norm_{bn_index}.{param}"] = \
+                        dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.bn{bn_index}.{param}")
+
+            if param in ["weight", "bias"]:
+                if param == "weight":
+                    for block_index in [1, 2]:
+                        for conv_index in [1, 2]:
+                            dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.conv_{conv_index}.{param}"] = \
+                                dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.conv{conv_index}.{param}")
+
+                    dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.project_to_feature_dim.{param}"] = \
+                        dpt_model_dict.pop(f"scratch.layer{layer_index + 1}_rn.{param}")
+
+                dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.project.{param}"] = \
+                    dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.out_conv.{param}")
+
+                dpt_model_dict[f"decoder.projection_blocks.{layer_index}.project.0.{param}"] = \
+                    dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.0.project.0.{param}")
+
+                dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.project_to_out_channel.{param}"] = \
+                    dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.3.{param}")
+
+                if layer_index != 2:
+                    dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.upsample.{param}"] = \
+                        dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.4.{param}")
+
+    # Changing state dict keys for segmentation head
+    dpt_model_dict = {
+        name.replace("scratch.output_conv", "segmentation_head.head"): parameter
+        for name, parameter in dpt_model_dict.items()
+    }
+
+    # Changing state dict keys for encoder layers
+    dpt_model_dict = {
+        name.replace("pretrained.model", "encoder.model"): parameter
+        for name, parameter in dpt_model_dict.items()
+    }
+
+    # Removing keys, value pairs associated with auxiliary head
+    dpt_model_dict = {
+        name: parameter
+        for name, parameter in dpt_model_dict.items()
+        if not name.startswith("auxlayer")
+    }
+    # fmt: on
+
+    smp_model.load_state_dict(dpt_model_dict, strict=True)
+
+    # ------- DO NOT touch this section -------
+    smp_model.eval()
+
+    input_tensor = torch.ones((1, 3, 384, 384))
+    output = smp_model(input_tensor)
+
+    print(output.shape)
+    print(output[0, 0, :3, :3])
+
+    expected_slice = torch.tensor(
+        [
+            [3.4243, 3.4553, 3.4863],
+            [3.3332, 3.2876, 3.2419],
+            [3.2422, 3.1199, 2.9975],
+        ]
+    )
+
+    torch.testing.assert_close(
+        output[0, 0, :3, :3], expected_slice, atol=1e-4, rtol=1e-4
+    )
+
+    # Saving
+    transform = get_transform()
+
+    transform.save_pretrained(HF_HUB_PATH)
+    smp_model.save_pretrained(HF_HUB_PATH, push_to_hub=PUSH_TO_HUB)
+
+    # Re-loading to make sure everything is saved correctly
+    smp_model = smp.from_pretrained(HF_HUB_PATH)
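
The least obvious part of the remapping above is the reversed fusion order: the reference DPT runs its refinenet blocks deepest-first, while the SMP decoder indexes its fusion blocks from 0. The correspondence the loop implements, spelled out (illustrative only):

    # scratch.refinenet4 -> decoder.fusion_blocks.0, ..., scratch.refinenet1 -> decoder.fusion_blocks.3
    mapping = {f"scratch.refinenet{4 - i}": f"decoder.fusion_blocks.{i}" for i in range(4)}
    print(mapping)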

segmentation_models_pytorch/__init__.py

+3
@@ -14,6 +14,7 @@
 from .decoders.pan import PAN
 from .decoders.upernet import UPerNet
 from .decoders.segformer import Segformer
+from .decoders.dpt import DPT
 from .base.hub_mixin import from_pretrained

 from .__version__ import __version__
@@ -34,6 +35,7 @@
     PAN,
     UPerNet,
     Segformer,
+    DPT,
 ]
 MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES}

@@ -84,6 +86,7 @@ def create_model(
     "PAN",
     "UPerNet",
     "Segformer",
+    "DPT",
     "from_pretrained",
     "create_model",
     "__version__",

segmentation_models_pytorch/decoders/deeplabv3/model.py

+2 -4
@@ -34,8 +34,7 @@ class DeepLabV3(SegmentationModel):
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
             Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
-                **callable** and **None**.
-            Default is **None**
+                **callable** and **None**. Default is **None**.
         upsampling: Final upsampling factor. Default is **None** to preserve input-output spatial shape identity
         aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
             on top of encoder if **aux_params** is not **None** (default). Supported params:
@@ -159,8 +158,7 @@ class DeepLabV3Plus(SegmentationModel):
         classes: A number of classes for output mask (or you can think as a number of channels of output mask)
         activation: An activation function to apply after the final convolution layer.
             Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
-                **callable** and **None**.
-            Default is **None**
+                **callable** and **None**. Default is **None**.
         upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity.
         aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
             on top of encoder if **aux_params** is not **None** (default). Supported params:
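
The change above only reflows the docstring; behaviour is unchanged. For reference, a typical use of the activation argument (a sketch, not taken from this commit): the default None returns raw logits, while "sigmoid" is convenient for binary masks.

    import segmentation_models_pytorch as smp

    model = smp.DeepLabV3(encoder_name="resnet34", classes=1, activation="sigmoid")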

segmentation_models_pytorch/decoders/dpt/__init__.py (new file)

+3
@@ -0,0 +1,3 @@
+from .model import DPT
+
+__all__ = ["DPT"]
