@@ -83,9 +83,22 @@ def _info(self) -> tfds.core.DatasetInfo:
83
83
84
84
tmp = dict (features )
85
85
86
+ # add all image features from observations to a new featuresdict
87
+ image_from_observation = {}
88
+ if 'steps' in tmp and 'observation' in tmp ['steps' ]:
89
+ observation = tmp ['steps' ]['observation' ]
90
+ for feature_name , feature_data in observation .items ():
91
+ if isinstance (feature_data , tfds .features .Image ):
92
+ image_from_observation [feature_name ] = feature_data
93
+ image_from_observation_dict = tfds .features .FeaturesDict (
94
+ image_from_observation
95
+ )
96
+ tmp ['display_image' ] = image_from_observation_dict
97
+
86
98
for key in self .KEYS_TO_STRIP :
87
99
if key in tmp :
88
100
del tmp [key ]
101
+
89
102
features = tfds .features .FeaturesDict (tmp )
90
103
91
104
return tfds .core .DatasetInfo (
@@ -115,20 +128,42 @@ def _generate_examples(
115
128
split = split_info .name
116
129
read_config = read_config_lib .ReadConfig (add_tfds_id = True )
117
130
131
+ features = self .get_ds_builder ().info .features
132
+ tmp = dict (features )
133
+ image_name_from_observation_set = set ()
134
+ if 'steps' in tmp and 'observation' in tmp ['steps' ]:
135
+ observation = tmp ['steps' ]['observation' ]
136
+ for feature_name , feature_data in observation .items ():
137
+ if isinstance (feature_data , tfds .features .Image ):
138
+ image_name_from_observation_set .add (feature_name )
139
+
118
140
decode_fn = builder .info .features ['steps' ].feature .decode_example
119
141
120
142
def converter_fn (example ):
121
143
# Decode the RLDS Episode and transform it to numpy.
122
144
example_out = dict (example )
145
+
123
146
example_out ['steps' ] = tf .data .Dataset .from_tensor_slices (
124
147
example_out ['steps' ]
125
148
).map (decode_fn )
149
+
126
150
steps = list (iter (example_out ['steps' ].take (- 1 )))
127
151
example_out ['steps' ] = steps
128
152
129
153
example_out = dataset_utils .as_numpy (example_out )
154
+ new_step = example_out ['steps' ]
155
+ first_step = new_step [0 ]
156
+ image_feature_dict = {}
157
+
158
+ for feature_name , feature_data in first_step ['observation' ].items ():
159
+ # all the features are arrays now we cannot distinguish image/others
160
+ if feature_name in image_name_from_observation_set :
161
+ image_feature_dict [feature_name ] = feature_data
162
+ if image_feature_dict :
163
+ example_out ['display_image' ] = image_feature_dict
130
164
131
165
example_id = example_out ['tfds_id' ].decode ('utf-8' )
166
+
132
167
del example_out ['tfds_id' ]
133
168
for key in self .KEYS_TO_STRIP :
134
169
if key in example_out :
0 commit comments