Skip to content

Commit 7480e31

Browse files
authored
script for converting retinanet weights from torchvision (#2233)
* script for converting retinanet weights from torchvision * check numericals for torch and keras model after weight conversion * accommodate script to include weights for second preset * nit * Add save to preset code * update preset versions
1 parent 39a8f43 commit 7480e31

File tree

3 files changed

+346
-15
lines changed

3 files changed

+346
-15
lines changed

keras_hub/src/models/retinanet/retinanet_image_converter.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,3 @@
66
@keras_hub_export("keras_hub.layers.RetinaNetImageConverter")
class RetinaNetImageConverter(ImageConverter):
    """Image converter for RetinaNet models.

    All resizing/rescaling behavior is inherited unchanged from
    `ImageConverter`; this subclass only registers the backbone class so
    preset loading can associate the converter with `RetinaNetBackbone`.
    """

    # Links this converter to its backbone in the preset machinery.
    backbone_cls = RetinaNetBackbone
9-
10-
def __init__(
11-
self,
12-
*args,
13-
**kwargs,
14-
):
15-
# TODO: update presets and remove these old config options. They were
16-
# never needed.
17-
if "norm_mean" in kwargs:
18-
kwargs["offset"] = [-x for x in kwargs.pop("norm_mean")]
19-
if "norm_std" in kwargs:
20-
kwargs["scale"] = [1.0 / x for x in kwargs.pop("norm_std")]
21-
super().__init__(*args, **kwargs)

keras_hub/src/models/retinanet/retinanet_presets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"params": 34121239,
1212
"path": "retinanet",
1313
},
14-
"kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/3",
14+
"kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/4",
1515
},
1616
"retinanet_resnet50_fpn_v2_coco": {
1717
"metadata": {
@@ -22,6 +22,6 @@
2222
"params": 31558592,
2323
"path": "retinanet",
2424
},
25-
"kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_v2_coco/2",
25+
"kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_v2_coco/3",
2626
},
2727
}
Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
"""Convert ViT checkpoints.
2+
3+
export KAGGLE_USERNAME=XXX
4+
export KAGGLE_KEY=XXX
5+
6+
python tools/checkpoint_conversion/convert_retinanet_checkpoints.py \
7+
--preset retinanet_resnet50_fpn_coco
8+
"""
9+
10+
import os
11+
import shutil
12+
13+
import keras
14+
import numpy as np
15+
import torch
16+
from absl import app
17+
from absl import flags
18+
from keras import ops
19+
from torchvision.models.detection.retinanet import (
20+
RetinaNet_ResNet50_FPN_V2_Weights,
21+
)
22+
from torchvision.models.detection.retinanet import (
23+
RetinaNet_ResNet50_FPN_Weights,
24+
)
25+
from torchvision.models.detection.retinanet import retinanet_resnet50_fpn
26+
from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2
27+
28+
import keras_hub
29+
from keras_hub.src.models.backbone import Backbone
30+
from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
31+
from keras_hub.src.models.retinanet.retinanet_image_converter import (
32+
RetinaNetImageConverter,
33+
)
34+
from keras_hub.src.models.retinanet.retinanet_object_detector import (
35+
RetinaNetObjectDetector,
36+
)
37+
from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( # noqa: E501
38+
RetinaNetObjectDetectorPreprocessor,
39+
)
40+
41+
FLAGS = flags.FLAGS

# Maps each KerasHub preset name to the torchvision weights enum used as
# the conversion source.
PRESET_MAP = {
    "retinanet_resnet50_fpn_coco": RetinaNet_ResNet50_FPN_Weights.DEFAULT,
    "retinanet_resnet50_fpn_v2_coco": RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT,
}

# Which preset to convert; must be a key of PRESET_MAP.
flags.DEFINE_string(
    "preset",
    None,
    f"Must be one of {','.join(PRESET_MAP.keys())}",
    required=True,
)
# Optional Kaggle URI to upload the converted preset to.
flags.DEFINE_string(
    "upload_uri",
    None,
    'Could be "kaggle://keras/{variant}/keras/{preset}"',
    required=False,
)
60+
61+
62+
def get_keras_backbone(use_p5=False):
    """Build a randomly initialized KerasHub RetinaNet backbone.

    Args:
        use_p5: Forwarded to `RetinaNetBackbone`; presumably controls
            whether P5 (rather than C5) feeds the coarser FPN levels, as
            used by the v2 torchvision variant — confirm against the
            backbone implementation.

    Returns:
        A `RetinaNetBackbone` wrapping an unweighted ResNet-50 encoder.
    """
    encoder = Backbone.from_preset("resnet_50_imagenet", load_weights=False)
    return RetinaNetBackbone(
        image_encoder=encoder,
        min_level=3,
        max_level=7,
        use_p5=use_p5,
    )
