Skip to content

Commit

Permalink
Add support for multidimensional arrays in the CroissantBuilder.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 730986386
  • Loading branch information
The TensorFlow Datasets Authors committed Feb 25, 2025
1 parent b7d6e96 commit 7c40c08
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tensorflow_datasets/core/dataset_builders/croissant_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from tensorflow_datasets.core.features import features_dict
from tensorflow_datasets.core.features import image_feature
from tensorflow_datasets.core.features import sequence_feature
from tensorflow_datasets.core.features import tensor_feature
from tensorflow_datasets.core.features import text_feature
from tensorflow_datasets.core.utils import conversion_utils
from tensorflow_datasets.core.utils import croissant_utils
Expand All @@ -75,6 +76,51 @@ def _strip_record_set_prefix(
}


def array_datatype_converter(
    feature: type_utils.TfdsDType | feature_lib.FeatureConnector | None,
    field: mlc.Field,
    int_dtype: type_utils.TfdsDType = np.int64,
    float_dtype: type_utils.TfdsDType = np.float32,
):
  """Wraps `feature` in the array container described by `field`.

  The container is chosen from `field.array_shape_tuple` and `field.data_type`:

  * one dimension: a single `Sequence` around `feature`;
  * several dimensions, with at least one equal to -1 or with a data type that
    has no entry in the native dtype table below: one nested `Sequence` per
    dimension;
  * otherwise: a `Tensor` of the declared shape and mapped dtype (`feature`
    itself is not used in this case).

  Args:
    feature: The inner feature to wrap.
    field: The mlc.Field describing the array.
    int_dtype: The dtype to use for TFDS integer features. Defaults to np.int64.
    float_dtype: The dtype to use for TFDS float features. Defaults to
      np.float32.

  Returns:
    A sequence or tensor feature built as described above.
  """
  native_dtypes = {
      int: int_dtype,
      float: float_dtype,
      bool: np.bool_,
      bytes: np.str_,
  }
  # Resolved up front; only the Tensor case at the bottom consumes it.
  tensor_dtype = native_dtypes.get(field.data_type, None)
  shape = field.array_shape_tuple

  # Plain list: a single level of Sequence is enough.
  if len(shape) == 1:
    return sequence_feature.Sequence(feature, doc=field.description)

  # Unknown dimension or non-native dtype: nest one Sequence per dimension.
  if (-1 in shape) or (field.data_type not in native_dtypes):
    nested = feature
    for _ in shape:
      nested = sequence_feature.Sequence(nested, doc=field.description)
    return nested

  # Fully specified shape with a native dtype.
  return tensor_feature.Tensor(
      shape=shape,
      dtype=tensor_dtype,
      doc=field.description,
  )


def datatype_converter(
field: mlc.Field,
int_dtype: type_utils.TfdsDType = np.int64,
Expand Down Expand Up @@ -133,6 +179,16 @@ def datatype_converter(
else:
raise ValueError(f'Unknown data type: {field_data_type}.')

if feature and field.is_array:
feature = array_datatype_converter(
feature=feature,
field=field,
int_dtype=int_dtype,
float_dtype=float_dtype,
)
# If the field is repeated, we return a sequence feature. `field.repeated` is
# deprecated starting from Croissant 1.1, but we still support it for
# backwards compatibility.
if feature and field.repeated:
feature = sequence_feature.Sequence(feature, doc=field.description)
return feature
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,57 @@ def test_complex_datatype_converter(field, feature_type, subfield_types):
)


def test_multidimensional_datatype_converter():
  """A text array with the fixed shape "2,2" converts to a string Tensor."""
  fixed_shape_field = mlc.Field(
      data_types=mlc.DataType.TEXT,
      description="Text feature",
      is_array=True,
      array_shape="2,2",
  )
  converted = croissant_builder.datatype_converter(fixed_shape_field)
  # Container type first, then its shape and dtype.
  assert isinstance(converted, tensor_feature.Tensor)
  assert converted.shape == (2, 2)
  assert converted.dtype == np.str_


def test_multidimensional_datatype_converter_image_object():
  """A "2,2" array of image objects converts to two nested Sequences."""
  image_array_field = mlc.Field(
      data_types=mlc.DataType.IMAGE_OBJECT,
      description="Text feature",
      is_array=True,
      array_shape="2,2",
  )
  converted = croissant_builder.datatype_converter(image_array_field)
  # Walk the nesting from the outside in: Sequence -> Sequence -> Image.
  assert isinstance(converted, sequence_feature.Sequence)
  inner = converted.feature
  assert isinstance(inner, sequence_feature.Sequence)
  assert isinstance(inner.feature, image_feature.Image)


def test_multidimensional_datatype_converter_plain_list():
  """A one-dimensional text array ("-1") converts to a Sequence of Text."""
  list_field = mlc.Field(
      data_types=mlc.DataType.TEXT,
      description="Text feature",
      is_array=True,
      array_shape="-1",
  )
  converted = croissant_builder.datatype_converter(list_field)
  # Exactly one level of Sequence wrapping the Text feature.
  assert isinstance(converted, sequence_feature.Sequence)
  assert isinstance(converted.feature, text_feature.Text)


def test_multidimensional_datatype_converter_unknown_shape():
  """A text array with an unknown dimension ("-1,2") nests two Sequences."""
  partially_known_field = mlc.Field(
      data_types=mlc.DataType.TEXT,
      description="Text feature",
      is_array=True,
      array_shape="-1,2",
  )
  converted = croissant_builder.datatype_converter(partially_known_field)
  # Walk the nesting from the outside in: Sequence -> Sequence -> Text.
  assert isinstance(converted, sequence_feature.Sequence)
  inner = converted.feature
  assert isinstance(inner, sequence_feature.Sequence)
  assert isinstance(inner.feature, text_feature.Text)


def test_sequence_feature_datatype_converter():
field = mlc.Field(
data_types=mlc.DataType.TEXT,
Expand Down

0 comments on commit 7c40c08

Please sign in to comment.