Skip to content

Commit f178c98

Browse files
committed
Add utility to map COCO IDs to class names
Similar to the ImageNet utility.
1 parent 0b63ce0 commit f178c98

File tree

6 files changed

+212
-0
lines changed

6 files changed

+212
-0
lines changed

keras_hub/api/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
since your modifications would be overwritten.
55
"""
66

7+
from keras_hub.src.utils.coco.coco_utils import coco_id_to_name
8+
from keras_hub.src.utils.coco.coco_utils import coco_name_to_id
79
from keras_hub.src.utils.imagenet.imagenet_utils import (
810
decode_imagenet_predictions,
911
)
12+
from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_id_to_name
13+
from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_name_to_id

keras_hub/src/utils/coco/__init__.py

Whitespace-only changes.
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from keras_hub.src.api_export import keras_hub_export
2+
3+
4+
@keras_hub_export("keras_hub.utils.coco_id_to_name")
5+
def coco_id_to_name(id):
6+
"""Convert a single COCO class name to a class ID.
7+
8+
Args:
9+
id: An integer class id from 0 to 91.
10+
11+
Returns:
12+
The human readable image class name, e.g. "bicycle".
13+
14+
Example:
15+
```python
16+
>>> keras_hub.utils.coco_id_to_name(2)
17+
"bicycle"
18+
```
19+
"""
20+
return COCO_NAMES[id]
21+
22+
23+
@keras_hub_export("keras_hub.utils.coco_name_to_id")
24+
def coco_name_to_id(name):
25+
"""Convert a single COCO class name to a class ID.
26+
27+
Args:
28+
name: A human readable image class name, e.g. "bicycle".
29+
30+
Returns:
31+
The integer class id from 0 to 999.
32+
33+
Example:
34+
```python
35+
>>> keras_hub.utils.coco_name_to_id("bicycle")
36+
2
37+
```
38+
"""
39+
return COCO_IDS[name]
40+
41+
42+
COCO_NAMES = {
43+
0: "unlabeled",
44+
1: "person",
45+
2: "bicycle",
46+
3: "car",
47+
4: "motorcycle",
48+
5: "airplane",
49+
6: "bus",
50+
7: "train",
51+
8: "truck",
52+
9: "boat",
53+
10: "traffic_light",
54+
11: "fire_hydrant",
55+
12: "street_sign",
56+
13: "stop_sign",
57+
14: "parking_meter",
58+
15: "bench",
59+
16: "bird",
60+
17: "cat",
61+
18: "dog",
62+
19: "horse",
63+
20: "sheep",
64+
21: "cow",
65+
22: "elephant",
66+
23: "bear",
67+
24: "zebra",
68+
25: "giraffe",
69+
26: "hat",
70+
27: "backpack",
71+
28: "umbrella",
72+
29: "shoe",
73+
30: "eye_glasses",
74+
31: "handbag",
75+
32: "tie",
76+
33: "suitcase",
77+
34: "frisbee",
78+
35: "skis",
79+
36: "snowboard",
80+
37: "sports_ball",
81+
38: "kite",
82+
39: "baseball_bat",
83+
40: "baseball_glove",
84+
41: "skateboard",
85+
42: "surfboard",
86+
43: "tennis_racket",
87+
44: "bottle",
88+
45: "plate",
89+
46: "wine_glass",
90+
47: "cup",
91+
48: "fork",
92+
49: "knife",
93+
50: "spoon",
94+
51: "bowl",
95+
52: "banana",
96+
53: "apple",
97+
54: "sandwich",
98+
55: "orange",
99+
56: "broccoli",
100+
57: "carrot",
101+
58: "hot_dog",
102+
59: "pizza",
103+
60: "donut",
104+
61: "cake",
105+
62: "chair",
106+
63: "couch",
107+
64: "potted_plant",
108+
65: "bed",
109+
66: "mirror",
110+
67: "dining_table",
111+
68: "window",
112+
69: "desk",
113+
70: "toilet",
114+
71: "door",
115+
72: "tv",
116+
73: "laptop",
117+
74: "mouse",
118+
75: "remote",
119+
76: "keyboard",
120+
77: "cell_phone",
121+
78: "microwave",
122+
79: "oven",
123+
80: "toaster",
124+
81: "sink",
125+
82: "refrigerator",
126+
83: "blender",
127+
84: "book",
128+
85: "clock",
129+
86: "vase",
130+
87: "scissors",
131+
88: "teddy_bear",
132+
89: "hair_drier",
133+
90: "toothbrush",
134+
91: "hair_brush",
135+
}
136+
137+
COCO_IDS = {v: k for k, v in COCO_NAMES.items()}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from keras_hub.src.tests.test_case import TestCase
2+
from keras_hub.src.utils.coco.coco_utils import coco_id_to_name
3+
from keras_hub.src.utils.coco.coco_utils import coco_name_to_id
4+
5+
6+
class CocoUtilsTest(TestCase):
7+
def test_coco_id_to_name(self):
8+
self.assertEqual(coco_id_to_name(0), "unlabeled")
9+
self.assertEqual(coco_id_to_name(24), "zebra")
10+
with self.assertRaises(KeyError):
11+
coco_id_to_name(2001)
12+
13+
def test_coco_name_to_id(self):
14+
self.assertEqual(coco_name_to_id("unlabeled"), 0)
15+
self.assertEqual(coco_name_to_id("zebra"), 24)
16+
with self.assertRaises(KeyError):
17+
coco_name_to_id("whirligig")

keras_hub/src/utils/imagenet/imagenet_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,44 @@
33
from keras_hub.src.api_export import keras_hub_export
44

55

6+
@keras_hub_export("keras_hub.utils.imagenet_id_to_name")
7+
def imagenet_id_to_name(id):
8+
"""Convert a single ImageNet class ID to a class name.
9+
10+
Args:
11+
id: An integer class id from 0 to 999.
12+
13+
Returns:
14+
The human readable image class name, e.g. "goldfish".
15+
16+
Example:
17+
```python
18+
>>> keras_hub.utils.imagenet_id_to_name(1)
19+
"goldfish"
20+
```
21+
"""
22+
return IMAGENET_NAMES[id][1]
23+
24+
25+
@keras_hub_export("keras_hub.utils.imagenet_name_to_id")
26+
def imagenet_name_to_id(name):
27+
"""Convert a single ImageNet class name to a class ID.
28+
29+
Args:
30+
name: A human readable image class name, e.g. "goldfish".
31+
32+
Returns:
33+
The integer class id from 0 to 999.
34+
35+
Example:
36+
```python
37+
>>> keras_hub.utils.imagenet_name_to_id("goldfish")
38+
1
39+
```
40+
"""
41+
return IMAGENET_IDS[name]
42+
43+
644
@keras_hub_export("keras_hub.utils.decode_imagenet_predictions")
745
def decode_imagenet_predictions(preds, top=5, include_synset_ids=False):
846
"""Decodes the predictions for an ImageNet-1k prediction.
@@ -1052,3 +1090,5 @@ def decode_imagenet_predictions(preds, top=5, include_synset_ids=False):
10521090
998: ("n13133613", "ear"),
10531091
999: ("n15075141", "toilet_tissue"),
10541092
}
1093+
1094+
IMAGENET_IDS = {v[1]: k for k, v in IMAGENET_NAMES.items()}

keras_hub/src/utils/imagenet/imagenet_utils_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,23 @@
44
from keras_hub.src.utils.imagenet.imagenet_utils import (
55
decode_imagenet_predictions,
66
)
7+
from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_id_to_name
8+
from keras_hub.src.utils.imagenet.imagenet_utils import imagenet_name_to_id
79

810

911
class ImageNetUtilsTest(TestCase):
12+
def test_imagenet_id_to_name(self):
13+
self.assertEqual(imagenet_id_to_name(0), "tench")
14+
self.assertEqual(imagenet_id_to_name(21), "kite")
15+
with self.assertRaises(KeyError):
16+
imagenet_id_to_name(2001)
17+
18+
def test_imagenet_name_to_id(self):
19+
self.assertEqual(imagenet_name_to_id("tench"), 0)
20+
self.assertEqual(imagenet_name_to_id("kite"), 21)
21+
with self.assertRaises(KeyError):
22+
imagenet_name_to_id(2001)
23+
1024
def test_decode_imagenet_predictions(self):
1125
preds = np.array(
1226
[

0 commit comments

Comments
 (0)