74+
75+
76+
# Helper functions.
77+
def port_weight(keras_variable, torch_tensor, hook_fn=None):
    """Assign a torch tensor into a Keras variable.

    Args:
        keras_variable: Destination variable; must expose `.shape` and
            `.assign(...)`.
        torch_tensor: Source tensor/array.
        hook_fn: Optional callable `(tensor, target_shape) -> tensor`
            applied before assignment (e.g. transposing conv kernels).
    """
    if hook_fn is not None:
        torch_tensor = hook_fn(torch_tensor, list(keras_variable.shape))
    keras_variable.assign(torch_tensor)
81+
82+
83+
def convert_image_encoder(state_dict, backbone):
    """Port the torchvision ResNet trunk weights into the Keras encoder.

    Args:
        state_dict: State dict of the full torchvision detection model
            (trunk weights live under the `backbone.body.*` prefix).
        backbone: The Keras ResNet backbone (`image_encoder`).
    """

    def copy_conv(layer_name, prefix):
        # Torch conv kernels are (out, in, h, w); Keras wants (h, w, in, out).
        port_weight(
            backbone.get_layer(layer_name).kernel,
            torch_tensor=state_dict[f"{prefix}.weight"],
            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
        )

    def copy_bn(layer_name, prefix):
        bn = backbone.get_layer(layer_name)
        port_weight(bn.gamma, torch_tensor=state_dict[f"{prefix}.weight"])
        port_weight(bn.beta, torch_tensor=state_dict[f"{prefix}.bias"])
        port_weight(
            bn.moving_mean, torch_tensor=state_dict[f"{prefix}.running_mean"]
        )
        port_weight(
            bn.moving_variance,
            torch_tensor=state_dict[f"{prefix}.running_var"],
        )

    block_type = backbone.block_type

    # Stem.
    copy_conv("conv1_conv", "backbone.body.conv1")
    copy_bn("conv1_bn", "backbone.body.bn1")

    # Residual stages: Keras "stackN_blockM" maps to torch "layer{N+1}.M".
    for stack_index in range(len(backbone.stackwise_num_filters)):
        for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
            keras_name = f"stack{stack_index}_block{block_idx}"
            torch_name = f"backbone.body.layer{stack_index + 1}.{block_idx}"

            # First block of a stage may carry a projection shortcut
            # (always for bottleneck blocks; from stage 1 on otherwise).
            if block_idx == 0 and (
                block_type == "bottleneck_block" or stack_index > 0
            ):
                copy_conv(f"{keras_name}_0_conv", f"{torch_name}.downsample.0")
                copy_bn(f"{keras_name}_0_bn", f"{torch_name}.downsample.1")
            copy_conv(f"{keras_name}_1_conv", f"{torch_name}.conv1")
            copy_bn(f"{keras_name}_1_bn", f"{torch_name}.bn1")
            copy_conv(f"{keras_name}_2_conv", f"{torch_name}.conv2")
            copy_bn(f"{keras_name}_2_bn", f"{torch_name}.bn2")
            if block_type == "bottleneck_block":
                copy_conv(f"{keras_name}_3_conv", f"{torch_name}.conv3")
                copy_bn(f"{keras_name}_3_bn", f"{torch_name}.bn3")
