From a159666fa06a7f2417d7b6bd98c9f2a967b8c161 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Mon, 22 Dec 2025 17:54:40 +0100 Subject: [PATCH 01/11] Add fixed shape tensor --- CMakeLists.txt | 6 +- include/sparrow_extensions.hpp | 3 + .../sparrow_extensions/fixed_shape_tensor.hpp | 259 +++++++ src/fixed_shape_tensor.cpp | 526 +++++++++++++++ tests/CMakeLists.txt | 1 + tests/test_fixed_shape_tensor.cpp | 636 ++++++++++++++++++ 6 files changed, 1429 insertions(+), 2 deletions(-) create mode 100644 include/sparrow_extensions/fixed_shape_tensor.hpp create mode 100644 src/fixed_shape_tensor.cpp create mode 100644 tests/test_fixed_shape_tensor.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 346012e..fed5995 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -182,9 +182,10 @@ set(SPARROW_EXTENSIONS_HEADERS ${SPARROW_EXTENSIONS_INCLUDE_DIR}/sparrow_extensions/config/sparrow_extensions_version.hpp # ./ - ${SPARROW_EXTENSIONS_INCLUDE_DIR}/sparrow_extensions/uuid_array.hpp - ${SPARROW_EXTENSIONS_INCLUDE_DIR}/sparrow_extensions/json_array.hpp ${SPARROW_EXTENSIONS_INCLUDE_DIR}/sparrow_extensions/bool8_array.hpp + ${SPARROW_EXTENSIONS_INCLUDE_DIR}/sparrow_extensions/fixed_shape_tensor.hpp + ${SPARROW_EXTENSIONS_INCLUDE_DIR}/sparrow_extensions/json_array.hpp + ${SPARROW_EXTENSIONS_INCLUDE_DIR}/sparrow_extensions/uuid_array.hpp #../ # ${SPARROW_EXTENSIONS_INCLUDE_DIR}/sparrow_extensions.hpp @@ -192,6 +193,7 @@ set(SPARROW_EXTENSIONS_HEADERS set(SPARROW_EXTENSIONS_SRC ${SPARROW_EXTENSIONS_SOURCE_DIR}/bool8_array.cpp + ${SPARROW_EXTENSIONS_SOURCE_DIR}/fixed_shape_tensor.cpp ${SPARROW_EXTENSIONS_SOURCE_DIR}/json_array.cpp ${SPARROW_EXTENSIONS_SOURCE_DIR}/uuid_array.cpp ) diff --git a/include/sparrow_extensions.hpp b/include/sparrow_extensions.hpp index 47143e4..86e6e5a 100644 --- a/include/sparrow_extensions.hpp +++ b/include/sparrow_extensions.hpp @@ -19,4 +19,7 @@ #include // Extensions +#include +#include +#include #include diff --git 
a/include/sparrow_extensions/fixed_shape_tensor.hpp b/include/sparrow_extensions/fixed_shape_tensor.hpp
new file mode 100644
index 0000000..404bca1
--- /dev/null
+++ b/include/sparrow_extensions/fixed_shape_tensor.hpp
@@ -0,0 +1,262 @@
+// Copyright 2024 Man Group Operations Limited
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include "sparrow/array.hpp"
+#include "sparrow/list_array.hpp"
+#include "sparrow/types/data_type.hpp"
+#include "sparrow/utils/contracts.hpp"
+
+namespace sparrow_extensions
+{
+    /**
+     * @brief Fixed shape tensor array implementation following Arrow canonical extension
+     * specification.
+     *
+     * This class implements an Arrow-compatible array for storing fixed-shape tensors
+     * according to the Apache Arrow canonical extension specification for fixed shape tensors.
+     * Each tensor is stored as a FixedSizeList with a product of shape dimensions as list size.
+     *
+     * The fixed shape tensor extension type is defined as:
+     * - Extension name: "arrow.fixed_shape_tensor"
+     * - Storage type: FixedSizeList where:
+     *   - value_type is the data type of individual tensor elements
+     *   - list_size is the product of all elements in tensor shape
+     *
+     * Extension type parameters:
+     * - value_type: the Arrow data type of individual tensor elements
+     * - shape: the physical shape of the contained tensors as an array
+     * - dim_names (optional): explicit names to tensor dimensions as an array
+     * - permutation (optional): indices of the desired ordering of the original dimensions
+     *
+     * The metadata must be a valid JSON object including shape of the contained tensors
+     * as an array with key "shape" plus optional dimension names with keys "dim_names"
+     * and ordering of the dimensions with key "permutation".
+     *
+     * Example metadata:
+     * - Simple shape: { "shape": [2, 5] }
+     * - With dim_names: { "shape": [100, 200, 500], "dim_names": ["C", "H", "W"] }
+     * - With permutation: { "shape": [100, 200, 500], "permutation": [2, 0, 1] }
+     *
+     * Note: Elements in a fixed shape tensor extension array are stored in
+     * row-major/C-contiguous order.
+     *
+     * Related Apache Arrow specification:
+     * https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor
+     */
+    class fixed_shape_tensor_extension
+    {
+    public:
+
+        static constexpr std::string_view EXTENSION_NAME = "arrow.fixed_shape_tensor";
+
+        /**
+         * @brief Metadata for fixed shape tensor extension.
+         *
+         * Stores the shape, optional dimension names, and optional permutation
+         * for the tensor layout.
+         */
+        struct metadata
+        {
+            std::vector<std::int64_t> shape;
+            std::optional<std::vector<std::string>> dim_names;
+            std::optional<std::vector<std::int64_t>> permutation;
+
+            /**
+             * @brief Validates that the metadata is well-formed.
+             *
+             * @return true if metadata is valid, false otherwise
+             *
+             * Validation rules:
+             * - shape must not be empty
+             * - shape elements must all be positive
+             * - if dim_names is present, its size must equal shape size
+             * - if permutation is present:
+             *   - its size must equal shape size
+             *   - it must contain exactly the values [0, 1, ..., N-1] in some order
+             */
+            [[nodiscard]] bool is_valid() const;
+
+            /**
+             * @brief Computes the total number of elements (product of shape).
+             *
+             * @return Product of all dimensions in shape
+             */
+            [[nodiscard]] std::int64_t compute_size() const;
+
+            /**
+             * @brief Serializes metadata to JSON string.
+             *
+             * Dimension names are written as JSON string literals: quotes, backslashes
+             * and control characters are escaped.
+             *
+             * @return JSON string representation of the metadata
+             */
+            [[nodiscard]] std::string to_json() const;
+
+            /**
+             * @brief Deserializes metadata from JSON string.
+             *
+             * @param json JSON string to parse
+             * @return Parsed metadata structure
+             * @throws std::runtime_error if the JSON is malformed or describes invalid metadata
+             */
+            [[nodiscard]] static metadata from_json(std::string_view json);
+        };
+
+        /**
+         * @brief Initializes the extension metadata on an arrow proxy.
+         *
+         * @param proxy Arrow proxy to initialize
+         * @param tensor_metadata Metadata describing the tensor shape and layout
+         *
+         * @pre proxy must represent a FixedSizeList
+         * @pre tensor_metadata must be valid
+         * @post Extension metadata is added to the proxy
+         */
+        static void init(sparrow::arrow_proxy& proxy, const metadata& tensor_metadata);
+
+        /**
+         * @brief Extracts metadata from an arrow proxy.
+         *
+         * @param proxy Arrow proxy to extract metadata from
+         * @return Metadata structure parsed from the proxy's extension metadata
+         * @throws std::runtime_error if metadata is missing or invalid
+         */
+        [[nodiscard]] static metadata extract_metadata(const sparrow::arrow_proxy& proxy);
+    };
+
+    /**
+     * @brief Fixed shape tensor array wrapping a fixed_sized_list_array.
+     *
+     * This class provides a convenient interface for working with fixed-shape tensors
+     * while maintaining compatibility with the Arrow format.
+     */
+    class fixed_shape_tensor_array
+    {
+    public:
+
+        using size_type = std::size_t;
+        using metadata_type = fixed_shape_tensor_extension::metadata;
+
+        /**
+         * @brief Constructs a fixed shape tensor array from an arrow proxy.
+         *
+         * @param proxy Arrow proxy containing the tensor data
+         *
+         * @pre proxy must contain valid Fixed Size List array data
+         * @pre proxy must have valid extension metadata
+         * @post Array is initialized with data from proxy
+         */
+        explicit fixed_shape_tensor_array(sparrow::arrow_proxy proxy);
+
+        /**
+         * @brief Constructs a fixed shape tensor array from values and shape.
+         *
+         * @param list_size Total number of elements per tensor (product of shape)
+         * @param flat_values Flattened sparrow array of all tensor elements in row-major order
+         * @param tensor_metadata Metadata describing the tensor shape and layout
+         * @param nullable Whether the array should support null values
+         *
+         * @pre flat_values.size() must be divisible by list_size
+         * @pre list_size must equal tensor_metadata.compute_size()
+         * @pre tensor_metadata must be valid
+         * @post Array contains tensors reshaped according to the metadata
+         */
+        fixed_shape_tensor_array(
+            std::uint64_t list_size,
+            sparrow::array&& flat_values,
+            const metadata_type& tensor_metadata,
+            bool nullable = true
+        );
+
+        // Default special members
+        fixed_shape_tensor_array(const fixed_shape_tensor_array&) = default;
+        fixed_shape_tensor_array& operator=(const fixed_shape_tensor_array&) = default;
+        fixed_shape_tensor_array(fixed_shape_tensor_array&&) noexcept = default;
+        fixed_shape_tensor_array& operator=(fixed_shape_tensor_array&&) noexcept = default;
+        ~fixed_shape_tensor_array() = default;
+
+        /**
+         * @brief Returns the number of tensors in the array.
+         */
+        [[nodiscard]] size_type size() const;
+
+        /**
+         * @brief Returns the metadata describing the tensor shape and layout.
+         */
+        [[nodiscard]] const metadata_type& get_metadata() const;
+
+        /**
+         * @brief Returns the shape of each tensor.
+         */
+        [[nodiscard]] const std::vector<std::int64_t>& shape() const;
+
+        /**
+         * @brief Returns the underlying fixed_sized_list_array.
+         */
+        [[nodiscard]] const sparrow::fixed_sized_list_array& storage() const;
+
+        /**
+         * @brief Returns the underlying fixed_sized_list_array.
+         */
+        [[nodiscard]] sparrow::fixed_sized_list_array& storage();
+
+        /**
+         * @brief Access tensor at index i.
+         *
+         * @param i Index of the tensor
+         * @return A list_value representing the tensor at index i
+         *
+         * @pre i < size()
+         */
+        [[nodiscard]] auto operator[](size_type i) const -> decltype(std::declval<const sparrow::fixed_sized_list_array&>()[i]);
+
+        /**
+         * @brief Returns the underlying arrow_proxy.
+         */
+        [[nodiscard]] const sparrow::arrow_proxy& get_arrow_proxy() const;
+
+        /**
+         * @brief Returns the underlying arrow_proxy.
+         */
+        [[nodiscard]] sparrow::arrow_proxy& get_arrow_proxy();
+
+    private:
+
+        sparrow::fixed_sized_list_array m_storage;
+        metadata_type m_metadata;
+    };
+
+}  // namespace sparrow_extensions
+
+namespace sparrow::detail
+{
+    template <>
+    struct get_data_type_from_array<sparrow_extensions::fixed_shape_tensor_array>
+    {
+        [[nodiscard]] static constexpr sparrow::data_type get()
+        {
+            return sparrow::data_type::FIXED_SIZED_LIST;
+        }
+    };
+}
diff --git a/src/fixed_shape_tensor.cpp b/src/fixed_shape_tensor.cpp
new file mode 100644
index 0000000..54a41a9
--- /dev/null
+++ b/src/fixed_shape_tensor.cpp
@@ -0,0 +1,707 @@
+// Copyright 2024 Man Group Operations Limited
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "sparrow_extensions/fixed_shape_tensor.hpp"
+
+#include <algorithm>
+#include <cctype>
+#include <numeric>
+#include <sstream>
+#include <stdexcept>
+
+#include "sparrow/layout/array_access.hpp"
+#include "sparrow/layout/array_registry.hpp"
+#include "sparrow/utils/contracts.hpp"
+
+#include "sparrow_extensions/config/config.hpp"
+
+namespace sparrow_extensions
+{
+    namespace
+    {
+        // Writes `value` to `oss` as a JSON string literal, quotes included. Quotes,
+        // backslashes and control characters are escaped so the output stays valid
+        // JSON whatever the dimension name contains.
+        void append_json_string(std::ostringstream& oss, std::string_view value)
+        {
+            constexpr std::string_view hex_digits = "0123456789abcdef";
+            oss << '"';
+            for (const char c : value)
+            {
+                const auto byte = static_cast<unsigned char>(c);
+                if (c == '"')
+                {
+                    oss << "\\\"";
+                }
+                else if (c == '\\')
+                {
+                    oss << "\\\\";
+                }
+                else if (byte < 0x20)
+                {
+                    oss << "\\u00" << hex_digits[byte >> 4] << hex_digits[byte & 0x0F];
+                }
+                else
+                {
+                    oss << c;
+                }
+            }
+            oss << '"';
+        }
+
+        // Reads the four hexadecimal digits of a \uXXXX escape starting at `pos`
+        // and advances `pos` past them.
+        std::uint32_t read_hex4(const std::string& text, std::size_t& pos)
+        {
+            if (pos + 4 > text.size())
+            {
+                throw std::runtime_error("Truncated \\u escape");
+            }
+            std::uint32_t value = 0;
+            for (int k = 0; k < 4; ++k)
+            {
+                const auto c = static_cast<unsigned char>(text[pos++]);
+                if (!std::isxdigit(c))
+                {
+                    throw std::runtime_error("Invalid \\u escape");
+                }
+                const int digit = std::isdigit(c) ? c - '0' : std::tolower(c) - 'a' + 10;
+                value = (value << 4) | static_cast<std::uint32_t>(digit);
+            }
+            return value;
+        }
+
+        // Appends the UTF-8 encoding of `code_point` to `out`.
+        void append_utf8(std::string& out, std::uint32_t code_point)
+        {
+            if (code_point < 0x80)
+            {
+                out.push_back(static_cast<char>(code_point));
+            }
+            else if (code_point < 0x800)
+            {
+                out.push_back(static_cast<char>(0xC0 | (code_point >> 6)));
+                out.push_back(static_cast<char>(0x80 | (code_point & 0x3F)));
+            }
+            else if (code_point < 0x10000)
+            {
+                out.push_back(static_cast<char>(0xE0 | (code_point >> 12)));
+                out.push_back(static_cast<char>(0x80 | ((code_point >> 6) & 0x3F)));
+                out.push_back(static_cast<char>(0x80 | (code_point & 0x3F)));
+            }
+            else
+            {
+                out.push_back(static_cast<char>(0xF0 | (code_point >> 18)));
+                out.push_back(static_cast<char>(0x80 | ((code_point >> 12) & 0x3F)));
+                out.push_back(static_cast<char>(0x80 | ((code_point >> 6) & 0x3F)));
+                out.push_back(static_cast<char>(0x80 | (code_point & 0x3F)));
+            }
+        }
+
+        // Decodes the escape sequence whose designator (the character right after
+        // the backslash) sits at `pos`, appends the decoded text to `out` and
+        // advances `pos` past the sequence.
+        void append_unescaped(std::string& out, const std::string& text, std::size_t& pos)
+        {
+            const char designator = text[pos++];
+            switch (designator)
+            {
+                case '"':
+                case '\\':
+                case '/':
+                    out.push_back(designator);
+                    return;
+                case 'b':
+                    out.push_back('\b');
+                    return;
+                case 'f':
+                    out.push_back('\f');
+                    return;
+                case 'n':
+                    out.push_back('\n');
+                    return;
+                case 'r':
+                    out.push_back('\r');
+                    return;
+                case 't':
+                    out.push_back('\t');
+                    return;
+                case 'u':
+                    break;
+                default:
+                    throw std::runtime_error("Invalid escape sequence");
+            }
+
+            std::uint32_t code_point = read_hex4(text, pos);
+            if (code_point >= 0xD800 && code_point <= 0xDBFF)
+            {
+                // A high surrogate only makes sense followed by a \uXXXX low surrogate.
+                if (pos + 2 > text.size() || text[pos] != '\\' || text[pos + 1] != 'u')
+                {
+                    throw std::runtime_error("Unpaired surrogate in \\u escape");
+                }
+                pos += 2;
+                const std::uint32_t low = read_hex4(text, pos);
+                if (low < 0xDC00 || low > 0xDFFF)
+                {
+                    throw std::runtime_error("Unpaired surrogate in \\u escape");
+                }
+                code_point = 0x10000 + ((code_point - 0xD800) << 10) + (low - 0xDC00);
+            }
+            else if (code_point >= 0xDC00 && code_point <= 0xDFFF)
+            {
+                throw std::runtime_error("Unpaired surrogate in \\u escape");
+            }
+            append_utf8(out, code_point);
+        }
+
+        // Checks, while `flat_values` is still intact, that it holds a whole number
+        // of tensors of `list_size` elements, then hands it on untouched. Running
+        // the check here keeps it ahead of the move into the storage array.
+        sparrow::array&& checked_flat_values(
+            [[maybe_unused]] std::uint64_t list_size,
+            sparrow::array&& flat_values
+        )
+        {
+            SPARROW_ASSERT_TRUE(list_size != 0 && flat_values.size() % list_size == 0);
+            return std::move(flat_values);
+        }
+    }  // namespace
+
+    // Metadata implementation
+    bool fixed_shape_tensor_extension::metadata::is_valid() const
+    {
+        // Shape must not be empty and all dimensions must be positive
+        if (shape.empty())
+        {
+            return false;
+        }
+
+        for (const auto dim : shape)
+        {
+            if (dim <= 0)
+            {
+                return false;
+            }
+        }
+
+        // If dim_names is present, it must match the shape size
+        if (dim_names.has_value() && dim_names->size() != shape.size())
+        {
+            return false;
+        }
+
+        // If permutation is present, validate it
+        if (permutation.has_value())
+        {
+            const auto& perm = *permutation;
+            if (perm.size() != shape.size())
+            {
+                return false;
+            }
+
+            // Check that permutation contains exactly [0, 1, ..., N-1]
+            std::vector<std::int64_t> sorted_perm = perm;
+            std::ranges::sort(sorted_perm);
+            for (std::size_t i = 0; i < sorted_perm.size(); ++i)
+            {
+                if (sorted_perm[i] != static_cast<std::int64_t>(i))
+                {
+                    return false;
+                }
+            }
+        }
+
+        return true;
+    }
+
+    std::int64_t fixed_shape_tensor_extension::metadata::compute_size() const
+    {
+        return std::accumulate(
+            shape.begin(),
+            shape.end(),
+            std::int64_t{1},
+            std::multiplies<std::int64_t>{}
+        );
+    }
+
+    std::string fixed_shape_tensor_extension::metadata::to_json() const
+    {
+        std::ostringstream oss;
+        oss << "{\"shape\":[";
+        for (std::size_t i = 0; i < shape.size(); ++i)
+        {
+            if (i > 0)
+            {
+                oss << ",";
+            }
+            oss << shape[i];
+        }
+        oss << "]";
+
+        if (dim_names.has_value())
+        {
+            oss << ",\"dim_names\":[";
+            for (std::size_t i = 0; i < dim_names->size(); ++i)
+            {
+                if (i > 0)
+                {
+                    oss << ",";
+                }
+                append_json_string(oss, (*dim_names)[i]);
+            }
+            oss << "]";
+        }
+
+        if (permutation.has_value())
+        {
+            oss << ",\"permutation\":[";
+            for (std::size_t i = 0; i < permutation->size(); ++i)
+            {
+                if (i > 0)
+                {
+                    oss << ",";
+                }
+                oss << (*permutation)[i];
+            }
+            oss << "]";
+        }
+
+        oss << "}";
+        return oss.str();
+    }
+
+    fixed_shape_tensor_extension::metadata fixed_shape_tensor_extension::metadata::from_json(
+        std::string_view json
+    )
+    {
+        metadata result;
+
+        // Simple JSON parser for the fixed structure we expect
+        // This is a minimal implementation - production code might use a proper JSON library
+
+        std::string json_str(json);
+        std::size_t pos = 0;
+
+        // Helper to skip whitespace
+        auto skip_whitespace = [&]()
+        {
+            while (pos < json_str.size() && std::isspace(static_cast<unsigned char>(json_str[pos])))
+            {
+                ++pos;
+            }
+        };
+
+        // Helper to read a string value, decoding JSON escape sequences
+        auto read_string = [&]() -> std::string
+        {
+            skip_whitespace();
+            if (pos >= json_str.size() || json_str[pos] != '"')
+            {
+                throw std::runtime_error("Expected opening quote");
+            }
+            ++pos;  // Skip opening quote
+
+            std::string value;
+            while (pos < json_str.size() && json_str[pos] != '"')
+            {
+                if (json_str[pos] != '\\')
+                {
+                    value.push_back(json_str[pos++]);
+                    continue;
+                }
+
+                ++pos;  // Skip the backslash
+                if (pos >= json_str.size())
+                {
+                    throw std::runtime_error("Expected closing quote");
+                }
+                append_unescaped(value, json_str, pos);
+            }
+
+            if (pos >= json_str.size())
+            {
+                throw std::runtime_error("Expected closing quote");
+            }
+
+            ++pos;  // Skip closing quote
+            return value;
+        };
+
+        // Helper to read an integer
+        auto read_int = [&]() -> std::int64_t
+        {
+            skip_whitespace();
+            const std::size_t start = pos;
+            if (pos < json_str.size() && (json_str[pos] == '-' || json_str[pos] == '+'))
+            {
+                ++pos;
+            }
+
+            const std::size_t digits_start = pos;
+            while (pos < json_str.size() && std::isdigit(static_cast<unsigned char>(json_str[pos])))
+            {
+                ++pos;
+            }
+
+            // std::stoll reports bad input with exceptions that do not derive from
+            // std::runtime_error, so rule them out or translate them here.
+            if (pos == digits_start)
+            {
+                throw std::runtime_error("Expected integer");
+            }
+            try
+            {
+                return std::stoll(json_str.substr(start, pos - start));
+            }
+            catch (const std::out_of_range&)
+            {
+                throw std::runtime_error("Integer out of range");
+            }
+        };
+
+        // Helper to read an array of integers
+        auto read_int_array = [&]() -> std::vector<std::int64_t>
+        {
+            std::vector<std::int64_t> arr;
+            skip_whitespace();
+            if (pos >= json_str.size() || json_str[pos] != '[')
+            {
+                throw std::runtime_error("Expected opening bracket");
+            }
+            ++pos;
+
+            skip_whitespace();
+            if (pos < json_str.size() && json_str[pos] == ']')
+            {
+                ++pos;
+                return arr;
+            }
+
+            while (true)
+            {
+                arr.push_back(read_int());
+                skip_whitespace();
+
+                if (pos >= json_str.size())
+                {
+                    throw std::runtime_error("Unexpected end of JSON");
+                }
+
+                if (json_str[pos] == ']')
+                {
+                    ++pos;
+                    break;
+                }
+
+                if (json_str[pos] != ',')
+                {
+                    throw std::runtime_error("Expected comma or closing bracket");
+                }
+                ++pos;
+            }
+
+            return arr;
+        };
+
+        // Helper to read an array of strings
+        auto read_string_array = [&]() -> std::vector<std::string>
+        {
+            std::vector<std::string> arr;
+            skip_whitespace();
+            if (pos >= json_str.size() || json_str[pos] != '[')
+            {
+                throw std::runtime_error("Expected opening bracket");
+            }
+            ++pos;
+
+            skip_whitespace();
+            if (pos < json_str.size() && json_str[pos] == ']')
+            {
+                ++pos;
+                return arr;
+            }
+
+            while (true)
+            {
+                arr.push_back(read_string());
+                skip_whitespace();
+
+                if (pos >= json_str.size())
+                {
+                    throw std::runtime_error("Unexpected end of JSON");
+                }
+
+                if (json_str[pos] == ']')
+                {
+                    ++pos;
+                    break;
+                }
+
+                if (json_str[pos] != ',')
+                {
+                    throw std::runtime_error("Expected comma or closing bracket");
+                }
+                ++pos;
+            }
+
+            return arr;
+        };
+
+        // Parse the JSON object
+        skip_whitespace();
+        if (pos >= json_str.size() || json_str[pos] != '{')
+        {
+            throw std::runtime_error("Expected opening brace");
+        }
+        ++pos;
+
+        while (true)
+        {
+            skip_whitespace();
+            if (pos >= json_str.size())
+            {
+                throw std::runtime_error("Unexpected end of JSON");
+            }
+
+            if (json_str[pos] == '}')
+            {
+                ++pos;
+                break;
+            }
+
+            std::string key = read_string();
+            skip_whitespace();
+
+            if (pos >= json_str.size() || json_str[pos] != ':')
+            {
+                throw std::runtime_error("Expected colon after key");
+            }
+            ++pos;
+
+            if (key == "shape")
+            {
+                result.shape = read_int_array();
+            }
+            else if (key == "dim_names")
+            {
+                result.dim_names = read_string_array();
+            }
+            else if (key == "permutation")
+            {
+                result.permutation = read_int_array();
+            }
+            else
+            {
+                throw std::runtime_error("Unknown key: " + key);
+            }
+
+            skip_whitespace();
+            if (pos >= json_str.size())
+            {
+                throw std::runtime_error("Unexpected end of JSON");
+            }
+
+            if (json_str[pos] == '}')
+            {
+                ++pos;
+                break;
+            }
+
+            if (json_str[pos] != ',')
+            {
+                throw std::runtime_error("Expected comma or closing brace");
+            }
+            ++pos;
+        }
+
+        if (result.shape.empty())
+        {
+            throw std::runtime_error("Missing required 'shape' field");
+        }
+
+        if (!result.is_valid())
+        {
+            throw std::runtime_error("Invalid metadata");
+        }
+
+        return result;
+    }
+
+    void fixed_shape_tensor_extension::init(
+        sparrow::arrow_proxy& proxy,
+        const metadata& tensor_metadata
+    )
+    {
+        SPARROW_ASSERT_TRUE(tensor_metadata.is_valid());
+
+        // Get existing metadata
+        std::optional<sparrow::key_value_view> existing_metadata = proxy.metadata();
+        std::vector<sparrow::metadata_pair> extension_metadata =
+            existing_metadata.has_value()
+                ? std::vector<sparrow::metadata_pair>(
+                      existing_metadata->begin(),
+                      existing_metadata->end()
+                  )
+                : std::vector<sparrow::metadata_pair>{};
+
+        // Check if extension metadata already exists
+        const bool has_extension_name = std::ranges::find_if(
+                                            extension_metadata,
+                                            [](const auto& pair)
+                                            {
+                                                return pair.first == "ARROW:extension:name"
+                                                       && pair.second == EXTENSION_NAME;
+                                            }
+                                        )
+                                        != extension_metadata.end();
+
+        if (!has_extension_name)
+        {
+            extension_metadata.emplace_back("ARROW:extension:name", std::string(EXTENSION_NAME));
+            extension_metadata.emplace_back(
+                "ARROW:extension:metadata",
+                tensor_metadata.to_json()
+            );
+        }
+
+        proxy.set_metadata(std::make_optional(extension_metadata));
+    }
+
+    fixed_shape_tensor_extension::metadata fixed_shape_tensor_extension::extract_metadata(
+        const sparrow::arrow_proxy& proxy
+    )
+    {
+        std::optional<sparrow::key_value_view> metadata_opt = proxy.metadata();
+        if (!metadata_opt.has_value())
+        {
+            throw std::runtime_error("Missing extension metadata");
+        }
+
+        const auto& metadata = *metadata_opt;
+        std::string metadata_json;
+
+        for (const auto& [key, value] : metadata)
+        {
+            if (key == "ARROW:extension:metadata")
+            {
+                metadata_json = value;
+                break;
+            }
+        }
+
+        if (metadata_json.empty())
+        {
+            throw std::runtime_error("Missing ARROW:extension:metadata");
+        }
+
+        return metadata::from_json(metadata_json);
+    }
+
+    // fixed_shape_tensor_array implementation
+
+    fixed_shape_tensor_array::fixed_shape_tensor_array(sparrow::arrow_proxy proxy)
+        // m_storage is declared first, so it already owns the proxy when the
+        // metadata is read back from it; this avoids copying the proxy.
+        : m_storage(std::move(proxy))
+        , m_metadata(fixed_shape_tensor_extension::extract_metadata(
+              sparrow::detail::array_access::get_arrow_proxy(m_storage)
+          ))
+    {
+        SPARROW_ASSERT_TRUE(m_metadata.is_valid());
+    }
+
+    fixed_shape_tensor_array::fixed_shape_tensor_array(
+        std::uint64_t list_size,
+        sparrow::array&& flat_values,
+        const metadata_type& tensor_metadata,
+        bool nullable
+    )
+        : m_storage(list_size, checked_flat_values(list_size, std::move(flat_values)), nullable)
+        , m_metadata(tensor_metadata)
+    {
+        SPARROW_ASSERT_TRUE(m_metadata.is_valid());
+        SPARROW_ASSERT_TRUE(static_cast<std::int64_t>(list_size) == m_metadata.compute_size());
+
+        // Add extension metadata to the storage using array_access
+        fixed_shape_tensor_extension::init(
+            sparrow::detail::array_access::get_arrow_proxy(m_storage),
+            m_metadata
+        );
+    }
+
+    auto fixed_shape_tensor_array::size() const -> size_type
+    {
+        return m_storage.size();
+    }
+
+    auto fixed_shape_tensor_array::get_metadata() const -> const metadata_type&
+    {
+        return m_metadata;
+    }
+
+    auto fixed_shape_tensor_array::shape() const -> const std::vector<std::int64_t>&
+    {
+        return m_metadata.shape;
+    }
+
+    auto fixed_shape_tensor_array::storage() const -> const sparrow::fixed_sized_list_array&
+    {
+        return m_storage;
+    }
+
+    auto fixed_shape_tensor_array::storage() -> sparrow::fixed_sized_list_array&
+    {
+        return m_storage;
+    }
+
+    auto fixed_shape_tensor_array::operator[](size_type i) const -> decltype(std::declval<const sparrow::fixed_sized_list_array&>()[i])
+    {
+        return m_storage[i];
+    }
+
+    auto fixed_shape_tensor_array::get_arrow_proxy() const -> const sparrow::arrow_proxy&
+    {
+        return sparrow::detail::array_access::get_arrow_proxy(m_storage);
+    }
+
+    auto fixed_shape_tensor_array::get_arrow_proxy() -> sparrow::arrow_proxy&
+    {
+        return sparrow::detail::array_access::get_arrow_proxy(m_storage);
+    }
+
+}  // namespace sparrow_extensions
+
+namespace sparrow::detail
+{
+    SPARROW_EXTENSIONS_API const bool fixed_shape_tensor_array_registered = []()
+    {
+        auto& registry = array_registry::instance();
+
+        registry.register_extension(
+            data_type::FIXED_SIZED_LIST,
+            "arrow.fixed_shape_tensor",
+            [](arrow_proxy proxy)
+            {
+                return cloning_ptr<array_wrapper>{
+                    new array_wrapper_impl<sparrow_extensions::fixed_shape_tensor_array>(
+                        sparrow_extensions::fixed_shape_tensor_array(std::move(proxy))
+                    )
+                };
+            }
+        );
+
+        return true;
+    }();
+}
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index d488678..12ab355 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -35,6 +35,7 @@ endif()
 set(SPARROW_EXTENSIONS_TESTS_SOURCES
     main.cpp
test_bool8_array.cpp + test_fixed_shape_tensor.cpp test_json_array.cpp test_uuid_array.cpp metadata_sample.hpp diff --git a/tests/test_fixed_shape_tensor.cpp b/tests/test_fixed_shape_tensor.cpp new file mode 100644 index 0000000..136508e --- /dev/null +++ b/tests/test_fixed_shape_tensor.cpp @@ -0,0 +1,636 @@ +// Copyright 2024 Man Group Operations Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include + +#include + +#include +#include + +#include "sparrow_extensions/fixed_shape_tensor.hpp" + +namespace sparrow_extensions +{ + TEST_SUITE("fixed_shape_tensor") + { + using metadata = fixed_shape_tensor_extension::metadata; + + TEST_CASE("metadata::is_valid") + { + SUBCASE("valid simple shape") + { + metadata meta{{2, 3}, std::nullopt, std::nullopt}; + CHECK(meta.is_valid()); + } + + SUBCASE("valid with dim_names") + { + metadata meta{{100, 200, 500}, std::vector{"C", "H", "W"}, std::nullopt}; + CHECK(meta.is_valid()); + } + + SUBCASE("valid with permutation") + { + metadata meta{{100, 200, 500}, std::nullopt, std::vector{2, 0, 1}}; + CHECK(meta.is_valid()); + } + + SUBCASE("valid with both dim_names and permutation") + { + metadata meta{ + {100, 200, 500}, + std::vector{"C", "H", "W"}, + std::vector{2, 0, 1} + }; + CHECK(meta.is_valid()); + } + + SUBCASE("invalid empty shape") + { + metadata meta{{}, std::nullopt, std::nullopt}; + CHECK_FALSE(meta.is_valid()); + } + + SUBCASE("invalid negative dimension") + { + 
metadata meta{{2, -3}, std::nullopt, std::nullopt}; + CHECK_FALSE(meta.is_valid()); + } + + SUBCASE("invalid zero dimension") + { + metadata meta{{2, 0, 4}, std::nullopt, std::nullopt}; + CHECK_FALSE(meta.is_valid()); + } + + SUBCASE("invalid dim_names size mismatch") + { + metadata meta{{100, 200}, std::vector{"C", "H", "W"}, std::nullopt}; + CHECK_FALSE(meta.is_valid()); + } + + SUBCASE("invalid permutation size mismatch") + { + metadata meta{{100, 200, 500}, std::nullopt, std::vector{2, 0}}; + CHECK_FALSE(meta.is_valid()); + } + + SUBCASE("invalid permutation values") + { + metadata meta{{100, 200, 500}, std::nullopt, std::vector{0, 0, 1}}; + CHECK_FALSE(meta.is_valid()); + } + + SUBCASE("invalid permutation out of range") + { + metadata meta{{100, 200, 500}, std::nullopt, std::vector{0, 1, 3}}; + CHECK_FALSE(meta.is_valid()); + } + } + + TEST_CASE("compute_size") + { + SUBCASE("simple 2D") + { + metadata meta{{2, 5}, std::nullopt, std::nullopt}; + CHECK_EQ(meta.compute_size(), 10); + } + + SUBCASE("3D tensor") + { + metadata meta{{100, 200, 500}, std::nullopt, std::nullopt}; + CHECK_EQ(meta.compute_size(), 10000000); + } + + SUBCASE("1D tensor") + { + metadata meta{{42}, std::nullopt, std::nullopt}; + CHECK_EQ(meta.compute_size(), 42); + } + + SUBCASE("4D tensor") + { + metadata meta{{2, 3, 4, 5}, std::nullopt, std::nullopt}; + CHECK_EQ(meta.compute_size(), 120); + } + } + + TEST_CASE("to_json") + { + SUBCASE("simple shape") + { + metadata meta{{2, 5}, std::nullopt, std::nullopt}; + const std::string json = meta.to_json(); + CHECK_EQ(json, R"({"shape":[2,5]})"); + } + + SUBCASE("with dim_names") + { + metadata meta{{100, 200, 500}, std::vector{"C", "H", "W"}, std::nullopt}; + const std::string json = meta.to_json(); + CHECK_EQ(json, R"({"shape":[100,200,500],"dim_names":["C","H","W"]})"); + } + + SUBCASE("with permutation") + { + metadata meta{{100, 200, 500}, std::nullopt, std::vector{2, 0, 1}}; + const std::string json = meta.to_json(); + CHECK_EQ(json, 
R"({"shape":[100,200,500],"permutation":[2,0,1]})"); + } + + SUBCASE("with both dim_names and permutation") + { + metadata meta{ + {100, 200, 500}, + std::vector{"C", "H", "W"}, + std::vector{2, 0, 1} + }; + const std::string json = meta.to_json(); + CHECK_EQ( + json, + R"({"shape":[100,200,500],"dim_names":["C","H","W"],"permutation":[2,0,1]})" + ); + } + } + + TEST_CASE("from_json") + { + SUBCASE("simple shape") + { + const std::string json = R"({"shape":[2,5]})"; + const metadata meta = metadata::from_json(json); + CHECK(meta.is_valid()); + REQUIRE_EQ(meta.shape.size(), 2); + CHECK_EQ(meta.shape[0], 2); + CHECK_EQ(meta.shape[1], 5); + CHECK_FALSE(meta.dim_names.has_value()); + CHECK_FALSE(meta.permutation.has_value()); + } + + SUBCASE("with dim_names") + { + const std::string json = R"({"shape":[100,200,500],"dim_names":["C","H","W"]})"; + const metadata meta = metadata::from_json(json); + CHECK(meta.is_valid()); + REQUIRE_EQ(meta.shape.size(), 3); + CHECK_EQ(meta.shape[0], 100); + CHECK_EQ(meta.shape[1], 200); + CHECK_EQ(meta.shape[2], 500); + REQUIRE(meta.dim_names.has_value()); + REQUIRE_EQ(meta.dim_names->size(), 3); + CHECK_EQ((*meta.dim_names)[0], "C"); + CHECK_EQ((*meta.dim_names)[1], "H"); + CHECK_EQ((*meta.dim_names)[2], "W"); + CHECK_FALSE(meta.permutation.has_value()); + } + + SUBCASE("with permutation") + { + const std::string json = R"({"shape":[100,200,500],"permutation":[2,0,1]})"; + const metadata meta = metadata::from_json(json); + CHECK(meta.is_valid()); + REQUIRE_EQ(meta.shape.size(), 3); + CHECK_FALSE(meta.dim_names.has_value()); + REQUIRE(meta.permutation.has_value()); + REQUIRE_EQ(meta.permutation->size(), 3); + CHECK_EQ((*meta.permutation)[0], 2); + CHECK_EQ((*meta.permutation)[1], 0); + CHECK_EQ((*meta.permutation)[2], 1); + } + + SUBCASE("with whitespace") + { + const std::string json = R"( { "shape" : [ 2 , 5 ] } )"; + const metadata meta = metadata::from_json(json); + CHECK(meta.is_valid()); + REQUIRE_EQ(meta.shape.size(), 2); + 
CHECK_EQ(meta.shape[0], 2); + CHECK_EQ(meta.shape[1], 5); + } + + SUBCASE("invalid - missing shape") + { + const std::string json = R"({"dim_names":["C","H","W"]})"; + CHECK_THROWS_AS(metadata::from_json(json), std::runtime_error); + } + + SUBCASE("invalid - malformed JSON") + { + const std::string json = R"({"shape":[2,5)"; + CHECK_THROWS_AS(metadata::from_json(json), std::runtime_error); + } + } + + TEST_CASE("round-trip serialization") + { + SUBCASE("simple") + { + metadata original{{2, 5}, std::nullopt, std::nullopt}; + const std::string json = original.to_json(); + const metadata parsed = metadata::from_json(json); + CHECK(parsed.shape == original.shape); + CHECK(parsed.dim_names == original.dim_names); + CHECK(parsed.permutation == original.permutation); + } + + SUBCASE("complex") + { + metadata original{ + {100, 200, 500}, + std::vector{"C", "H", "W"}, + std::vector{2, 0, 1} + }; + const std::string json = original.to_json(); + const metadata parsed = metadata::from_json(json); + CHECK(parsed.shape == original.shape); + CHECK(parsed.dim_names == original.dim_names); + CHECK(parsed.permutation == original.permutation); + } + } + + TEST_CASE("fixed_shape_tensor_array::constructor with simple 2D tensors") + { + // Create a flattened array of 3 tensors of shape [2, 3] + // Total elements: 3 * 2 * 3 = 18 + std::vector flat_data; + for (int i = 0; i < 18; ++i) + { + flat_data.push_back(static_cast(i)); + } + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 3}; + metadata tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + CHECK_EQ(tensor_array.size(), 3); + CHECK(tensor_array.shape() == shape); + + const auto& retrieved_meta = tensor_array.get_metadata(); + CHECK(retrieved_meta.shape == shape); + 
CHECK_FALSE(retrieved_meta.dim_names.has_value()); + CHECK_FALSE(retrieved_meta.permutation.has_value()); + } + + TEST_CASE("constructor with 3D tensors and dim_names") + { + // Create 2 tensors of shape [2, 2, 2] + // Total elements: 2 * 2 * 2 * 2 = 16 + std::vector flat_data(16); + std::iota(flat_data.begin(), flat_data.end(), 0); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 2, 2}; + const std::vector dim_names{"X", "Y", "Z"}; + metadata tensor_meta{shape, dim_names, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + CHECK_EQ(tensor_array.size(), 2); + CHECK(tensor_array.shape() == shape); + + const auto& meta = tensor_array.get_metadata(); + CHECK(meta.shape == shape); + REQUIRE(meta.dim_names.has_value()); + CHECK(*meta.dim_names == dim_names); + CHECK_FALSE(meta.permutation.has_value()); + } + + TEST_CASE("constructor with permutation") + { + // Create 1 tensor of shape [3, 4, 5] + // Total elements: 60 + std::vector flat_data(60); + std::iota(flat_data.begin(), flat_data.end(), 0.0); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{3, 4, 5}; + const std::vector permutation{2, 0, 1}; // Logical shape is [5, 3, 4] + metadata tensor_meta{shape, std::nullopt, permutation}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + CHECK_EQ(tensor_array.size(), 1); + CHECK(tensor_array.shape() == shape); + + const auto& meta = tensor_array.get_metadata(); + CHECK(meta.shape == shape); + CHECK_FALSE(meta.dim_names.has_value()); + REQUIRE(meta.permutation.has_value()); + CHECK(*meta.permutation == permutation); + } + + TEST_CASE("constructor with validity bitmap") + { + // Create 4 tensors 
of shape [2, 2] + // Total elements: 4 * 2 * 2 = 16 + std::vector flat_data(16); + std::iota(flat_data.begin(), flat_data.end(), 0); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 2}; + metadata tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + CHECK_EQ(tensor_array.size(), 4); + + // When nullable is true but no validity bitmap is provided, + // all elements are valid by default + const auto& storage = tensor_array.storage(); + CHECK(storage[0].has_value()); + CHECK(storage[1].has_value()); + CHECK(storage[2].has_value()); + CHECK(storage[3].has_value()); + } + + TEST_CASE("element access") + { + // Create 2 tensors of shape [2, 3] + std::vector flat_data{ + // First tensor + 1.0f, + 2.0f, + 3.0f, + 4.0f, + 5.0f, + 6.0f, + // Second tensor + 7.0f, + 8.0f, + 9.0f, + 10.0f, + 11.0f, + 12.0f + }; + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 3}; + metadata tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + // Access first tensor + auto tensor0 = tensor_array[0]; + CHECK(tensor0.has_value()); + + // Access second tensor + auto tensor1 = tensor_array[1]; + CHECK(tensor1.has_value()); + } + + TEST_CASE("1D tensor (vector)") + { + // Create 5 vectors of length 10 + std::vector flat_data(50); + std::iota(flat_data.begin(), flat_data.end(), 0); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{10}; + metadata tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + 
list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + CHECK_EQ(tensor_array.size(), 5); + CHECK_EQ(tensor_array.shape()[0], 10); + } + + TEST_CASE("extension metadata roundtrip") + { + // Create array with metadata + std::vector flat_data(12); + std::iota(flat_data.begin(), flat_data.end(), 0.0f); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 3}; + const std::vector dim_names{"rows", "cols"}; + const std::vector permutation{1, 0}; + metadata tensor_meta{shape, dim_names, permutation}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + // Get and verify the metadata + const auto& extracted_meta = tensor_array.get_metadata(); + + // Verify all fields + CHECK(extracted_meta.shape == shape); + REQUIRE(extracted_meta.dim_names.has_value()); + CHECK(*extracted_meta.dim_names == dim_names); + REQUIRE(extracted_meta.permutation.has_value()); + CHECK(*extracted_meta.permutation == permutation); + } + + TEST_CASE("copy constructor") + { + std::vector flat_data{1, 2, 3, 4, 5, 6}; + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 3}; + metadata tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array original( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + fixed_shape_tensor_array copy(original); + + CHECK_EQ(copy.size(), original.size()); + CHECK(copy.shape() == original.shape()); + CHECK(copy.get_metadata().shape == original.get_metadata().shape); + } + + TEST_CASE("move constructor") + { + std::vector flat_data{1, 2, 3, 4, 5, 6}; + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 3}; + metadata tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t 
list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array original( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + const std::size_t original_size = original.size(); + const auto original_shape = original.shape(); + + fixed_shape_tensor_array moved(std::move(original)); + + CHECK_EQ(moved.size(), original_size); + CHECK(moved.shape() == original_shape); + } + + TEST_CASE("storage access") + { + std::vector flat_data(6); + std::iota(flat_data.begin(), flat_data.end(), 0.0f); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 3}; + metadata tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + // Test const access + const auto& const_storage = tensor_array.storage(); + CHECK_EQ(const_storage.size(), 1); + + // Test mutable access + auto& mut_storage = tensor_array.storage(); + CHECK_EQ(mut_storage.size(), 1); + } + + TEST_CASE("spec examples") + { + SUBCASE("Example: { \"shape\": [2, 5]}") + { + std::vector flat_data(10); // 1 tensor of shape [2, 5] + std::iota(flat_data.begin(), flat_data.end(), 0.0); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 5}; + metadata tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + CHECK_EQ(tensor_array.size(), 1); + CHECK(tensor_array.shape() == shape); + CHECK_EQ(tensor_array.get_metadata().compute_size(), 10); + } + + SUBCASE("Example with dim_names for NCHW: { \"shape\": [100, 200, 500], \"dim_names\": " + "[\"C\", \"H\", \"W\"]}") + { + // Just one tensor for testing + const std::int64_t tensor_size = 100 * 
200 * 500; + std::vector flat_data(tensor_size, 0.0f); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{100, 200, 500}; + const std::vector dim_names{"C", "H", "W"}; + metadata tensor_meta{shape, dim_names, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + CHECK_EQ(tensor_array.size(), 1); + CHECK(tensor_array.shape() == shape); + const auto& retrieved_meta = tensor_array.get_metadata(); + REQUIRE(retrieved_meta.dim_names.has_value()); + CHECK(*retrieved_meta.dim_names == dim_names); + } + + SUBCASE("Example with permutation: { \"shape\": [100, 200, 500], \"permutation\": [2, 0, " + "1]}") + { + // Physical shape [100, 200, 500], logical shape [500, 100, 200] + const std::int64_t tensor_size = 100 * 200 * 500; + std::vector flat_data(tensor_size, 0.0f); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{100, 200, 500}; + const std::vector permutation{2, 0, 1}; + metadata tensor_meta{shape, std::nullopt, permutation}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + true + ); + + CHECK_EQ(tensor_array.size(), 1); + CHECK(tensor_array.shape() == shape); // Physical shape + const auto& retrieved_meta = tensor_array.get_metadata(); + REQUIRE(retrieved_meta.permutation.has_value()); + CHECK(*retrieved_meta.permutation == permutation); + + // Note: Logical shape would be [500, 100, 200] + // which is shape[permutation[i]] for each i + } + } + } +} // namespace sparrow_extensions From c70efecffe4ccb14cd5427063fe7c28f5c581ca6 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Mon, 22 Dec 2025 18:32:12 +0100 Subject: [PATCH 02/11] wip --- .../sparrow_extensions/fixed_shape_tensor.hpp | 7 +- 
src/fixed_shape_tensor.cpp | 129 ++++-------------- tests/test_fixed_shape_tensor.cpp | 1 - 3 files changed, 35 insertions(+), 102 deletions(-) diff --git a/include/sparrow_extensions/fixed_shape_tensor.hpp b/include/sparrow_extensions/fixed_shape_tensor.hpp index 404bca1..f0b20c6 100644 --- a/include/sparrow_extensions/fixed_shape_tensor.hpp +++ b/include/sparrow_extensions/fixed_shape_tensor.hpp @@ -26,6 +26,8 @@ #include "sparrow/types/data_type.hpp" #include "sparrow/utils/contracts.hpp" +#include "sparrow_extensions/config/config.hpp" + namespace sparrow_extensions { /** @@ -148,7 +150,7 @@ namespace sparrow_extensions * This class provides a convenient interface for working with fixed-shape tensors * while maintaining compatibility with the Arrow format. */ - class fixed_shape_tensor_array + class SPARROW_EXTENSIONS_API fixed_shape_tensor_array { public: @@ -226,7 +228,8 @@ namespace sparrow_extensions * * @pre i < size() */ - [[nodiscard]] auto operator[](size_type i) const -> decltype(std::declval()[i]); + [[nodiscard]] auto operator[](size_type i) const + -> decltype(std::declval()[i]); /** * @brief Returns the underlying arrow_proxy. 
diff --git a/src/fixed_shape_tensor.cpp b/src/fixed_shape_tensor.cpp index 54a41a9..ba0c832 100644 --- a/src/fixed_shape_tensor.cpp +++ b/src/fixed_shape_tensor.cpp @@ -31,19 +31,11 @@ namespace sparrow_extensions bool fixed_shape_tensor_extension::metadata::is_valid() const { // Shape must not be empty and all dimensions must be positive - if (shape.empty()) + if (shape.empty() || !std::ranges::all_of(shape, [](auto dim) { return dim > 0; })) { return false; } - for (const auto dim : shape) - { - if (dim <= 0) - { - return false; - } - } - // If dim_names is present, it must match the shape size if (dim_names.has_value() && dim_names->size() != shape.size()) { @@ -62,13 +54,10 @@ namespace sparrow_extensions // Check that permutation contains exactly [0, 1, ..., N-1] std::vector sorted_perm = perm; std::ranges::sort(sorted_perm); - for (std::size_t i = 0; i < sorted_perm.size(); ++i) - { - if (sorted_perm[i] != static_cast(i)) - { - return false; - } - } + return std::ranges::equal( + sorted_perm, + std::views::iota(std::int64_t{0}, static_cast(sorted_perm.size())) + ); } return true; @@ -86,27 +75,28 @@ namespace sparrow_extensions std::string fixed_shape_tensor_extension::metadata::to_json() const { - std::ostringstream oss; - oss << "{\"shape\":["; - for (std::size_t i = 0; i < shape.size(); ++i) + // Helper to serialize integer array + auto serialize_int_array = [](std::ostringstream& oss, const std::vector& arr) { - if (i > 0) + oss << "["; + for (std::size_t i = 0; i < arr.size(); ++i) { - oss << ","; + if (i > 0) oss << ","; + oss << arr[i]; } - oss << shape[i]; - } - oss << "]"; + oss << "]"; + }; + + std::ostringstream oss; + oss << "{\"shape\":"; + serialize_int_array(oss, shape); if (dim_names.has_value()) { oss << ",\"dim_names\":["; for (std::size_t i = 0; i < dim_names->size(); ++i) { - if (i > 0) - { - oss << ","; - } + if (i > 0) oss << ","; oss << "\"" << (*dim_names)[i] << "\""; } oss << "]"; @@ -114,16 +104,8 @@ namespace sparrow_extensions if 
(permutation.has_value()) { - oss << ",\"permutation\":["; - for (std::size_t i = 0; i < permutation->size(); ++i) - { - if (i > 0) - { - oss << ","; - } - oss << (*permutation)[i]; - } - oss << "]"; + oss << ",\"permutation\":"; + serialize_int_array(oss, *permutation); } oss << "}"; @@ -193,10 +175,10 @@ namespace sparrow_extensions return std::stoll(json_str.substr(start, pos - start)); }; - // Helper to read an array of integers - auto read_int_array = [&]() -> std::vector + // Generic helper to read an array + auto read_array = [&](auto reader) -> std::vector { - std::vector arr; + std::vector arr; skip_whitespace(); if (pos >= json_str.size() || json_str[pos] != '[') { @@ -213,7 +195,7 @@ namespace sparrow_extensions while (true) { - arr.push_back(read_int()); + arr.push_back(reader()); skip_whitespace(); if (pos >= json_str.size()) @@ -237,49 +219,8 @@ namespace sparrow_extensions return arr; }; - // Helper to read an array of strings - auto read_string_array = [&]() -> std::vector - { - std::vector arr; - skip_whitespace(); - if (pos >= json_str.size() || json_str[pos] != '[') - { - throw std::runtime_error("Expected opening bracket"); - } - ++pos; - - skip_whitespace(); - if (pos < json_str.size() && json_str[pos] == ']') - { - ++pos; - return arr; - } - - while (true) - { - arr.push_back(read_string()); - skip_whitespace(); - - if (pos >= json_str.size()) - { - throw std::runtime_error("Unexpected end of JSON"); - } - - if (json_str[pos] == ']') - { - ++pos; - break; - } - - if (json_str[pos] != ',') - { - throw std::runtime_error("Expected comma or closing bracket"); - } - ++pos; - } - - return arr; - }; + auto read_int_array = [&]() { return read_array.template operator()(read_int); }; + auto read_string_array = [&]() { return read_array.template operator()(read_string); }; // Parse the JSON object skip_whitespace(); @@ -405,30 +346,21 @@ namespace sparrow_extensions const sparrow::arrow_proxy& proxy ) { - std::optional metadata_opt = proxy.metadata(); 
+ const auto metadata_opt = proxy.metadata(); if (!metadata_opt.has_value()) { throw std::runtime_error("Missing extension metadata"); } - const auto& metadata = *metadata_opt; - std::string metadata_json; - - for (const auto& [key, value] : metadata) + for (const auto& [key, value] : *metadata_opt) { if (key == "ARROW:extension:metadata") { - metadata_json = value; - break; + return metadata::from_json(value); } } - if (metadata_json.empty()) - { - throw std::runtime_error("Missing ARROW:extension:metadata"); - } - - return metadata::from_json(metadata_json); + throw std::runtime_error("Missing ARROW:extension:metadata"); } // fixed_shape_tensor_array implementation @@ -451,7 +383,6 @@ namespace sparrow_extensions { SPARROW_ASSERT_TRUE(m_metadata.is_valid()); SPARROW_ASSERT_TRUE(static_cast(list_size) == m_metadata.compute_size()); - SPARROW_ASSERT_TRUE(m_storage.size() * list_size == (flat_values.size() ? flat_values.size() : m_storage.size() * list_size)); // Add extension metadata to the storage using array_access fixed_shape_tensor_extension::init( diff --git a/tests/test_fixed_shape_tensor.cpp b/tests/test_fixed_shape_tensor.cpp index 136508e..a964921 100644 --- a/tests/test_fixed_shape_tensor.cpp +++ b/tests/test_fixed_shape_tensor.cpp @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include #include #include From d8974f2ce0d5c18c848499693f1ca7f6ba4820c0 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Tue, 23 Dec 2025 09:56:06 +0100 Subject: [PATCH 03/11] wip --- .../sparrow_extensions/fixed_shape_tensor.hpp | 128 +++++++++++- src/fixed_shape_tensor.cpp | 35 +++- tests/test_fixed_shape_tensor.cpp | 195 ++++++++++++++---- 3 files changed, 309 insertions(+), 49 deletions(-) diff --git a/include/sparrow_extensions/fixed_shape_tensor.hpp b/include/sparrow_extensions/fixed_shape_tensor.hpp index f0b20c6..b1dcdc8 100644 --- a/include/sparrow_extensions/fixed_shape_tensor.hpp +++ b/include/sparrow_extensions/fixed_shape_tensor.hpp @@ -21,10 +21,9 @@ #include #include -#include "sparrow/array.hpp" +#include "sparrow/buffer/dynamic_bitset/dynamic_bitset.hpp" // Workaround for sparrow 2.0.0 bug #include "sparrow/list_array.hpp" #include "sparrow/types/data_type.hpp" -#include "sparrow/utils/contracts.hpp" #include "sparrow_extensions/config/config.hpp" @@ -77,7 +76,7 @@ namespace sparrow_extensions * Stores the shape, optional dimension names, and optional permutation * for the tensor layout. 
*/ - struct metadata + struct SPARROW_EXTENSIONS_API metadata { std::vector shape; std::optional> dim_names; @@ -174,18 +173,89 @@ namespace sparrow_extensions * @param list_size Total number of elements per tensor (product of shape) * @param flat_values Flattened sparrow array of all tensor elements in row-major order * @param tensor_metadata Metadata describing the tensor shape and layout - * @param nullable Whether the array should support null values * * @pre flat_values.size() must be divisible by list_size * @pre list_size must equal tensor_metadata.compute_size() * @pre tensor_metadata must be valid * @post Array contains tensors reshaped according to the metadata */ + fixed_shape_tensor_array( + std::uint64_t list_size, + sparrow::array&& flat_values, + const metadata_type& tensor_metadata + ); + + /** + * @brief Constructs a fixed shape tensor array with name and/or metadata. + * + * @param list_size Total number of elements per tensor (product of shape) + * @param flat_values Flattened sparrow array of all tensor elements in row-major order + * @param tensor_metadata Metadata describing the tensor shape and layout + * @param name Name for the array + * @param arrow_metadata Optional Arrow metadata key-value pairs + * + * @pre flat_values.size() must be divisible by list_size + * @pre list_size must equal tensor_metadata.compute_size() + * @pre tensor_metadata must be valid + * @post Array contains tensors with the specified name and metadata + */ + fixed_shape_tensor_array( + std::uint64_t list_size, + sparrow::array&& flat_values, + const metadata_type& tensor_metadata, + std::string_view name, + std::optional> arrow_metadata = std::nullopt + ); + + /** + * @brief Constructs a fixed shape tensor array with validity bitmap. 
+ * + * @tparam VB Type of validity bitmap input + * @param list_size Total number of elements per tensor (product of shape) + * @param flat_values Flattened sparrow array of all tensor elements in row-major order + * @param tensor_metadata Metadata describing the tensor shape and layout + * @param validity_input Validity bitmap (one bit per tensor) + * + * @pre flat_values.size() must be divisible by list_size + * @pre list_size must equal tensor_metadata.compute_size() + * @pre tensor_metadata must be valid + * @pre validity_input size must match number of tensors (flat_values.size() / list_size) + * @post Array contains tensors with the specified validity bitmap + */ + template + fixed_shape_tensor_array( + std::uint64_t list_size, + sparrow::array&& flat_values, + const metadata_type& tensor_metadata, + VB&& validity_input + ); + + /** + * @brief Constructs a fixed shape tensor array with validity, name, and metadata. + * + * @tparam VB Type of validity bitmap input + * @tparam METADATA_RANGE Type of metadata container + * @param list_size Total number of elements per tensor (product of shape) + * @param flat_values Flattened sparrow array of all tensor elements in row-major order + * @param tensor_metadata Metadata describing the tensor shape and layout + * @param validity_input Validity bitmap (one bit per tensor) + * @param name Optional name for the array + * @param arrow_metadata Optional Arrow metadata key-value pairs + * + * @pre flat_values.size() must be divisible by list_size + * @pre list_size must equal tensor_metadata.compute_size() + * @pre tensor_metadata must be valid + * @pre validity_input size must match number of tensors (flat_values.size() / list_size) + * @post Array contains tensors with the specified validity bitmap, name, and metadata + */ + template > fixed_shape_tensor_array( std::uint64_t list_size, sparrow::array&& flat_values, const metadata_type& tensor_metadata, - bool nullable = true + VB&& validity_input, + std::optional name, 
+ std::optional arrow_metadata = std::nullopt ); // Default special members @@ -247,6 +317,54 @@ namespace sparrow_extensions metadata_type m_metadata; }; + // Template constructor implementations + + template + fixed_shape_tensor_array::fixed_shape_tensor_array( + std::uint64_t list_size, + sparrow::array&& flat_values, + const metadata_type& tensor_metadata, + VB&& validity_input + ) + : m_storage(list_size, std::move(flat_values), std::forward(validity_input)) + , m_metadata(tensor_metadata) + { + SPARROW_ASSERT_TRUE(m_metadata.is_valid()); + SPARROW_ASSERT_TRUE(static_cast(list_size) == m_metadata.compute_size()); + + fixed_shape_tensor_extension::init( + sparrow::detail::array_access::get_arrow_proxy(m_storage), + m_metadata + ); + } + + template + fixed_shape_tensor_array::fixed_shape_tensor_array( + std::uint64_t list_size, + sparrow::array&& flat_values, + const metadata_type& tensor_metadata, + VB&& validity_input, + std::optional name, + std::optional arrow_metadata + ) + : m_storage( + list_size, + std::move(flat_values), + std::forward(validity_input), + name, + std::forward>(arrow_metadata) + ) + , m_metadata(tensor_metadata) + { + SPARROW_ASSERT_TRUE(m_metadata.is_valid()); + SPARROW_ASSERT_TRUE(static_cast(list_size) == m_metadata.compute_size()); + + fixed_shape_tensor_extension::init( + sparrow::detail::array_access::get_arrow_proxy(m_storage), + m_metadata + ); + } + } // namespace sparrow_extensions namespace sparrow::detail diff --git a/src/fixed_shape_tensor.cpp b/src/fixed_shape_tensor.cpp index ba0c832..ca49c0c 100644 --- a/src/fixed_shape_tensor.cpp +++ b/src/fixed_shape_tensor.cpp @@ -375,10 +375,9 @@ namespace sparrow_extensions fixed_shape_tensor_array::fixed_shape_tensor_array( std::uint64_t list_size, sparrow::array&& flat_values, - const metadata_type& tensor_metadata, - bool nullable + const metadata_type& tensor_metadata ) - : m_storage(list_size, std::move(flat_values), nullable) + : m_storage(list_size, std::move(flat_values)) , 
m_metadata(tensor_metadata) { SPARROW_ASSERT_TRUE(m_metadata.is_valid()); @@ -391,6 +390,32 @@ namespace sparrow_extensions ); } + fixed_shape_tensor_array::fixed_shape_tensor_array( + std::uint64_t list_size, + sparrow::array&& flat_values, + const metadata_type& tensor_metadata, + std::string_view name, + std::optional> arrow_metadata + ) + : m_storage(list_size, std::move(flat_values)) + , m_metadata(tensor_metadata) + { + SPARROW_ASSERT_TRUE(m_metadata.is_valid()); + SPARROW_ASSERT_TRUE(static_cast(list_size) == m_metadata.compute_size()); + + // Get the proxy and set name/metadata + auto& proxy = sparrow::detail::array_access::get_arrow_proxy(m_storage); + proxy.set_name(name); + + if (arrow_metadata.has_value()) + { + proxy.set_metadata(std::make_optional(*arrow_metadata)); + } + + // Add extension metadata + fixed_shape_tensor_extension::init(proxy, m_metadata); + } + auto fixed_shape_tensor_array::size() const -> size_type { return m_storage.size(); @@ -406,12 +431,12 @@ namespace sparrow_extensions return m_metadata.shape; } - auto fixed_shape_tensor_array::storage() const -> const sparrow::fixed_sized_list_array& + const sparrow::fixed_sized_list_array& fixed_shape_tensor_array::storage() const { return m_storage; } - auto fixed_shape_tensor_array::storage() -> sparrow::fixed_sized_list_array& + sparrow::fixed_sized_list_array& fixed_shape_tensor_array::storage() { return m_storage; } diff --git a/tests/test_fixed_shape_tensor.cpp b/tests/test_fixed_shape_tensor.cpp index a964921..bbbe440 100644 --- a/tests/test_fixed_shape_tensor.cpp +++ b/tests/test_fixed_shape_tensor.cpp @@ -281,9 +281,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); CHECK_EQ(tensor_array.size(), 3); CHECK(tensor_array.shape() == shape); @@ -310,9 +308,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, 
sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); CHECK_EQ(tensor_array.size(), 2); CHECK(tensor_array.shape() == shape); @@ -340,9 +336,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); CHECK_EQ(tensor_array.size(), 1); CHECK(tensor_array.shape() == shape); @@ -369,9 +363,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); CHECK_EQ(tensor_array.size(), 4); @@ -412,9 +404,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); // Access first tensor auto tensor0 = tensor_array[0]; @@ -439,9 +429,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); CHECK_EQ(tensor_array.size(), 5); CHECK_EQ(tensor_array.shape()[0], 10); @@ -463,9 +451,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); // Get and verify the metadata const auto& extracted_meta = tensor_array.get_metadata(); @@ -489,9 +475,7 @@ namespace sparrow_extensions fixed_shape_tensor_array original( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); fixed_shape_tensor_array copy(original); @@ -511,9 +495,7 @@ namespace sparrow_extensions fixed_shape_tensor_array original( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); const std::size_t original_size = original.size(); const auto original_shape = original.shape(); @@ -537,9 +519,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, 
sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); // Test const access const auto& const_storage = tensor_array.storage(); @@ -565,9 +545,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); CHECK_EQ(tensor_array.size(), 1); CHECK(tensor_array.shape() == shape); @@ -590,9 +568,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); CHECK_EQ(tensor_array.size(), 1); CHECK(tensor_array.shape() == shape); @@ -617,9 +593,7 @@ namespace sparrow_extensions fixed_shape_tensor_array tensor_array( list_size, sparrow::array(std::move(values_array)), - tensor_meta, - true - ); + tensor_meta); CHECK_EQ(tensor_array.size(), 1); CHECK(tensor_array.shape() == shape); // Physical shape @@ -631,5 +605,148 @@ namespace sparrow_extensions // which is shape[permutation[i]] for each i } } + + TEST_CASE("constructor with name and metadata") + { + SUBCASE("with name only") + { + std::vector flat_data; + for (int i = 0; i < 12; ++i) + { + flat_data.push_back(static_cast(i)); + } + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 3}; + metadata tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + "my_tensor_array"); + + CHECK_EQ(tensor_array.size(), 2); + CHECK(tensor_array.shape() == shape); + + const auto& proxy = tensor_array.get_arrow_proxy(); + CHECK(proxy.name() == "my_tensor_array"); + } + + SUBCASE("with metadata only") + { + std::vector flat_data(8); + std::iota(flat_data.begin(), flat_data.end(), 0); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 2}; + metadata 
tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + std::vector arrow_meta{ + {"key1", "value1"}, + {"key2", "value2"} + }; + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + "", // empty name + arrow_meta); + + CHECK_EQ(tensor_array.size(), 2); + CHECK(tensor_array.shape() == shape); + + const auto& proxy = tensor_array.get_arrow_proxy(); + const auto metadata_opt = proxy.metadata(); + REQUIRE(metadata_opt.has_value()); + + bool found_key1 = false; + bool found_key2 = false; + for (const auto& [key, value] : *metadata_opt) + { + if (key == "key1" && value == "value1") + found_key1 = true; + if (key == "key2" && value == "value2") + found_key2 = true; + } + CHECK(found_key1); + CHECK(found_key2); + } + + SUBCASE("with both name and metadata") + { + std::vector flat_data(24); + std::iota(flat_data.begin(), flat_data.end(), 0.0); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 3, 4}; + const std::vector dim_names{"X", "Y", "Z"}; + metadata tensor_meta{shape, dim_names, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + std::vector arrow_meta{ + {"author", "test"}, + {"version", "1.0"} + }; + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + "named_tensor", + arrow_meta); + + CHECK_EQ(tensor_array.size(), 1); + CHECK(tensor_array.shape() == shape); + + const auto& proxy = tensor_array.get_arrow_proxy(); + CHECK(proxy.name() == "named_tensor"); + + const auto& meta = tensor_array.get_metadata(); + REQUIRE(meta.dim_names.has_value()); + CHECK(*meta.dim_names == dim_names); + + const auto metadata_opt = proxy.metadata(); + REQUIRE(metadata_opt.has_value()); + + bool found_extension = false; + bool found_author = false; + for (const auto& [key, value] : *metadata_opt) + { + if (key == 
"ARROW:extension:name" && value == "arrow.fixed_shape_tensor") + found_extension = true; + if (key == "author" && value == "test") + found_author = true; + } + CHECK(found_extension); + CHECK(found_author); + } + + SUBCASE("simple name without metadata") + { + std::vector flat_data(6); + std::iota(flat_data.begin(), flat_data.end(), 0.0f); + + sparrow::primitive_array values_array(flat_data); + const std::vector shape{2, 3}; + metadata tensor_meta{shape, std::nullopt, std::nullopt}; + const std::uint64_t list_size = static_cast(tensor_meta.compute_size()); + + fixed_shape_tensor_array tensor_array( + list_size, + sparrow::array(std::move(values_array)), + tensor_meta, + "test_array"); + + CHECK_EQ(tensor_array.size(), 1); + CHECK(tensor_array.shape() == shape); + + const auto& proxy = tensor_array.get_arrow_proxy(); + CHECK(proxy.name() == "test_array"); + } + } } } // namespace sparrow_extensions From 825a17ee62a5563645e7fb016ee17016b96effff Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Tue, 23 Dec 2025 10:07:07 +0100 Subject: [PATCH 04/11] try --- .../sparrow_extensions/fixed_shape_tensor.hpp | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/include/sparrow_extensions/fixed_shape_tensor.hpp b/include/sparrow_extensions/fixed_shape_tensor.hpp index b1dcdc8..cd78a16 100644 --- a/include/sparrow_extensions/fixed_shape_tensor.hpp +++ b/include/sparrow_extensions/fixed_shape_tensor.hpp @@ -347,22 +347,27 @@ namespace sparrow_extensions std::optional name, std::optional arrow_metadata ) - : m_storage( - list_size, - std::move(flat_values), - std::forward(validity_input), - name, - std::forward>(arrow_metadata) - ) + : m_storage(list_size, std::move(flat_values), std::forward(validity_input)) , m_metadata(tensor_metadata) { SPARROW_ASSERT_TRUE(m_metadata.is_valid()); SPARROW_ASSERT_TRUE(static_cast(list_size) == m_metadata.compute_size()); - fixed_shape_tensor_extension::init( - 
sparrow::detail::array_access::get_arrow_proxy(m_storage), - m_metadata - ); + // Get the proxy and set name/metadata if provided + auto& proxy = sparrow::detail::array_access::get_arrow_proxy(m_storage); + + if (name.has_value()) + { + proxy.set_name(*name); + } + + if (arrow_metadata.has_value()) + { + proxy.set_metadata(std::make_optional(*arrow_metadata)); + } + + // Add extension metadata + fixed_shape_tensor_extension::init(proxy, m_metadata); } } // namespace sparrow_extensions From 54fcaba6a8df0feb795ce8ccbdde490933e33e1d Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Tue, 23 Dec 2025 10:17:21 +0100 Subject: [PATCH 05/11] fix --- include/sparrow_extensions/fixed_shape_tensor.hpp | 2 +- src/fixed_shape_tensor.cpp | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/include/sparrow_extensions/fixed_shape_tensor.hpp b/include/sparrow_extensions/fixed_shape_tensor.hpp index cd78a16..47458d3 100644 --- a/include/sparrow_extensions/fixed_shape_tensor.hpp +++ b/include/sparrow_extensions/fixed_shape_tensor.hpp @@ -21,7 +21,7 @@ #include #include -#include "sparrow/buffer/dynamic_bitset/dynamic_bitset.hpp" // Workaround for sparrow 2.0.0 bug +#include "sparrow/buffer/dynamic_bitset/dynamic_bitset.hpp" #include "sparrow/list_array.hpp" #include "sparrow/types/data_type.hpp" diff --git a/src/fixed_shape_tensor.cpp b/src/fixed_shape_tensor.cpp index ca49c0c..7cd54a6 100644 --- a/src/fixed_shape_tensor.cpp +++ b/src/fixed_shape_tensor.cpp @@ -22,6 +22,7 @@ #include "sparrow/layout/array_access.hpp" #include "sparrow/layout/array_registry.hpp" #include "sparrow/utils/contracts.hpp" +#include "sparrow/array.hpp" #include "sparrow_extensions/config/config.hpp" @@ -377,7 +378,7 @@ namespace sparrow_extensions sparrow::array&& flat_values, const metadata_type& tensor_metadata ) - : m_storage(list_size, std::move(flat_values)) + : m_storage(list_size, std::move(flat_values), std::vector{}) , m_metadata(tensor_metadata) { 
SPARROW_ASSERT_TRUE(m_metadata.is_valid()); @@ -397,7 +398,7 @@ namespace sparrow_extensions std::string_view name, std::optional> arrow_metadata ) - : m_storage(list_size, std::move(flat_values)) + : m_storage(list_size, std::move(flat_values), std::vector{}) , m_metadata(tensor_metadata) { SPARROW_ASSERT_TRUE(m_metadata.is_valid()); From 98045a864a258958a5f99a4d7952dd59d13796e0 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Tue, 23 Dec 2025 11:12:48 +0100 Subject: [PATCH 06/11] Use simd json --- CMakeLists.txt | 2 +- cmake/external_dependencies.cmake | 13 +- environment-dev.yml | 2 +- src/fixed_shape_tensor.cpp | 361 ++++++++++++------------------ 4 files changed, 153 insertions(+), 225 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index fed5995..b69dc7c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -248,7 +248,7 @@ target_include_directories(sparrow-extensions PUBLIC set_target_properties(sparrow-extensions PROPERTIES CMAKE_CXX_EXTENSIONS OFF) target_compile_features(sparrow-extensions PUBLIC cxx_std_20) -target_link_libraries(sparrow-extensions PUBLIC sparrow::sparrow ${SPARROW_EXTENSIONS_INTERFACE_DEPENDENCIES}) +target_link_libraries(sparrow-extensions PUBLIC ${SPARROW_EXTENSIONS_INTERFACE_DEPENDENCIES}) if(ENABLE_COVERAGE) enable_coverage(sparrow-extensions) diff --git a/cmake/external_dependencies.cmake b/cmake/external_dependencies.cmake index ebc8586..8901aba 100644 --- a/cmake/external_dependencies.cmake +++ b/cmake/external_dependencies.cmake @@ -78,11 +78,22 @@ if(NOT TARGET sparrow::sparrow) add_library(sparrow::sparrow ALIAS sparrow) endif() +# add sparrow::sparrow to SPARROW_EXTENSIONS_INTERFACE_DEPENDENCIES list +set(SPARROW_EXTENSIONS_INTERFACE_DEPENDENCIES sparrow::sparrow) + +find_package_or_fetch( + PACKAGE_NAME simdjson + GIT_REPOSITORY https://github.com/simdjson/simdjson.git + TAG v2.4.12 +) + +set(SPARROW_EXTENSIONS_INTERFACE_DEPENDENCIES ${SPARROW_EXTENSIONS_INTERFACE_DEPENDENCIES} simdjson::simdjson) + 
if(SPARROW_EXTENSIONS_BUILD_TESTS) find_package_or_fetch( PACKAGE_NAME doctest GIT_REPOSITORY https://github.com/doctest/doctest.git - TAG v2.4.12 + TAG v4.2.4 ) # better_junit_reporter is provided by sparrow diff --git a/environment-dev.yml b/environment-dev.yml index 453d838..1b0f537 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -7,7 +7,7 @@ dependencies: - ninja - cxx-compiler # Libraries dependencies - - nlohmann_json + - simdjson - sparrow-devel >=1.4.0 # Testing dependencies - doctest diff --git a/src/fixed_shape_tensor.cpp b/src/fixed_shape_tensor.cpp index 7cd54a6..9588fda 100644 --- a/src/fixed_shape_tensor.cpp +++ b/src/fixed_shape_tensor.cpp @@ -19,6 +19,8 @@ #include #include +#include + #include "sparrow/layout/array_access.hpp" #include "sparrow/layout/array_registry.hpp" #include "sparrow/utils/contracts.hpp" @@ -28,6 +30,19 @@ namespace sparrow_extensions { + namespace + { + // JSON serialization size estimation constants + constexpr std::size_t json_base_size = 10; // {"shape":[]} + constexpr std::size_t json_integer_avg_size = 10; // Average size per integer + constexpr std::size_t json_dim_names_overhead = 15; // ,"dim_names":[] + constexpr std::size_t json_string_overhead = 3; // "name", + constexpr std::size_t json_permutation_overhead = 17; // ,"permutation":[] + + // JSON parsing capacity hints + constexpr std::size_t typical_tensor_dimensions = 8; // Typical tensor rank (2-4 dims, reserve 8) + } + // Metadata implementation bool fixed_shape_tensor_extension::metadata::is_valid() const { @@ -52,13 +67,18 @@ namespace sparrow_extensions return false; } - // Check that permutation contains exactly [0, 1, ..., N-1] - std::vector sorted_perm = perm; - std::ranges::sort(sorted_perm); - return std::ranges::equal( - sorted_perm, - std::views::iota(std::int64_t{0}, static_cast(sorted_perm.size())) - ); + // Check that permutation contains exactly [0, 1, ..., N-1] without copying + // Use a bitset to track seen indices + std::vector 
seen(perm.size(), false); + for (const auto idx : perm) + { + if (idx < 0 || static_cast(idx) >= perm.size() || seen[static_cast(idx)]) + { + return false; + } + seen[static_cast(idx)] = true; + } + return true; } return true; @@ -66,241 +86,137 @@ namespace sparrow_extensions std::int64_t fixed_shape_tensor_extension::metadata::compute_size() const { - return std::accumulate( + return std::reduce( shape.begin(), shape.end(), std::int64_t{1}, - std::multiplies{} + std::multiplies<>{} ); } std::string fixed_shape_tensor_extension::metadata::to_json() const { - // Helper to serialize integer array - auto serialize_int_array = [](std::ostringstream& oss, const std::vector& arr) + // Pre-calculate approximate size to minimize allocations + std::size_t estimated_size = json_base_size; + estimated_size += shape.size() * json_integer_avg_size; + if (dim_names.has_value()) + { + estimated_size += json_dim_names_overhead; + for (const auto& name : *dim_names) + { + estimated_size += name.size() + json_string_overhead; + } + } + if (permutation.has_value()) + { + estimated_size += json_permutation_overhead + permutation->size() * json_integer_avg_size; + } + + std::string result; + result.reserve(estimated_size); + + auto serialize_array = [&result](const auto& arr, auto&& formatter) { - oss << "["; - for (std::size_t i = 0; i < arr.size(); ++i) + result += '['; + bool first = true; + for (const auto& item : arr) { - if (i > 0) oss << ","; - oss << arr[i]; + if (!first) result += ','; + first = false; + formatter(item); } - oss << "]"; + result += ']'; }; - std::ostringstream oss; - oss << "{\"shape\":"; - serialize_int_array(oss, shape); + result += "{\"shape\":"; + serialize_array(shape, [&result](const auto& val) { result += std::to_string(val); }); if (dim_names.has_value()) { - oss << ",\"dim_names\":["; - for (std::size_t i = 0; i < dim_names->size(); ++i) - { - if (i > 0) oss << ","; - oss << "\"" << (*dim_names)[i] << "\""; - } - oss << "]"; + result += 
",\"dim_names\":"; + serialize_array(*dim_names, [&result](const auto& val) { + result += '\"'; + result += val; + result += '\"'; + }); } if (permutation.has_value()) { - oss << ",\"permutation\":"; - serialize_int_array(oss, *permutation); + result += ",\"permutation\":"; + serialize_array(*permutation, [&result](const auto& val) { result += std::to_string(val); }); } - oss << "}"; - return oss.str(); + result += '}'; + return result; } fixed_shape_tensor_extension::metadata fixed_shape_tensor_extension::metadata::from_json( std::string_view json ) { - metadata result; - - // Simple JSON parser for the fixed structure we expect - // This is a minimal implementation - production code might use a proper JSON library - - std::string json_str(json); - std::size_t pos = 0; - - // Helper to skip whitespace - auto skip_whitespace = [&]() + auto parse_int_array = [](simdjson::ondemand::array arr) -> std::vector { - while (pos < json_str.size() && std::isspace(json_str[pos])) + std::vector result; + result.reserve(typical_tensor_dimensions); + for (auto value : arr) { - ++pos; + result.push_back(static_cast(value.get_int64())); } + return result; }; - // Helper to read a string value - auto read_string = [&]() -> std::string + auto parse_string_array = [](simdjson::ondemand::array arr) -> std::vector { - skip_whitespace(); - if (pos >= json_str.size() || json_str[pos] != '"') - { - throw std::runtime_error("Expected opening quote"); - } - ++pos; // Skip opening quote - - std::size_t start = pos; - while (pos < json_str.size() && json_str[pos] != '"') - { - ++pos; - } - - if (pos >= json_str.size()) - { - throw std::runtime_error("Expected closing quote"); - } - - std::string value = json_str.substr(start, pos - start); - ++pos; // Skip closing quote - return value; - }; - - // Helper to read an integer - auto read_int = [&]() -> std::int64_t - { - skip_whitespace(); - std::size_t start = pos; - if (pos < json_str.size() && (json_str[pos] == '-' || json_str[pos] == '+')) - 
{ - ++pos; - } - while (pos < json_str.size() && std::isdigit(json_str[pos])) - { - ++pos; - } - return std::stoll(json_str.substr(start, pos - start)); - }; - - // Generic helper to read an array - auto read_array = [&](auto reader) -> std::vector - { - std::vector arr; - skip_whitespace(); - if (pos >= json_str.size() || json_str[pos] != '[') - { - throw std::runtime_error("Expected opening bracket"); - } - ++pos; - - skip_whitespace(); - if (pos < json_str.size() && json_str[pos] == ']') + std::vector result; + result.reserve(typical_tensor_dimensions); + for (auto value : arr) { - ++pos; - return arr; + result.emplace_back(value.get_string().value()); } - - while (true) - { - arr.push_back(reader()); - skip_whitespace(); - - if (pos >= json_str.size()) - { - throw std::runtime_error("Unexpected end of JSON"); - } - - if (json_str[pos] == ']') - { - ++pos; - break; - } - - if (json_str[pos] != ',') - { - throw std::runtime_error("Expected comma or closing bracket"); - } - ++pos; - } - - return arr; + return result; }; - auto read_int_array = [&]() { return read_array.template operator()(read_int); }; - auto read_string_array = [&]() { return read_array.template operator()(read_string); }; - - // Parse the JSON object - skip_whitespace(); - if (pos >= json_str.size() || json_str[pos] != '{') - { - throw std::runtime_error("Expected opening brace"); - } - ++pos; - - while (true) + try { - skip_whitespace(); - if (pos >= json_str.size()) - { - throw std::runtime_error("Unexpected end of JSON"); - } - - if (json_str[pos] == '}') + metadata result; + + simdjson::ondemand::parser parser; + simdjson::padded_string padded_json(json); + simdjson::ondemand::document doc = parser.iterate(padded_json); + + // Parse shape (required) + result.shape = parse_int_array(doc["shape"].get_array()); + + if (result.shape.empty()) { - ++pos; - break; + throw std::runtime_error("Missing required 'shape' field"); } - - std::string key = read_string(); - skip_whitespace(); - - if (pos >= 
json_str.size() || json_str[pos] != ':') + + // Parse optional fields + auto dim_names_field = doc["dim_names"]; + if (dim_names_field.error() == simdjson::SUCCESS) { - throw std::runtime_error("Expected colon after key"); + result.dim_names = parse_string_array(dim_names_field.get_array()); } - ++pos; - - if (key == "shape") + + auto permutation_field = doc["permutation"]; + if (permutation_field.error() == simdjson::SUCCESS) { - result.shape = read_int_array(); + result.permutation = parse_int_array(permutation_field.get_array()); } - else if (key == "dim_names") + + if (!result.is_valid()) { - result.dim_names = read_string_array(); + throw std::runtime_error("Invalid metadata"); } - else if (key == "permutation") - { - result.permutation = read_int_array(); - } - else - { - throw std::runtime_error("Unknown key: " + key); - } - - skip_whitespace(); - if (pos >= json_str.size()) - { - throw std::runtime_error("Unexpected end of JSON"); - } - - if (json_str[pos] == '}') - { - ++pos; - break; - } - - if (json_str[pos] != ',') - { - throw std::runtime_error("Expected comma or closing brace"); - } - ++pos; + + return result; } - - if (result.shape.empty()) - { - throw std::runtime_error("Missing required 'shape' field"); - } - - if (!result.is_valid()) + catch (const simdjson::simdjson_error& e) { - throw std::runtime_error("Invalid metadata"); + throw std::runtime_error(std::string("JSON parsing error: ") + e.what()); } - - return result; } void fixed_shape_tensor_extension::init( @@ -311,36 +227,40 @@ namespace sparrow_extensions SPARROW_ASSERT_TRUE(tensor_metadata.is_valid()); // Get existing metadata - std::optional existing_metadata = proxy.metadata(); - std::vector extension_metadata = - existing_metadata.has_value() - ? 
std::vector( - existing_metadata->begin(), - existing_metadata->end() - ) - : std::vector{}; - - // Check if extension metadata already exists - const bool has_extension_name = std::ranges::find_if( - extension_metadata, - [](const auto& pair) - { - return pair.first == "ARROW:extension:name" - && pair.second == EXTENSION_NAME; - } - ) - != extension_metadata.end(); - - if (!has_extension_name) + auto existing_metadata = proxy.metadata(); + std::vector extension_metadata; + + if (existing_metadata.has_value()) { - extension_metadata.emplace_back("ARROW:extension:name", std::string(EXTENSION_NAME)); - extension_metadata.emplace_back( - "ARROW:extension:metadata", - tensor_metadata.to_json() - ); + extension_metadata.assign(existing_metadata->begin(), existing_metadata->end()); + + // Check if extension metadata already exists + const bool has_extension_name = std::ranges::find_if( + extension_metadata, + [](const auto& pair) + { + return pair.first == "ARROW:extension:name" + && pair.second == EXTENSION_NAME; + } + ) + != extension_metadata.end(); + + if (has_extension_name) + { + proxy.set_metadata(std::make_optional(std::move(extension_metadata))); + return; + } } + + // Reserve space for new entries + extension_metadata.reserve(extension_metadata.size() + 2); + extension_metadata.emplace_back("ARROW:extension:name", std::string(EXTENSION_NAME)); + extension_metadata.emplace_back( + "ARROW:extension:metadata", + tensor_metadata.to_json() + ); - proxy.set_metadata(std::make_optional(extension_metadata)); + proxy.set_metadata(std::make_optional(std::move(extension_metadata))); } fixed_shape_tensor_extension::metadata fixed_shape_tensor_extension::extract_metadata( @@ -384,7 +304,6 @@ namespace sparrow_extensions SPARROW_ASSERT_TRUE(m_metadata.is_valid()); SPARROW_ASSERT_TRUE(static_cast(list_size) == m_metadata.compute_size()); - // Add extension metadata to the storage using array_access fixed_shape_tensor_extension::init( 
sparrow::detail::array_access::get_arrow_proxy(m_storage), m_metadata @@ -404,7 +323,6 @@ namespace sparrow_extensions SPARROW_ASSERT_TRUE(m_metadata.is_valid()); SPARROW_ASSERT_TRUE(static_cast(list_size) == m_metadata.compute_size()); - // Get the proxy and set name/metadata auto& proxy = sparrow::detail::array_access::get_arrow_proxy(m_storage); proxy.set_name(name); @@ -413,7 +331,6 @@ namespace sparrow_extensions proxy.set_metadata(std::make_optional(*arrow_metadata)); } - // Add extension metadata fixed_shape_tensor_extension::init(proxy, m_metadata); } From 609d6593f9af8a006deed3eaade661269a20800b Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Tue, 23 Dec 2025 11:30:29 +0100 Subject: [PATCH 07/11] fix --- cmake/external_dependencies.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/external_dependencies.cmake b/cmake/external_dependencies.cmake index 8901aba..f4d92ab 100644 --- a/cmake/external_dependencies.cmake +++ b/cmake/external_dependencies.cmake @@ -84,7 +84,7 @@ set(SPARROW_EXTENSIONS_INTERFACE_DEPENDENCIES sparrow::sparrow) find_package_or_fetch( PACKAGE_NAME simdjson GIT_REPOSITORY https://github.com/simdjson/simdjson.git - TAG v2.4.12 + TAG v4.2.4 ) set(SPARROW_EXTENSIONS_INTERFACE_DEPENDENCIES ${SPARROW_EXTENSIONS_INTERFACE_DEPENDENCIES} simdjson::simdjson) @@ -93,7 +93,7 @@ if(SPARROW_EXTENSIONS_BUILD_TESTS) find_package_or_fetch( PACKAGE_NAME doctest GIT_REPOSITORY https://github.com/doctest/doctest.git - TAG v4.2.4 + TAG v2.4.12 ) # better_junit_reporter is provided by sparrow From bb2a572a7077624935b1d5e206068b75554436a2 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Tue, 23 Dec 2025 11:35:13 +0100 Subject: [PATCH 08/11] Add documentation --- README.md | 2 + docs/source/fixed_shape_tensor_array.md | 506 ++++++++++++++++++++++++ docs/source/main_page.md | 2 + 3 files changed, 510 insertions(+) create mode 100644 docs/source/fixed_shape_tensor_array.md diff --git a/README.md b/README.md index 
357d7c0..f14fd2d 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,8 @@ Extension types for the [sparrow](https://github.com/man-group/sparrow) library - `uuid_array`: Arrow-compatible array for storing UUID values as 16-byte fixed-width binary according to the `arrow.uuid` extension type specification. - `json_array`: Arrow-compatible array for storing JSON values as UTF-8 strings according to the `arrow.json` extension type specification. - `bool8_array`: Arrow-compatible array for storing boolean values as 8-bit integers according to the `arrow.bool8` extension type specification. +- `fixed_shape_tensor_array`: Arrow-compatible array for storing fixed-shape tensors according to the `arrow.fixed_shape_tensor` extension type specification. + ## Installation ### Install from sources diff --git a/docs/source/fixed_shape_tensor_array.md b/docs/source/fixed_shape_tensor_array.md new file mode 100644 index 0000000..472aac3 --- /dev/null +++ b/docs/source/fixed_shape_tensor_array.md @@ -0,0 +1,506 @@ +Fixed Shape Tensor Array {#fixed_shape_tensor_array} +=========================== + +Introduction +------------ + +The Fixed Shape Tensor Array is an Arrow-compatible array for storing multi-dimensional tensors with a fixed shape according to the [Apache Arrow canonical extension specification for FixedShapeTensor](https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor). + +This extension enables efficient storage and transfer of tensors (multi-dimensional arrays) with the same shape. Each element in the array represents a complete tensor of the specified shape. The underlying storage uses Arrow's `FixedSizeList` type to store flattened tensor data. 
+ +The FixedShapeTensor extension type is defined as: +- Extension name: `arrow.fixed_shape_tensor` +- Storage type: `FixedSizeList<T>` where `T` is the value type +- Extension metadata: JSON object containing shape, optional dimension names, and optional permutation +- List size: Product of all dimensions in the shape + +Metadata Structure +------------------ + +The extension metadata is a JSON object with the following fields: + +```json +{ + "shape": [dim0, dim1, ..., dimN], + "dim_names": ["name0", "name1", ..., "nameN"], // optional + "permutation": [idx0, idx1, ..., idxN] // optional +} +``` + +### Fields + +- **shape** (required): Array of positive integers specifying the dimensions of each tensor +- **dim_names** (optional): Array of strings naming each dimension (must match shape length) +- **permutation** (optional): Array defining the physical-to-logical dimension mapping (must be a valid permutation of [0, 1, ..., N-1]) + +Usage +----- + +### Basic Usage + +```cpp +#include "sparrow_extensions/fixed_shape_tensor.hpp" +using namespace sparrow_extensions; +using namespace sparrow; + +// Create 3 tensors of shape [2, 3] (2 rows, 3 columns) +std::vector<float> flat_data = { + // First tensor + 1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f, + // Second tensor + 7.0f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, + // Third tensor + 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f +}; + +primitive_array<float> values_array(flat_data); +fixed_shape_tensor_extension::metadata tensor_meta{ + {2, 3}, // shape + std::nullopt, // no dimension names + std::nullopt // no permutation +}; + +const std::uint64_t list_size = tensor_meta.compute_size(); // 2 * 3 = 6 + +fixed_shape_tensor_array tensor_array( + list_size, + array(std::move(values_array)), + tensor_meta +); + +// Access properties +std::cout << "Number of tensors: " << tensor_array.size() << "\n"; // 3 +std::cout << "Tensor shape: "; +for (auto dim : tensor_array.shape()) +{ + std::cout << dim << " "; // 2 3 +} +std::cout << "\n"; + +// Access
individual tensors +auto first_tensor = tensor_array[0]; +if (first_tensor.has_value()) +{ + // Process first tensor... +} +``` + +### With Dimension Names + +```cpp +// Create a 3D tensor with named dimensions +std::vector flat_data(100 * 200 * 500); // Fill with data... + +primitive_array values_array(flat_data); +fixed_shape_tensor_extension::metadata tensor_meta{ + {100, 200, 500}, // shape: channels, height, width + std::vector{"C", "H", "W"}, // dimension names + std::nullopt // no permutation +}; + +const std::uint64_t list_size = tensor_meta.compute_size(); + +fixed_shape_tensor_array tensor_array( + list_size, + array(std::move(values_array)), + tensor_meta +); + +// Access metadata +const auto& meta = tensor_array.get_metadata(); +if (meta.dim_names.has_value()) +{ + for (const auto& name : *meta.dim_names) + { + std::cout << name << " "; // C H W + } +} +``` + +### With Permutation + +```cpp +// Physical storage is [100, 200, 500] but logical layout is [500, 100, 200] +std::vector flat_data(100 * 200 * 500); // Fill with data... 
+ +primitive_array values_array(flat_data); +fixed_shape_tensor_extension::metadata tensor_meta{ + {100, 200, 500}, // physical shape + std::nullopt, + std::vector{2, 0, 1} // permutation: logical[i] = physical[perm[i]] +}; + +const std::uint64_t list_size = tensor_meta.compute_size(); + +fixed_shape_tensor_array tensor_array( + list_size, + array(std::move(values_array)), + tensor_meta +); + +// Physical shape: [100, 200, 500] +// Logical shape: [500, 100, 200] +const auto& meta = tensor_array.get_metadata(); +if (meta.permutation.has_value()) +{ + std::cout << "Has permutation: "; + for (auto idx : *meta.permutation) + { + std::cout << idx << " "; // 2 0 1 + } +} +``` + +### With Array Name and Metadata + +```cpp +// Create tensors with custom name and Arrow metadata +std::vector flat_data(24); // 1 tensor of shape [2, 3, 4] +std::iota(flat_data.begin(), flat_data.end(), 0.0); + +primitive_array values_array(flat_data); +fixed_shape_tensor_extension::metadata tensor_meta{ + {2, 3, 4}, + std::vector{"X", "Y", "Z"}, + std::nullopt +}; + +std::vector arrow_metadata{ + {"author", "research_team"}, + {"version", "2.0"}, + {"experiment", "trial_42"} +}; + +const std::uint64_t list_size = tensor_meta.compute_size(); + +fixed_shape_tensor_array tensor_array( + list_size, + array(std::move(values_array)), + tensor_meta, + "neural_network_weights", // array name + arrow_metadata // additional metadata +); + +// Access the Arrow proxy to read metadata +const auto& proxy = tensor_array.get_arrow_proxy(); +std::cout << "Array name: " << proxy.name() << "\n"; + +if (auto meta_opt = proxy.metadata()) +{ + for (const auto& [key, value] : *meta_opt) + { + std::cout << key << ": " << value << "\n"; + } +} +``` + +### With Validity Bitmap + +```cpp +#include "sparrow_extensions/fixed_shape_tensor.hpp" +using namespace sparrow_extensions; +using namespace sparrow; + +// Create 4 tensors of shape [2, 2], with some null values +std::vector flat_data(16); +std::iota(flat_data.begin(), 
flat_data.end(), 0); + +primitive_array values_array(flat_data); +fixed_shape_tensor_extension::metadata tensor_meta{{2, 2}, std::nullopt, std::nullopt}; + +// Create validity bitmap: first and third tensors are valid, others are null +std::vector validity{true, false, true, false}; + +const std::uint64_t list_size = tensor_meta.compute_size(); + +fixed_shape_tensor_array tensor_array( + list_size, + array(std::move(values_array)), + tensor_meta, + validity +); + +// Check which tensors are valid +for (size_t i = 0; i < tensor_array.size(); ++i) +{ + auto tensor = tensor_array[i]; + if (tensor.has_value()) + { + std::cout << "Tensor " << i << " is valid\n"; + } + else + { + std::cout << "Tensor " << i << " is null\n"; + } +} +``` + +### JSON Metadata Serialization + +```cpp +#include "sparrow_extensions/fixed_shape_tensor.hpp" +using namespace sparrow_extensions; + +// Create metadata +fixed_shape_tensor_extension::metadata meta{ + {100, 200, 500}, + std::vector{"C", "H", "W"}, + std::vector{2, 0, 1} +}; + +// Serialize to JSON +std::string json = meta.to_json(); +// Result: {"shape":[100,200,500],"dim_names":["C","H","W"],"permutation":[2,0,1]} + +// Deserialize from JSON +auto parsed_meta = fixed_shape_tensor_extension::metadata::from_json(json); + +// Validate +if (parsed_meta.is_valid()) +{ + std::cout << "Metadata is valid\n"; + std::cout << "Tensor size: " << parsed_meta.compute_size() << "\n"; +} +``` + +Constructors +------------ + +The `fixed_shape_tensor_array` class provides five constructors to accommodate different use cases: + +### 1. From Arrow Proxy (Reconstruction) + +```cpp +fixed_shape_tensor_array(sparrow::arrow_proxy proxy); +``` + +Reconstructs a tensor array from an existing Arrow proxy. Used internally by the Arrow extension registry. + +### 2. 
Basic Constructor + +```cpp +template +fixed_shape_tensor_array( + std::uint64_t list_size, + sparrow::array&& flat_values, + const metadata_type& tensor_metadata +); +``` + +Creates a tensor array with the specified shape and flattened values. All tensors are valid (no null values). + +**Parameters:** +- `list_size`: Product of all dimensions (from `tensor_metadata.compute_size()`) +- `flat_values`: Flattened tensor data as a primitive array +- `tensor_metadata`: Shape, optional dim_names, and optional permutation + +### 3. With Name and Metadata + +```cpp +template +fixed_shape_tensor_array( + std::uint64_t list_size, + sparrow::array&& flat_values, + const metadata_type& tensor_metadata, + std::string_view name, + M&& arrow_metadata = std::nullopt +); +``` + +Creates a tensor array with custom name and Arrow metadata fields. + +**Parameters:** +- All parameters from basic constructor, plus: +- `name`: Name for the array (stored in Arrow schema) +- `arrow_metadata`: Optional additional metadata pairs + +### 4. With Validity Bitmap + +```cpp +template +fixed_shape_tensor_array( + std::uint64_t list_size, + sparrow::array&& flat_values, + const metadata_type& tensor_metadata, + VB&& validity +); +``` + +Creates a tensor array where some tensors may be null. + +**Parameters:** +- All parameters from basic constructor, plus: +- `validity`: Validity bitmap (e.g., `std::vector`) indicating which tensors are valid + +### 5. Complete Constructor + +```cpp +template +fixed_shape_tensor_array( + std::uint64_t list_size, + sparrow::array&& flat_values, + const metadata_type& tensor_metadata, + VB&& validity, + std::string_view name, + M&& arrow_metadata = std::nullopt +); +``` + +Combines all features: validity bitmap, name, and metadata. + +API Reference +------------- + +### fixed_shape_tensor_array + +The `fixed_shape_tensor_array` class wraps `sparrow::fixed_sized_list_array` with extension metadata for tensor storage. 
+ +| Method | Description | +| ------ | ----------- | +| `size() const` | Returns the number of tensors in the array | +| `shape() const` | Returns the shape vector for each tensor | +| `get_metadata() const` | Returns the complete tensor metadata (shape, dim_names, permutation) | +| `storage() const` | Returns const reference to underlying `fixed_sized_list_array` | +| `storage()` | Returns mutable reference to underlying `fixed_sized_list_array` | +| `operator[](size_type i) const` | Accesses the i-th tensor (returns nullable reference) | +| `get_arrow_proxy() const` | Returns const reference to Arrow proxy | +| `get_arrow_proxy()` | Returns mutable reference to Arrow proxy | + +### fixed_shape_tensor_extension::metadata + +The metadata structure for tensor arrays. + +| Field | Type | Description | +| ----- | ---- | ----------- | +| `shape` | `std::vector<std::int64_t>` | Dimensions of each tensor (required) | +| `dim_names` | `std::optional<std::vector<std::string>>` | Optional dimension names | +| `permutation` | `std::optional<std::vector<std::int64_t>>` | Optional dimension permutation | + +| Method | Description | +| ------ | ----------- | +| `is_valid() const` | Validates metadata consistency | +| `compute_size() const` | Computes the product of all dimensions | +| `to_json() const` | Serializes metadata to JSON string | +| `static from_json(std::string_view)` | Deserializes metadata from JSON | + +### Metadata Validation Rules + +The `is_valid()` method checks: +- Shape is not empty and all dimensions are positive +- If `dim_names` is present, its size matches the shape size +- If `permutation` is present: + - Its size matches the shape size + - It contains exactly the indices [0, 1, ..., N-1] without duplicates + - All indices are in valid range + +Performance Considerations +-------------------------- + +### Optimizations + +The implementation includes several performance optimizations: + +1. **String reservation in JSON serialization**: Pre-calculates approximate JSON size to minimize allocations +2. 
**Direct string concatenation**: Avoids `std::ostringstream` overhead by using direct string operations +3. **Vector capacity hints**: Reserves space for typical tensor dimensions (2-4D) during JSON parsing +4. **Bitset-based permutation validation**: O(n) validation instead of O(n log n) sorting +5. **Move semantics**: Efficiently transfers metadata and arrays without copying +6. **Early returns**: Skips unnecessary work when extension metadata already exists + +### Best Practices + +- Use `compute_size()` to calculate the required `list_size` parameter +- Pre-allocate flat data vectors when possible +- Use move semantics when passing arrays to constructors +- Validate metadata with `is_valid()` before creating arrays +- For large tensors, consider memory layout and cache efficiency + +Extension Metadata +------------------ + +The Fixed Shape Tensor array automatically sets the following Arrow extension metadata: + +- `ARROW:extension:name`: `"arrow.fixed_shape_tensor"` +- `ARROW:extension:metadata`: JSON string containing shape, optional dim_names, and optional permutation + +This metadata is added to the Arrow schema, allowing other Arrow implementations to recognize and correctly interpret the tensor arrays. + +Examples from Specification +---------------------------- + +### Example 1: Simple 2×5 Tensor + +```cpp +// From spec: { "shape": [2, 5] } +std::vector data(10); +std::iota(data.begin(), data.end(), 0.0); + +primitive_array values(data); +fixed_shape_tensor_extension::metadata meta{{2, 5}, std::nullopt, std::nullopt}; + +fixed_shape_tensor_array tensors( + meta.compute_size(), + array(std::move(values)), + meta +); +// Contains 1 tensor of shape [2, 5] +``` + +### Example 2: Image Data with Dimension Names + +```cpp +// From spec: { "shape": [100, 200, 500], "dim_names": ["C", "H", "W"] } +std::vector image_data(100 * 200 * 500); +// Fill with image data... 
+ +primitive_array values(image_data); +fixed_shape_tensor_extension::metadata meta{ + {100, 200, 500}, + std::vector{"C", "H", "W"}, + std::nullopt +}; + +fixed_shape_tensor_array images( + meta.compute_size(), + array(std::move(values)), + meta +); +// Contains 1 tensor representing channels × height × width +``` + +### Example 3: Permuted Layout + +```cpp +// From spec: { "shape": [100, 200, 500], "permutation": [2, 0, 1] } +// Physical layout: [100, 200, 500] +// Logical layout: [500, 100, 200] + +std::vector data(100 * 200 * 500); +// Fill with data in physical layout... + +primitive_array values(data); +fixed_shape_tensor_extension::metadata meta{ + {100, 200, 500}, + std::nullopt, + std::vector{2, 0, 1} +}; + +fixed_shape_tensor_array tensors( + meta.compute_size(), + array(std::move(values)), + meta +); +// Data is stored in physical layout but can be interpreted with logical shape +``` + +See Also +-------- + +- [Apache Arrow Canonical Extension: FixedShapeTensor](https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor) +- [sparrow::fixed_sized_list_array](https://github.com/man-group/sparrow) +- [Bool8 Array](@ref bool8_array) +- [UUID Array](@ref uuid_array) +- [JSON Array](@ref json_array) diff --git a/docs/source/main_page.md b/docs/source/main_page.md index eaf022e..dd76dfa 100644 --- a/docs/source/main_page.md +++ b/docs/source/main_page.md @@ -25,3 +25,5 @@ This software is licensed under the Apache License 2.0. 
See the LICENSE file for \subpage bool8_array +\subpage fixed_shape_tensor_array + From c1c9e9959b5858c6a9e9aa0484c32e852fd7f319 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Tue, 23 Dec 2025 11:44:43 +0100 Subject: [PATCH 09/11] wip --- src/fixed_shape_tensor.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fixed_shape_tensor.cpp b/src/fixed_shape_tensor.cpp index 9588fda..fd93bf6 100644 --- a/src/fixed_shape_tensor.cpp +++ b/src/fixed_shape_tensor.cpp @@ -16,7 +16,6 @@ #include #include -#include #include #include From df7073e5b60682fd068f7d137096b84f506aea37 Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Tue, 23 Dec 2025 11:47:56 +0100 Subject: [PATCH 10/11] wip --- include/sparrow_extensions/fixed_shape_tensor.hpp | 14 -------------- tests/test_fixed_shape_tensor.cpp | 14 -------------- 2 files changed, 28 deletions(-) diff --git a/include/sparrow_extensions/fixed_shape_tensor.hpp b/include/sparrow_extensions/fixed_shape_tensor.hpp index 47458d3..929fcd9 100644 --- a/include/sparrow_extensions/fixed_shape_tensor.hpp +++ b/include/sparrow_extensions/fixed_shape_tensor.hpp @@ -1,17 +1,3 @@ -// Copyright 2024 Man Group Operations Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- #pragma once #include diff --git a/tests/test_fixed_shape_tensor.cpp b/tests/test_fixed_shape_tensor.cpp index bbbe440..fce7acc 100644 --- a/tests/test_fixed_shape_tensor.cpp +++ b/tests/test_fixed_shape_tensor.cpp @@ -1,17 +1,3 @@ -// Copyright 2024 Man Group Operations Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - #include #include From 02d3bcf917f38cc944dfa2d7937b993da48cba7b Mon Sep 17 00:00:00 2001 From: Alexis Placet Date: Wed, 24 Dec 2025 09:18:03 +0100 Subject: [PATCH 11/11] wip --- src/fixed_shape_tensor.cpp | 136 +++++++++++++++++++++---------------- 1 file changed, 77 insertions(+), 59 deletions(-) diff --git a/src/fixed_shape_tensor.cpp b/src/fixed_shape_tensor.cpp index fd93bf6..13fd374 100644 --- a/src/fixed_shape_tensor.cpp +++ b/src/fixed_shape_tensor.cpp @@ -20,10 +20,10 @@ #include +#include "sparrow/array.hpp" #include "sparrow/layout/array_access.hpp" #include "sparrow/layout/array_registry.hpp" #include "sparrow/utils/contracts.hpp" -#include "sparrow/array.hpp" #include "sparrow_extensions/config/config.hpp" @@ -32,21 +32,28 @@ namespace sparrow_extensions namespace { // JSON serialization size estimation constants - constexpr std::size_t json_base_size = 10; // {"shape":[]} - constexpr std::size_t json_integer_avg_size = 10; // Average size per integer - constexpr std::size_t json_dim_names_overhead = 15; // ,"dim_names":[] - constexpr std::size_t json_string_overhead = 3; // "name", - constexpr std::size_t 
json_permutation_overhead = 17; // ,"permutation":[] - + constexpr std::size_t json_base_size = 10; // {"shape":[]} + constexpr std::size_t json_integer_avg_size = 10; // Average size per integer + constexpr std::size_t json_dim_names_overhead = 15; // ,"dim_names":[] + constexpr std::size_t json_string_overhead = 3; // "name", + constexpr std::size_t json_permutation_overhead = 17; // ,"permutation":[] + // JSON parsing capacity hints - constexpr std::size_t typical_tensor_dimensions = 8; // Typical tensor rank (2-4 dims, reserve 8) + constexpr std::size_t typical_tensor_dimensions = 8; // Typical tensor rank (2-4 dims, reserve 8) } // Metadata implementation bool fixed_shape_tensor_extension::metadata::is_valid() const { // Shape must not be empty and all dimensions must be positive - if (shape.empty() || !std::ranges::all_of(shape, [](auto dim) { return dim > 0; })) + if (shape.empty() + || !std::ranges::all_of( + shape, + [](auto dim) + { + return dim > 0; + } + )) { return false; } @@ -71,7 +78,8 @@ namespace sparrow_extensions std::vector seen(perm.size(), false); for (const auto idx : perm) { - if (idx < 0 || static_cast(idx) >= perm.size() || seen[static_cast(idx)]) + if (idx < 0 || static_cast(idx) >= perm.size() + || seen[static_cast(idx)]) { return false; } @@ -85,12 +93,7 @@ namespace sparrow_extensions std::int64_t fixed_shape_tensor_extension::metadata::compute_size() const { - return std::reduce( - shape.begin(), - shape.end(), - std::int64_t{1}, - std::multiplies<>{} - ); + return std::reduce(shape.begin(), shape.end(), std::int64_t{1}, std::multiplies<>{}); } std::string fixed_shape_tensor_extension::metadata::to_json() const @@ -110,17 +113,20 @@ namespace sparrow_extensions { estimated_size += json_permutation_overhead + permutation->size() * json_integer_avg_size; } - + std::string result; result.reserve(estimated_size); - + auto serialize_array = [&result](const auto& arr, auto&& formatter) { result += '['; bool first = true; for (const auto& 
item : arr) { - if (!first) result += ','; + if (!first) + { + result += ','; + } first = false; formatter(item); } @@ -128,31 +134,46 @@ namespace sparrow_extensions }; result += "{\"shape\":"; - serialize_array(shape, [&result](const auto& val) { result += std::to_string(val); }); + serialize_array( + shape, + [&result](const auto& val) + { + result += std::to_string(val); + } + ); if (dim_names.has_value()) { result += ",\"dim_names\":"; - serialize_array(*dim_names, [&result](const auto& val) { - result += '\"'; - result += val; - result += '\"'; - }); + serialize_array( + *dim_names, + [&result](const auto& val) + { + result += '\"'; + result += val; + result += '\"'; + } + ); } if (permutation.has_value()) { result += ",\"permutation\":"; - serialize_array(*permutation, [&result](const auto& val) { result += std::to_string(val); }); + serialize_array( + *permutation, + [&result](const auto& val) + { + result += std::to_string(val); + } + ); } result += '}'; return result; } - fixed_shape_tensor_extension::metadata fixed_shape_tensor_extension::metadata::from_json( - std::string_view json - ) + fixed_shape_tensor_extension::metadata + fixed_shape_tensor_extension::metadata::from_json(std::string_view json) { auto parse_int_array = [](simdjson::ondemand::array arr) -> std::vector { @@ -179,37 +200,43 @@ namespace sparrow_extensions try { metadata result; - + simdjson::ondemand::parser parser; simdjson::padded_string padded_json(json); simdjson::ondemand::document doc = parser.iterate(padded_json); - + // Parse shape (required) - result.shape = parse_int_array(doc["shape"].get_array()); - - if (result.shape.empty()) + auto shape_field = doc["shape"]; + if (shape_field.error() != simdjson::SUCCESS) { throw std::runtime_error("Missing required 'shape' field"); } - + + result.shape = parse_int_array(shape_field.get_array()); + + if (result.shape.empty()) + { + throw std::runtime_error("'shape' field cannot be empty"); + } + // Parse optional fields auto 
dim_names_field = doc["dim_names"]; if (dim_names_field.error() == simdjson::SUCCESS) { result.dim_names = parse_string_array(dim_names_field.get_array()); } - + auto permutation_field = doc["permutation"]; if (permutation_field.error() == simdjson::SUCCESS) { result.permutation = parse_int_array(permutation_field.get_array()); } - + if (!result.is_valid()) { throw std::runtime_error("Invalid metadata"); } - + return result; } catch (const simdjson::simdjson_error& e) @@ -218,21 +245,18 @@ namespace sparrow_extensions } } - void fixed_shape_tensor_extension::init( - sparrow::arrow_proxy& proxy, - const metadata& tensor_metadata - ) + void fixed_shape_tensor_extension::init(sparrow::arrow_proxy& proxy, const metadata& tensor_metadata) { SPARROW_ASSERT_TRUE(tensor_metadata.is_valid()); // Get existing metadata auto existing_metadata = proxy.metadata(); std::vector extension_metadata; - + if (existing_metadata.has_value()) { extension_metadata.assign(existing_metadata->begin(), existing_metadata->end()); - + // Check if extension metadata already exists const bool has_extension_name = std::ranges::find_if( extension_metadata, @@ -243,28 +267,24 @@ namespace sparrow_extensions } ) != extension_metadata.end(); - + if (has_extension_name) { proxy.set_metadata(std::make_optional(std::move(extension_metadata))); return; } } - + // Reserve space for new entries extension_metadata.reserve(extension_metadata.size() + 2); extension_metadata.emplace_back("ARROW:extension:name", std::string(EXTENSION_NAME)); - extension_metadata.emplace_back( - "ARROW:extension:metadata", - tensor_metadata.to_json() - ); + extension_metadata.emplace_back("ARROW:extension:metadata", tensor_metadata.to_json()); proxy.set_metadata(std::make_optional(std::move(extension_metadata))); } - fixed_shape_tensor_extension::metadata fixed_shape_tensor_extension::extract_metadata( - const sparrow::arrow_proxy& proxy - ) + fixed_shape_tensor_extension::metadata + 
fixed_shape_tensor_extension::extract_metadata(const sparrow::arrow_proxy& proxy) { const auto metadata_opt = proxy.metadata(); if (!metadata_opt.has_value()) @@ -303,10 +323,7 @@ namespace sparrow_extensions SPARROW_ASSERT_TRUE(m_metadata.is_valid()); SPARROW_ASSERT_TRUE(static_cast(list_size) == m_metadata.compute_size()); - fixed_shape_tensor_extension::init( - sparrow::detail::array_access::get_arrow_proxy(m_storage), - m_metadata - ); + fixed_shape_tensor_extension::init(sparrow::detail::array_access::get_arrow_proxy(m_storage), m_metadata); } fixed_shape_tensor_array::fixed_shape_tensor_array( @@ -324,7 +341,7 @@ namespace sparrow_extensions auto& proxy = sparrow::detail::array_access::get_arrow_proxy(m_storage); proxy.set_name(name); - + if (arrow_metadata.has_value()) { proxy.set_metadata(std::make_optional(*arrow_metadata)); @@ -358,7 +375,8 @@ namespace sparrow_extensions return m_storage; } - auto fixed_shape_tensor_array::operator[](size_type i) const -> decltype(std::declval()[i]) + auto fixed_shape_tensor_array::operator[](size_type i) const + -> decltype(std::declval()[i]) { return m_storage[i]; }