Skip to content

Commit 25714a7

Browse files
author
The TensorFlow Datasets Authors
committed
add a new feature "display_image"
PiperOrigin-RevId: 638014851
1 parent 6bbba45 commit 25714a7

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tensorflow_datasets/robotics/dataset_importer_builder.py

+35
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,22 @@ def _info(self) -> tfds.core.DatasetInfo:
8383

8484
tmp = dict(features)
8585

86+
# add all image features from observations to a new featuresdict
87+
image_from_observation = {}
88+
if 'steps' in tmp and 'observation' in tmp['steps']:
89+
observation = tmp['steps']['observation']
90+
for feature_name, feature_data in observation.items():
91+
if isinstance(feature_data, tfds.features.Image):
92+
image_from_observation[feature_name] = feature_data
93+
image_from_observation_dict = tfds.features.FeaturesDict(
94+
image_from_observation
95+
)
96+
tmp['display_image'] = image_from_observation_dict
97+
8698
for key in self.KEYS_TO_STRIP:
8799
if key in tmp:
88100
del tmp[key]
101+
89102
features = tfds.features.FeaturesDict(tmp)
90103

91104
return tfds.core.DatasetInfo(
@@ -115,20 +128,42 @@ def _generate_examples(
115128
split = split_info.name
116129
read_config = read_config_lib.ReadConfig(add_tfds_id=True)
117130

131+
features = self.get_ds_builder().info.features
132+
tmp = dict(features)
133+
image_name_from_observation_set = set()
134+
if 'steps' in tmp and 'observation' in tmp['steps']:
135+
observation = tmp['steps']['observation']
136+
for feature_name, feature_data in observation.items():
137+
if isinstance(feature_data, tfds.features.Image):
138+
image_name_from_observation_set.add(feature_name)
139+
118140
decode_fn = builder.info.features['steps'].feature.decode_example
119141

120142
def converter_fn(example):
121143
# Decode the RLDS Episode and transform it to numpy.
122144
example_out = dict(example)
145+
123146
example_out['steps'] = tf.data.Dataset.from_tensor_slices(
124147
example_out['steps']
125148
).map(decode_fn)
149+
126150
steps = list(iter(example_out['steps'].take(-1)))
127151
example_out['steps'] = steps
128152

129153
example_out = dataset_utils.as_numpy(example_out)
154+
new_step = example_out['steps']
155+
first_step = new_step[0]
156+
image_feature_dict = {}
157+
158+
for feature_name, feature_data in first_step['observation'].items():
159+
# all the features are arrays now we cannot distinguish image/others
160+
if feature_name in image_name_from_observation_set:
161+
image_feature_dict[feature_name] = feature_data
162+
if image_feature_dict:
163+
example_out['display_image'] = image_feature_dict
130164

131165
example_id = example_out['tfds_id'].decode('utf-8')
166+
132167
del example_out['tfds_id']
133168
for key in self.KEYS_TO_STRIP:
134169
if key in example_out:

0 commit comments

Comments
 (0)