140+
141+
142+
def convert_fpn(state_dict, fpn_network):
    """Port torchvision FPN weights into the Keras feature pyramid.

    Args:
        state_dict: State dict of the full torchvision detection model.
        fpn_network: The Keras feature pyramid (`backbone.feature_pyramid`).
    """

    def copy_conv(keras_layer, prefix):
        port_weight(
            keras_layer.kernel,
            torch_tensor=state_dict[f"{prefix}.weight"],
            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
        )
        port_weight(
            keras_layer.bias, torch_tensor=state_dict[f"{prefix}.bias"]
        )

    # Lateral (1x1) convs: level key "pN" -> torch inner_blocks index N-3.
    # NOTE(review): `int(level[1])` assumes single-digit level names.
    for level, lateral_layer in fpn_network.lateral_conv_layers.items():
        level_num = int(level[1])
        copy_conv(lateral_layer, f"backbone.fpn.inner_blocks.{level_num - 3}.0")

    # Output (3x3) convs plus extra coarser levels (p6/p7), discriminated
    # by the layer's name.
    for level, output_layer in fpn_network.output_conv_layers.items():
        level_num = int(level[1])
        if "output" in output_layer.name:
            copy_conv(
                output_layer, f"backbone.fpn.layer_blocks.{level_num - 3}.0"
            )
        if "coarser" in output_layer.name:
            copy_conv(output_layer, f"backbone.fpn.extra_blocks.p{level_num}")
164+
165+
166+
def convert_head_weights(state_dict, keras_model):
    """Port torchvision detection-head weights into the Keras heads.

    For the v1 preset ("retinanet_resnet50_fpn_coco") the tower convs carry
    biases and are copied whole. Otherwise only the kernel is copied —
    presumably the v2 torch convs are bias-free because GroupNorm follows;
    confirm against the torchvision head definition. GroupNorm gamma/beta
    are copied separately from the `.1` slot of each tower stage.
    """

    def copy_conv(keras_layer, torch_weight_prefix):
        port_weight(
            keras_layer.kernel,
            torch_tensor=state_dict[f"{torch_weight_prefix}.weight"],
            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
        )
        port_weight(
            keras_layer.bias,
            torch_tensor=state_dict[f"{torch_weight_prefix}.bias"],
        )

    is_v1 = FLAGS.preset == "retinanet_resnet50_fpn_coco"

    # Box regression tower.
    for idx, conv in enumerate(keras_model.box_head.conv_layers):
        if is_v1:
            copy_conv(conv, f"head.regression_head.conv.{idx}.0")
        else:
            # Kernel only: no bias key exists for these convs.
            port_weight(
                conv.kernel,
                torch_tensor=state_dict[
                    f"head.regression_head.conv.{idx}.0.weight"
                ],
                hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
            )
    for idx, norm in enumerate(keras_model.box_head.group_norm_layers):
        port_weight(
            norm.gamma,
            state_dict[f"head.regression_head.conv.{idx}.1.weight"],
        )
        port_weight(
            norm.beta, state_dict[f"head.regression_head.conv.{idx}.1.bias"]
        )
    copy_conv(
        keras_model.box_head.prediction_layer,
        torch_weight_prefix="head.regression_head.bbox_reg",
    )

    # Classification tower.
    for idx, conv in enumerate(keras_model.classification_head.conv_layers):
        if is_v1:
            copy_conv(conv, f"head.classification_head.conv.{idx}.0")
        else:
            port_weight(
                conv.kernel,
                torch_tensor=state_dict[
                    f"head.classification_head.conv.{idx}.0.weight"
                ],
                hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
            )
    for idx, norm in enumerate(
        keras_model.classification_head.group_norm_layers
    ):
        port_weight(
            norm.gamma,
            state_dict[f"head.classification_head.conv.{idx}.1.weight"],
        )
        port_weight(
            norm.beta,
            state_dict[f"head.classification_head.conv.{idx}.1.bias"],
        )
    copy_conv(
        keras_model.classification_head.prediction_layer,
        torch_weight_prefix="head.classification_head.cls_logits",
    )
232+
233+
234+
def convert_backbone_weights(state_dict, backbone):
    """Port every backbone weight: ResNet-50 trunk first, then the FPN."""
    convert_image_encoder(state_dict, backbone.image_encoder)
    convert_fpn(state_dict, backbone.feature_pyramid)
239+
240+
241+
def convert_image_converter(torch_model):
    """Build a Keras image converter mirroring the torch transform.

    Folds the torch per-channel normalization into a single scale/offset:
    (x/255 - mean) / std == x * (1/255/std) + (-mean/std).

    Args:
        torch_model: The torchvision RetinaNet model whose `transform`
            holds `image_mean`, `image_std`, and `min_size`.

    Returns:
        A configured `RetinaNetImageConverter`.
    """
    transform = torch_model.transform
    mean = transform.image_mean
    std = transform.image_std
    resolution = transform.min_size[0]
    return RetinaNetImageConverter(
        image_size=(resolution, resolution),
        pad_to_aspect_ratio=True,
        crop_to_aspect_ratio=False,
        scale=[1.0 / 255.0 / s for s in std],
        offset=[-m / s for m, s in zip(mean, std)],
    )
252+
253+
254+
def main(_):
    """Convert a torchvision RetinaNet checkpoint into a KerasHub preset.

    Loads the torchvision reference model, ports all weights into a
    freshly built KerasHub model, numerically compares raw head outputs
    on a sample image, saves the preset locally, and optionally uploads
    it to Kaggle.
    """
    if FLAGS.preset not in PRESET_MAP.keys():
        raise ValueError(
            f"Invalid preset {FLAGS.preset}. Must be one "
            f"of {','.join(PRESET_MAP.keys())}"
        )
    preset = FLAGS.preset
    torch_preset = PRESET_MAP[preset]
    # Start from a clean output directory.
    if os.path.exists(preset):
        shutil.rmtree(preset)
    os.makedirs(preset)

    print(f"🏃 Converting {preset}")

    # Load the torchvision model (weights download on first use) and
    # build the matching Keras backbone. The v2 variant uses `use_p5`.
    if preset == "retinanet_resnet50_fpn_coco":
        torch_model = retinanet_resnet50_fpn(weights=torch_preset)
        torch_model.eval()
        keras_backbone = get_keras_backbone()
    elif preset == "retinanet_resnet50_fpn_v2_coco":
        torch_model = retinanet_resnet50_fpn_v2(weights=torch_preset)
        torch_model.eval()
        keras_backbone = get_keras_backbone(use_p5=True)

    state_dict = torch_model.state_dict()
    print("✅ Torch and KerasHub model loaded.")

    convert_backbone_weights(state_dict, keras_backbone)
    print("✅ Backbone weights converted.")

    keras_image_converter = convert_image_converter(torch_model)
    print("✅ Loaded image converter")

    preprocessor = RetinaNetObjectDetectorPreprocessor(
        image_converter=keras_image_converter
    )

    keras_model = RetinaNetObjectDetector(
        backbone=keras_backbone,
        num_classes=len(torch_preset.meta["categories"]),
        preprocessor=preprocessor,
        # Only the v2 variant has GroupNorm in its head conv towers.
        use_prediction_head_norm=preset == "retinanet_resnet50_fpn_v2_coco",
    )

    convert_head_weights(state_dict, keras_model)
    print("✅ Loaded head weights")

    # Sanity check: run both models on the same image and compare raw
    # head outputs (before any decoding/NMS).
    filepath = keras.utils.get_file(
        origin="http://farm4.staticflickr.com/3755/10245052896_958cbf4766_z.jpg"
    )
    image = keras.utils.load_img(filepath)
    image = ops.cast(image, "float32")
    image = ops.expand_dims(image, axis=0)
    keras_image = preprocessor(image)
    # Torch expects channels-first input.
    torch_image = ops.transpose(keras_image, axes=(0, 3, 1, 2))
    torch_image = ops.convert_to_numpy(torch_image)
    torch_image = torch.from_numpy(torch_image)

    keras_outputs = keras_model(keras_image)
    with torch.no_grad():
        torch_mid_outputs = list(torch_model.backbone(torch_image).values())
        torch_outputs = torch_model.head(torch_mid_outputs)

    # Mean absolute differences should be near zero after a correct port.
    bbox_diff = np.mean(
        np.abs(
            ops.convert_to_numpy(keras_outputs["bbox_regression"])
            - torch_outputs["bbox_regression"].numpy()
        )
    )
    cls_logits_diff = np.mean(
        np.abs(
            ops.convert_to_numpy(keras_outputs["cls_logits"])
            - torch_outputs["cls_logits"].numpy()
        )
    )
    print("🔶 Modeling Bounding Box Logits difference:", bbox_diff)
    print("🔶 Modeling Class Logits difference:", cls_logits_diff)

    keras_model.save_to_preset(f"./{preset}")
    print(f"🏁 Preset saved to ./{preset}.")

    upload_uri = FLAGS.upload_uri
    if upload_uri:
        keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}")
        print(f"🏁 Preset uploaded to {upload_uri}")


if __name__ == "__main__":
    app.run(main)

0 commit comments

Comments
 (0)