From 2f59c9ae3c2062b155d5da6d943aa24b54065043 Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Sat, 4 Oct 2025 04:40:34 +0000 Subject: [PATCH 1/6] Update vendored DuckDB sources to 5657cbdc0b --- CMakeLists.txt | 1 + .../scalar/generic/current_setting.cpp | 5 +- .../extension/json/include/json_common.hpp | 6 +- src/duckdb/extension/json/json_functions.cpp | 6 +- src/duckdb/extension/json/json_reader.cpp | 3 +- .../extension/parquet/column_writer.cpp | 56 +- .../parquet/include/column_writer.hpp | 16 +- .../include/reader/string_column_reader.hpp | 13 + .../reader/variant/variant_binary_decoder.hpp | 6 +- .../include/writer/list_column_writer.hpp | 9 +- .../include/writer/struct_column_writer.hpp | 6 +- .../writer/templated_column_writer.hpp | 1 + .../include/writer/variant_column_writer.hpp | 43 + .../extension/parquet/parquet_extension.cpp | 33 +- .../extension/parquet/parquet_reader.cpp | 13 +- .../extension/parquet/parquet_statistics.cpp | 12 +- .../extension/parquet/parquet_writer.cpp | 2 +- .../parquet/reader/string_column_reader.cpp | 21 +- .../reader/variant/variant_binary_decoder.cpp | 16 +- .../parquet/writer/array_column_writer.cpp | 9 +- .../parquet/writer/list_column_writer.cpp | 21 +- .../writer/primitive_column_writer.cpp | 2 +- .../parquet/writer/struct_column_writer.cpp | 2 +- .../writer/variant/convert_variant.cpp | 633 ++++++++++++++ .../parquet/writer/variant_column_writer.cpp | 131 +++ .../catalog_entry/duck_table_entry.cpp | 4 +- .../catalog_entry/table_catalog_entry.cpp | 2 +- .../src/catalog/catalog_search_path.cpp | 4 +- src/duckdb/src/common/enum_util.cpp | 56 +- src/duckdb/src/common/enums/metric_type.cpp | 6 + .../src/common/enums/optimizer_type.cpp | 1 + src/duckdb/src/common/extra_type_info.cpp | 15 + src/duckdb/src/common/file_system.cpp | 30 - .../src/common/operator/cast_operators.cpp | 9 + src/duckdb/src/common/render_tree.cpp | 2 +- src/duckdb/src/common/sorting/sort.cpp | 12 +- 
src/duckdb/src/common/sorting/sorted_run.cpp | 167 +++- .../src/common/sorting/sorted_run_merger.cpp | 122 +-- src/duckdb/src/common/string_util.cpp | 15 + src/duckdb/src/common/types.cpp | 13 + .../types/column/column_data_collection.cpp | 3 +- src/duckdb/src/common/types/geometry.cpp | 773 ++++++++++++++++++ .../types/row/tuple_data_collection.cpp | 19 +- src/duckdb/src/common/types/value.cpp | 7 + src/duckdb/src/common/types/vector.cpp | 4 + src/duckdb/src/common/virtual_file_system.cpp | 69 +- .../src/execution/index/art/art_merger.cpp | 3 - src/duckdb/src/execution/index/art/prefix.cpp | 4 + .../operator/join/physical_asof_join.cpp | 485 +++++------ .../operator/join/physical_iejoin.cpp | 626 +++++++++----- .../join/physical_piecewise_merge_join.cpp | 553 +++++++------ .../operator/join/physical_range_join.cpp | 311 ++++--- .../persistent/physical_batch_insert.cpp | 10 +- .../operator/persistent/physical_insert.cpp | 6 +- .../physical_plan/plan_asof_join.cpp | 6 +- .../aggregate/sorted_aggregate_function.cpp | 1 + .../src/function/cast/default_casts.cpp | 2 + src/duckdb/src/function/cast/geo_casts.cpp | 24 + src/duckdb/src/function/cast/string_cast.cpp | 2 + .../function/cast/variant/from_variant.cpp | 24 +- .../src/function/cast/variant/to_json.cpp | 6 + .../function/scalar/variant/variant_utils.cpp | 10 + .../src/function/table/direct_file_reader.cpp | 2 +- .../table/system/duckdb_connection_count.cpp | 45 + .../table/system/pragma_storage_info.cpp | 2 +- .../src/function/table/system_functions.cpp | 1 + .../function/table/version/pragma_version.cpp | 6 +- .../window/window_merge_sort_tree.cpp | 3 + .../function/window/window_rank_function.cpp | 1 + .../window/window_rownumber_function.cpp | 1 + .../catalog_entry/duck_table_entry.hpp | 2 +- .../catalog_entry/table_catalog_entry.hpp | 2 +- .../catalog/default/builtin_types/types.hpp | 5 +- .../src/include/duckdb/common/enum_util.hpp | 8 + .../duckdb/common/enums/metric_type.hpp | 1 + 
.../duckdb/common/enums/optimizer_type.hpp | 3 +- .../include/duckdb/common/extra_type_info.hpp | 15 +- .../duckdb/common/operator/cast_operators.hpp | 13 + .../include/duckdb/common/sorting/sort.hpp | 25 +- .../duckdb/common/sorting/sorted_run.hpp | 31 +- .../common/sorting/sorted_run_merger.hpp | 14 +- .../src/include/duckdb/common/string_util.hpp | 2 + .../src/include/duckdb/common/types.hpp | 3 + .../include/duckdb/common/types/geometry.hpp | 35 + .../common/types/row/block_iterator.hpp | 5 +- .../types/row/tuple_data_collection.hpp | 5 +- .../src/include/duckdb/common/types/value.hpp | 2 + .../include/duckdb/common/types/variant.hpp | 2 + .../duckdb/common/virtual_file_system.hpp | 5 +- .../operator/join/physical_asof_join.hpp | 12 +- .../join/physical_piecewise_merge_join.hpp | 2 + .../operator/join/physical_range_join.hpp | 110 ++- .../duckdb/function/cast/default_casts.hpp | 1 + .../cast/variant/primitive_to_variant.hpp | 3 + .../cast/variant/struct_to_variant.hpp | 2 +- .../cast/variant/variant_to_variant.hpp | 2 +- .../duckdb/function/compression_function.hpp | 8 +- .../function/scalar/variant_functions.hpp | 2 +- .../duckdb/function/scalar/variant_utils.hpp | 2 + .../function/table/system_functions.hpp | 4 + .../include/duckdb/main/attached_database.hpp | 4 + .../include/duckdb/main/client_context.hpp | 6 + .../src/include/duckdb/main/connection.hpp | 1 - .../duckdb/main/connection_manager.hpp | 1 - .../main/database_file_path_manager.hpp | 9 +- .../include/duckdb/main/extension_entries.hpp | 1 + .../include/duckdb/main/profiling_info.hpp | 19 +- .../include/duckdb/main/query_profiler.hpp | 3 +- .../src/include/duckdb/main/secret/secret.hpp | 4 +- .../optimizer/common_subplan_optimizer.hpp | 31 + .../include/duckdb/optimizer/cte_inlining.hpp | 1 + .../duckdb/optimizer/filter_pushdown.hpp | 2 + .../src/include/duckdb/planner/binder.hpp | 65 +- .../expression_binder/lateral_binder.hpp | 4 +- .../duckdb/planner/operator/logical_cte.hpp | 2 +- 
.../operator/logical_dependent_join.hpp | 6 +- .../subquery/flatten_dependent_join.hpp | 4 +- .../subquery/has_correlated_expressions.hpp | 4 +- .../planner/subquery/rewrite_cte_scan.hpp | 4 +- .../duckdb/planner/tableref/bound_joinref.hpp | 2 +- .../include/duckdb/storage/buffer_manager.hpp | 1 + .../storage/compression/alp/alp_analyze.hpp | 7 +- .../storage/compression/alp/alp_scan.hpp | 2 +- .../compression/alprd/alprd_analyze.hpp | 6 +- .../storage/compression/alprd/alprd_scan.hpp | 2 +- .../storage/compression/chimp/chimp_scan.hpp | 2 +- .../storage/compression/empty_validity.hpp | 2 +- .../storage/compression/patas/patas_scan.hpp | 2 +- .../include/duckdb/storage/data_pointer.hpp | 1 + .../src/include/duckdb/storage/data_table.hpp | 4 +- .../storage/metadata/metadata_manager.hpp | 2 +- .../duckdb/storage/optimistic_data_writer.hpp | 13 +- .../storage/standard_buffer_manager.hpp | 2 +- .../storage/statistics/string_stats.hpp | 2 + .../duckdb/storage/string_uncompressed.hpp | 2 +- .../storage/table/array_column_data.hpp | 4 +- .../duckdb/storage/table/column_data.hpp | 11 +- .../duckdb/storage/table/column_segment.hpp | 1 - .../duckdb/storage/table/list_column_data.hpp | 4 +- .../duckdb/storage/table/row_group.hpp | 17 +- .../storage/table/row_group_collection.hpp | 17 +- .../duckdb/storage/table/scan_state.hpp | 10 +- .../storage/table/standard_column_data.hpp | 4 +- .../storage/table/struct_column_data.hpp | 4 +- .../duckdb/transaction/cleanup_state.hpp | 4 +- .../duckdb/transaction/local_storage.hpp | 6 + src/duckdb/src/main/attached_database.cpp | 16 +- src/duckdb/src/main/client_data.cpp | 3 + src/duckdb/src/main/connection.cpp | 7 +- .../src/main/database_file_path_manager.cpp | 18 +- src/duckdb/src/main/database_manager.cpp | 2 +- src/duckdb/src/main/http/http_util.cpp | 4 +- src/duckdb/src/main/profiling_info.cpp | 33 +- src/duckdb/src/main/query_profiler.cpp | 76 +- .../optimizer/common_subplan_optimizer.cpp | 575 +++++++++++++ 
src/duckdb/src/optimizer/cte_inlining.cpp | 13 +- src/duckdb/src/optimizer/filter_pushdown.cpp | 71 +- src/duckdb/src/optimizer/optimizer.cpp | 7 + .../pushdown/pushdown_inner_join.cpp | 1 + .../optimizer/pushdown/pushdown_left_join.cpp | 1 + .../pushdown/pushdown_semi_anti_join.cpp | 1 + src/duckdb/src/parser/parser.cpp | 42 +- src/duckdb/src/planner/binder.cpp | 4 +- .../binder/query_node/plan_subquery.cpp | 14 +- .../planner/binder/statement/bind_create.cpp | 6 +- .../binder/statement/bind_merge_into.cpp | 12 +- .../planner/binder/tableref/bind_pivot.cpp | 11 +- .../binder/tableref/bind_table_function.cpp | 11 +- src/duckdb/src/planner/expression_binder.cpp | 4 +- .../expression_binder/lateral_binder.cpp | 22 +- .../table_function_binder.cpp | 4 + .../operator/logical_dependent_join.cpp | 4 +- .../subquery/flatten_dependent_join.cpp | 26 +- .../subquery/has_correlated_expressions.cpp | 2 +- .../src/planner/subquery/rewrite_cte_scan.cpp | 4 +- .../src/storage/compression/bitpacking.cpp | 17 +- .../src/storage/compression/dict_fsst.cpp | 5 +- .../compression/dictionary_compression.cpp | 5 +- .../compression/fixed_size_uncompressed.cpp | 4 +- src/duckdb/src/storage/compression/fsst.cpp | 4 +- .../storage/compression/numeric_constant.cpp | 2 +- src/duckdb/src/storage/compression/rle.cpp | 2 +- .../storage/compression/roaring/common.cpp | 2 +- .../compression/string_uncompressed.cpp | 3 +- .../compression/validity_uncompressed.cpp | 2 +- src/duckdb/src/storage/compression/zstd.cpp | 4 +- src/duckdb/src/storage/data_table.cpp | 14 +- src/duckdb/src/storage/local_storage.cpp | 16 +- .../src/storage/metadata/metadata_manager.cpp | 2 +- .../src/storage/optimistic_data_writer.cpp | 44 +- .../storage/serialization/serialize_types.cpp | 12 + .../src/storage/standard_buffer_manager.cpp | 2 +- .../src/storage/statistics/string_stats.cpp | 8 + .../src/storage/table/array_column_data.cpp | 12 +- src/duckdb/src/storage/table/column_data.cpp | 8 +- 
.../src/storage/table/column_segment.cpp | 3 +- .../src/storage/table/list_column_data.cpp | 12 +- src/duckdb/src/storage/table/row_group.cpp | 55 +- .../storage/table/row_group_collection.cpp | 41 +- .../storage/table/standard_column_data.cpp | 12 +- .../src/storage/table/struct_column_data.cpp | 11 +- src/duckdb/src/storage/wal_replay.cpp | 2 +- src/duckdb/src/transaction/cleanup_state.cpp | 4 +- .../transaction/duck_transaction_manager.cpp | 4 +- src/duckdb/src/transaction/undo_buffer.cpp | 3 +- src/duckdb/third_party/httplib/httplib.hpp | 7 +- .../yyjson/include/yyjson_utils.hpp | 33 + src/duckdb/ub_extension_parquet_writer.cpp | 2 + .../ub_extension_parquet_writer_variant.cpp | 2 + src/duckdb/ub_src_common_types.cpp | 2 + src/duckdb/ub_src_function_cast.cpp | 2 + src/duckdb/ub_src_function_table_system.cpp | 2 + src/duckdb/ub_src_optimizer.cpp | 2 + 213 files changed, 4945 insertions(+), 1580 deletions(-) create mode 100644 src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp create mode 100644 src/duckdb/extension/parquet/writer/variant/convert_variant.cpp create mode 100644 src/duckdb/extension/parquet/writer/variant_column_writer.cpp create mode 100644 src/duckdb/src/common/types/geometry.cpp create mode 100644 src/duckdb/src/function/cast/geo_casts.cpp create mode 100644 src/duckdb/src/function/table/system/duckdb_connection_count.cpp create mode 100644 src/duckdb/src/include/duckdb/common/types/geometry.hpp create mode 100644 src/duckdb/src/include/duckdb/optimizer/common_subplan_optimizer.hpp create mode 100644 src/duckdb/src/optimizer/common_subplan_optimizer.cpp create mode 100644 src/duckdb/third_party/yyjson/include/yyjson_utils.hpp create mode 100644 src/duckdb/ub_extension_parquet_writer_variant.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 7153f8424..280269485 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -389,6 +389,7 @@ set(DUCKDB_SRC_FILES src/duckdb/ub_extension_parquet_reader.cpp 
src/duckdb/ub_extension_parquet_reader_variant.cpp src/duckdb/ub_extension_parquet_writer.cpp + src/duckdb/ub_extension_parquet_writer_variant.cpp src/duckdb/third_party/parquet/parquet_types.cpp src/duckdb/third_party/thrift/thrift/protocol/TProtocol.cpp src/duckdb/third_party/thrift/thrift/transport/TTransportException.cpp diff --git a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp index 4464f0544..31e1afe17 100644 --- a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp +++ b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp @@ -53,10 +53,7 @@ unique_ptr CurrentSettingBind(ClientContext &context, ScalarFuncti if (!context.TryGetCurrentSetting(key, val)) { auto extension_name = Catalog::AutoloadExtensionByConfigName(context, key); // If autoloader didn't throw, the config is now available - if (!context.TryGetCurrentSetting(key, val)) { - throw InternalException("Extension %s did not provide the '%s' config setting", - extension_name.ToStdString(), key); - } + context.TryGetCurrentSetting(key, val); } bound_function.return_type = val.type(); diff --git a/src/duckdb/extension/json/include/json_common.hpp b/src/duckdb/extension/json/include/json_common.hpp index f6dd78f05..81bbd6868 100644 --- a/src/duckdb/extension/json/include/json_common.hpp +++ b/src/duckdb/extension/json/include/json_common.hpp @@ -13,6 +13,7 @@ #include "duckdb/common/operator/string_cast.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "yyjson.hpp" +#include "duckdb/common/types/blob.hpp" using namespace duckdb_yyjson; // NOLINT @@ -228,11 +229,8 @@ struct JSONCommon { static string FormatParseError(const char *data, idx_t length, yyjson_read_err &error, const string &extra = "") { D_ASSERT(error.code != YYJSON_READ_SUCCESS); - // Go to blob so we can have a better error message for weird strings - auto blob = 
Value::BLOB(string(data, length)); // Truncate, so we don't print megabytes worth of JSON - string input = blob.ToString(); - input = input.length() > 50 ? string(input.c_str(), 47) + "..." : input; + auto input = length > 50 ? string(data, 47) + "..." : string(data, length); // Have to replace \r, otherwise output is unreadable input = StringUtil::Replace(input, "\r", "\\r"); return StringUtil::Format("Malformed JSON at byte %lld of input: %s. %s Input: \"%s\"", error.pos, error.msg, diff --git a/src/duckdb/extension/json/json_functions.cpp b/src/duckdb/extension/json/json_functions.cpp index 2d09828c3..2d0ef11f5 100644 --- a/src/duckdb/extension/json/json_functions.cpp +++ b/src/duckdb/extension/json/json_functions.cpp @@ -394,7 +394,11 @@ void JSONFunctions::RegisterSimpleCastFunctions(ExtensionLoader &loader) { loader.RegisterCastFunction(LogicalType::LIST(LogicalType::JSON()), LogicalTypeId::VARCHAR, CastJSONListToVarchar, json_list_to_varchar_cost); - // VARCHAR to JSON[] (also needs a special case otherwise get a VARCHAR -> VARCHAR[] cast first) + // JSON[] to JSON is allowed implicitly + loader.RegisterCastFunction(LogicalType::LIST(LogicalType::JSON()), LogicalType::JSON(), CastJSONListToVarchar, + 100); + + // VARCHAR to JSON[] (also needs a special case otherwise we get a VARCHAR -> VARCHAR[] cast first) const auto varchar_to_json_list_cost = CastFunctionSet::ImplicitCastCost(db, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::JSON())) - 1; BoundCastInfo varchar_to_json_list_info(CastVarcharToJSONList, nullptr, JSONFunctionLocalState::InitCastLocalState); diff --git a/src/duckdb/extension/json/json_reader.cpp b/src/duckdb/extension/json/json_reader.cpp index b52026a4e..ad61da7d2 100644 --- a/src/duckdb/extension/json/json_reader.cpp +++ b/src/duckdb/extension/json/json_reader.cpp @@ -184,8 +184,7 @@ void JSONReader::OpenJSONFile() { if (!IsOpen()) { auto &fs = FileSystem::GetFileSystem(context); auto regular_file_handle = fs.OpenFile(file, 
FileFlags::FILE_FLAGS_READ | options.compression); - file_handle = make_uniq(QueryContext(context), std::move(regular_file_handle), - BufferAllocator::Get(context)); + file_handle = make_uniq(context, std::move(regular_file_handle), BufferAllocator::Get(context)); } Reset(); } diff --git a/src/duckdb/extension/parquet/column_writer.cpp b/src/duckdb/extension/parquet/column_writer.cpp index 7cdd51bc5..55bfa2007 100644 --- a/src/duckdb/extension/parquet/column_writer.cpp +++ b/src/duckdb/extension/parquet/column_writer.cpp @@ -13,6 +13,7 @@ #include "writer/list_column_writer.hpp" #include "writer/primitive_column_writer.hpp" #include "writer/struct_column_writer.hpp" +#include "writer/variant_column_writer.hpp" #include "writer/templated_column_writer.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/operator/comparison_operators.hpp" @@ -181,8 +182,7 @@ void ColumnWriter::CompressPage(MemoryStream &temp_writer, size_t &compressed_si } } -void ColumnWriter::HandleRepeatLevels(ColumnWriterState &state, ColumnWriterState *parent, idx_t count, - idx_t max_repeat) const { +void ColumnWriter::HandleRepeatLevels(ColumnWriterState &state, ColumnWriterState *parent, idx_t count) const { if (!parent) { // no repeat levels without a parent node return; @@ -264,6 +264,45 @@ ParquetColumnSchema ColumnWriter::FillParquetSchema(vector VARIANT { + // metadata BYTE_ARRAY, + // value BYTE_ARRAY, + // [] + // } + + const bool is_shredded = false; + + // variant group + duckdb_parquet::SchemaElement top_element; + top_element.repetition_type = null_type; + top_element.num_children = is_shredded ? 
3 : 2; + top_element.logicalType.__isset.VARIANT = true; + top_element.logicalType.VARIANT.__isset.specification_version = true; + top_element.logicalType.VARIANT.specification_version = 1; + top_element.__isset.logicalType = true; + top_element.__isset.num_children = true; + top_element.__isset.repetition_type = true; + schemas.push_back(std::move(top_element)); + + child_list_t child_types; + child_types.emplace_back("metadata", LogicalType::BLOB); + child_types.emplace_back("value", LogicalType::BLOB); + if (is_shredded) { + throw NotImplementedException("Writing shredded VARIANT isn't supported for Parquet yet"); + } + + ParquetColumnSchema variant_column(name, type, max_define, max_repeat, schema_idx, 0); + variant_column.children.reserve(child_types.size()); + for (auto &child_type : child_types) { + variant_column.children.emplace_back(FillParquetSchema(schemas, child_type.second, child_type.first, + child_field_ids, max_repeat, max_define + 1, false)); + } + return variant_column; + } + if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { auto &child_types = StructType::GetChildTypes(type); // set up the schema element for this struct @@ -400,6 +439,19 @@ ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &write auto &type = schema.type; auto can_have_nulls = parquet_schemas[schema.schema_index].repetition_type == FieldRepetitionType::OPTIONAL; path_in_schema.push_back(schema.name); + + if (type.id() == LogicalTypeId::STRUCT && type.GetAlias() == "PARQUET_VARIANT") { + D_ASSERT(schema.children.size() == 2); //! 
NOTE: shredded variants not supported yet + + vector> child_writers; + child_writers.reserve(schema.children.size()); + for (idx_t i = 0; i < schema.children.size(); i++) { + child_writers.push_back( + CreateWriterRecursive(context, writer, parquet_schemas, schema.children[i], path_in_schema)); + } + return make_uniq(writer, schema, path_in_schema, std::move(child_writers), can_have_nulls); + } + if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { // construct the child writers recursively vector> child_writers; diff --git a/src/duckdb/extension/parquet/include/column_writer.hpp b/src/duckdb/extension/parquet/include/column_writer.hpp index d475e903b..d208bddf6 100644 --- a/src/duckdb/extension/parquet/include/column_writer.hpp +++ b/src/duckdb/extension/parquet/include/column_writer.hpp @@ -71,11 +71,6 @@ class ColumnWriter { bool can_have_nulls); virtual ~ColumnWriter(); - ParquetWriter &writer; - const ParquetColumnSchema &column_schema; - vector schema_path; - bool can_have_nulls; - public: const LogicalType &Type() const { return column_schema.type; @@ -129,10 +124,19 @@ class ColumnWriter { protected: void HandleDefineLevels(ColumnWriterState &state, ColumnWriterState *parent, const ValidityMask &validity, const idx_t count, const uint16_t define_value, const uint16_t null_value) const; - void HandleRepeatLevels(ColumnWriterState &state_p, ColumnWriterState *parent, idx_t count, idx_t max_repeat) const; + void HandleRepeatLevels(ColumnWriterState &state_p, ColumnWriterState *parent, idx_t count) const; void CompressPage(MemoryStream &temp_writer, size_t &compressed_size, data_ptr_t &compressed_data, AllocatedData &compressed_buf); + +public: + ParquetWriter &writer; + const ParquetColumnSchema &column_schema; + vector schema_path; + bool can_have_nulls; + +protected: + vector> child_writers; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp 
b/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp index 4bc19516a..bfc0692af 100644 --- a/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/string_column_reader.hpp @@ -14,12 +14,25 @@ namespace duckdb { class StringColumnReader : public ColumnReader { + enum class StringColumnType : uint8_t { VARCHAR, JSON, OTHER }; + + static StringColumnType GetStringColumnType(const LogicalType &type) { + if (type.IsJSONType()) { + return StringColumnType::JSON; + } + if (type.id() == LogicalTypeId::VARCHAR) { + return StringColumnType::VARCHAR; + } + return StringColumnType::OTHER; + } + public: static constexpr const PhysicalType TYPE = PhysicalType::VARCHAR; public: StringColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema); idx_t fixed_width_string_length; + const StringColumnType string_column_type; public: static void VerifyString(const char *str_data, uint32_t str_len, const bool isVarchar); diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp index a7c717709..17efcd46e 100644 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp @@ -137,10 +137,8 @@ class VariantBinaryDecoder { static VariantValue Decode(const VariantMetadata &metadata, const_data_ptr_t data); public: - static VariantValue PrimitiveTypeDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, - const_data_ptr_t data); - static VariantValue ShortStringDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, - const_data_ptr_t data); + static VariantValue PrimitiveTypeDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data); + static VariantValue ShortStringDecode(const VariantValueMetadata 
&value_metadata, const_data_ptr_t data); static VariantValue ObjectDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, const_data_ptr_t data); static VariantValue ArrayDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, diff --git a/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp index f1070b0f1..902d3001c 100644 --- a/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp @@ -28,13 +28,11 @@ class ListColumnWriter : public ColumnWriter { public: ListColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, unique_ptr child_writer_p, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls), - child_writer(std::move(child_writer_p)) { + : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { + child_writers.push_back(std::move(child_writer_p)); } ~ListColumnWriter() override = default; - unique_ptr child_writer; - public: unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; bool HasAnalyze() override; @@ -46,6 +44,9 @@ class ListColumnWriter : public ColumnWriter { void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; void FinalizeWrite(ColumnWriterState &state) override; + +protected: + ColumnWriter &GetChildWriter(); }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp index 8927c391b..bbb6cd06b 100644 --- a/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp @@ -16,13 +16,11 @@ class StructColumnWriter : 
public ColumnWriter { public: StructColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, vector> child_writers_p, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls), - child_writers(std::move(child_writers_p)) { + : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls) { + child_writers = std::move(child_writers_p); } ~StructColumnWriter() override = default; - vector> child_writers; - public: unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; bool HasAnalyze() override; diff --git a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp index c035bba43..4c9f1d8aa 100644 --- a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp @@ -197,6 +197,7 @@ class StandardColumnWriter : public PrimitiveColumnWriter { const bool check_parent_empty = parent && !parent->is_empty.empty(); const idx_t parent_index = state.definition_levels.size(); + D_ASSERT(!check_parent_empty || parent_index < parent->is_empty.size()); const idx_t vcount = check_parent_empty ? 
parent->definition_levels.size() - state.definition_levels.size() : count; diff --git a/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp new file mode 100644 index 000000000..d1c5af1cf --- /dev/null +++ b/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp @@ -0,0 +1,43 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// writer/variant_column_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "column_writer.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" + +namespace duckdb { + +class VariantColumnWriter : public ColumnWriter { +public: + VariantColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, + vector> child_writers_p, bool can_have_nulls) + : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls), + child_writers(std::move(child_writers_p)) { + } + ~VariantColumnWriter() override = default; + + vector> child_writers; + +public: + unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; + bool HasAnalyze() override; + void Analyze(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override; + void FinalizeAnalyze(ColumnWriterState &state) override; + void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count, + bool vector_can_span_multiple_pages) override; + + void BeginWrite(ColumnWriterState &state) override; + void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; + void FinalizeWrite(ColumnWriterState &state) override; + +public: + static ScalarFunction GetTransformFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_extension.cpp 
b/src/duckdb/extension/parquet/parquet_extension.cpp index 37e6cd0b7..6e8e6fe1b 100644 --- a/src/duckdb/extension/parquet/parquet_extension.cpp +++ b/src/duckdb/extension/parquet/parquet_extension.cpp @@ -15,6 +15,7 @@ #include "reader/struct_column_reader.hpp" #include "zstd_file_system.hpp" #include "writer/primitive_column_writer.hpp" +#include "writer/variant_column_writer.hpp" #include #include @@ -43,6 +44,8 @@ #include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/table/row_group.hpp" @@ -828,6 +831,19 @@ static bool IsTypeLossy(const LogicalType &type) { return type.id() == LogicalTypeId::HUGEINT || type.id() == LogicalTypeId::UHUGEINT; } +static bool IsGeometryType(const LogicalType &type, ClientContext &context) { + if (type.id() != LogicalTypeId::BLOB) { + return false; + } + if (!type.HasAlias()) { + return false; + } + if (type.GetAlias() != "GEOMETRY") { + return false; + } + return GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context); +} + static vector> ParquetWriteSelect(CopyToSelectInput &input) { auto &context = input.context; @@ -843,8 +859,7 @@ static vector> ParquetWriteSelect(CopyToSelectInput &inpu // Spatial types need to be encoded into WKB when writing GeoParquet. 
// But dont perform this conversion if this is a EXPORT DATABASE statement - if (input.copy_to_type == CopyToType::COPY_TO_FILE && type.id() == LogicalTypeId::BLOB && type.HasAlias() && - type.GetAlias() == "GEOMETRY" && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { + if (input.copy_to_type == CopyToType::COPY_TO_FILE && IsGeometryType(type, context)) { LogicalType wkb_blob_type(LogicalTypeId::BLOB); wkb_blob_type.SetAlias("WKB_BLOB"); @@ -853,6 +868,17 @@ static vector> ParquetWriteSelect(CopyToSelectInput &inpu cast_expr->SetAlias(name); result.push_back(std::move(cast_expr)); any_change = true; + } else if (input.copy_to_type == CopyToType::COPY_TO_FILE && type.id() == LogicalTypeId::VARIANT) { + vector> arguments; + arguments.push_back(std::move(expr)); + + auto transform_func = VariantColumnWriter::GetTransformFunction(); + transform_func.bind(context, transform_func, arguments); + + auto func_expr = make_uniq(transform_func.return_type, transform_func, + std::move(arguments), nullptr, false); + result.push_back(std::move(func_expr)); + any_change = true; } // If this is an EXPORT DATABASE statement, we dont want to write "lossy" types, instead cast them to VARCHAR else if (input.copy_to_type == CopyToType::EXPORT_DATABASE && TypeVisitor::Contains(type, IsTypeLossy)) { @@ -924,6 +950,9 @@ static void LoadInternal(ExtensionLoader &loader) { ParquetBloomProbeFunction bloom_probe_fun; loader.RegisterFunction(MultiFileReader::CreateFunctionSet(bloom_probe_fun)); + // variant_to_parquet_variant + loader.RegisterFunction(VariantColumnWriter::GetTransformFunction()); + CopyFunction function("parquet"); function.copy_to_select = ParquetWriteSelect; function.copy_to_bind = ParquetWriteBind; diff --git a/src/duckdb/extension/parquet/parquet_reader.cpp b/src/duckdb/extension/parquet/parquet_reader.cpp index cad5f3a9b..2ce1a4403 100644 --- a/src/duckdb/extension/parquet/parquet_reader.cpp +++ b/src/duckdb/extension/parquet/parquet_reader.cpp @@ -92,7 
+92,7 @@ static shared_ptr LoadMetadata(ClientContext &context, Allocator &allocator, CachingFileHandle &file_handle, const shared_ptr &encryption_config, const EncryptionUtil &encryption_util, optional_idx footer_size) { - auto file_proto = CreateThriftFileProtocol(QueryContext(context), file_handle, false); + auto file_proto = CreateThriftFileProtocol(context, file_handle, false); auto &transport = reinterpret_cast(*file_proto->getTransport()); auto file_size = transport.GetSize(); if (file_size < 12) { @@ -570,7 +570,10 @@ ParquetColumnSchema ParquetReader::ParseSchemaRecursive(idx_t depth, idx_t max_d auto file_meta_data = GetFileMetadata(); D_ASSERT(file_meta_data); - D_ASSERT(next_schema_idx < file_meta_data->schema.size()); + if (next_schema_idx >= file_meta_data->schema.size()) { + throw InvalidInputException("Malformed Parquet schema in file \"%s\": invalid schema index %d", file.path, + next_schema_idx); + } auto &s_ele = file_meta_data->schema[next_schema_idx]; auto this_idx = next_schema_idx; @@ -837,7 +840,7 @@ ParquetReader::ParquetReader(ClientContext &context_p, OpenFileInfo file_p, Parq shared_ptr metadata_p) : BaseFileReader(std::move(file_p)), fs(CachingFileSystem::Get(context_p)), allocator(BufferAllocator::Get(context_p)), parquet_options(std::move(parquet_options_p)) { - file_handle = fs.OpenFile(QueryContext(context_p), file, FileFlags::FILE_FLAGS_READ); + file_handle = fs.OpenFile(context_p, file, FileFlags::FILE_FLAGS_READ); if (!file_handle->CanSeek()) { throw NotImplementedException( "Reading parquet files from a FIFO stream is not supported and cannot be efficiently supported since " @@ -1236,7 +1239,7 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat state.prefetch_mode = false; } - state.file_handle = fs.OpenFile(QueryContext(context), file, flags); + state.file_handle = fs.OpenFile(context, file, flags); } state.adaptive_filter.reset(); state.scan_filters.clear(); @@ -1247,7 +1250,7 @@ void 
ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat } } - state.thrift_file_proto = CreateThriftFileProtocol(QueryContext(context), *state.file_handle, state.prefetch_mode); + state.thrift_file_proto = CreateThriftFileProtocol(context, *state.file_handle, state.prefetch_mode); state.root_reader = CreateReader(context); state.define_buf.resize(allocator, STANDARD_VECTOR_SIZE); state.repeat_buf.resize(allocator, STANDARD_VECTOR_SIZE); diff --git a/src/duckdb/extension/parquet/parquet_statistics.cpp b/src/duckdb/extension/parquet/parquet_statistics.cpp index 5f7d93718..a22613271 100644 --- a/src/duckdb/extension/parquet/parquet_statistics.cpp +++ b/src/duckdb/extension/parquet/parquet_statistics.cpp @@ -395,23 +395,21 @@ unique_ptr ParquetStatisticsUtils::TransformColumnStatistics(con } break; case LogicalTypeId::VARCHAR: { - auto string_stats = StringStats::CreateEmpty(type); + auto string_stats = StringStats::CreateUnknown(type); if (parquet_stats.__isset.min_value) { StringColumnReader::VerifyString(parquet_stats.min_value.c_str(), parquet_stats.min_value.size(), true); - StringStats::Update(string_stats, parquet_stats.min_value); + StringStats::SetMin(string_stats, parquet_stats.min_value); } else if (parquet_stats.__isset.min) { StringColumnReader::VerifyString(parquet_stats.min.c_str(), parquet_stats.min.size(), true); - StringStats::Update(string_stats, parquet_stats.min); + StringStats::SetMin(string_stats, parquet_stats.min); } if (parquet_stats.__isset.max_value) { StringColumnReader::VerifyString(parquet_stats.max_value.c_str(), parquet_stats.max_value.size(), true); - StringStats::Update(string_stats, parquet_stats.max_value); + StringStats::SetMax(string_stats, parquet_stats.max_value); } else if (parquet_stats.__isset.max) { StringColumnReader::VerifyString(parquet_stats.max.c_str(), parquet_stats.max.size(), true); - StringStats::Update(string_stats, parquet_stats.max); + StringStats::SetMax(string_stats, parquet_stats.max); } - 
StringStats::SetContainsUnicode(string_stats); - StringStats::ResetMaxStringLength(string_stats); row_group_stats = string_stats.ToUnique(); break; } diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp index 2021335ad..99a420242 100644 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ b/src/duckdb/extension/parquet/parquet_writer.cpp @@ -459,7 +459,7 @@ void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGro write_states.emplace_back(col_writers.back().get().InitializeWriteState(row_group)); } - for (auto &chunk : buffer.Chunks({column_ids})) { + for (auto &chunk : buffer.Chunks(column_ids)) { for (idx_t i = 0; i < next; i++) { if (col_writers[i].get().HasAnalyze()) { col_writers[i].get().Analyze(*write_states[i], nullptr, chunk.data[i], chunk.size()); diff --git a/src/duckdb/extension/parquet/reader/string_column_reader.cpp b/src/duckdb/extension/parquet/reader/string_column_reader.cpp index 6b2a3db6d..867dbb4d8 100644 --- a/src/duckdb/extension/parquet/reader/string_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/string_column_reader.cpp @@ -9,7 +9,7 @@ namespace duckdb { // String Column Reader //===--------------------------------------------------------------------===// StringColumnReader::StringColumnReader(ParquetReader &reader, const ParquetColumnSchema &schema) - : ColumnReader(reader, schema) { + : ColumnReader(reader, schema), string_column_type(GetStringColumnType(Type())) { fixed_width_string_length = 0; if (schema.parquet_type == Type::FIXED_LEN_BYTE_ARRAY) { fixed_width_string_length = schema.type_length; @@ -26,13 +26,26 @@ void StringColumnReader::VerifyString(const char *str_data, uint32_t str_len, co size_t pos; auto utf_type = Utf8Proc::Analyze(str_data, str_len, &reason, &pos); if (utf_type == UnicodeType::INVALID) { - throw InvalidInputException("Invalid string encoding found in Parquet file: value \"" + - Blob::ToString(string_t(str_data, 
str_len)) + "\" is not valid UTF8!"); + throw InvalidInputException("Invalid string encoding found in Parquet file: value \"%s\" is not valid UTF8!", + Blob::ToString(string_t(str_data, str_len))); } } void StringColumnReader::VerifyString(const char *str_data, uint32_t str_len) { - VerifyString(str_data, str_len, Type().id() == LogicalTypeId::VARCHAR); + switch (string_column_type) { + case StringColumnType::VARCHAR: + VerifyString(str_data, str_len, true); + break; + case StringColumnType::JSON: { + const auto error = StringUtil::ValidateJSON(str_data, str_len); + if (!error.empty()) { + throw InvalidInputException("Invalid JSON found in Parquet file: %s", error); + } + break; + } + default: + break; + } } class ParquetStringVectorBuffer : public VectorBuffer { diff --git a/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp b/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp index eacff5501..0388da0b3 100644 --- a/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp +++ b/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp @@ -15,7 +15,7 @@ static constexpr uint8_t VERSION_MASK = 0xF; static constexpr uint8_t SORTED_STRINGS_MASK = 0x1; static constexpr uint8_t SORTED_STRINGS_SHIFT = 4; static constexpr uint8_t OFFSET_SIZE_MINUS_ONE_MASK = 0x3; -static constexpr uint8_t OFFSET_SIZE_MINUS_ONE_SHIFT = 5; +static constexpr uint8_t OFFSET_SIZE_MINUS_ONE_SHIFT = 6; static constexpr uint8_t BASIC_TYPE_MASK = 0x3; static constexpr uint8_t VALUE_HEADER_SHIFT = 2; @@ -74,8 +74,8 @@ VariantMetadata::VariantMetadata(const string_t &metadata) : metadata(metadata) const_data_ptr_t ptr = reinterpret_cast(metadata_data + sizeof(uint8_t)); idx_t dictionary_size = ReadVariableLengthLittleEndian(header.offset_size, ptr); - offsets = ptr; - bytes = offsets + ((dictionary_size + 1) * header.offset_size); + auto offsets = ptr; + auto bytes = offsets + ((dictionary_size + 1) * header.offset_size); idx_t 
last_offset = ReadVariableLengthLittleEndian(header.offset_size, ptr); for (idx_t i = 0; i < dictionary_size; i++) { auto next_offset = ReadVariableLengthLittleEndian(header.offset_size, ptr); @@ -140,8 +140,7 @@ hugeint_t DecodeDecimal(const_data_ptr_t data, uint8_t &scale, uint8_t &width) { return result; } -VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, +VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data) { switch (value_metadata.primitive_type) { case VariantPrimitiveType::NULL_TYPE: { @@ -267,8 +266,7 @@ VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantMetadata &me } } -VariantValue VariantBinaryDecoder::ShortStringDecode(const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, +VariantValue VariantBinaryDecoder::ShortStringDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data) { D_ASSERT(value_metadata.string_size < 64); auto string_data = reinterpret_cast(data); @@ -348,10 +346,10 @@ VariantValue VariantBinaryDecoder::Decode(const VariantMetadata &variant_metadat data++; switch (value_metadata.basic_type) { case VariantBasicType::PRIMITIVE: { - return PrimitiveTypeDecode(variant_metadata, value_metadata, data); + return PrimitiveTypeDecode(value_metadata, data); } case VariantBasicType::SHORT_STRING: { - return ShortStringDecode(variant_metadata, value_metadata, data); + return ShortStringDecode(value_metadata, data); } case VariantBasicType::OBJECT: { return ObjectDecode(variant_metadata, value_metadata, data); diff --git a/src/duckdb/extension/parquet/writer/array_column_writer.cpp b/src/duckdb/extension/parquet/writer/array_column_writer.cpp index 60284ff28..2a9c9a9d5 100644 --- a/src/duckdb/extension/parquet/writer/array_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/array_column_writer.cpp @@ -6,7 +6,7 @@ void 
ArrayColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *p auto &state = state_p.Cast(); auto &array_child = ArrayVector::GetEntry(vector); auto array_size = ArrayType::GetSize(vector.GetType()); - child_writer->Analyze(*state.child_state, &state_p, array_child, array_size * count); + GetChildWriter().Analyze(*state.child_state, &state_p, array_child, array_size * count); } void ArrayColumnWriter::WriteArrayState(ListColumnWriterState &state, idx_t array_size, uint16_t first_repeat_level, @@ -35,10 +35,9 @@ void ArrayColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *p // write definition levels and repeats // the main difference between this and ListColumnWriter::Prepare is that we need to make sure to write out // repetition levels and definitions for the child elements of the array even if the array itself is NULL. - idx_t start = 0; idx_t vcount = parent ? parent->definition_levels.size() - state.parent_index : count; idx_t vector_index = 0; - for (idx_t i = start; i < vcount; i++) { + for (idx_t i = 0; i < vcount; i++) { idx_t parent_index = state.parent_index + i; if (parent && !parent->is_empty.empty() && parent->is_empty[parent_index]) { WriteArrayState(state, array_size, parent->repetition_levels[parent_index], @@ -63,14 +62,14 @@ void ArrayColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *p auto &array_child = ArrayVector::GetEntry(vector); // The elements of a single array should not span multiple Parquet pages // So, we force the entire vector to fit on a single page by setting "vector_can_span_multiple_pages=false" - child_writer->Prepare(*state.child_state, &state_p, array_child, count * array_size, false); + GetChildWriter().Prepare(*state.child_state, &state_p, array_child, count * array_size, false); } void ArrayColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) { auto &state = state_p.Cast(); auto array_size = ArrayType::GetSize(vector.GetType()); auto &array_child = 
ArrayVector::GetEntry(vector); - child_writer->Write(*state.child_state, array_child, count * array_size); + GetChildWriter().Write(*state.child_state, array_child, count * array_size); } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/list_column_writer.cpp b/src/duckdb/extension/parquet/writer/list_column_writer.cpp index 8fba00c23..b043a94bc 100644 --- a/src/duckdb/extension/parquet/writer/list_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/list_column_writer.cpp @@ -4,23 +4,23 @@ namespace duckdb { unique_ptr ListColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { auto result = make_uniq(row_group, row_group.columns.size()); - result->child_state = child_writer->InitializeWriteState(row_group); + result->child_state = GetChildWriter().InitializeWriteState(row_group); return std::move(result); } bool ListColumnWriter::HasAnalyze() { - return child_writer->HasAnalyze(); + return GetChildWriter().HasAnalyze(); } void ListColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { auto &state = state_p.Cast(); auto &list_child = ListVector::GetEntry(vector); auto list_count = ListVector::GetListSize(vector); - child_writer->Analyze(*state.child_state, &state_p, list_child, list_count); + GetChildWriter().Analyze(*state.child_state, &state_p, list_child, list_count); } void ListColumnWriter::FinalizeAnalyze(ColumnWriterState &state_p) { auto &state = state_p.Cast(); - child_writer->FinalizeAnalyze(*state.child_state); + GetChildWriter().FinalizeAnalyze(*state.child_state); } static idx_t GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count) { @@ -114,12 +114,12 @@ void ListColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *pa auto child_length = GetConsecutiveChildList(vector, child_list, 0, count); // The elements of a single list should not span multiple Parquet pages // So, we force the entire vector to fit on a 
single page by setting "vector_can_span_multiple_pages=false" - child_writer->Prepare(*state.child_state, &state_p, child_list, child_length, false); + GetChildWriter().Prepare(*state.child_state, &state_p, child_list, child_length, false); } void ListColumnWriter::BeginWrite(ColumnWriterState &state_p) { auto &state = state_p.Cast(); - child_writer->BeginWrite(*state.child_state); + GetChildWriter().BeginWrite(*state.child_state); } void ListColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) { @@ -128,12 +128,17 @@ void ListColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t c auto &list_child = ListVector::GetEntry(vector); Vector child_list(list_child); auto child_length = GetConsecutiveChildList(vector, child_list, 0, count); - child_writer->Write(*state.child_state, child_list, child_length); + GetChildWriter().Write(*state.child_state, child_list, child_length); } void ListColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { auto &state = state_p.Cast(); - child_writer->FinalizeWrite(*state.child_state); + GetChildWriter().FinalizeWrite(*state.child_state); +} + +ColumnWriter &ListColumnWriter::GetChildWriter() { + D_ASSERT(child_writers.size() == 1); + return *child_writers[0]; } } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp index d3ebd7dfc..16189ab24 100644 --- a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp @@ -44,7 +44,7 @@ void PrimitiveColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterStat idx_t vcount = parent ? 
parent->definition_levels.size() - state.definition_levels.size() : count; idx_t parent_index = state.definition_levels.size(); auto &validity = FlatVector::Validity(vector); - HandleRepeatLevels(state, parent, count, MaxRepeat()); + HandleRepeatLevels(state, parent, count); HandleDefineLevels(state, parent, validity, count, MaxDefine(), MaxDefine() - 1); idx_t vector_index = 0; diff --git a/src/duckdb/extension/parquet/writer/struct_column_writer.cpp b/src/duckdb/extension/parquet/writer/struct_column_writer.cpp index e65515ad5..c9b6bcf9d 100644 --- a/src/duckdb/extension/parquet/writer/struct_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/struct_column_writer.cpp @@ -67,7 +67,7 @@ void StructColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState * parent->is_empty.end()); } } - HandleRepeatLevels(state_p, parent, count, MaxRepeat()); + HandleRepeatLevels(state_p, parent, count); HandleDefineLevels(state_p, parent, validity, count, PARQUET_DEFINE_VALID, MaxDefine() - 1); auto &child_vectors = StructVector::GetEntries(vector); for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { diff --git a/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp b/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp new file mode 100644 index 000000000..ba6c707dd --- /dev/null +++ b/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp @@ -0,0 +1,633 @@ +#include "writer/variant_column_writer.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" +#include "reader/variant/variant_binary_decoder.hpp" +#include "duckdb/common/types/uuid.hpp" + +namespace duckdb { + +static idx_t CalculateByteLength(idx_t value) { + if (value == 0) { + return 1; + } + auto value_data = reinterpret_cast(&value); + idx_t irrelevant_bytes = 0; + //! 
Check how many of the most significant bytes are 0 + for (idx_t i = sizeof(idx_t); i > 0 && value_data[i - 1] == 0; i--) { + irrelevant_bytes++; + } + return sizeof(idx_t) - irrelevant_bytes; +} + +static uint8_t EncodeMetadataHeader(idx_t byte_length) { + D_ASSERT(byte_length <= 4); + + uint8_t header_byte = 0; + //! Set 'version' to 1 + header_byte |= static_cast(1); + //! Set 'sorted_strings' to 1 + header_byte |= static_cast(1) << 4; + //! Set 'offset_size_minus_one' to byte_length-1 + header_byte |= (static_cast(byte_length) - 1) << 6; + +#ifdef DEBUG + auto decoded_header = VariantMetadataHeader::FromHeaderByte(header_byte); + D_ASSERT(decoded_header.offset_size == byte_length); +#endif + + return header_byte; +} + +static void CreateMetadata(UnifiedVariantVectorData &variant, Vector &metadata, idx_t count) { + auto &keys = variant.keys; + auto keys_data = variant.keys_data; + + //! NOTE: the parquet variant is limited to a max dictionary size of NumericLimits::Maximum() + //! Whereas we can have NumericLimits::Maximum() *per* string in DuckDB + auto metadata_data = FlatVector::GetData(metadata); + for (idx_t row = 0; row < count; row++) { + uint64_t dictionary_count = 0; + if (variant.RowIsValid(row)) { + auto list_entry = keys_data[keys.sel->get_index(row)]; + dictionary_count = list_entry.length; + } + idx_t dictionary_size = 0; + for (idx_t i = 0; i < dictionary_count; i++) { + auto &key = variant.GetKey(row, i); + dictionary_size += key.GetSize(); + } + if (dictionary_size >= NumericLimits::Maximum()) { + throw InvalidInputException("The total length of the dictionary exceeds a 4 byte value (uint32_t), failed " + "to export VARIANT to Parquet"); + } + + auto byte_length = CalculateByteLength(dictionary_size); + auto total_length = 1 + (byte_length * (dictionary_count + 2)) + dictionary_size; + + metadata_data[row] = StringVector::EmptyString(metadata, total_length); + auto &metadata_blob = metadata_data[row]; + auto metadata_blob_data = 
metadata_blob.GetDataWriteable(); + + metadata_blob_data[0] = EncodeMetadataHeader(byte_length); + memcpy(metadata_blob_data + 1, reinterpret_cast(&dictionary_count), byte_length); + + auto offset_ptr = metadata_blob_data + 1 + byte_length; + auto string_ptr = metadata_blob_data + 1 + byte_length + ((dictionary_count + 1) * byte_length); + idx_t total_offset = 0; + for (idx_t i = 0; i < dictionary_count; i++) { + memcpy(offset_ptr + (i * byte_length), reinterpret_cast(&total_offset), byte_length); + auto &key = variant.GetKey(row, i); + + memcpy(string_ptr + total_offset, key.GetData(), key.GetSize()); + total_offset += key.GetSize(); + } + memcpy(offset_ptr + (dictionary_count * byte_length), reinterpret_cast(&total_offset), byte_length); + D_ASSERT(offset_ptr + ((dictionary_count + 1) * byte_length) == string_ptr); + D_ASSERT(string_ptr + total_offset == metadata_blob_data + total_length); + metadata_blob.SetSizeAndFinalize(total_length, total_length); + +#ifdef DEBUG + auto decoded_metadata = VariantMetadata(metadata_blob); + D_ASSERT(decoded_metadata.strings.size() == dictionary_count); + for (idx_t i = 0; i < dictionary_count; i++) { + D_ASSERT(decoded_metadata.strings[i] == variant.GetKey(row, i).GetString()); + } +#endif + } +} + +static idx_t AnalyzeValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, + vector &offsets) { + idx_t total_size = 0; + //! Every value has at least a value header + total_size++; + + idx_t offset_size = offsets.size(); + VariantLogicalType type_id = VariantLogicalType::VARIANT_NULL; + if (variant.RowIsValid(row)) { + type_id = variant.GetTypeId(row, values_index); + } + switch (type_id) { + case VariantLogicalType::OBJECT: { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + //! 
Calculate value and key offsets for all children + idx_t total_offset = 0; + uint32_t highest_keys_index = 0; + offsets.resize(offset_size + nested_data.child_count + 1); + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + offsets[offset_size + i] = total_offset; + + total_offset += AnalyzeValueData(variant, row, values_index, offsets); + highest_keys_index = MaxValue(highest_keys_index, keys_index); + } + offsets[offset_size + nested_data.child_count] = total_offset; + + //! Calculate the sizes for the objects value data + auto field_id_size = CalculateByteLength(highest_keys_index); + auto field_offset_size = CalculateByteLength(total_offset); + auto num_elements = nested_data.child_count; + const bool is_large = num_elements > NumericLimits::Maximum(); + + //! Now add the sizes for the objects value data + if (is_large) { + total_size += sizeof(uint32_t); + } else { + total_size += sizeof(uint8_t); + } + total_size += num_elements * field_id_size; + total_size += (num_elements + 1) * field_offset_size; + total_size += total_offset; + break; + } + case VariantLogicalType::ARRAY: { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + idx_t total_offset = 0; + offsets.resize(offset_size + nested_data.child_count + 1); + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + offsets[offset_size + i] = total_offset; + + total_offset += AnalyzeValueData(variant, row, values_index, offsets); + } + offsets[offset_size + nested_data.child_count] = total_offset; + + auto field_offset_size = CalculateByteLength(total_offset); + auto num_elements = nested_data.child_count; + const bool is_large = num_elements > NumericLimits::Maximum(); + + if (is_large) { + total_size += sizeof(uint32_t); + } else { 
+ total_size += sizeof(uint8_t); + } + total_size += (num_elements + 1) * field_offset_size; + total_size += total_offset; + break; + } + case VariantLogicalType::BLOB: + case VariantLogicalType::VARCHAR: { + auto string_value = VariantUtils::DecodeStringData(variant, row, values_index); + total_size += string_value.GetSize(); + if (type_id == VariantLogicalType::BLOB || string_value.GetSize() > 64) { + //! Save as regular string value + total_size += sizeof(uint32_t); + } + break; + } + case VariantLogicalType::VARIANT_NULL: + case VariantLogicalType::BOOL_TRUE: + case VariantLogicalType::BOOL_FALSE: + break; + case VariantLogicalType::INT8: + total_size += sizeof(uint8_t); + break; + case VariantLogicalType::INT16: + total_size += sizeof(uint16_t); + break; + case VariantLogicalType::INT32: + total_size += sizeof(uint32_t); + break; + case VariantLogicalType::INT64: + total_size += sizeof(uint64_t); + break; + case VariantLogicalType::FLOAT: + total_size += sizeof(float); + break; + case VariantLogicalType::DOUBLE: + total_size += sizeof(double); + break; + case VariantLogicalType::DECIMAL: { + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + total_size += 1; + if (decimal_data.width <= 9) { + total_size += sizeof(int32_t); + } else if (decimal_data.width <= 18) { + total_size += sizeof(int64_t); + } else if (decimal_data.width <= 38) { + total_size += sizeof(uhugeint_t); + } else { + throw InvalidInputException("Can't convert VARIANT DECIMAL(%d, %d) to Parquet VARIANT", decimal_data.width, + decimal_data.scale); + } + break; + } + case VariantLogicalType::UUID: + total_size += sizeof(uhugeint_t); + break; + case VariantLogicalType::DATE: + total_size += sizeof(uint32_t); + break; + case VariantLogicalType::TIME_MICROS: + case VariantLogicalType::TIMESTAMP_MICROS: + case VariantLogicalType::TIMESTAMP_NANOS: + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + total_size += sizeof(uint64_t); + break; + case 
VariantLogicalType::INTERVAL: + case VariantLogicalType::BIGNUM: + case VariantLogicalType::BITSTRING: + case VariantLogicalType::TIMESTAMP_MILIS: + case VariantLogicalType::TIMESTAMP_SEC: + case VariantLogicalType::TIME_MICROS_TZ: + case VariantLogicalType::TIME_NANOS: + case VariantLogicalType::UINT8: + case VariantLogicalType::UINT16: + case VariantLogicalType::UINT32: + case VariantLogicalType::UINT64: + case VariantLogicalType::UINT128: + case VariantLogicalType::INT128: + default: + throw InvalidInputException("Can't convert VARIANT of type '%s' to Parquet VARIANT", + EnumUtil::ToString(type_id)); + } + + return total_size; +} + +template +void WritePrimitiveTypeHeader(data_ptr_t &value_data) { + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::PRIMITIVE); + value_header |= static_cast(TYPE_ID) << 2; + + *value_data = value_header; + value_data++; +} + +template +void CopySimplePrimitiveData(const UnifiedVariantVectorData &variant, data_ptr_t &value_data, idx_t row, + uint32_t values_index) { + auto byte_offset = variant.GetByteOffset(row, values_index); + auto data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = data + byte_offset; + memcpy(value_data, ptr, sizeof(T)); + value_data += sizeof(T); +} + +void CopyUUIDData(const UnifiedVariantVectorData &variant, data_ptr_t &value_data, idx_t row, uint32_t values_index) { + + auto byte_offset = variant.GetByteOffset(row, values_index); + auto data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = data + byte_offset; + + auto uuid = Load(ptr); + BaseUUID::ToBlob(uuid, value_data); + value_data += sizeof(uhugeint_t); +} + +static void WritePrimitiveValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, + data_ptr_t &value_data, const vector &offsets, idx_t &offset_index) { + VariantLogicalType type_id = VariantLogicalType::VARIANT_NULL; + if (variant.RowIsValid(row)) { + type_id = variant.GetTypeId(row, values_index); + } 
+ + D_ASSERT(type_id != VariantLogicalType::OBJECT && type_id != VariantLogicalType::ARRAY); + switch (type_id) { + case VariantLogicalType::BLOB: + case VariantLogicalType::VARCHAR: { + auto string_value = VariantUtils::DecodeStringData(variant, row, values_index); + auto string_size = string_value.GetSize(); + if (type_id == VariantLogicalType::BLOB || string_size > 64) { + if (type_id == VariantLogicalType::BLOB) { + WritePrimitiveTypeHeader(value_data); + } else { + WritePrimitiveTypeHeader(value_data); + } + Store(string_size, value_data); + value_data += sizeof(uint32_t); + } else { + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::SHORT_STRING); + value_header |= static_cast(string_size) << 2; + + *value_data = value_header; + value_data++; + } + memcpy(value_data, reinterpret_cast(string_value.GetData()), string_size); + value_data += string_size; + break; + } + case VariantLogicalType::VARIANT_NULL: + WritePrimitiveTypeHeader(value_data); + break; + case VariantLogicalType::BOOL_TRUE: + WritePrimitiveTypeHeader(value_data); + break; + case VariantLogicalType::BOOL_FALSE: + WritePrimitiveTypeHeader(value_data); + break; + case VariantLogicalType::INT8: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::INT16: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::INT32: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::INT64: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::FLOAT: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::DOUBLE: + WritePrimitiveTypeHeader(value_data); + 
CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::UUID: + WritePrimitiveTypeHeader(value_data); + CopyUUIDData(variant, value_data, row, values_index); + break; + case VariantLogicalType::DATE: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIME_MICROS: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIMESTAMP_MICROS: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIMESTAMP_NANOS: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + WritePrimitiveTypeHeader(value_data); + CopySimplePrimitiveData(variant, value_data, row, values_index); + break; + case VariantLogicalType::DECIMAL: { + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + + if (decimal_data.width <= 4 || decimal_data.width > 38) { + throw InvalidInputException("Can't convert VARIANT DECIMAL(%d, %d) to Parquet VARIANT", decimal_data.width, + decimal_data.scale); + } else if (decimal_data.width <= 9) { + WritePrimitiveTypeHeader(value_data); + Store(decimal_data.scale, value_data); + value_data++; + memcpy(value_data, decimal_data.value_ptr, sizeof(int32_t)); + value_data += sizeof(int32_t); + } else if (decimal_data.width <= 18) { + WritePrimitiveTypeHeader(value_data); + Store(decimal_data.scale, value_data); + value_data++; + memcpy(value_data, decimal_data.value_ptr, sizeof(int64_t)); + value_data += sizeof(int64_t); + } else if (decimal_data.width <= 38) { + WritePrimitiveTypeHeader(value_data); + Store(decimal_data.scale, value_data); + value_data++; + memcpy(value_data, decimal_data.value_ptr, sizeof(hugeint_t)); + 
value_data += sizeof(hugeint_t); + } else { + throw InternalException( + "Uncovered VARIANT(DECIMAL) -> Parquet VARIANT conversion for type 'DECIMAL(%d, %d)'", + decimal_data.width, decimal_data.scale); + } + break; + } + case VariantLogicalType::INTERVAL: + case VariantLogicalType::BIGNUM: + case VariantLogicalType::BITSTRING: + case VariantLogicalType::TIMESTAMP_MILIS: + case VariantLogicalType::TIMESTAMP_SEC: + case VariantLogicalType::TIME_MICROS_TZ: + case VariantLogicalType::TIME_NANOS: + case VariantLogicalType::UINT8: + case VariantLogicalType::UINT16: + case VariantLogicalType::UINT32: + case VariantLogicalType::UINT64: + case VariantLogicalType::UINT128: + case VariantLogicalType::INT128: + default: + throw InvalidInputException("Can't convert VARIANT of type '%s' to Parquet VARIANT", + EnumUtil::ToString(type_id)); + } +} + +static void WriteValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, + data_ptr_t &value_data, const vector &offsets, idx_t &offset_index) { + VariantLogicalType type_id = VariantLogicalType::VARIANT_NULL; + if (variant.RowIsValid(row)) { + type_id = variant.GetTypeId(row, values_index); + } + if (type_id == VariantLogicalType::OBJECT) { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + //! -- Object value header -- + + //! 
Determine the 'field_id_size' + uint32_t highest_keys_index = 0; + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + highest_keys_index = MaxValue(highest_keys_index, keys_index); + } + auto field_id_size = CalculateByteLength(highest_keys_index); + + uint32_t last_offset = 0; + if (nested_data.child_count) { + last_offset = offsets[offset_index + nested_data.child_count]; + } + offset_index += nested_data.child_count + 1; + auto field_offset_size = CalculateByteLength(last_offset); + + auto num_elements = nested_data.child_count; + const bool is_large = num_elements > NumericLimits::Maximum(); + + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::OBJECT); + value_header |= static_cast(is_large) << 6; + value_header |= (static_cast(field_id_size) - 1) << 4; + value_header |= (static_cast(field_offset_size) - 1) << 2; + +#ifdef DEBUG + auto object_value_header = VariantValueMetadata::FromHeaderByte(value_header); + D_ASSERT(object_value_header.basic_type == VariantBasicType::OBJECT); + D_ASSERT(object_value_header.is_large == is_large); + D_ASSERT(object_value_header.field_offset_size == field_offset_size); + D_ASSERT(object_value_header.field_id_size == field_id_size); +#endif + + *value_data = value_header; + value_data++; + + //! Write the 'num_elements' + if (is_large) { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint32_t); + } else { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint8_t); + } + + //! Write the 'field_id' entries + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + memcpy(value_data, reinterpret_cast(&keys_index), field_id_size); + value_data += field_id_size; + } + + //! 
Write the 'field_offset' entries and the child 'value's + auto children_ptr = value_data + ((num_elements + 1) * field_offset_size); + idx_t total_offset = 0; + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + auto start_ptr = children_ptr; + WriteValueData(variant, row, values_index, children_ptr, offsets, offset_index); + total_offset += (children_ptr - start_ptr); + } + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + D_ASSERT(children_ptr - total_offset == value_data); + value_data = children_ptr; + } else if (type_id == VariantLogicalType::ARRAY) { + auto nested_data = VariantUtils::DecodeNestedData(variant, row, values_index); + + //! -- Array value header -- + + uint32_t last_offset = 0; + if (nested_data.child_count) { + last_offset = offsets[offset_index + nested_data.child_count]; + } + offset_index += nested_data.child_count + 1; + auto field_offset_size = CalculateByteLength(last_offset); + + auto num_elements = nested_data.child_count; + const bool is_large = num_elements > NumericLimits::Maximum(); + + uint8_t value_header = 0; + value_header |= static_cast(VariantBasicType::ARRAY); + value_header |= static_cast(is_large) << 4; + value_header |= (static_cast(field_offset_size) - 1) << 2; + +#ifdef DEBUG + auto array_value_header = VariantValueMetadata::FromHeaderByte(value_header); + D_ASSERT(array_value_header.basic_type == VariantBasicType::ARRAY); + D_ASSERT(array_value_header.is_large == is_large); + D_ASSERT(array_value_header.field_offset_size == field_offset_size); +#endif + + *value_data = value_header; + value_data++; + + //! 
Write the 'num_elements' + if (is_large) { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint32_t); + } else { + Store(static_cast(num_elements), value_data); + value_data += sizeof(uint8_t); + } + + //! Write the 'field_offset' entries and the child 'value's + auto children_ptr = value_data + ((num_elements + 1) * field_offset_size); + idx_t total_offset = 0; + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); + + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + auto start_ptr = children_ptr; + WriteValueData(variant, row, values_index, children_ptr, offsets, offset_index); + total_offset += (children_ptr - start_ptr); + } + memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); + value_data += field_offset_size; + D_ASSERT(children_ptr - total_offset == value_data); + value_data = children_ptr; + } else { + WritePrimitiveValueData(variant, row, values_index, value_data, offsets, offset_index); + } +} + +static void CreateValues(UnifiedVariantVectorData &variant, Vector &value, idx_t count) { + auto value_data = FlatVector::GetData(value); + + for (idx_t row = 0; row < count; row++) { + //! The (relative) offsets for each value, in the case of nesting + vector offsets; + //! 
Determine the size of this 'value' blob + idx_t blob_length = AnalyzeValueData(variant, row, 0, offsets); + if (!blob_length) { + continue; + } + value_data[row] = StringVector::EmptyString(value, blob_length); + auto &value_blob = value_data[row]; + auto value_blob_data = reinterpret_cast(value_blob.GetDataWriteable()); + + idx_t offset_index = 0; + WriteValueData(variant, row, 0, value_blob_data, offsets, offset_index); + D_ASSERT(data_ptr_cast(value_blob.GetDataWriteable() + blob_length) == value_blob_data); + value_blob.SetSizeAndFinalize(blob_length, blob_length); + } +} + +static void ToParquetVariant(DataChunk &input, ExpressionState &state, Vector &result) { + // DuckDB Variant: + // - keys = VARCHAR[] + // - children = STRUCT(keys_index UINTEGER, values_index UINTEGER)[] + // - values = STRUCT(type_id UTINYINT, byte_offset UINTEGER)[] + // - data = BLOB + + // Parquet VARIANT: + // - metadata = BLOB + // - value = BLOB + + auto &variant_vec = input.data[0]; + auto count = input.size(); + + RecursiveUnifiedVectorFormat recursive_format; + Vector::RecursiveToUnifiedFormat(variant_vec, count, recursive_format); + UnifiedVariantVectorData variant(recursive_format); + + auto &result_vectors = StructVector::GetEntries(result); + CreateMetadata(variant, *result_vectors[0], count); + CreateValues(variant, *result_vectors[1], count); +} + +static LogicalType GetParquetVariantType(const LogicalType &type) { + (void)type; + child_list_t children; + children.emplace_back("metadata", LogicalType::BLOB); + children.emplace_back("value", LogicalType::BLOB); + auto res = LogicalType::STRUCT(std::move(children)); + res.SetAlias("PARQUET_VARIANT"); + return res; +} + +static unique_ptr BindTransform(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments) { + if (arguments.empty()) { + return nullptr; + } + auto type = ExpressionBinder::GetExpressionReturnType(*arguments[0]); + bound_function.return_type = GetParquetVariantType(type); + return nullptr; 
+} + +ScalarFunction VariantColumnWriter::GetTransformFunction() { + ScalarFunction transform("variant_to_parquet_variant", {LogicalType::VARIANT()}, LogicalType::ANY, ToParquetVariant, + BindTransform); + transform.null_handling = FunctionNullHandling::SPECIAL_HANDLING; + return transform; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/variant_column_writer.cpp b/src/duckdb/extension/parquet/writer/variant_column_writer.cpp new file mode 100644 index 000000000..b4f401da8 --- /dev/null +++ b/src/duckdb/extension/parquet/writer/variant_column_writer.cpp @@ -0,0 +1,131 @@ +#include "writer/variant_column_writer.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/common/helper.hpp" + +namespace duckdb { + +namespace { + +class VariantColumnWriterState : public ColumnWriterState { +public: + VariantColumnWriterState(duckdb_parquet::RowGroup &row_group, idx_t col_idx) + : row_group(row_group), col_idx(col_idx) { + } + ~VariantColumnWriterState() override = default; + + duckdb_parquet::RowGroup &row_group; + idx_t col_idx; + vector> child_states; +}; + +} // namespace + +unique_ptr VariantColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { + auto result = make_uniq(row_group, row_group.columns.size()); + + result->child_states.reserve(child_writers.size()); + for (auto &child_writer : child_writers) { + result->child_states.push_back(child_writer->InitializeWriteState(row_group)); + } + return std::move(result); +} + +bool VariantColumnWriter::HasAnalyze() { + for (auto &child_writer : child_writers) { + if (child_writer->HasAnalyze()) { + return true; + } + } + return false; +} + +void VariantColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { + auto &state = state_p.Cast(); + auto &child_vectors = StructVector::GetEntries(vector); + for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { + // Need to check again. 
It might be that just one child needs it but the rest not + if (child_writers[child_idx]->HasAnalyze()) { + child_writers[child_idx]->Analyze(*state.child_states[child_idx], &state_p, *child_vectors[child_idx], + count); + } + } +} + +void VariantColumnWriter::FinalizeAnalyze(ColumnWriterState &state_p) { + auto &state = state_p.Cast(); + for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { + // Need to check again. It might be that just one child needs it but the rest not + if (child_writers[child_idx]->HasAnalyze()) { + child_writers[child_idx]->FinalizeAnalyze(*state.child_states[child_idx]); + } + } +} + +void VariantColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count, + bool vector_can_span_multiple_pages) { + D_ASSERT(child_writers.size() == 2); + auto &metadata_writer = *child_writers[0]; + auto &value_writer = *child_writers[1]; + + auto &state = state_p.Cast(); + auto &metadata_state = *state.child_states[0]; + auto &value_state = *state.child_states[1]; + + auto &validity = FlatVector::Validity(vector); + if (parent) { + // propagate empty entries from the parent + if (state.is_empty.size() < parent->is_empty.size()) { + state.is_empty.insert(state.is_empty.end(), parent->is_empty.begin() + state.is_empty.size(), + parent->is_empty.end()); + } + } + HandleRepeatLevels(state_p, parent, count); + HandleDefineLevels(state_p, parent, validity, count, PARQUET_DEFINE_VALID, MaxDefine() - 1); + + auto &child_vectors = StructVector::GetEntries(vector); + metadata_writer.Prepare(metadata_state, &state_p, *child_vectors[0], count, vector_can_span_multiple_pages); + value_writer.Prepare(value_state, &state_p, *child_vectors[1], count, vector_can_span_multiple_pages); +} + +void VariantColumnWriter::BeginWrite(ColumnWriterState &state_p) { + D_ASSERT(child_writers.size() == 2); + auto &metadata_writer = *child_writers[0]; + auto &value_writer = *child_writers[1]; + + auto &state = 
state_p.Cast(); + auto &metadata_state = *state.child_states[0]; + auto &value_state = *state.child_states[1]; + + metadata_writer.BeginWrite(metadata_state); + value_writer.BeginWrite(value_state); +} + +void VariantColumnWriter::Write(ColumnWriterState &state_p, Vector &input, idx_t count) { + D_ASSERT(child_writers.size() == 2); + + auto &metadata_writer = *child_writers[0]; + auto &value_writer = *child_writers[1]; + + auto &state = state_p.Cast(); + auto &metadata_state = *state.child_states[0]; + auto &value_state = *state.child_states[1]; + + auto &child_vectors = StructVector::GetEntries(input); + metadata_writer.Write(metadata_state, *child_vectors[0], count); + value_writer.Write(value_state, *child_vectors[1], count); +} + +void VariantColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { + D_ASSERT(child_writers.size() == 2); + auto &metadata_writer = *child_writers[0]; + auto &value_writer = *child_writers[1]; + + auto &state = state_p.Cast(); + auto &metadata_state = *state.child_states[0]; + auto &value_state = *state.child_states[1]; + + metadata_writer.FinalizeWrite(metadata_state); + value_writer.FinalizeWrite(value_state); +} + +} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp index b80204ac0..6202e1030 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp @@ -1285,8 +1285,8 @@ TableFunction DuckTableEntry::GetScanFunction(ClientContext &context, unique_ptr return TableScanFunction::GetFunction(); } -vector DuckTableEntry::GetColumnSegmentInfo() { - return storage->GetColumnSegmentInfo(); +vector DuckTableEntry::GetColumnSegmentInfo(const QueryContext &context) { + return storage->GetColumnSegmentInfo(context); } TableStorageInfo DuckTableEntry::GetStorageInfo(ClientContext &context) { diff --git a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp 
b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp index 22a173fd8..5ed480e86 100644 --- a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -266,7 +266,7 @@ void LogicalUpdate::BindExtraColumns(TableCatalogEntry &table, LogicalGet &get, } } -vector TableCatalogEntry::GetColumnSegmentInfo() { +vector TableCatalogEntry::GetColumnSegmentInfo(const QueryContext &context) { return {}; } diff --git a/src/duckdb/src/catalog/catalog_search_path.cpp b/src/duckdb/src/catalog/catalog_search_path.cpp index 6af56c22d..6388b9134 100644 --- a/src/duckdb/src/catalog/catalog_search_path.cpp +++ b/src/duckdb/src/catalog/catalog_search_path.cpp @@ -24,8 +24,8 @@ string CatalogSearchEntry::ToString() const { string CatalogSearchEntry::WriteOptionallyQuoted(const string &input) { for (idx_t i = 0; i < input.size(); i++) { - if (input[i] == '.' || input[i] == ',') { - return "\"" + input + "\""; + if (input[i] == '.' 
|| input[i] == ',' || input[i] == '"') { + return "\"" + StringUtil::Replace(input, "\"", "\"\"") + "\""; } } return input; diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index 324ba7004..49417ac45 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -89,6 +89,7 @@ #include "duckdb/common/types/column/partitioned_column_data.hpp" #include "duckdb/common/types/conflict_manager.hpp" #include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/geometry.hpp" #include "duckdb/common/types/hyperloglog.hpp" #include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/common/types/row/partitioned_tuple_data.hpp" @@ -1795,19 +1796,20 @@ const StringUtil::EnumStringLiteral *GetExtraTypeInfoTypeValues() { { static_cast(ExtraTypeInfoType::ARRAY_TYPE_INFO), "ARRAY_TYPE_INFO" }, { static_cast(ExtraTypeInfoType::ANY_TYPE_INFO), "ANY_TYPE_INFO" }, { static_cast(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO), "INTEGER_LITERAL_TYPE_INFO" }, - { static_cast(ExtraTypeInfoType::TEMPLATE_TYPE_INFO), "TEMPLATE_TYPE_INFO" } + { static_cast(ExtraTypeInfoType::TEMPLATE_TYPE_INFO), "TEMPLATE_TYPE_INFO" }, + { static_cast(ExtraTypeInfoType::GEO_TYPE_INFO), "GEO_TYPE_INFO" } }; return values; } template<> const char* EnumUtil::ToChars(ExtraTypeInfoType value) { - return StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 13, "ExtraTypeInfoType", static_cast(value)); + return StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 14, "ExtraTypeInfoType", static_cast(value)); } template<> ExtraTypeInfoType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 13, "ExtraTypeInfoType", value)); + return static_cast(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 14, "ExtraTypeInfoType", value)); } const StringUtil::EnumStringLiteral *GetFileBufferTypeValues() { @@ -2059,6 +2061,30 @@ GateStatus EnumUtil::FromString(const char 
*value) { return static_cast(StringUtil::StringToEnum(GetGateStatusValues(), 2, "GateStatus", value)); } +const StringUtil::EnumStringLiteral *GetGeometryTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(GeometryType::INVALID), "INVALID" }, + { static_cast(GeometryType::POINT), "POINT" }, + { static_cast(GeometryType::LINESTRING), "LINESTRING" }, + { static_cast(GeometryType::POLYGON), "POLYGON" }, + { static_cast(GeometryType::MULTIPOINT), "MULTIPOINT" }, + { static_cast(GeometryType::MULTILINESTRING), "MULTILINESTRING" }, + { static_cast(GeometryType::MULTIPOLYGON), "MULTIPOLYGON" }, + { static_cast(GeometryType::GEOMETRYCOLLECTION), "GEOMETRYCOLLECTION" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(GeometryType value) { + return StringUtil::EnumToString(GetGeometryTypeValues(), 8, "GeometryType", static_cast(value)); +} + +template<> +GeometryType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetGeometryTypeValues(), 8, "GeometryType", value)); +} + const StringUtil::EnumStringLiteral *GetHLLStorageTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(HLLStorageType::HLL_V1), "HLL_V1" }, @@ -2599,6 +2625,7 @@ const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { { static_cast(LogicalTypeId::POINTER), "POINTER" }, { static_cast(LogicalTypeId::VALIDITY), "VALIDITY" }, { static_cast(LogicalTypeId::UUID), "UUID" }, + { static_cast(LogicalTypeId::GEOMETRY), "GEOMETRY" }, { static_cast(LogicalTypeId::STRUCT), "STRUCT" }, { static_cast(LogicalTypeId::LIST), "LIST" }, { static_cast(LogicalTypeId::MAP), "MAP" }, @@ -2615,12 +2642,12 @@ const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { template<> const char* EnumUtil::ToChars(LogicalTypeId value) { - return StringUtil::EnumToString(GetLogicalTypeIdValues(), 50, "LogicalTypeId", static_cast(value)); + return StringUtil::EnumToString(GetLogicalTypeIdValues(), 
51, "LogicalTypeId", static_cast(value)); } template<> LogicalTypeId EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 50, "LogicalTypeId", value)); + return static_cast(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 51, "LogicalTypeId", value)); } const StringUtil::EnumStringLiteral *GetLookupResultTypeValues() { @@ -2825,19 +2852,20 @@ const StringUtil::EnumStringLiteral *GetMetricsTypeValues() { { static_cast(MetricsType::OPTIMIZER_MATERIALIZED_CTE), "OPTIMIZER_MATERIALIZED_CTE" }, { static_cast(MetricsType::OPTIMIZER_SUM_REWRITER), "OPTIMIZER_SUM_REWRITER" }, { static_cast(MetricsType::OPTIMIZER_LATE_MATERIALIZATION), "OPTIMIZER_LATE_MATERIALIZATION" }, - { static_cast(MetricsType::OPTIMIZER_CTE_INLINING), "OPTIMIZER_CTE_INLINING" } + { static_cast(MetricsType::OPTIMIZER_CTE_INLINING), "OPTIMIZER_CTE_INLINING" }, + { static_cast(MetricsType::OPTIMIZER_COMMON_SUBPLAN), "OPTIMIZER_COMMON_SUBPLAN" } }; return values; } template<> const char* EnumUtil::ToChars(MetricsType value) { - return StringUtil::EnumToString(GetMetricsTypeValues(), 54, "MetricsType", static_cast(value)); + return StringUtil::EnumToString(GetMetricsTypeValues(), 55, "MetricsType", static_cast(value)); } template<> MetricsType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMetricsTypeValues(), 54, "MetricsType", value)); + return static_cast(StringUtil::StringToEnum(GetMetricsTypeValues(), 55, "MetricsType", value)); } const StringUtil::EnumStringLiteral *GetMultiFileColumnMappingModeValues() { @@ -3069,19 +3097,20 @@ const StringUtil::EnumStringLiteral *GetOptimizerTypeValues() { { static_cast(OptimizerType::MATERIALIZED_CTE), "MATERIALIZED_CTE" }, { static_cast(OptimizerType::SUM_REWRITER), "SUM_REWRITER" }, { static_cast(OptimizerType::LATE_MATERIALIZATION), "LATE_MATERIALIZATION" }, - { static_cast(OptimizerType::CTE_INLINING), "CTE_INLINING" } + { 
static_cast(OptimizerType::CTE_INLINING), "CTE_INLINING" }, + { static_cast(OptimizerType::COMMON_SUBPLAN), "COMMON_SUBPLAN" } }; return values; } template<> const char* EnumUtil::ToChars(OptimizerType value) { - return StringUtil::EnumToString(GetOptimizerTypeValues(), 29, "OptimizerType", static_cast(value)); + return StringUtil::EnumToString(GetOptimizerTypeValues(), 30, "OptimizerType", static_cast(value)); } template<> OptimizerType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 29, "OptimizerType", value)); + return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 30, "OptimizerType", value)); } const StringUtil::EnumStringLiteral *GetOrderByNullTypeValues() { @@ -4806,6 +4835,7 @@ const StringUtil::EnumStringLiteral *GetVariantLogicalTypeValues() { { static_cast(VariantLogicalType::ARRAY), "ARRAY" }, { static_cast(VariantLogicalType::BIGNUM), "BIGNUM" }, { static_cast(VariantLogicalType::BITSTRING), "BITSTRING" }, + { static_cast(VariantLogicalType::GEOMETRY), "GEOMETRY" }, { static_cast(VariantLogicalType::ENUM_SIZE), "ENUM_SIZE" } }; return values; @@ -4813,12 +4843,12 @@ const StringUtil::EnumStringLiteral *GetVariantLogicalTypeValues() { template<> const char* EnumUtil::ToChars(VariantLogicalType value) { - return StringUtil::EnumToString(GetVariantLogicalTypeValues(), 34, "VariantLogicalType", static_cast(value)); + return StringUtil::EnumToString(GetVariantLogicalTypeValues(), 35, "VariantLogicalType", static_cast(value)); } template<> VariantLogicalType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetVariantLogicalTypeValues(), 34, "VariantLogicalType", value)); + return static_cast(StringUtil::StringToEnum(GetVariantLogicalTypeValues(), 35, "VariantLogicalType", value)); } const StringUtil::EnumStringLiteral *GetVectorAuxiliaryDataTypeValues() { diff --git a/src/duckdb/src/common/enums/metric_type.cpp 
b/src/duckdb/src/common/enums/metric_type.cpp index 866049251..c2bb79f50 100644 --- a/src/duckdb/src/common/enums/metric_type.cpp +++ b/src/duckdb/src/common/enums/metric_type.cpp @@ -41,6 +41,7 @@ profiler_settings_t MetricsUtils::GetOptimizerMetrics() { MetricsType::OPTIMIZER_SUM_REWRITER, MetricsType::OPTIMIZER_LATE_MATERIALIZATION, MetricsType::OPTIMIZER_CTE_INLINING, + MetricsType::OPTIMIZER_COMMON_SUBPLAN, }; } @@ -115,6 +116,8 @@ MetricsType MetricsUtils::GetOptimizerMetricByType(OptimizerType type) { return MetricsType::OPTIMIZER_LATE_MATERIALIZATION; case OptimizerType::CTE_INLINING: return MetricsType::OPTIMIZER_CTE_INLINING; + case OptimizerType::COMMON_SUBPLAN: + return MetricsType::OPTIMIZER_COMMON_SUBPLAN; default: throw InternalException("OptimizerType %s cannot be converted to a MetricsType", EnumUtil::ToString(type)); }; @@ -178,6 +181,8 @@ OptimizerType MetricsUtils::GetOptimizerTypeByMetric(MetricsType type) { return OptimizerType::LATE_MATERIALIZATION; case MetricsType::OPTIMIZER_CTE_INLINING: return OptimizerType::CTE_INLINING; + case MetricsType::OPTIMIZER_COMMON_SUBPLAN: + return OptimizerType::COMMON_SUBPLAN; default: return OptimizerType::INVALID; }; @@ -213,6 +218,7 @@ bool MetricsUtils::IsOptimizerMetric(MetricsType type) { case MetricsType::OPTIMIZER_SUM_REWRITER: case MetricsType::OPTIMIZER_LATE_MATERIALIZATION: case MetricsType::OPTIMIZER_CTE_INLINING: + case MetricsType::OPTIMIZER_COMMON_SUBPLAN: return true; default: return false; diff --git a/src/duckdb/src/common/enums/optimizer_type.cpp b/src/duckdb/src/common/enums/optimizer_type.cpp index b0d669500..be5fc8309 100644 --- a/src/duckdb/src/common/enums/optimizer_type.cpp +++ b/src/duckdb/src/common/enums/optimizer_type.cpp @@ -40,6 +40,7 @@ static const DefaultOptimizerType internal_optimizer_types[] = { {"sum_rewriter", OptimizerType::SUM_REWRITER}, {"late_materialization", OptimizerType::LATE_MATERIALIZATION}, {"cte_inlining", OptimizerType::CTE_INLINING}, + {"common_subplan", 
OptimizerType::COMMON_SUBPLAN}, {nullptr, OptimizerType::INVALID}}; string OptimizerTypeToString(OptimizerType type) { diff --git a/src/duckdb/src/common/extra_type_info.cpp b/src/duckdb/src/common/extra_type_info.cpp index 1d3160814..6218f3e7b 100644 --- a/src/duckdb/src/common/extra_type_info.cpp +++ b/src/duckdb/src/common/extra_type_info.cpp @@ -507,4 +507,19 @@ shared_ptr TemplateTypeInfo::Copy() const { return make_shared_ptr(*this); } +//===--------------------------------------------------------------------===// +// Geo Type Info +//===--------------------------------------------------------------------===// +GeoTypeInfo::GeoTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::GEO_TYPE_INFO) { +} + +bool GeoTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + // No additional info to compare + return true; +} + +shared_ptr GeoTypeInfo::Copy() const { + return make_shared_ptr(*this); +} + } // namespace duckdb diff --git a/src/duckdb/src/common/file_system.cpp b/src/duckdb/src/common/file_system.cpp index 926cfb6a0..32288e77a 100644 --- a/src/duckdb/src/common/file_system.cpp +++ b/src/duckdb/src/common/file_system.cpp @@ -628,39 +628,9 @@ bool FileSystem::CanHandleFile(const string &fpath) { throw NotImplementedException("%s: CanHandleFile is not implemented!", GetName()); } -static string LookupExtensionForPattern(const string &pattern) { - for (const auto &entry : EXTENSION_FILE_PREFIXES) { - if (StringUtil::StartsWith(pattern, entry.name)) { - return entry.extension; - } - } - return ""; -} - vector FileSystem::GlobFiles(const string &pattern, ClientContext &context, const FileGlobInput &input) { auto result = Glob(pattern); if (result.empty()) { - string required_extension = LookupExtensionForPattern(pattern); - if (!required_extension.empty() && !context.db->ExtensionIsLoaded(required_extension)) { - auto &dbconfig = DBConfig::GetConfig(context); - if (!ExtensionHelper::CanAutoloadExtension(required_extension) || - 
!dbconfig.options.autoload_known_extensions) { - auto error_message = - "File " + pattern + " requires the extension " + required_extension + " to be loaded"; - error_message = - ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, required_extension); - throw MissingExtensionException(error_message); - } - // an extension is required to read this file, but it is not loaded - try to load it - ExtensionHelper::AutoLoadExtension(context, required_extension); - // success! glob again - // check the extension is loaded just in case to prevent an infinite loop here - if (!context.db->ExtensionIsLoaded(required_extension)) { - throw InternalException("Extension load \"%s\" did not throw but somehow the extension was not loaded", - required_extension); - } - return GlobFiles(pattern, context, input); - } if (input.behavior == FileGlobOptions::FALLBACK_GLOB && !HasGlob(pattern)) { // if we have no glob in the pattern and we have an extension, we try to glob if (!HasGlob(pattern)) { diff --git a/src/duckdb/src/common/operator/cast_operators.cpp b/src/duckdb/src/common/operator/cast_operators.cpp index f26c16131..723897a1c 100644 --- a/src/duckdb/src/common/operator/cast_operators.cpp +++ b/src/duckdb/src/common/operator/cast_operators.cpp @@ -19,6 +19,7 @@ #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/types/vector.hpp" +#include "duckdb/common/types/geometry.hpp" #include "duckdb/common/types.hpp" #include "fast_float/fast_float.h" #include "duckdb/common/types/bit.hpp" @@ -1560,6 +1561,14 @@ bool TryCastBlobToUUID::Operation(string_t input, hugeint_t &result, bool strict return true; } +//===--------------------------------------------------------------------===// +// Cast To Geometry +//===--------------------------------------------------------------------===// +template <> +bool TryCastToGeometry::Operation(string_t input, string_t &result, Vector &result_vector, CastParameters ¶meters) 
{ + return Geometry::FromString(input, result, result_vector, parameters.strict); +} + //===--------------------------------------------------------------------===// // Cast To Date //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/render_tree.cpp b/src/duckdb/src/common/render_tree.cpp index 582d5e1ad..ee9621814 100644 --- a/src/duckdb/src/common/render_tree.cpp +++ b/src/duckdb/src/common/render_tree.cpp @@ -103,7 +103,7 @@ static unique_ptr CreateNode(const ProfilingNode &op) { auto &info = op.GetProfilingInfo(); InsertionOrderPreservingMap extra_info; if (info.Enabled(info.settings, MetricsType::EXTRA_INFO)) { - extra_info = op.GetProfilingInfo().extra_info; + extra_info = op.GetProfilingInfo().GetMetricValue>(MetricsType::EXTRA_INFO); } string node_name = "QUERY"; diff --git a/src/duckdb/src/common/sorting/sort.cpp b/src/duckdb/src/common/sorting/sort.cpp index 2159878ff..56bde8499 100644 --- a/src/duckdb/src/common/sorting/sort.cpp +++ b/src/duckdb/src/common/sorting/sort.cpp @@ -141,7 +141,7 @@ class SortLocalSinkState : public LocalSinkState { D_ASSERT(!sorted_run); // TODO: we want to pass "sort.is_index_sort" instead of just "false" here // so that we can do an approximate sort, but that causes issues in the ART - sorted_run = make_uniq(context, sort.key_layout, sort.payload_layout, false); + sorted_run = make_uniq(context, sort, false); } public: @@ -366,8 +366,7 @@ ProgressData Sort::GetSinkProgress(ClientContext &context, GlobalSinkState &gsta class SortGlobalSourceState : public GlobalSourceState { public: SortGlobalSourceState(const Sort &sort, ClientContext &context, SortGlobalSinkState &sink_p) - : sink(sink_p), merger(*sort.decode_sort_key, sort.key_layout, std::move(sink.sorted_runs), - sort.output_projection_columns, sink.partition_size, sink.external, false), + : sink(sink_p), merger(sort, std::move(sink.sorted_runs), sink.partition_size, sink.external, false), 
merger_global_state(merger.total_count == 0 ? nullptr : merger.GetGlobalSourceState(context)) { // TODO: we want to pass "sort.is_index_sort" instead of just "false" here // so that we can do an approximate sort, but that causes issues in the ART @@ -502,12 +501,15 @@ SourceResultType Sort::MaterializeSortedRun(ExecutionContext &context, OperatorS } auto &lstate = input.local_state.Cast(); OperatorSourceInput merger_input {*gstate.merger_global_state, *lstate.merger_local_state, input.interrupt_state}; - return gstate.merger.MaterializeMerge(context, merger_input); + return gstate.merger.MaterializeSortedRun(context, merger_input); } unique_ptr Sort::GetSortedRun(GlobalSourceState &global_state) { auto &gstate = global_state.Cast(); - return gstate.merger.GetMaterialized(gstate); + if (gstate.merger.total_count == 0) { + return nullptr; + } + return gstate.merger.GetSortedRun(*gstate.merger_global_state); } } // namespace duckdb diff --git a/src/duckdb/src/common/sorting/sorted_run.cpp b/src/duckdb/src/common/sorting/sorted_run.cpp index 57c390d32..c72c554ec 100644 --- a/src/duckdb/src/common/sorting/sorted_run.cpp +++ b/src/duckdb/src/common/sorting/sorted_run.cpp @@ -1,6 +1,7 @@ #include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/types/row/tuple_data_collection.hpp" +#include "duckdb/common/sorting/sort.hpp" #include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types/row/block_iterator.hpp" @@ -9,14 +10,161 @@ namespace duckdb { -SortedRun::SortedRun(ClientContext &context_p, shared_ptr key_layout, - shared_ptr payload_layout, bool is_index_sort_p) - : context(context_p), - key_data(make_uniq(BufferManager::GetBufferManager(context), std::move(key_layout))), - payload_data( - payload_layout && payload_layout->ColumnCount() != 0 - ? 
make_uniq(BufferManager::GetBufferManager(context), std::move(payload_layout)) - : nullptr), +//===--------------------------------------------------------------------===// +// SortedRunScanState +//===--------------------------------------------------------------------===// +SortedRunScanState::SortedRunScanState(ClientContext &context, const Sort &sort_p) + : sort(sort_p), key_executor(context, *sort.decode_sort_key) { + key.Initialize(context, {sort.key_layout->GetTypes()[0]}); + decoded_key.Initialize(context, {sort.decode_sort_key->return_type}); +} + +void SortedRunScanState::Scan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, + DataChunk &chunk) { + const auto sort_key_type = sort.key_layout->GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedScan(sorted_run, sort_key_pointers, count, chunk); + default: + throw NotImplementedException("SortedRunMergerLocalState::ScanPartition for %s", + EnumUtil::ToString(sort_key_type)); + } +} + +template +void TemplatedGetKeyAndPayload(SORT_KEY *const *const sort_keys, const 
idx_t &count, DataChunk &key, + data_ptr_t *const payload_ptrs) { + const auto key_data = FlatVector::GetData(key.data[0]); + for (idx_t i = 0; i < count; i++) { + auto &sort_key = *sort_keys[i]; + sort_key.Deconstruct(key_data[i]); + if (SORT_KEY::HAS_PAYLOAD) { + payload_ptrs[i] = sort_key.GetPayload(); + } + } + key.SetCardinality(count); +} + +template +void GetKeyAndPayload(SORT_KEY *const *const sort_keys, const idx_t &count, DataChunk &key, + data_ptr_t *const payload_ptrs) { + const auto type_id = key.data[0].GetType().id(); + switch (type_id) { + case LogicalTypeId::BLOB: + return TemplatedGetKeyAndPayload(sort_keys, count, key, payload_ptrs); + case LogicalTypeId::BIGINT: + return TemplatedGetKeyAndPayload(sort_keys, count, key, payload_ptrs); + default: + throw NotImplementedException("GetKeyAndPayload for %s", EnumUtil::ToString(type_id)); + } +} + +template +void TemplatedReconstructSortKey(SORT_KEY *const *const sort_keys, const idx_t &count) { + for (idx_t i = 0; i < count; i++) { + sort_keys[i]->ByteSwap(); + } +} + +template +void ReconstructSortKey(SORT_KEY *const *const sort_keys, const idx_t &count, const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::BLOB: + return TemplatedReconstructSortKey(sort_keys, count); + case LogicalTypeId::BIGINT: + break; // NOP + default: + throw NotImplementedException("ReconstructSortKey for %s", EnumUtil::ToString(type.id())); + } +} + +template +void SortedRunScanState::TemplatedScan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, + DataChunk &chunk) { + using SORT_KEY = SortKey; + + const auto &output_projection_columns = sort.output_projection_columns; + idx_t opc_idx = 0; + + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + const auto payload_ptrs = FlatVector::GetData(payload_state.chunk_state.row_locations); + bool gathered_payload = false; + + // Decode from key + if (!output_projection_columns[0].is_payload) { + key.Reset(); + 
GetKeyAndPayload(sort_keys, count, key, payload_ptrs); + + decoded_key.Reset(); + key_executor.Execute(key, decoded_key); + ReconstructSortKey(sort_keys, count, key.data[0].GetType()); + + const auto &decoded_key_entries = StructVector::GetEntries(decoded_key.data[0]); + for (; opc_idx < output_projection_columns.size(); opc_idx++) { + const auto &opc = output_projection_columns[opc_idx]; + if (opc.is_payload) { + break; + } + chunk.data[opc.output_col_idx].Reference(*decoded_key_entries[opc.layout_col_idx]); + } + + gathered_payload = true; + } + + // If there are no payload columns, we're done here + if (opc_idx != output_projection_columns.size()) { + if (!gathered_payload) { + // Gather row pointers from keys + for (idx_t i = 0; i < count; i++) { + payload_ptrs[i] = sort_keys[i]->GetPayload(); + } + } + + // Init scan state + auto &payload_data = *sorted_run.payload_data; + if (payload_state.pin_state.properties == TupleDataPinProperties::INVALID) { + payload_data.InitializeScan(payload_state, TupleDataPinProperties::ALREADY_PINNED); + } + TupleDataCollection::ResetCachedCastVectors(payload_state.chunk_state, payload_state.chunk_state.column_ids); + + // Now gather from payload + for (; opc_idx < output_projection_columns.size(); opc_idx++) { + const auto &opc = output_projection_columns[opc_idx]; + D_ASSERT(opc.is_payload); + payload_data.Gather(payload_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), + count, opc.layout_col_idx, chunk.data[opc.output_col_idx], + *FlatVector::IncrementalSelectionVector(), + payload_state.chunk_state.cached_cast_vectors[opc.layout_col_idx]); + } + } + + chunk.SetCardinality(count); +} + +//===--------------------------------------------------------------------===// +// SortedRun +//===--------------------------------------------------------------------===// +SortedRun::SortedRun(ClientContext &context_p, const Sort &sort_p, bool is_index_sort_p) + : context(context_p), sort(sort_p), 
key_data(make_uniq(context, sort.key_layout)), + payload_data(sort.payload_layout && sort.payload_layout->ColumnCount() != 0 + ? make_uniq(context, sort.payload_layout) + : nullptr), is_index_sort(is_index_sort_p), finalized(false) { key_data->InitializeAppend(key_append_state, TupleDataPinProperties::KEEP_EVERYTHING_PINNED); if (payload_data) { @@ -25,8 +173,7 @@ SortedRun::SortedRun(ClientContext &context_p, shared_ptr key_l } unique_ptr SortedRun::CreateRunForMaterialization() const { - auto res = make_uniq(context, key_data->GetLayoutPtr(), - payload_data ? payload_data->GetLayoutPtr() : nullptr, is_index_sort); + auto res = make_uniq(context, sort, is_index_sort); res->key_append_state.pin_state.properties = TupleDataPinProperties::UNPIN_AFTER_DONE; res->payload_append_state.pin_state.properties = TupleDataPinProperties::UNPIN_AFTER_DONE; res->finalized = true; diff --git a/src/duckdb/src/common/sorting/sorted_run_merger.cpp b/src/duckdb/src/common/sorting/sorted_run_merger.cpp index eb879edc5..d87cef470 100644 --- a/src/duckdb/src/common/sorting/sorted_run_merger.cpp +++ b/src/duckdb/src/common/sorting/sorted_run_merger.cpp @@ -1,5 +1,6 @@ #include "duckdb/common/sorting/sorted_run_merger.hpp" +#include "duckdb/common/sorting/sort.hpp" #include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types/row/block_iterator.hpp" @@ -100,7 +101,7 @@ class SortedRunMergerLocalState : public LocalSourceState { //! Whether this thread has finished the work it has been assigned bool TaskFinished() const; //! Do the work this thread has been assigned - void ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk); + SourceResultType ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk); private: //! Computes upper partition boundaries using K-way Merge Path @@ -154,12 +155,10 @@ class SortedRunMergerLocalState : public LocalSourceState { //! 
Variables for scanning idx_t merged_partition_count; idx_t merged_partition_index; - TupleDataScanState payload_state; - //! For decoding sort keys - ExpressionExecutor key_executor; - DataChunk key; - DataChunk decoded_key; + //! For scanning + Vector sort_key_pointers; + SortedRunScanState sorted_run_scan_state; }; //===--------------------------------------------------------------------===// @@ -172,7 +171,7 @@ class SortedRunMergerGlobalState : public GlobalSourceState { merger(merger_p), num_runs(merger.sorted_runs.size()), num_partitions((merger.total_count + (merger.partition_size - 1)) / merger.partition_size), iterator_state_type(GetBlockIteratorStateType(merger.external)), - sort_key_type(merger.key_layout->GetSortKeyType()), next_partition_idx(0), total_scanned(0), + sort_key_type(merger.sort.key_layout->GetSortKeyType()), next_partition_idx(0), total_scanned(0), destroy_partition_idx(0) { // Initialize partitions partitions.resize(num_partitions); @@ -292,7 +291,7 @@ SortedRunMergerLocalState::SortedRunMergerLocalState(SortedRunMergerGlobalState : iterator_state_type(gstate.iterator_state_type), sort_key_type(gstate.sort_key_type), task(SortedRunMergerTask::FINISHED), run_boundaries(gstate.num_runs), merged_partition_count(DConstants::INVALID_INDEX), merged_partition_index(DConstants::INVALID_INDEX), - key_executor(gstate.context, gstate.merger.decode_sort_key) { + sorted_run_scan_state(gstate.context, gstate.merger.sort), sort_key_pointers(LogicalType::POINTER) { for (const auto &run : gstate.merger.sorted_runs) { auto &key_data = *run->key_data; switch (iterator_state_type) { @@ -308,8 +307,6 @@ SortedRunMergerLocalState::SortedRunMergerLocalState(SortedRunMergerGlobalState EnumUtil::ToString(iterator_state_type)); } } - key.Initialize(gstate.context, {gstate.merger.key_layout->GetTypes()[0]}); - decoded_key.Initialize(gstate.context, {gstate.merger.decode_sort_key.return_type}); } bool SortedRunMergerLocalState::TaskFinished() const { @@ -328,7 
+325,8 @@ bool SortedRunMergerLocalState::TaskFinished() const { } } -void SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk) { +SourceResultType SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, + optional_ptr chunk) { D_ASSERT(task != SortedRunMergerTask::FINISHED); switch (task) { case SortedRunMergerTask::COMPUTE_BOUNDARIES: @@ -352,14 +350,18 @@ void SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, if (!chunk || chunk->size() == 0) { gstate.DestroyScannedData(); gstate.partitions[partition_idx.GetIndex()]->scanned = true; - gstate.total_scanned += merged_partition_count; + const auto scan_count_after_adding = gstate.total_scanned.fetch_add(merged_partition_count); partition_idx = optional_idx::Invalid(); task = SortedRunMergerTask::FINISHED; + if (scan_count_after_adding == gstate.merger.total_count) { + return SourceResultType::FINISHED; + } } break; default: throw NotImplementedException("SortedRunMergerLocalState::ExecuteTask for task"); } + return SourceResultType::HAVE_MORE_OUTPUT; } void SortedRunMergerLocalState::ComputePartitionBoundaries(SortedRunMergerGlobalState &gstate, @@ -718,61 +720,16 @@ void SortedRunMergerLocalState::TemplatedScanPartition(SortedRunMergerGlobalStat using SORT_KEY = SortKey; const auto count = MinValue(merged_partition_count - merged_partition_index, STANDARD_VECTOR_SIZE); - const auto &output_projection_columns = gstate.merger.output_projection_columns; - idx_t opc_idx = 0; - + // Grab pointers to sort keys const auto merged_partition_keys = reinterpret_cast(merged_partition.get()) + merged_partition_index; - const auto payload_ptrs = FlatVector::GetData(payload_state.chunk_state.row_locations); - bool gathered_payload = false; - - // Decode from key - if (!output_projection_columns[0].is_payload) { - key.Reset(); - GetKeyAndPayload(merged_partition_keys, count, key, payload_ptrs); - - decoded_key.Reset(); - 
key_executor.Execute(key, decoded_key); - - const auto &decoded_key_entries = StructVector::GetEntries(decoded_key.data[0]); - for (; opc_idx < output_projection_columns.size(); opc_idx++) { - const auto &opc = output_projection_columns[opc_idx]; - if (opc.is_payload) { - break; - } - chunk.data[opc.output_col_idx].Reference(*decoded_key_entries[opc.layout_col_idx]); - } - gathered_payload = true; - } - - // If there are no payload columns, we're done here - if (opc_idx != output_projection_columns.size()) { - if (!gathered_payload) { - // Gather row pointers from keys - for (idx_t i = 0; i < count; i++) { - payload_ptrs[i] = merged_partition_keys[i].GetPayload(); - } - } - - // Init scan state - auto &payload_data = *gstate.merger.sorted_runs.back()->payload_data; - if (payload_state.pin_state.properties == TupleDataPinProperties::INVALID) { - payload_data.InitializeScan(payload_state, TupleDataPinProperties::ALREADY_PINNED); - } - TupleDataCollection::ResetCachedCastVectors(payload_state.chunk_state, payload_state.chunk_state.column_ids); - - // Now gather from payload - for (; opc_idx < output_projection_columns.size(); opc_idx++) { - const auto &opc = output_projection_columns[opc_idx]; - D_ASSERT(opc.is_payload); - payload_data.Gather(payload_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), - count, opc.layout_col_idx, chunk.data[opc.output_col_idx], - *FlatVector::IncrementalSelectionVector(), - payload_state.chunk_state.cached_cast_vectors[opc.layout_col_idx]); - } + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + for (idx_t i = 0; i < count; i++) { + sort_keys[i] = &merged_partition_keys[i]; } - merged_partition_index += count; - chunk.SetCardinality(count); + + // Scan + sorted_run_scan_state.Scan(*gstate.merger.sorted_runs[0], sort_key_pointers, count, chunk); } void SortedRunMergerLocalState::MaterializePartition(SortedRunMergerGlobalState &gstate) { @@ -812,7 +769,9 @@ void 
SortedRunMergerLocalState::MaterializePartition(SortedRunMergerGlobalState // Add to global state lock_guard guard(gstate.materialized_partition_lock); - gstate.materialized_partitions.resize(partition_idx.GetIndex()); + if (gstate.materialized_partitions.size() < partition_idx.GetIndex() + 1) { + gstate.materialized_partitions.resize(partition_idx.GetIndex() + 1); + } gstate.materialized_partitions[partition_idx.GetIndex()] = std::move(sorted_run); } @@ -833,7 +792,7 @@ unique_ptr SortedRunMergerLocalState::TemplatedMaterializePartition(S while (merged_partition_index < merged_partition_count) { const auto count = MinValue(merged_partition_count - merged_partition_index, STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < count + count; i++) { + for (idx_t i = 0; i < count; i++) { auto &key = merged_partition_keys[merged_partition_index + i]; key_locations[i] = data_ptr_cast(&key); if (!SORT_KEY::CONSTANT_SIZE) { @@ -855,7 +814,7 @@ unique_ptr SortedRunMergerLocalState::TemplatedMaterializePartition(S if (!sorted_run->payload_data->GetLayout().AllConstant()) { sorted_run->payload_data->FindHeapPointers(payload_data_input, count); } - sorted_run->payload_append_state.chunk_state.heap_sizes.Reference(key_data_input.heap_sizes); + sorted_run->payload_append_state.chunk_state.heap_sizes.Reference(payload_data_input.heap_sizes); sorted_run->payload_data->Build(sorted_run->payload_append_state.pin_state, sorted_run->payload_append_state.chunk_state, 0, count); sorted_run->payload_data->CopyRows(sorted_run->payload_append_state.chunk_state, payload_data_input, @@ -876,12 +835,9 @@ unique_ptr SortedRunMergerLocalState::TemplatedMaterializePartition(S //===--------------------------------------------------------------------===// // Sorted Run Merger //===--------------------------------------------------------------------===// -SortedRunMerger::SortedRunMerger(const Expression &decode_sort_key_p, shared_ptr key_layout_p, - vector> &&sorted_runs_p, - const vector 
&output_projection_columns_p, +SortedRunMerger::SortedRunMerger(const Sort &sort_p, vector> &&sorted_runs_p, idx_t partition_size_p, bool external_p, bool is_index_sort_p) - : decode_sort_key(decode_sort_key_p), key_layout(std::move(key_layout_p)), sorted_runs(std::move(sorted_runs_p)), - output_projection_columns(output_projection_columns_p), total_count(SortedRunsTotalCount(sorted_runs)), + : sort(sort_p), sorted_runs(std::move(sorted_runs_p)), total_count(SortedRunsTotalCount(sorted_runs)), partition_size(partition_size_p), external(external_p), is_index_sort(is_index_sort_p) { } @@ -929,30 +885,28 @@ ProgressData SortedRunMerger::GetProgress(ClientContext &, GlobalSourceState &gs //===--------------------------------------------------------------------===// // Non-Standard Interface //===--------------------------------------------------------------------===// -SourceResultType SortedRunMerger::MaterializeMerge(ExecutionContext &, OperatorSourceInput &input) const { +SourceResultType SortedRunMerger::MaterializeSortedRun(ExecutionContext &, OperatorSourceInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); + SourceResultType res = SourceResultType::HAVE_MORE_OUTPUT; while (true) { if (!lstate.TaskFinished() || gstate.AssignTask(lstate)) { - lstate.ExecuteTask(gstate, nullptr); + res = lstate.ExecuteTask(gstate, nullptr); } else { break; } } - if (gstate.total_scanned == total_count) { - // This signals that the data has been fully materialized - return SourceResultType::FINISHED; - } - // This signals that no more tasks are left, but that the data has not yet been fully materialized - return SourceResultType::HAVE_MORE_OUTPUT; + // The thread that completes the materialization returns FINISHED, all other threads return HAVE_MORE_OUTPUT + return res; } -unique_ptr SortedRunMerger::GetMaterialized(GlobalSourceState &global_state) { +unique_ptr SortedRunMerger::GetSortedRun(GlobalSourceState &global_state) { auto 
&gstate = global_state.Cast(); + D_ASSERT(total_count != 0); + lock_guard guard(gstate.materialized_partition_lock); if (gstate.materialized_partitions.empty()) { - D_ASSERT(total_count == 0); return nullptr; } auto &target = *gstate.materialized_partitions[0]; @@ -963,7 +917,9 @@ unique_ptr SortedRunMerger::GetMaterialized(GlobalSourceState &global target.payload_data->Combine(*source.payload_data); } } - return std::move(gstate.materialized_partitions[0]); + auto res = std::move(gstate.materialized_partitions[0]); + gstate.materialized_partitions.clear(); + return res; } } // namespace duckdb diff --git a/src/duckdb/src/common/string_util.cpp b/src/duckdb/src/common/string_util.cpp index 51be7c3eb..1e6309ee0 100644 --- a/src/duckdb/src/common/string_util.cpp +++ b/src/duckdb/src/common/string_util.cpp @@ -702,6 +702,21 @@ string StringUtil::ToComplexJSONMap(const ComplexJSON &complex_json) { return ComplexJSON::GetValueRecursive(complex_json); } +string StringUtil::ValidateJSON(const char *data, const idx_t &len) { + // Same flags as in JSON extension + static constexpr auto READ_FLAG = + YYJSON_READ_ALLOW_INF_AND_NAN | YYJSON_READ_ALLOW_TRAILING_COMMAS | YYJSON_READ_BIGNUM_AS_RAW; + yyjson_read_err error; + yyjson_doc *doc = yyjson_read_opts((char *)data, len, READ_FLAG, nullptr, &error); // NOLINT: for yyjson + if (error.code != YYJSON_READ_SUCCESS) { + return StringUtil::Format("Malformed JSON at byte %lld of input: %s. 
Input: \"%s\"", error.pos, error.msg, + string(data, len)); + } + + yyjson_doc_free(doc); + return string(); +} + string StringUtil::ExceptionToJSONMap(ExceptionType type, const string &message, const unordered_map &map) { D_ASSERT(map.find("exception_type") == map.end()); diff --git a/src/duckdb/src/common/types.cpp b/src/duckdb/src/common/types.cpp index 40ff794e6..15fd9364e 100644 --- a/src/duckdb/src/common/types.cpp +++ b/src/duckdb/src/common/types.cpp @@ -159,6 +159,8 @@ PhysicalType LogicalType::GetInternalType() { return PhysicalType::UNKNOWN; case LogicalTypeId::AGGREGATE_STATE: return PhysicalType::VARCHAR; + case LogicalTypeId::GEOMETRY: + return PhysicalType::VARCHAR; default: throw InternalException("Invalid LogicalType %s", ToString()); } @@ -1344,6 +1346,8 @@ static idx_t GetLogicalTypeScore(const LogicalType &type) { return 102; case LogicalTypeId::BIGNUM: return 103; + case LogicalTypeId::GEOMETRY: + return 104; // nested types case LogicalTypeId::STRUCT: return 125; @@ -2014,6 +2018,15 @@ LogicalType LogicalType::VARIANT() { return LogicalType(LogicalTypeId::VARIANT, std::move(info)); } +//===--------------------------------------------------------------------===// +// Spatial Types +//===--------------------------------------------------------------------===// + +LogicalType LogicalType::GEOMETRY() { + auto info = make_shared_ptr(); + return LogicalType(LogicalTypeId::GEOMETRY, std::move(info)); +} + //===--------------------------------------------------------------------===// // Logical Type //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/types/column/column_data_collection.cpp b/src/duckdb/src/common/types/column/column_data_collection.cpp index b53e07d68..ad478d515 100644 --- a/src/duckdb/src/common/types/column/column_data_collection.cpp +++ b/src/duckdb/src/common/types/column/column_data_collection.cpp @@ -842,7 +842,8 @@ void ColumnDataCopyArray(ColumnDataMetaData 
&meta_data, const UnifiedVectorForma child_vector.ToUnifiedFormat(copy_count * array_size, child_vector_data); // Broadcast and sync the validity of the array vector to the child vector - + // This requires creating a copy of the validity mask: we cannot modify the input validity + child_vector_data.validity = ValidityMask(child_vector_data.validity, child_vector_data.validity.Capacity()); if (source_data.validity.IsMaskSet()) { for (idx_t i = 0; i < copy_count; i++) { auto source_idx = source_data.sel->get_index(offset + i); diff --git a/src/duckdb/src/common/types/geometry.cpp b/src/duckdb/src/common/types/geometry.cpp new file mode 100644 index 000000000..2ec9ac53a --- /dev/null +++ b/src/duckdb/src/common/types/geometry.cpp @@ -0,0 +1,773 @@ +#include "duckdb/common/types/geometry.hpp" +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/common/types/vector.hpp" +#include "fast_float/fast_float.h" +#include "fmt/format.h" + +//---------------------------------------------------------------------------------------------------------------------- +// Internals +//---------------------------------------------------------------------------------------------------------------------- +namespace duckdb { + +namespace { + +class BlobWriter { +public: + template + void Write(const T &value) { + auto ptr = reinterpret_cast(&value); + buffer.insert(buffer.end(), ptr, ptr + sizeof(T)); + } + + template + struct Reserved { + size_t offset; + T value; + }; + + template + Reserved Reserve() { + auto offset = buffer.size(); + buffer.resize(buffer.size() + sizeof(T)); + return {offset, T()}; + } + + template + void Write(const Reserved &reserved) { + if (reserved.offset + sizeof(T) > buffer.size()) { + throw InternalException("Write out of bounds in BinaryWriter"); + } + auto ptr = reinterpret_cast(&reserved.value); + // We've reserved 0 bytes, so we can safely memcpy + memcpy(buffer.data() + reserved.offset, ptr, sizeof(T)); + } + + void Write(const char *data, 
size_t size) { + D_ASSERT(data != nullptr); + buffer.insert(buffer.end(), data, data + size); + } + + const vector &GetBuffer() const { + return buffer; + } + +private: + vector buffer; +}; + +class BlobReader { +public: + BlobReader(const char *data, uint32_t size) : beg(data), pos(data), end(data + size) { + } + + template + T Read() { + if (pos + sizeof(T) > end) { + throw InvalidInputException("Unexpected end of binary data at position %zu", pos - beg); + } + T value; + if (LE) { + memcpy(&value, pos, sizeof(T)); + pos += sizeof(T); + } else { + char temp[sizeof(T)]; + for (size_t i = 0; i < sizeof(T); ++i) { + temp[i] = pos[sizeof(T) - 1 - i]; + } + memcpy(&value, temp, sizeof(T)); + pos += sizeof(T); + } + return value; + } + + void Skip(size_t size) { + if (pos + size > end) { + throw InvalidInputException("Skipping beyond end of binary data at position %zu", pos - beg); + } + pos += size; + } + + const char *Reserve(size_t size) { + if (pos + size > end) { + throw InvalidInputException("Reserving beyond end of binary data at position %zu", pos - beg); + } + auto current_pos = pos; + pos += size; + return current_pos; + } + + size_t GetPosition() const { + return static_cast(pos - beg); + } + + bool IsAtEnd() const { + return pos >= end; + } + +private: + const char *beg; + const char *pos; + const char *end; +}; + +class TextWriter { +public: + void Write(const char *str) { + buffer.insert(buffer.end(), str, str + strlen(str)); + } + void Write(char c) { + buffer.push_back(c); + } + void Write(double value) { + duckdb_fmt::format_to(std::back_inserter(buffer), "{}", value); + // Remove trailing zero + if (buffer.back() == '0') { + buffer.pop_back(); + if (buffer.back() == '.') { + buffer.pop_back(); + } + } + } + const vector &GetBuffer() const { + return buffer; + } + +private: + vector buffer; +}; + +class TextReader { +public: + TextReader(const char *text, const uint32_t size) : beg(text), pos(text), end(text + size) { + } + + bool TryMatch(const char 
*str) { + auto ptr = pos; + while (*str && pos < end && tolower(*pos) == tolower(*str)) { + pos++; + str++; + } + if (*str == '\0') { + SkipWhitespace(); // remove trailing whitespace + return true; // matched + } else { + pos = ptr; // reset position + return false; // not matched + } + } + + bool TryMatch(char c) { + if (pos < end && tolower(*pos) == tolower(c)) { + pos++; + SkipWhitespace(); // remove trailing whitespace + return true; // matched + } + return false; // not matched + } + + void Match(const char *str) { + if (!TryMatch(str)) { + throw InvalidInputException("Expected '%s' but got '%c' at position %zu", str, *pos, pos - beg); + } + } + + void Match(char c) { + if (!TryMatch(c)) { + throw InvalidInputException("Expected '%c' but got '%c' at position %zu", c, *pos, pos - beg); + } + } + + double MatchNumber() { + // Now use fast_float to parse the number + double num; + const auto res = duckdb_fast_float::from_chars(pos, end, num); + if (res.ec != std::errc()) { + throw InvalidInputException("Expected number at position %zu", pos - beg); + } + + pos = res.ptr; // update position to the end of the parsed number + + SkipWhitespace(); // remove trailing whitespace + return num; // return the parsed number + } + + idx_t GetPosition() const { + return static_cast(pos - beg); + } + + void Reset() { + pos = beg; + } + +private: + void SkipWhitespace() { + while (pos < end && isspace(*pos)) { + pos++; + } + } + + const char *beg; + const char *pos; + const char *end; +}; + +void FromStringRecursive(TextReader &reader, BlobWriter &writer, uint32_t depth, bool parent_has_z, bool parent_has_m) { + + if (depth == Geometry::MAX_RECURSION_DEPTH) { + throw InvalidInputException("Geometry string exceeds maximum recursion depth of %d", + Geometry::MAX_RECURSION_DEPTH); + } + + GeometryType type; + + if (reader.TryMatch("point")) { + type = GeometryType::POINT; + } else if (reader.TryMatch("linestring")) { + type = GeometryType::LINESTRING; + } else if 
(reader.TryMatch("polygon")) { + type = GeometryType::POLYGON; + } else if (reader.TryMatch("multipoint")) { + type = GeometryType::MULTIPOINT; + } else if (reader.TryMatch("multilinestring")) { + type = GeometryType::MULTILINESTRING; + } else if (reader.TryMatch("multipolygon")) { + type = GeometryType::MULTIPOLYGON; + } else if (reader.TryMatch("geometrycollection")) { + type = GeometryType::GEOMETRYCOLLECTION; + } else { + throw InvalidInputException("Unknown geometry type at position %zu", reader.GetPosition()); + } + + const auto has_z = reader.TryMatch("z"); + const auto has_m = reader.TryMatch("m"); + + const auto is_empty = reader.TryMatch("empty"); + + if ((depth != 0) && ((parent_has_z != has_z) || (parent_has_m != has_m))) { + throw InvalidInputException("Geometry has inconsistent Z/M dimensions, starting at position %zu", + reader.GetPosition()); + } + + // How many dimensions does this geometry have? + const uint32_t dims = 2 + (has_z ? 1 : 0) + (has_m ? 1 : 0); + + // WKB type + const auto meta = static_cast(type) + (has_z ? 1000 : 0) + (has_m ? 
2000 : 0); + // Write the geometry type and vertex type + writer.Write(1); // LE Byte Order + writer.Write(meta); + + switch (type) { + case GeometryType::POINT: { + if (is_empty) { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + // Write NaN for each dimension, if point is empty + writer.Write(std::numeric_limits::quiet_NaN()); + } + } else { + reader.Match('('); + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + reader.Match(')'); + } + } break; + case GeometryType::LINESTRING: { + if (is_empty) { + writer.Write(0); // No vertices in empty linestring + break; + } + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + } break; + case GeometryType::POLYGON: { + if (is_empty) { + writer.Write(0); + break; // No rings in empty polygon + } + auto ring_count = writer.Reserve(); + reader.Match('('); + do { + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + ring_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(ring_count); + } break; + case GeometryType::MULTIPOINT: { + if (is_empty) { + writer.Write(0); // No points in empty multipoint + break; + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + bool has_paren = reader.TryMatch('('); + + const auto part_meta = static_cast(GeometryType::POINT) + (has_z ? 1000 : 0) + (has_m ? 
2000 : 0); + writer.Write(1); + writer.Write(part_meta); + + if (reader.TryMatch("EMPTY")) { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + // Write NaN for each dimension, if point is empty + writer.Write(std::numeric_limits::quiet_NaN()); + } + } else { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + } + if (has_paren) { + reader.Match(')'); // Match the closing parenthesis if it was opened + } + part_count.value++; + } while (reader.TryMatch(',')); + writer.Write(part_count); + } break; + case GeometryType::MULTILINESTRING: { + if (is_empty) { + writer.Write(0); + return; // No linestrings in empty multilinestring + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + + const auto part_meta = + static_cast(GeometryType::LINESTRING) + (has_z ? 1000 : 0) + (has_m ? 2000 : 0); + writer.Write(1); + writer.Write(part_meta); + + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + part_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(part_count); + } break; + case GeometryType::MULTIPOLYGON: { + if (is_empty) { + writer.Write(0); // No polygons in empty multipolygon + break; + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + + const auto part_meta = + static_cast(GeometryType::POLYGON) + (has_z ? 1000 : 0) + (has_m ? 
2000 : 0); + writer.Write(1); + writer.Write(part_meta); + + auto ring_count = writer.Reserve(); + reader.Match('('); + do { + auto vert_count = writer.Reserve(); + reader.Match('('); + do { + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + auto value = reader.MatchNumber(); + writer.Write(value); + } + vert_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(vert_count); + ring_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(ring_count); + part_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(part_count); + } break; + case GeometryType::GEOMETRYCOLLECTION: { + if (is_empty) { + writer.Write(0); // No geometries in empty geometry collection + break; + } + auto part_count = writer.Reserve(); + reader.Match('('); + do { + // Recursively parse the geometry inside the collection + FromStringRecursive(reader, writer, depth + 1, has_z, has_m); + part_count.value++; + } while (reader.TryMatch(',')); + reader.Match(')'); + writer.Write(part_count); + } break; + default: + throw InvalidInputException("Unknown geometry type %d at position %zu", static_cast(type), + reader.GetPosition()); + } +} + +void ToStringRecursive(BlobReader &reader, TextWriter &writer, idx_t depth, bool parent_has_z, bool parent_has_m) { + if (depth == Geometry::MAX_RECURSION_DEPTH) { + throw InvalidInputException("Geometry exceeds maximum recursion depth of %d", Geometry::MAX_RECURSION_DEPTH); + } + + // Read the byte order (should always be 1 for little-endian) + auto byte_order = reader.Read(); + if (byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", byte_order); + } + + const auto meta = reader.Read(); + const auto type = static_cast(meta % 1000); + const auto flag = meta / 1000; + const auto has_z = (flag & 0x01) != 0; + const auto has_m = (flag & 0x02) != 0; + + if ((depth != 0) && ((parent_has_z != has_z) || (parent_has_m != has_m))) { + throw 
InvalidInputException("Geometry has inconsistent Z/M dimensions, starting at position %zu", + reader.GetPosition()); + } + + const uint32_t dims = 2 + (has_z ? 1 : 0) + (has_m ? 1 : 0); + const auto flag_str = has_z ? (has_m ? " ZM " : " Z ") : (has_m ? " M " : " "); + + switch (type) { + case GeometryType::POINT: { + writer.Write("POINT"); + writer.Write(flag_str); + + double vert[4] = {0, 0, 0, 0}; + auto all_nan = true; + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + vert[d_idx] = reader.Read(); + all_nan &= std::isnan(vert[d_idx]); + } + if (all_nan) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + writer.Write(vert[d_idx]); + } + writer.Write(')'); + } break; + case GeometryType::LINESTRING: { + writer.Write("LINESTRING"); + ; + writer.Write(flag_str); + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); + writer.Write(value); + } + } + writer.Write(')'); + } break; + case GeometryType::POLYGON: { + writer.Write("POLYGON"); + writer.Write(flag_str); + const auto ring_count = reader.Read(); + if (ring_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + if (ring_idx > 0) { + writer.Write(", "); + } + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = 
reader.Read(); + writer.Write(value); + } + } + writer.Write(')'); + } + writer.Write(')'); + } break; + case GeometryType::MULTIPOINT: { + writer.Write("MULTIPOINT"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + const auto part_byte_order = reader.Read(); + if (part_byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", part_byte_order); + } + const auto part_meta = reader.Read(); + const auto part_type = static_cast(part_meta % 1000); + const auto part_flag = part_meta / 1000; + const auto part_has_z = (part_flag & 0x01) != 0; + const auto part_has_m = (part_flag & 0x02) != 0; + + if (part_type != GeometryType::POINT) { + throw InvalidInputException("Expected POINT in MULTIPOINT but got %d", static_cast(part_type)); + } + + if ((has_z != part_has_z) || (has_m != part_has_m)) { + throw InvalidInputException( + "Geometry has inconsistent Z/M dimensions in MULTIPOINT, starting at position %zu", + reader.GetPosition()); + } + if (part_idx > 0) { + writer.Write(", "); + } + double vert[4] = {0, 0, 0, 0}; + auto all_nan = true; + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + vert[d_idx] = reader.Read(); + all_nan &= std::isnan(vert[d_idx]); + } + if (all_nan) { + writer.Write("EMPTY"); + continue; + } + // writer.Write('('); + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + writer.Write(vert[d_idx]); + } + // writer.Write(')'); + } + writer.Write(')'); + + } break; + case GeometryType::MULTILINESTRING: { + writer.Write("MULTILINESTRING"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + const auto part_byte_order = reader.Read(); + if 
(part_byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", part_byte_order); + } + const auto part_meta = reader.Read(); + const auto part_type = static_cast(part_meta % 1000); + const auto part_flag = part_meta / 1000; + const auto part_has_z = (part_flag & 0x01) != 0; + const auto part_has_m = (part_flag & 0x02) != 0; + + if (part_type != GeometryType::LINESTRING) { + throw InvalidInputException("Expected LINESTRING in MULTILINESTRING but got %d", + static_cast(part_type)); + } + if ((has_z != part_has_z) || (has_m != part_has_m)) { + throw InvalidInputException( + "Geometry has inconsistent Z/M dimensions in MULTILINESTRING, starting at position %zu", + reader.GetPosition()); + } + if (part_idx > 0) { + writer.Write(", "); + } + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); + writer.Write(value); + } + } + writer.Write(')'); + } + writer.Write(')'); + } break; + case GeometryType::MULTIPOLYGON: { + writer.Write("MULTIPOLYGON"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + if (part_idx > 0) { + writer.Write(", "); + } + + const auto part_byte_order = reader.Read(); + if (part_byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", part_byte_order); + } + const auto part_meta = reader.Read(); + const auto part_type = static_cast(part_meta % 1000); + const auto part_flag = part_meta / 1000; + const auto part_has_z = (part_flag & 0x01) != 0; + const auto part_has_m = (part_flag & 0x02) != 0; + if (part_type != 
GeometryType::POLYGON) { + throw InvalidInputException("Expected POLYGON in MULTIPOLYGON but got %d", static_cast(part_type)); + } + if ((has_z != part_has_z) || (has_m != part_has_m)) { + throw InvalidInputException( + "Geometry has inconsistent Z/M dimensions in MULTIPOLYGON, starting at position %zu", + reader.GetPosition()); + } + + const auto ring_count = reader.Read(); + if (ring_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + if (ring_idx > 0) { + writer.Write(", "); + } + const auto vert_count = reader.Read(); + if (vert_count == 0) { + writer.Write("EMPTY"); + continue; + } + writer.Write('('); + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + if (vert_idx > 0) { + writer.Write(", "); + } + for (uint32_t d_idx = 0; d_idx < dims; d_idx++) { + if (d_idx > 0) { + writer.Write(' '); + } + auto value = reader.Read(); + writer.Write(value); + } + } + writer.Write(')'); + } + writer.Write(')'); + } + writer.Write(')'); + } break; + case GeometryType::GEOMETRYCOLLECTION: { + writer.Write("GEOMETRYCOLLECTION"); + writer.Write(flag_str); + const auto part_count = reader.Read(); + if (part_count == 0) { + writer.Write("EMPTY"); + return; + } + writer.Write('('); + for (uint32_t part_idx = 0; part_idx < part_count; part_idx++) { + if (part_idx > 0) { + writer.Write(", "); + } + // Recursively parse the geometry inside the collection + ToStringRecursive(reader, writer, depth + 1, has_z, has_m); + } + writer.Write(')'); + } break; + default: + throw InvalidInputException("Unsupported geometry type %d in WKB", static_cast(type)); + } +} + +} // namespace + +} // namespace duckdb + +//---------------------------------------------------------------------------------------------------------------------- +// Public interface +//---------------------------------------------------------------------------------------------------------------------- +namespace 
duckdb { + +bool Geometry::FromString(const string_t &wkt_text, string_t &result, Vector &result_vector, bool strict) { + TextReader reader(wkt_text.GetData(), static_cast(wkt_text.GetSize())); + BlobWriter writer; + + FromStringRecursive(reader, writer, 0, false, false); + + const auto &buffer = writer.GetBuffer(); + result = StringVector::AddStringOrBlob(result_vector, buffer.data(), buffer.size()); + return true; +} + +string_t Geometry::ToString(Vector &result, const string_t &geom) { + BlobReader reader(geom.GetData(), static_cast(geom.GetSize())); + TextWriter writer; + + ToStringRecursive(reader, writer, 0, false, false); + + // Convert the buffer to string_t + const auto &buffer = writer.GetBuffer(); + return StringVector::AddString(result, buffer.data(), buffer.size()); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_collection.cpp b/src/duckdb/src/common/types/row/tuple_data_collection.cpp index ffd4a2b4c..9329dbb58 100644 --- a/src/duckdb/src/common/types/row/tuple_data_collection.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_collection.cpp @@ -18,6 +18,10 @@ TupleDataCollection::TupleDataCollection(BufferManager &buffer_manager, shared_p Initialize(); } +TupleDataCollection::TupleDataCollection(ClientContext &context, shared_ptr layout_ptr) + : TupleDataCollection(BufferManager::GetBufferManager(context), std::move(layout_ptr)) { +} + TupleDataCollection::~TupleDataCollection() { } @@ -562,11 +566,16 @@ void TupleDataCollection::InitializeScan(TupleDataParallelScanState &state, vect InitializeScan(state.scan_state, std::move(column_ids), properties); } -idx_t TupleDataCollection::FetchChunk(TupleDataScanState &state, const idx_t segment_idx, const idx_t chunk_idx, - const bool init_heap) { - auto &segment = *segments[segment_idx]; - allocator->InitializeChunkState(segment, state.pin_state, state.chunk_state, chunk_idx, init_heap); - return segment.chunks[chunk_idx].count; +idx_t 
TupleDataCollection::FetchChunk(TupleDataScanState &state, idx_t chunk_idx, bool init_heap) { + for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { + auto &segment = *segments[segment_idx]; + if (chunk_idx < segment.ChunkCount()) { + segment.allocator->InitializeChunkState(segment, state.pin_state, state.chunk_state, chunk_idx, init_heap); + return segment.chunks[chunk_idx].count; + } + chunk_idx -= segment.ChunkCount(); + } + throw InternalException("Chunk index out of in TupleDataCollection::FetchChunk"); } bool TupleDataCollection::Scan(TupleDataScanState &state, DataChunk &result) { diff --git a/src/duckdb/src/common/types/value.cpp b/src/duckdb/src/common/types/value.cpp index 2bef3a82d..a232904e1 100644 --- a/src/duckdb/src/common/types/value.cpp +++ b/src/duckdb/src/common/types/value.cpp @@ -919,6 +919,13 @@ Value Value::BIGNUM(const string &data) { return result; } +Value Value::GEOMETRY(const_data_ptr_t data, idx_t len) { + Value result(LogicalTypeId::GEOMETRY); + result.is_null = false; + result.value_info_ = make_shared_ptr(string(const_char_ptr_cast(data), len)); + return result; +} + Value Value::BLOB(const string &data) { Value result(LogicalType::BLOB); result.is_null = false; diff --git a/src/duckdb/src/common/types/vector.cpp b/src/duckdb/src/common/types/vector.cpp index ad27b162d..c4b3c5e02 100644 --- a/src/duckdb/src/common/types/vector.cpp +++ b/src/duckdb/src/common/types/vector.cpp @@ -724,6 +724,10 @@ Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { auto str = reinterpret_cast(data)[index]; return Value::BIGNUM(const_data_ptr_cast(str.data.GetData()), str.data.GetSize()); } + case LogicalTypeId::GEOMETRY: { + auto str = reinterpret_cast(data)[index]; + return Value::GEOMETRY(const_data_ptr_cast(str.GetData()), str.GetSize()); + } case LogicalTypeId::AGGREGATE_STATE: { auto str = reinterpret_cast(data)[index]; return Value::AGGREGATE_STATE(vector->GetType(), const_data_ptr_cast(str.GetData()), 
str.GetSize()); diff --git a/src/duckdb/src/common/virtual_file_system.cpp b/src/duckdb/src/common/virtual_file_system.cpp index 7940d0120..dfb542ec7 100644 --- a/src/duckdb/src/common/virtual_file_system.cpp +++ b/src/duckdb/src/common/virtual_file_system.cpp @@ -17,6 +17,7 @@ VirtualFileSystem::VirtualFileSystem(unique_ptr &&inner) : default_f unique_ptr VirtualFileSystem::OpenFileExtended(const OpenFileInfo &file, FileOpenFlags flags, optional_ptr opener) { + auto compression = flags.Compression(); if (compression == FileCompressionType::AUTO_DETECT) { // auto-detect compression settings based on file name @@ -34,8 +35,9 @@ unique_ptr VirtualFileSystem::OpenFileExtended(const OpenFileInfo &f } } // open the base file handle in UNCOMPRESSED mode + flags.SetCompression(FileCompressionType::UNCOMPRESSED); - auto file_handle = FindFileSystem(file.path).OpenFile(file, flags, opener); + auto file_handle = FindFileSystem(file.path, opener).OpenFile(file, flags, opener); if (!file_handle) { return nullptr; } @@ -111,7 +113,7 @@ void VirtualFileSystem::RemoveDirectory(const string &directory, optional_ptr &callback, optional_ptr opener) { - return FindFileSystem(directory).ListFiles(directory, callback, opener); + return FindFileSystem(directory, opener).ListFiles(directory, callback, opener); } void VirtualFileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { @@ -119,7 +121,7 @@ void VirtualFileSystem::MoveFile(const string &source, const string &target, opt } bool VirtualFileSystem::FileExists(const string &filename, optional_ptr opener) { - return FindFileSystem(filename).FileExists(filename, opener); + return FindFileSystem(filename, opener).FileExists(filename, opener); } bool VirtualFileSystem::IsPipe(const string &filename, optional_ptr opener) { @@ -139,7 +141,7 @@ string VirtualFileSystem::PathSeparator(const string &path) { } vector VirtualFileSystem::Glob(const string &path, FileOpener *opener) { - return 
FindFileSystem(path).Glob(path, opener); + return FindFileSystem(path, opener).Glob(path, opener); } void VirtualFileSystem::RegisterSubSystem(unique_ptr fs) { @@ -224,16 +226,61 @@ bool VirtualFileSystem::SubSystemIsDisabled(const string &name) { return disabled_file_systems.find(name) != disabled_file_systems.end(); } +FileSystem &VirtualFileSystem::FindFileSystem(const string &path, optional_ptr opener) { + return FindFileSystem(path, FileOpener::TryGetDatabase(opener)); +} + +FileSystem &VirtualFileSystem::FindFileSystem(const string &path, optional_ptr db_instance) { + auto fs = FindFileSystemInternal(path); + + if (!fs && db_instance) { + string required_extension; + + for (const auto &entry : EXTENSION_FILE_PREFIXES) { + if (StringUtil::StartsWith(path, entry.name)) { + required_extension = entry.extension; + } + } + if (!required_extension.empty() && db_instance && !db_instance->ExtensionIsLoaded(required_extension)) { + auto &dbconfig = DBConfig::GetConfig(*db_instance); + if (!ExtensionHelper::CanAutoloadExtension(required_extension) || + !dbconfig.options.autoload_known_extensions) { + auto error_message = "File " + path + " requires the extension " + required_extension + " to be loaded"; + error_message = + ExtensionHelper::AddExtensionInstallHintToErrorMsg(*db_instance, error_message, required_extension); + throw MissingExtensionException(error_message); + } + // an extension is required to read this file, but it is not loaded - try to load it + ExtensionHelper::AutoLoadExtension(*db_instance, required_extension); + } + + // Retry after having autoloaded + fs = FindFileSystem(path); + } + + if (!fs) { + fs = default_fs; + } + if (!disabled_file_systems.empty() && disabled_file_systems.find(fs->GetName()) != disabled_file_systems.end()) { + throw PermissionException("File system %s has been disabled by configuration", fs->GetName()); + } + return *fs; +} + FileSystem &VirtualFileSystem::FindFileSystem(const string &path) { - auto &fs = 
FindFileSystemInternal(path); - if (!disabled_file_systems.empty() && disabled_file_systems.find(fs.GetName()) != disabled_file_systems.end()) { - throw PermissionException("File system %s has been disabled by configuration", fs.GetName()); + auto fs = FindFileSystemInternal(path); + if (!fs) { + fs = default_fs; + } + if (!disabled_file_systems.empty() && disabled_file_systems.find(fs->GetName()) != disabled_file_systems.end()) { + throw PermissionException("File system %s has been disabled by configuration", fs->GetName()); } - return fs; + return *fs; } -FileSystem &VirtualFileSystem::FindFileSystemInternal(const string &path) { +optional_ptr VirtualFileSystem::FindFileSystemInternal(const string &path) { FileSystem *fs = nullptr; + for (auto &sub_system : sub_systems) { if (sub_system->CanHandleFile(path)) { if (sub_system->IsManuallySet()) { @@ -245,7 +292,9 @@ FileSystem &VirtualFileSystem::FindFileSystemInternal(const string &path) { if (fs) { return *fs; } - return *default_fs; + + // We could use default_fs, that's on the caller + return nullptr; } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/art_merger.cpp b/src/duckdb/src/execution/index/art/art_merger.cpp index 70781cbfb..61d2ec317 100644 --- a/src/duckdb/src/execution/index/art/art_merger.cpp +++ b/src/duckdb/src/execution/index/art/art_merger.cpp @@ -217,9 +217,6 @@ void ARTMerger::MergeNodeAndPrefix(Node &node, Node &prefix, const GateStatus pa auto child = node.GetChildMutable(art, byte); // Reduce the prefix to the bytes after pos. - // We always reduce by at least one byte, - // thus, if the prefix was a gate, it no longer is. 
- prefix.SetGateStatus(GateStatus::GATE_NOT_SET); Prefix::Reduce(art, prefix, pos); if (child) { diff --git a/src/duckdb/src/execution/index/art/prefix.cpp b/src/duckdb/src/execution/index/art/prefix.cpp index 00e94967a..1d7861135 100644 --- a/src/duckdb/src/execution/index/art/prefix.cpp +++ b/src/duckdb/src/execution/index/art/prefix.cpp @@ -100,6 +100,10 @@ void Prefix::Reduce(ART &art, Node &node, const idx_t pos) { D_ASSERT(node.HasMetadata()); D_ASSERT(pos < Count(art)); + // We always reduce by at least one byte, + // thus, if the prefix was a gate, it no longer is. + node.SetGateStatus(GateStatus::GATE_NOT_SET); + Prefix prefix(art, node); if (pos == idx_t(prefix.data[Count(art)] - 1)) { auto next = *prefix.ptr; diff --git a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp index 719781992..54c224b9f 100644 --- a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp @@ -8,6 +8,7 @@ #include "duckdb/execution/operator/join/outer_join_marker.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parallel/event.hpp" +#include "duckdb/parallel/meta_pipeline.hpp" #include "duckdb/parallel/thread_context.hpp" namespace duckdb { @@ -74,40 +75,54 @@ PhysicalAsOfJoin::PhysicalAsOfJoin(PhysicalPlan &physical_plan, LogicalCompariso //===--------------------------------------------------------------------===// class AsOfGlobalSinkState : public GlobalSinkState { public: - AsOfGlobalSinkState(ClientContext &context, const PhysicalAsOfJoin &op) - : rhs_sink(context, op.rhs_partitions, op.rhs_orders, op.children[1].get().GetTypes(), {}, - op.estimated_cardinality), - is_outer(IsRightOuterJoin(op.join_type)), has_null(false) { + using PartitionSinkPtr = unique_ptr; + using PartitionMarkers = vector; + using LocalBuffers = vector>; + + AsOfGlobalSinkState(ClientContext &context, const PhysicalAsOfJoin &op) : 
is_outer(IsRightOuterJoin(op.join_type)) { + // Set up partitions for both sides + partition_sinks.reserve(2); + const vector> partitions_stats; + auto &lhs = op.children[0].get(); + auto sink = make_uniq(context, op.lhs_partitions, op.lhs_orders, lhs.GetTypes(), + partitions_stats, lhs.estimated_cardinality); + partition_sinks.emplace_back(std::move(sink)); + auto &rhs = op.children[1].get(); + sink = make_uniq(context, op.rhs_partitions, op.rhs_orders, rhs.GetTypes(), + partitions_stats, rhs.estimated_cardinality); + partition_sinks.emplace_back(std::move(sink)); + + local_buffers.resize(2); } idx_t Count() const { - return rhs_sink.count; + return partition_sinks[child]->count; } PartitionLocalSinkState *RegisterBuffer(ClientContext &context) { lock_guard guard(lock); - lhs_buffers.emplace_back(make_uniq(context, *lhs_sink)); - return lhs_buffers.back().get(); + auto &buffers = local_buffers[child]; + buffers.emplace_back(make_uniq(context, *partition_sinks[child])); + return buffers.back().get(); } - PartitionGlobalSinkState rhs_sink; - - // One per partition + //! The child that is being materialised (right/1 then left/0) + size_t child = 1; + //! The child's partitioning buffer + vector partition_sinks; + //! Whether the right side is outer const bool is_outer; + //! 
The right outer join markers (one per partition) vector right_outers; - bool has_null; - - // Left side buffering - unique_ptr lhs_sink; mutex lock; - vector> lhs_buffers; + vector local_buffers; }; class AsOfLocalSinkState : public LocalSinkState { public: - explicit AsOfLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) - : local_partition(context, gstate_p) { + AsOfLocalSinkState(ClientContext &context, AsOfGlobalSinkState &gsink) + : local_partition(context, *gsink.partition_sinks[gsink.child]) { } void Sink(DataChunk &input_chunk) { @@ -126,9 +141,8 @@ unique_ptr PhysicalAsOfJoin::GetGlobalSinkState(ClientContext & } unique_ptr PhysicalAsOfJoin::GetLocalSinkState(ExecutionContext &context) const { - // We only sink the RHS auto &gsink = sink_state->Cast(); - return make_uniq(context.client, gsink.rhs_sink); + return make_uniq(context.client, gsink); } SinkResultType PhysicalAsOfJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { @@ -152,176 +166,34 @@ SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, Cl OperatorSinkFinalizeInput &input) const { auto &gstate = input.global_state.Cast(); - // The data is all in so we can initialise the left partitioning. - const vector> partitions_stats; - gstate.lhs_sink = make_uniq(context, lhs_partitions, lhs_orders, - children[0].get().GetTypes(), partitions_stats, 0U); - gstate.lhs_sink->SyncPartitioning(gstate.rhs_sink); - - // Find the first group to sort - if (!gstate.rhs_sink.HasMergeTasks() && EmptyResultIfRHSIsEmpty()) { - // Empty input! 
- return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared_ptr(gstate.rhs_sink, pipeline, *this); - event.InsertEvent(std::move(new_event)); - - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class AsOfGlobalState : public GlobalOperatorState { -public: - explicit AsOfGlobalState(AsOfGlobalSinkState &gsink) { - // for FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple - auto &rhs_partition = gsink.rhs_sink; - auto &right_outers = gsink.right_outers; - right_outers.reserve(rhs_partition.hash_groups.size()); - for (const auto &hash_group : rhs_partition.hash_groups) { - right_outers.emplace_back(OuterJoinMarker(gsink.is_outer)); - right_outers.back().Initialize(hash_group->count); - } - } -}; - -unique_ptr PhysicalAsOfJoin::GetGlobalOperatorState(ClientContext &context) const { - auto &gsink = sink_state->Cast(); - return make_uniq(gsink); -} + // The data is all in so we can synchronise the left partitioning. + auto result = SinkFinalizeType::READY; + auto &partition_sink = *gstate.partition_sinks[gstate.child]; + if (gstate.child == 1) { + gstate.partition_sinks[1 - gstate.child]->SyncPartitioning(partition_sink); -class AsOfLocalState : public CachingOperatorState { -public: - AsOfLocalState(ClientContext &context, const PhysicalAsOfJoin &op) - : context(context), allocator(Allocator::Get(context)), op(op), lhs_executor(context), - left_outer(IsLeftOuterJoin(op.join_type)), fetch_next_left(true) { - lhs_keys.Initialize(allocator, op.join_key_types); - for (const auto &cond : op.conditions) { - lhs_executor.AddExpression(*cond.left); + // Find the first group to sort + if (!partition_sink.HasMergeTasks() && EmptyResultIfRHSIsEmpty()) { + // Empty input! 
+ result = SinkFinalizeType::NO_OUTPUT_POSSIBLE; } - - lhs_payload.Initialize(allocator, op.children[0].get().GetTypes()); - lhs_sel.Initialize(); - left_outer.Initialize(STANDARD_VECTOR_SIZE); - - auto &gsink = op.sink_state->Cast(); - lhs_partition_sink = gsink.RegisterBuffer(context); } - bool Sink(DataChunk &input); - OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk); - - ClientContext &context; - Allocator &allocator; - const PhysicalAsOfJoin &op; - - ExpressionExecutor lhs_executor; - DataChunk lhs_keys; - ValidityMask lhs_valid_mask; - SelectionVector lhs_sel; - DataChunk lhs_payload; - - OuterJoinMarker left_outer; - bool fetch_next_left; - - optional_ptr lhs_partition_sink; -}; - -bool AsOfLocalState::Sink(DataChunk &input) { - // Compute the join keys - lhs_keys.Reset(); - lhs_executor.Execute(input, lhs_keys); - lhs_keys.Flatten(); - - // Combine the NULLs - const auto count = input.size(); - lhs_valid_mask.Reset(); - for (auto col_idx : op.null_sensitive) { - auto &col = lhs_keys.data[col_idx]; - UnifiedVectorFormat unified; - col.ToUnifiedFormat(count, unified); - lhs_valid_mask.Combine(unified.validity, count); - } - - // Convert the mask to a selection vector - // and mark all the rows that cannot match for early return. 
- idx_t lhs_valid = 0; - const auto entry_count = lhs_valid_mask.EntryCount(count); - idx_t base_idx = 0; - left_outer.Reset(); - for (idx_t entry_idx = 0; entry_idx < entry_count;) { - const auto validity_entry = lhs_valid_mask.GetValidityEntry(entry_idx++); - const auto next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); - if (ValidityMask::AllValid(validity_entry)) { - for (; base_idx < next; ++base_idx) { - lhs_sel.set_index(lhs_valid++, base_idx); - left_outer.SetMatch(base_idx); - } - } else if (ValidityMask::NoneValid(validity_entry)) { - base_idx = next; - } else { - const auto start = base_idx; - for (; base_idx < next; ++base_idx) { - if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { - lhs_sel.set_index(lhs_valid++, base_idx); - left_outer.SetMatch(base_idx); - } - } - } - } - - // Slice the keys to the ones we can match - lhs_payload.Reset(); - if (lhs_valid == count) { - lhs_payload.Reference(input); - lhs_payload.SetCardinality(input); - } else { - lhs_payload.Slice(input, lhs_sel, lhs_valid); - lhs_payload.SetCardinality(lhs_valid); - - // Flush the ones that can't match - fetch_next_left = false; + // Schedule all the sorts for maximum thread utilisation + if (partition_sink.HasMergeTasks()) { + auto new_event = make_shared_ptr(partition_sink, pipeline, *this); + event.InsertEvent(std::move(new_event)); } - lhs_partition_sink->Sink(lhs_payload); + // Switch sides + gstate.child = 1 - gstate.child; - return false; -} - -OperatorResultType AsOfLocalState::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk) { - input.Verify(); - Sink(input); - - // If there were any unmatchable rows, return them now so we can forget about them. 
- if (!fetch_next_left) { - fetch_next_left = true; - left_outer.ConstructLeftJoinResult(input, chunk); - left_outer.Reset(); - } - - // Just keep asking for data and buffering it - return OperatorResultType::NEED_MORE_INPUT; + return result; } OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, GlobalOperatorState &gstate, OperatorState &lstate_p) const { - auto &gsink = sink_state->Cast(); - auto &lstate = lstate_p.Cast(); - - if (gsink.rhs_sink.count == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - ConstructEmptyJoinResult(join_type, gsink.has_null, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; - } - } - - return lstate.ExecuteInternal(context, input, chunk); + return OperatorResultType::FINISHED; } //===--------------------------------------------------------------------===// @@ -356,8 +228,7 @@ class AsOfProbeBuffer { return !fetch_next_left || (lhs_scanner && lhs_scanner->Remaining()); } - ClientContext &context; - Allocator &allocator; + ClientContext &client; const PhysicalAsOfJoin &op; BufferManager &buffer_manager; const bool force_external; @@ -365,13 +236,18 @@ class AsOfProbeBuffer { Orders lhs_orders; // LHS scanning - SelectionVector lhs_sel; + SelectionVector lhs_scan_sel; optional_ptr left_hash; OuterJoinMarker left_outer; unique_ptr left_itr; unique_ptr lhs_scanner; + DataChunk lhs_scanned; DataChunk lhs_payload; + ExpressionExecutor lhs_executor; + DataChunk lhs_keys; + ValidityMask lhs_valid_mask; idx_t left_group = 0; + SelectionVector lhs_match_sel; // RHS scanning optional_ptr right_hash; @@ -389,21 +265,27 @@ class AsOfProbeBuffer { bool fetch_next_left; }; -AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op) - : context(context), allocator(Allocator::Get(context)), op(op), - buffer_manager(BufferManager::GetBufferManager(context)), force_external(IsExternal(context)), - 
memory_per_thread(op.GetMaxThreadMemory(context)), left_outer(IsLeftOuterJoin(op.join_type)), filterer(context), - fetch_next_left(true) { +AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin &op) + : client(client), op(op), buffer_manager(BufferManager::GetBufferManager(client)), + force_external(IsExternal(client)), memory_per_thread(op.GetMaxThreadMemory(client)), + left_outer(IsLeftOuterJoin(op.join_type)), lhs_executor(client), filterer(client), fetch_next_left(true) { vector> partition_stats; Orders partitions; // Not used. PartitionGlobalSinkState::GenerateOrderings(partitions, lhs_orders, op.lhs_partitions, op.lhs_orders, partition_stats); + lhs_keys.Initialize(client, op.join_key_types); + for (const auto &cond : op.conditions) { + lhs_executor.AddExpression(*cond.left); + } + // We sort the row numbers of the incoming block, not the rows - lhs_payload.Initialize(allocator, op.children[0].get().GetTypes()); - rhs_payload.Initialize(allocator, op.children[1].get().GetTypes()); + lhs_scanned.Initialize(client, op.children[0].get().GetTypes()); + lhs_payload.Initialize(client, op.children[0].get().GetTypes()); + rhs_payload.Initialize(client, op.children[1].get().GetTypes()); - lhs_sel.Initialize(); + lhs_scan_sel.Initialize(); + lhs_match_sel.Initialize(); left_outer.Initialize(STANDARD_VECTOR_SIZE); if (op.predicate) { @@ -415,17 +297,22 @@ AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { auto &gsink = op.sink_state->Cast(); - auto &lhs_sink = *gsink.lhs_sink; - left_group = lhs_sink.bin_groups[scan_bin]; - // Always set right_group too for memory management - auto &rhs_sink = gsink.rhs_sink; + auto &rhs_sink = *gsink.partition_sinks[1]; if (scan_bin < rhs_sink.bin_groups.size()) { right_group = rhs_sink.bin_groups[scan_bin]; } else { right_group = rhs_sink.bin_groups.size(); } + auto &lhs_sink = *gsink.partition_sinks[0]; + left_group = 
lhs_sink.bin_groups[scan_bin]; + if (scan_bin < lhs_sink.bin_groups.size()) { + left_group = lhs_sink.bin_groups[scan_bin]; + } else { + left_group = lhs_sink.bin_groups.size(); + } + if (left_group >= lhs_sink.bin_groups.size()) { return; } @@ -457,13 +344,15 @@ void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { left_itr = make_uniq(left_sort, iterator_comp); // We are only probing the corresponding right side bin, which may be empty - // If they are empty, we leave the iterator as null so we can emit left matches + // If it is empty, we leave the iterator as null so we can emit left matches if (right_group < rhs_sink.bin_groups.size()) { right_hash = rhs_sink.hash_groups[right_group].get(); right_outer = gsink.right_outers.data() + right_group; auto &right_sort = *(right_hash->global_sort); - right_itr = make_uniq(right_sort, iterator_comp); - rhs_scanner = make_uniq(right_sort, false); + if (!right_sort.sorted_blocks.empty()) { + right_itr = make_uniq(right_sort, iterator_comp); + rhs_scanner = make_uniq(right_sort, false); + } } } @@ -473,9 +362,55 @@ bool AsOfProbeBuffer::NextLeft() { } // Scan the next sorted chunk - lhs_payload.Reset(); + lhs_scanned.Reset(); left_itr->SetIndex(lhs_scanner->Scanned()); - lhs_scanner->Scan(lhs_payload); + lhs_scanner->Scan(lhs_scanned); + + // Compute the join keys + lhs_keys.Reset(); + lhs_executor.Execute(lhs_scanned, lhs_keys); + lhs_keys.Flatten(); + + // Combine the NULLs + const auto count = lhs_scanned.size(); + lhs_valid_mask.Reset(); + for (auto col_idx : op.null_sensitive) { + auto &col = lhs_keys.data[col_idx]; + UnifiedVectorFormat unified; + col.ToUnifiedFormat(count, unified); + lhs_valid_mask.Combine(unified.validity, count); + } + + // Convert the mask to a selection vector + // and mark all the rows that cannot match for early return. 
+ idx_t lhs_valid = 0; + const auto entry_count = lhs_valid_mask.EntryCount(count); + idx_t base_idx = 0; + for (idx_t entry_idx = 0; entry_idx < entry_count;) { + const auto validity_entry = lhs_valid_mask.GetValidityEntry(entry_idx++); + const auto next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); + if (ValidityMask::AllValid(validity_entry)) { + for (; base_idx < next; ++base_idx) { + lhs_scan_sel.set_index(lhs_valid++, base_idx); + } + } else if (ValidityMask::NoneValid(validity_entry)) { + base_idx = next; + } else { + const auto start = base_idx; + for (; base_idx < next; ++base_idx) { + if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { + lhs_scan_sel.set_index(lhs_valid++, base_idx); + } + } + } + } + + // Slice the keys to the ones we can match + if (lhs_valid < count) { + lhs_payload.Slice(lhs_scanned, lhs_scan_sel, lhs_valid); + } else { + lhs_payload.Reference(lhs_scanned); + } return true; } @@ -488,7 +423,7 @@ void AsOfProbeBuffer::EndLeftScan() { rhs_scanner.reset(); right_outer = nullptr; - auto &rhs_sink = gsink.rhs_sink; + auto &rhs_sink = *gsink.partition_sinks[1]; if (!gsink.is_outer && right_group < rhs_sink.bin_groups.size()) { rhs_sink.hash_groups[right_group].reset(); } @@ -497,7 +432,7 @@ void AsOfProbeBuffer::EndLeftScan() { left_itr.reset(); lhs_scanner.reset(); - auto &lhs_sink = *gsink.lhs_sink; + auto &lhs_sink = *gsink.partition_sinks[0]; if (left_group < lhs_sink.bin_groups.size()) { lhs_sink.hash_groups[left_group].reset(); } @@ -564,14 +499,10 @@ void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { if (matches) { matches[i] = first; } - lhs_sel.set_index(lhs_match_count++, i); + lhs_match_sel.set_index(lhs_match_count++, i); } } -unique_ptr PhysicalAsOfJoin::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk) { // perform the actual join bool 
found_match[STANDARD_VECTOR_SIZE] = {false}; @@ -596,7 +527,7 @@ void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &c ResolveJoin(nullptr, matches); for (idx_t i = 0; i < lhs_match_count; ++i) { - const auto idx = lhs_sel[i]; + const auto idx = lhs_match_sel[i]; const auto match_pos = matches[idx]; // Skip to the range containing the match while (match_pos >= rhs_scanner->Scanned()) { @@ -616,10 +547,10 @@ void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &c // Slice the left payload into the result for (column_t i = 0; i < lhs_payload.ColumnCount(); ++i) { - chunk.data[i].Slice(lhs_payload.data[i], lhs_sel, lhs_match_count); + chunk.data[i].Slice(lhs_payload.data[i], lhs_match_sel, lhs_match_count); } chunk.SetCardinality(lhs_match_count); - auto match_sel = &lhs_sel; + auto match_sel = &lhs_match_sel; if (filterer.expressions.size() == 1) { lhs_match_count = filterer.SelectExpression(chunk, filter_sel); chunk.Slice(filter_sel, lhs_match_count); @@ -646,7 +577,7 @@ void AsOfProbeBuffer::GetData(ExecutionContext &context, DataChunk &chunk) { if (left_outer.Enabled()) { // left join: before we move to the next chunk, see if we need to output any vectors that didn't // have a match found - left_outer.ConstructLeftJoinResult(lhs_payload, chunk); + left_outer.ConstructLeftJoinResult(lhs_scanned, chunk); left_outer.Reset(); } return; @@ -678,39 +609,31 @@ void AsOfProbeBuffer::GetData(ExecutionContext &context, DataChunk &chunk) { class AsOfGlobalSourceState : public GlobalSourceState { public: explicit AsOfGlobalSourceState(AsOfGlobalSinkState &gsink_p) - : gsink(gsink_p), next_combine(0), combined(0), merged(0), mergers(0), next_left(0), flushed(0), next_right(0) { - } - - PartitionGlobalMergeStates &GetMergeStates() { - lock_guard guard(lock); - if (!merge_states) { - merge_states = make_uniq(*gsink.lhs_sink); + : gsink(gsink_p), next_left(0), flushed(0), next_right(0) { + + if (gsink.child == 1) { + // for 
FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple + auto &rhs_partition = *gsink.partition_sinks[gsink.child]; + auto &right_outers = gsink.right_outers; + right_outers.reserve(rhs_partition.hash_groups.size()); + for (const auto &hash_group : rhs_partition.hash_groups) { + right_outers.emplace_back(OuterJoinMarker(gsink.is_outer)); + right_outers.back().Initialize(hash_group->count); + } } - return *merge_states; } AsOfGlobalSinkState &gsink; - //! The next buffer to combine - atomic next_combine; - //! The number of combined buffers - atomic combined; - //! The number of combined buffers - atomic merged; - //! The number of combined buffers - atomic mergers; //! The next buffer to flush atomic next_left; //! The number of flushed buffers atomic flushed; //! The right outer output read position. atomic next_right; - //! The merge handler - mutex lock; - unique_ptr merge_states; public: idx_t MaxThreads() override { - return gsink.lhs_buffers.size(); + return gsink.local_buffers[0].size(); } }; @@ -723,16 +646,12 @@ class AsOfLocalSourceState : public LocalSourceState { public: using HashGroupPtr = unique_ptr; - AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, ClientContext &client_p); - - // Return true if we were not interrupted (another thread died) - bool CombineLeftPartitions(); - bool MergeLeftPartitions(); + AsOfLocalSourceState(ExecutionContext &context, AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op); idx_t BeginRightScan(const idx_t hash_bin); AsOfGlobalSourceState &gsource; - ClientContext &client; + ExecutionContext &context; //! The left side partition being probed AsOfProbeBuffer probe_buffer; @@ -742,51 +661,26 @@ class AsOfLocalSourceState : public LocalSourceState { HashGroupPtr hash_group; //! The read cursor unique_ptr scanner; - //! Pointer to the matches - const bool *found_match = {}; + //! 
Pointer to the right marker + const bool *rhs_matches = {}; }; -AsOfLocalSourceState::AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, - ClientContext &client_p) - : gsource(gsource), client(client_p), probe_buffer(gsource.gsink.lhs_sink->context, op) { - gsource.mergers++; -} - -bool AsOfLocalSourceState::CombineLeftPartitions() { - const auto buffer_count = gsource.gsink.lhs_buffers.size(); - while (gsource.combined < buffer_count && !client.interrupted) { - const auto next_combine = gsource.next_combine++; - if (next_combine < buffer_count) { - gsource.gsink.lhs_buffers[next_combine]->Combine(); - ++gsource.combined; - } else { - TaskScheduler::GetScheduler(client).YieldThread(); - } - } - - return !client.interrupted; -} - -bool AsOfLocalSourceState::MergeLeftPartitions() { - PartitionGlobalMergeStates::Callback local_callback; - PartitionLocalMergeState local_merge(*gsource.gsink.lhs_sink); - gsource.GetMergeStates().ExecuteTask(local_merge, local_callback); - gsource.merged++; - while (gsource.merged < gsource.mergers && !client.interrupted) { - TaskScheduler::GetScheduler(client).YieldThread(); - } - return !client.interrupted; +AsOfLocalSourceState::AsOfLocalSourceState(ExecutionContext &context, AsOfGlobalSourceState &gsource, + const PhysicalAsOfJoin &op) + : gsource(gsource), context(context), probe_buffer(context.client, op) { } idx_t AsOfLocalSourceState::BeginRightScan(const idx_t hash_bin_p) { hash_bin = hash_bin_p; - hash_group = std::move(gsource.gsink.rhs_sink.hash_groups[hash_bin]); + auto &rhs_sink = *gsource.gsink.partition_sinks[1]; + hash_group = std::move(rhs_sink.hash_groups[hash_bin]); if (hash_group->global_sort->sorted_blocks.empty()) { return 0; } scanner = make_uniq(*hash_group->global_sort); - found_match = gsource.gsink.right_outers[hash_bin].GetMatches(); + + rhs_matches = gsource.gsink.right_outers[hash_bin].GetMatches(); return scanner->Remaining(); } @@ -794,28 +688,18 @@ idx_t 
AsOfLocalSourceState::BeginRightScan(const idx_t hash_bin_p) { unique_ptr PhysicalAsOfJoin::GetLocalSourceState(ExecutionContext &context, GlobalSourceState &gstate) const { auto &gsource = gstate.Cast(); - return make_uniq(gsource, *this, context.client); + return make_uniq(context, gsource, *this); } SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { auto &gsource = input.global_state.Cast(); auto &lsource = input.local_state.Cast(); - auto &rhs_sink = gsource.gsink.rhs_sink; + auto &rhs_sink = *gsource.gsink.partition_sinks[1]; auto &client = context.client; - // Step 1: Combine the partitions - if (!lsource.CombineLeftPartitions()) { - return SourceResultType::FINISHED; - } - - // Step 2: Sort on all threads - if (!lsource.MergeLeftPartitions()) { - return SourceResultType::FINISHED; - } - - // Step 3: Join the partitions - auto &lhs_sink = *gsource.gsink.lhs_sink; + // Step 1: Join the partitions + auto &lhs_sink = *gsource.gsink.partition_sinks[0]; const auto left_bins = lhs_sink.grouping_data ? 
lhs_sink.grouping_data->GetPartitions().size() : 1; while (gsource.flushed < left_bins) { // Make sure we have something to flush @@ -847,7 +731,7 @@ SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk } } - // Step 4: Emit right join matches + // Step 2: Emit right join matches if (!IsRightOuterJoin(join_type)) { return SourceResultType::FINISHED; } @@ -856,7 +740,7 @@ SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk const auto right_groups = hash_groups.size(); DataChunk rhs_chunk; - rhs_chunk.Initialize(Allocator::Get(context.client), rhs_sink.payload_types); + rhs_chunk.Initialize(context.client, rhs_sink.payload_types); SelectionVector rsel(STANDARD_VECTOR_SIZE); while (chunk.size() == 0) { @@ -885,10 +769,10 @@ SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk } // figure out which tuples didn't find a match in the RHS - auto found_match = lsource.found_match; + auto rhs_matches = lsource.rhs_matches; idx_t result_count = 0; for (idx_t i = 0; i < count; i++) { - if (!found_match[rhs_position + i]) { + if (!rhs_matches[rhs_position + i]) { rsel.set_index(result_count++, i); } } @@ -912,4 +796,31 @@ SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk return chunk.size() > 0 ? 
SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; } +//===--------------------------------------------------------------------===// +// Pipeline Construction +//===--------------------------------------------------------------------===// +void PhysicalAsOfJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { + D_ASSERT(children.size() == 2); + if (meta_pipeline.HasRecursiveCTE()) { + throw NotImplementedException("AsOf joins are not supported in recursive CTEs yet"); + } + + // becomes a source after both children fully sink their data + meta_pipeline.GetState().SetPipelineSource(current, *this); + + // Create one child meta pipeline that will hold the LHS and RHS pipelines + auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); + + // Build out RHS first because that is the order the join planner expects. + auto rhs_pipeline = child_meta_pipeline.GetBasePipeline(); + children[1].get().BuildPipelines(*rhs_pipeline, child_meta_pipeline); + + // Build out LHS + auto &lhs_pipeline = child_meta_pipeline.CreatePipeline(); + children[0].get().BuildPipelines(lhs_pipeline, child_meta_pipeline); + + // Despite having the same sink, LHS and everything created after it need their own (same) PipelineFinishEvent + child_meta_pipeline.AddFinishEvent(lhs_pipeline); +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp index 90aba4722..24981993e 100644 --- a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp +++ b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp @@ -1,15 +1,8 @@ -#include - #include "duckdb/execution/operator/join/physical_iejoin.hpp" #include "duckdb/common/atomic.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/sort/sorted_block.hpp" -#include 
"duckdb/common/thread.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parallel/event.hpp" #include "duckdb/parallel/meta_pipeline.hpp" @@ -17,6 +10,8 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include + namespace duckdb { PhysicalIEJoin::PhysicalIEJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperator &left, @@ -82,17 +77,15 @@ class IEJoinGlobalState : public GlobalSinkState { public: IEJoinGlobalState(ClientContext &context, const PhysicalIEJoin &op) : child(1) { tables.resize(2); - RowLayout lhs_layout; - lhs_layout.Initialize(op.children[0].get().GetTypes()); + const auto &lhs_types = op.children[0].get().GetTypes(); vector lhs_order; lhs_order.emplace_back(op.lhs_orders[0].Copy()); - tables[0] = make_uniq(context, lhs_order, lhs_layout, op); + tables[0] = make_uniq(context, lhs_order, lhs_types, op); - RowLayout rhs_layout; - rhs_layout.Initialize(op.children[1].get().GetTypes()); + const auto &rhs_types = op.children[1].get().GetTypes(); vector rhs_order; rhs_order.emplace_back(op.rhs_orders[0].Copy()); - tables[1] = make_uniq(context, rhs_order, rhs_layout, op); + tables[1] = make_uniq(context, rhs_order, rhs_types, op); if (op.filter_pushdown) { skip_filter_pushdown = op.filter_pushdown->probe_info.empty(); @@ -100,11 +93,18 @@ class IEJoinGlobalState : public GlobalSinkState { } } - void Sink(DataChunk &input, IEJoinLocalState &lstate); - void Finalize(Pipeline &pipeline, Event &event) { + void Sink(ExecutionContext &context, DataChunk &input, IEJoinLocalState &lstate); + + void Finalize(ClientContext &client, InterruptState &interrupt) { // Sort the current input child D_ASSERT(child < tables.size()); - tables[child]->Finalize(pipeline, event); + 
tables[child]->Finalize(client, interrupt); + }; + + void Materialize(Pipeline &pipeline, Event &event) { + // Sort the current input child + D_ASSERT(child < tables.size()); + tables[child]->Materialize(pipeline, event); child = child ? 0 : 2; skip_filter_pushdown = true; }; @@ -123,8 +123,8 @@ class IEJoinLocalState : public LocalSinkState { public: using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; - IEJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, IEJoinGlobalState &gstate) - : table(context, op, gstate.child) { + IEJoinLocalState(ExecutionContext &context, const PhysicalRangeJoin &op, IEJoinGlobalState &gstate) + : table(context, *gstate.tables[gstate.child], gstate.child) { if (op.filter_pushdown) { local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); @@ -144,32 +144,23 @@ unique_ptr PhysicalIEJoin::GetGlobalSinkState(ClientContext &co unique_ptr PhysicalIEJoin::GetLocalSinkState(ExecutionContext &context) const { auto &ie_sink = sink_state->Cast(); - return make_uniq(context.client, *this, ie_sink); + return make_uniq(context, *this, ie_sink); } -void IEJoinGlobalState::Sink(DataChunk &input, IEJoinLocalState &lstate) { - auto &table = *tables[child]; - auto &global_sort_state = table.global_sort_state; - auto &local_sort_state = lstate.table.local_sort_state; - +void IEJoinGlobalState::Sink(ExecutionContext &context, DataChunk &input, IEJoinLocalState &lstate) { // Sink the data into the local sort state - lstate.table.Sink(input, global_sort_state); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= table.memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } + lstate.table.Sink(context, input); } SinkResultType PhysicalIEJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - if (gstate.child == 0 && 
gstate.tables[1]->global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + if (gstate.child == 0 && gstate.tables[1]->Count() == 0 && EmptyResultIfRHSIsEmpty()) { return SinkResultType::FINISHED; } - gstate.Sink(chunk, lstate); + gstate.Sink(context, chunk, lstate); if (filter_pushdown && !gstate.skip_filter_pushdown) { filter_pushdown->Sink(lstate.table.keys, *lstate.local_filter_state); @@ -181,7 +172,7 @@ SinkResultType PhysicalIEJoin::Sink(ExecutionContext &context, DataChunk &chunk, SinkCombineResultType PhysicalIEJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - gstate.tables[gstate.child]->Combine(lstate.table); + gstate.tables[gstate.child]->Combine(context, lstate.table); auto &client_profiler = QueryProfiler::Get(context.client); context.thread.profiler.Flush(*this); @@ -197,14 +188,13 @@ SinkCombineResultType PhysicalIEJoin::Combine(ExecutionContext &context, Operato //===--------------------------------------------------------------------===// // Finalize //===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, +SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, OperatorSinkFinalizeInput &input) const { auto &gstate = input.global_state.Cast(); if (filter_pushdown && !gstate.skip_filter_pushdown) { - (void)filter_pushdown->Finalize(context, nullptr, *gstate.global_filter_state, *this); + (void)filter_pushdown->Finalize(client, nullptr, *gstate.global_filter_state, *this); } auto &table = *gstate.tables[gstate.child]; - auto &global_sort_state = table.global_sort_state; if ((gstate.child == 1 && PropagatesBuildSide(join_type)) || (gstate.child == 0 && IsLeftOuterJoin(join_type))) { // for FULL/LEFT/RIGHT OUTER JOIN, initialize found_match to false for 
every tuple @@ -212,15 +202,18 @@ SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, Clie } SinkFinalizeType res; - if (gstate.child == 1 && global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + if (gstate.child == 1 && table.Count() == 0 && EmptyResultIfRHSIsEmpty()) { // Empty input! res = SinkFinalizeType::NO_OUTPUT_POSSIBLE; } else { res = SinkFinalizeType::READY; } + // Clean up the current table + gstate.Finalize(client, input.interrupt_state); + // Move to the next input child - gstate.Finalize(pipeline, event); + gstate.Materialize(pipeline, event); return res; } @@ -238,19 +231,48 @@ OperatorResultType PhysicalIEJoin::ExecuteInternal(ExecutionContext &context, Da //===--------------------------------------------------------------------===// struct IEJoinUnion { using SortedTable = PhysicalRangeJoin::GlobalSortedTable; + using ChunkRange = std::pair; + + // Comparison utilities + static bool IsStrictComparison(ExpressionType comparison) { + switch (comparison) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + return true; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return false; + default: + throw InternalException("Unimplemented comparison type for IEJoin!"); + } + } - static idx_t AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, - int64_t base, const idx_t block_idx); - - static void Sort(SortedTable &table) { - auto &global_sort_state = table.global_sort_state; - global_sort_state.PrepareMergePhase(); - while (global_sort_state.sorted_blocks.size() > 1) { - global_sort_state.InitializeMergeRound(); - MergeSorter merge_sorter(global_sort_state, global_sort_state.buffer_manager); - merge_sorter.PerformInMergeRound(); - global_sort_state.CompleteMergeRound(true); + template + static inline bool Compare(const T &lhs, const T &rhs, const bool strict) { + const bool 
less_than = lhs < rhs; + if (!less_than && !strict) { + return !(rhs < lhs); } + return less_than; + } + + template + static bool TemplatedCompareKeys(ExternalBlockIteratorState &state1, const idx_t pos1, + ExternalBlockIteratorState &state2, const idx_t pos2, bool strict); + + static bool CompareKeys(ExternalBlockIteratorState &state1, const idx_t pos1, ExternalBlockIteratorState &state2, + const idx_t pos2, bool strict, const SortKeyType &sort_key_type); + + static bool CompareBounds(SortedTable &t1, const ChunkRange &b1, SortedTable &t2, const ChunkRange &b2, + bool strict); + + static idx_t AppendKey(ExecutionContext &context, InterruptState &interrupt, SortedTable &table, + ExpressionExecutor &executor, SortedTable &marked, int64_t increment, int64_t rid, + const ChunkRange &range); + + static void Sort(ExecutionContext &context, InterruptState &interrupt, SortedTable &table) { + table.Finalize(context.client, interrupt); + table.Materialize(context, interrupt); } template @@ -258,21 +280,17 @@ struct IEJoinUnion { vector result; result.reserve(table.count); - auto &gstate = table.global_sort_state; - auto &blocks = *gstate.sorted_blocks[0]->payload_data; - PayloadScanner scanner(blocks, gstate, false); + auto &collection = *table.sorted->payload_data; + vector scan_ids(1, col_idx); + TupleDataScanState state; + collection.InitializeScan(state, scan_ids); DataChunk payload; - payload.Initialize(Allocator::DefaultAllocator(), gstate.payload_layout.GetTypes()); - for (;;) { - payload.Reset(); - scanner.Scan(payload); - const auto count = payload.size(); - if (!count) { - break; - } + collection.InitializeScanChunk(state, payload); - const auto data_ptr = FlatVector::GetData(payload.data[col_idx]); + while (collection.Scan(state, payload)) { + const auto count = payload.size(); + const auto data_ptr = FlatVector::GetData(payload.data[0]); for (idx_t i = 0; i < count; i++) { result.push_back(UnsafeNumericCast(data_ptr[i])); } @@ -281,12 +299,40 @@ struct 
IEJoinUnion { return result; } - IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, SortedTable &t2, - const idx_t b2); + class UnionIterator { + public: + UnionIterator(SortedTable &table, bool strict) : state(table.CreateIteratorState()), strict(strict) { + } + + inline idx_t GetIndex() const { + return index; + } + + inline void SetIndex(idx_t i) { + index = i; + } + + UnionIterator &operator++() { + ++index; + return *this; + } + + unique_ptr state; + idx_t index = 0; + const bool strict; + }; + + IEJoinUnion(ExecutionContext &context, const PhysicalIEJoin &op, SortedTable &t1, const ChunkRange &b1, + SortedTable &t2, const ChunkRange &b2); idx_t SearchL1(idx_t pos); + + template bool NextRow(); + using next_row_t = bool (duckdb::IEJoinUnion::*)(); + next_row_t next_row_func; + //! Inverted loop idx_t JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel); @@ -314,49 +360,64 @@ struct IEJoinUnion { idx_t n; idx_t i; idx_t j; - unique_ptr op1; - unique_ptr off1; - unique_ptr op2; - unique_ptr off2; + unique_ptr op1; + unique_ptr off1; + unique_ptr op2; + unique_ptr off2; int64_t lrid; }; -idx_t IEJoinUnion::AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, - int64_t base, const idx_t block_idx) { - LocalSortState local_sort_state; - local_sort_state.Initialize(marked.global_sort_state, marked.global_sort_state.buffer_manager); +idx_t IEJoinUnion::AppendKey(ExecutionContext &context, InterruptState &interrupt, SortedTable &table, + ExpressionExecutor &executor, SortedTable &marked, int64_t increment, int64_t rid, + const ChunkRange &chunk_range) { + const auto chunk_begin = chunk_range.first; + const auto chunk_end = chunk_range.second; // Reading const auto valid = table.count - table.has_null; - auto &gstate = table.global_sort_state; - PayloadScanner scanner(gstate, block_idx); - auto table_idx = block_idx * gstate.block_capacity; + auto &source = 
*table.sorted->payload_data; + TupleDataScanState scanner; + source.InitializeScan(scanner); DataChunk scanned; - scanned.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); + source.InitializeScanChunk(scanner, scanned); - // Writing - auto types = local_sort_state.sort_layout->logical_types; - const idx_t payload_idx = types.size(); + // TODO: Random access into TupleDataCollection (NextScanIndex is private...) + idx_t table_idx = 0; + for (idx_t i = 0; i < chunk_begin; ++i) { + source.Scan(scanner, scanned); + table_idx += scanned.size(); + } - const auto &payload_types = local_sort_state.payload_layout->GetTypes(); - types.insert(types.end(), payload_types.begin(), payload_types.end()); - const idx_t rid_idx = types.size() - 1; + // Writing + auto &sort = *marked.sort; + auto local_sort_state = sort.GetLocalSinkState(context); + vector types; + for (const auto &expr : executor.expressions) { + types.emplace_back(expr->return_type); + } + const idx_t rid_idx = types.size(); + types.emplace_back(LogicalType::BIGINT); DataChunk keys; DataChunk payload; keys.Initialize(Allocator::DefaultAllocator(), types); + OperatorSinkInput sink {*marked.global_sink, *local_sort_state, interrupt}; idx_t inserted = 0; - for (auto rid = base; table_idx < valid;) { - scanned.Reset(); - scanner.Scan(scanned); + for (auto chunk_idx = chunk_begin; chunk_idx < chunk_end; ++chunk_idx) { + source.Scan(scanner, scanned); // NULLs are at the end, so stop when we reach them auto scan_count = scanned.size(); if (table_idx + scan_count > valid) { - scan_count = valid - table_idx; - scanned.SetCardinality(scan_count); + if (table_idx >= valid) { + scan_count = 0; + ; + } else { + scan_count = valid - table_idx; + scanned.SetCardinality(scan_count); + } } if (scan_count == 0) { break; @@ -375,43 +436,89 @@ idx_t IEJoinUnion::AppendKey(SortedTable &table, ExpressionExecutor &executor, S rid += increment * UnsafeNumericCast(scan_count); // Sort on the sort columns (which will no 
longer be needed) - keys.Split(payload, payload_idx); - local_sort_state.SinkChunk(keys, payload); + sort.Sink(context, keys, sink); inserted += scan_count; - keys.Fuse(payload); - - // Flush when we have enough data - if (local_sort_state.SizeInBytes() >= marked.memory_per_thread) { - local_sort_state.Sort(marked.global_sort_state, true); - } } - marked.global_sort_state.AddLocalState(local_sort_state); + OperatorSinkCombineInput combine {*marked.global_sink, *local_sort_state, interrupt}; + sort.Combine(context, combine); marked.count += inserted; return inserted; } -IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, - SortedTable &t2, const idx_t b2) +// TODO: Function pointers? +template +bool IEJoinUnion::TemplatedCompareKeys(ExternalBlockIteratorState &state1, const idx_t pos1, + ExternalBlockIteratorState &state2, const idx_t pos2, bool strict) { + using SORT_KEY = SortKey; + using BLOCKS_ITERATOR = block_iterator_t; + + BLOCKS_ITERATOR bounds1(state1, pos1); + BLOCKS_ITERATOR bounds2(state2, pos2); + + return Compare(*bounds1, *bounds2, strict); +} + +bool IEJoinUnion::CompareKeys(ExternalBlockIteratorState &state1, const idx_t pos1, ExternalBlockIteratorState &state2, + const idx_t pos2, bool strict, const SortKeyType &sort_key_type) { + + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + 
case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedCompareKeys(state1, pos1, state2, pos2, strict); + default: + throw NotImplementedException("IEJoinUnion::CompareKeys for %s", EnumUtil::ToString(sort_key_type)); + } +} + +bool IEJoinUnion::CompareBounds(SortedTable &t1, const ChunkRange &b1, SortedTable &t2, const ChunkRange &b2, + bool strict) { + auto &keys1 = *t1.sorted->key_data; + ExternalBlockIteratorState state1(keys1, nullptr); + const idx_t pos1 = t1.BlockStart(b1.first); + + auto &keys2 = *t2.sorted->key_data; + ExternalBlockIteratorState state2(keys2, nullptr); + const idx_t pos2 = t2.BlockEnd(b2.second - 1); + + const auto sort_key_type = t1.GetSortKeyType(); + D_ASSERT(sort_key_type == t2.GetSortKeyType()); + return CompareKeys(state1, pos1, state2, pos2, strict, sort_key_type); +} + +IEJoinUnion::IEJoinUnion(ExecutionContext &context, const PhysicalIEJoin &op, SortedTable &t1, const ChunkRange &b1, + SortedTable &t2, const ChunkRange &b2) : n(0), i(0) { // input : query Q with 2 join predicates t1.X op1 t2.X' and t1.Y op2 t2.Y', tables T, T' of sizes m and n resp. // output: a list of tuple pairs (ti , tj) // Note that T/T' are already sorted on X/X' and contain the payload data // We only join the two block numbers and use the sizes of the blocks as the counts + InterruptState interrupt; + // 0. 
Filter out tables with no overlap - if (!t1.BlockSize(b1) || !t2.BlockSize(b2)) { + if (t1.sorted->key_data->ChunkCount() <= b1.first || t2.sorted->key_data->ChunkCount() <= b2.first) { return; } - const auto &cmp1 = op.conditions[0].comparison; - SBIterator bounds1(t1.global_sort_state, cmp1); - SBIterator bounds2(t2.global_sort_state, cmp1); - // t1.X[0] op1 t2.X'[-1] - bounds1.SetIndex(bounds1.block_capacity * b1); - bounds2.SetIndex(bounds2.block_capacity * b2 + t2.BlockSize(b2) - 1); - if (!bounds1.Compare(bounds2)) { + const auto strict1 = IsStrictComparison(op.conditions[0].comparison); + if (!CompareBounds(t1, b1, t2, b2, strict1)) { return; } @@ -428,8 +535,6 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte vector types; types.emplace_back(order2.expression->return_type); types.emplace_back(LogicalType::BIGINT); - RowLayout payload_layout; - payload_layout.Initialize(types); // Sort on the first expression auto ref = make_uniq(order1.expression->return_type, 0U); @@ -451,37 +556,37 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte // Using this OrderType, if i < j then value[i] (from left table) and value[j] (from right table) match // the condition (t1.time <= t2.time or t1.time < t2.time), then from_left will force them into the correct order. auto from_left = make_uniq(Value::BOOLEAN(true)); - orders.emplace_back(SBIterator::ComparisonValue(cmp1) == 0 ? OrderType::DESCENDING : OrderType::ASCENDING, - OrderByNullType::ORDER_DEFAULT, std::move(from_left)); + orders.emplace_back(!strict1 ? 
OrderType::DESCENDING : OrderType::ASCENDING, OrderByNullType::ORDER_DEFAULT, + std::move(from_left)); - l1 = make_uniq(context, orders, payload_layout, op); + l1 = make_uniq(context.client, orders, types, op); // LHS has positive rids - ExpressionExecutor l_executor(context); + ExpressionExecutor l_executor(context.client); l_executor.AddExpression(*order1.expression); // add const column true auto left_const = make_uniq(Value::BOOLEAN(true)); l_executor.AddExpression(*left_const); l_executor.AddExpression(*order2.expression); - AppendKey(t1, l_executor, *l1, 1, 1, b1); + AppendKey(context, interrupt, t1, l_executor, *l1, 1, 1, b1); // RHS has negative rids - ExpressionExecutor r_executor(context); + ExpressionExecutor r_executor(context.client); r_executor.AddExpression(*op.rhs_orders[0].expression); // add const column flase auto right_const = make_uniq(Value::BOOLEAN(false)); r_executor.AddExpression(*right_const); r_executor.AddExpression(*op.rhs_orders[1].expression); - AppendKey(t2, r_executor, *l1, -1, -1, b2); + AppendKey(context, interrupt, t2, r_executor, *l1, -1, -1, b2); - if (l1->global_sort_state.sorted_blocks.empty()) { + if (!l1->Count()) { return; } - Sort(*l1); + Sort(context, interrupt, *l1); - op1 = make_uniq(l1->global_sort_state, cmp1); - off1 = make_uniq(l1->global_sort_state, cmp1); + op1 = make_uniq(*l1, strict1); + off1 = make_uniq(*l1, strict1); // We don't actually need the L1 column, just its sort key, which is in the sort blocks li = ExtractColumn(*l1, types.size() - 1); @@ -493,22 +598,19 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte // For this we just need a two-column table of Y, P types.clear(); types.emplace_back(LogicalType::BIGINT); - payload_layout.Initialize(types); // Sort on the first expression orders.clear(); ref = make_uniq(order2.expression->return_type, 0U); orders.emplace_back(order2.type, order2.null_order, std::move(ref)); - ExpressionExecutor executor(context); + 
ExpressionExecutor executor(context.client); executor.AddExpression(*orders[0].expression); - l2 = make_uniq(context, orders, payload_layout, op); - for (idx_t base = 0, block_idx = 0; block_idx < l1->BlockCount(); ++block_idx) { - base += AppendKey(*l1, executor, *l2, 1, NumericCast(base), block_idx); - } + l2 = make_uniq(context.client, orders, types, op); + AppendKey(context, interrupt, *l1, executor, *l2, 1, 0, {0, l1->BlockCount()}); - Sort(*l2); + Sort(context, interrupt, *l2); // We don't actually need the L2 column, just its sort key, which is in the sort blocks @@ -526,15 +628,57 @@ IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, Sorte bloom_filter.Initialize(bloom_array.data(), bloom_count); // 11. for(i←1 to n) do - const auto &cmp2 = op.conditions[1].comparison; - op2 = make_uniq(l2->global_sort_state, cmp2); - off2 = make_uniq(l2->global_sort_state, cmp2); + const auto strict2 = IsStrictComparison(op.conditions[1].comparison); + op2 = make_uniq(*l2, strict2); + off2 = make_uniq(*l2, strict2); i = 0; j = 0; - (void)NextRow(); + + const auto sort_key_type = l2->GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_FIXED_32: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_FIXED_16: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_FIXED_24: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_FIXED_32: + next_row_func = &IEJoinUnion::NextRow; + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + next_row_func = &IEJoinUnion::NextRow; + break; + default: + throw 
NotImplementedException("IEJoinUnion for %s", EnumUtil::ToString(sort_key_type)); + } + + (this->*next_row_func)(); } +template bool IEJoinUnion::NextRow() { + using SORT_KEY = SortKey; + using BLOCKS_ITERATOR = block_iterator_t; + + BLOCKS_ITERATOR off2_itr(*off2->state); + BLOCKS_ITERATOR op2_itr(*op2->state); + const auto strict = off2->strict; + for (; i < n; ++i) { // 12. pos ← P[i] auto pos = p[i]; @@ -546,7 +690,7 @@ bool IEJoinUnion::NextRow() { // 16. B[pos] ← 1 op2->SetIndex(i); for (; off2->GetIndex() < n; ++(*off2)) { - if (!off2->Compare(*op2)) { + if (!Compare(off2_itr[off2->GetIndex()], op2_itr[op2->GetIndex()], strict)) { break; } const auto p2 = p[off2->GetIndex()]; @@ -652,7 +796,7 @@ idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rse } ++i; - if (!NextRow()) { + if (!(this->*next_row_func)()) { break; } } @@ -662,11 +806,28 @@ idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rse class IEJoinLocalSourceState : public LocalSourceState { public: - explicit IEJoinLocalSourceState(ClientContext &context, const PhysicalIEJoin &op) - : op(op), true_sel(STANDARD_VECTOR_SIZE), left_executor(context), right_executor(context), - left_matches(nullptr), right_matches(nullptr) { - auto &allocator = Allocator::Get(context); - unprojected.Initialize(allocator, op.unprojected_types); + IEJoinLocalSourceState(ClientContext &client, const PhysicalIEJoin &op) + : op(op), lsel(STANDARD_VECTOR_SIZE), rsel(STANDARD_VECTOR_SIZE), true_sel(STANDARD_VECTOR_SIZE), + left_executor(client), right_executor(client), left_matches(nullptr), right_matches(nullptr) + + { + auto &allocator = Allocator::Get(client); + unprojected.InitializeEmpty(op.unprojected_types); + lpayload.Initialize(allocator, op.children[0].get().GetTypes()); + rpayload.Initialize(allocator, op.children[1].get().GetTypes()); + + auto &ie_sink = op.sink_state->Cast(); + auto &left_table = *ie_sink.tables[0]; + auto &right_table = *ie_sink.tables[1]; + 
+ left_iterator = left_table.CreateIteratorState(); + right_iterator = right_table.CreateIteratorState(); + + left_table.InitializePayloadState(left_chunk_state); + right_table.InitializePayloadState(right_chunk_state); + + left_scan_state = left_table.CreateScanState(client); + right_scan_state = right_table.CreateScanState(client); if (op.conditions.size() < 3) { return; @@ -710,9 +871,19 @@ class IEJoinLocalSourceState : public LocalSourceState { idx_t left_base; idx_t left_block_index; + unique_ptr left_iterator; + TupleDataChunkState left_chunk_state; + SelectionVector lsel; + DataChunk lpayload; + unique_ptr left_scan_state; idx_t right_base; idx_t right_block_index; + unique_ptr right_iterator; + TupleDataChunkState right_chunk_state; + SelectionVector rsel; + DataChunk rpayload; + unique_ptr right_scan_state; // Trailing predicates SelectionVector true_sel; @@ -735,14 +906,27 @@ class IEJoinLocalSourceState : public LocalSourceState { void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state_p) const { auto &state = state_p.Cast(); auto &ie_sink = sink_state->Cast(); + + auto &chunk = state.unprojected; + auto &left_table = *ie_sink.tables[0]; + auto &lsel = state.lsel; + auto &lpayload = state.lpayload; + auto &left_iterator = *state.left_iterator; + auto &left_chunk_state = state.left_chunk_state; + auto &left_block_index = state.left_block_index; + auto &left_scan_state = *state.left_scan_state; + const auto left_cols = children[0].get().GetTypes().size(); + auto &right_table = *ie_sink.tables[1]; + auto &rsel = state.rsel; + auto &rpayload = state.rpayload; + auto &right_iterator = *state.right_iterator; + auto &right_chunk_state = state.right_chunk_state; + auto &right_block_index = state.right_block_index; + auto &right_scan_state = *state.right_scan_state; - const auto left_cols = children[0].get().GetTypes().size(); - auto &chunk = state.unprojected; do { - SelectionVector 
lsel(STANDARD_VECTOR_SIZE); - SelectionVector rsel(STANDARD_VECTOR_SIZE); auto result_count = state.joiner->JoinComplexBlocks(lsel, rsel); if (result_count == 0) { // exhausted this pair @@ -751,23 +935,23 @@ void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &re // found matches: extract them - chunk.Reset(); - SliceSortedPayload(chunk, left_table.global_sort_state, state.left_block_index, lsel, result_count, 0); - SliceSortedPayload(chunk, right_table.global_sort_state, state.right_block_index, rsel, result_count, - left_cols); - chunk.SetCardinality(result_count); + left_table.Repin(left_iterator); + right_table.Repin(right_iterator); + + SliceSortedPayload(lpayload, left_table, left_iterator, left_chunk_state, left_block_index, lsel, result_count, + left_scan_state); + SliceSortedPayload(rpayload, right_table, right_iterator, right_chunk_state, right_block_index, rsel, + result_count, right_scan_state); auto sel = FlatVector::IncrementalSelectionVector(); if (conditions.size() > 2) { // If there are more expressions to compute, - // split the result chunk into the left and right halves - // so we can compute the values for comparison. + // use the left and right payloads + // to we can compute the values for comparison. 
const auto tail_cols = conditions.size() - 2; - DataChunk right_chunk; - chunk.Split(right_chunk, left_cols); - state.left_executor.SetChunk(chunk); - state.right_executor.SetChunk(right_chunk); + state.left_executor.SetChunk(lpayload); + state.right_executor.SetChunk(rpayload); auto tail_count = result_count; auto true_sel = &state.true_sel; @@ -785,14 +969,25 @@ void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &re tail_count = SelectJoinTail(conditions[cmp_idx + 2].comparison, left, right, sel, tail_count, true_sel); sel = true_sel; } - chunk.Fuse(right_chunk); if (tail_count < result_count) { result_count = tail_count; - chunk.Slice(*sel, result_count); + lpayload.Slice(*sel, result_count); + rpayload.Slice(*sel, result_count); } } + // Merge the payloads + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].Reference(lpayload.data[col_idx]); + } else { + chunk.data[col_idx].Reference(rpayload.data[col_idx - left_cols]); + } + } + chunk.SetCardinality(result_count); + // We need all of the data to compute other predicates, // but we only return what is in the projection map ProjectResult(chunk, result); @@ -814,35 +1009,23 @@ void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &re class IEJoinGlobalSourceState : public GlobalSourceState { public: - explicit IEJoinGlobalSourceState(const PhysicalIEJoin &op, IEJoinGlobalState &gsink) + IEJoinGlobalSourceState(const PhysicalIEJoin &op, IEJoinGlobalState &gsink) : op(op), gsink(gsink), initialized(false), next_pair(0), completed(0), left_outers(0), next_left(0), right_outers(0), next_right(0) { } - void Initialize() { + void Initialize(ClientContext &client) { auto guard = Lock(); if (initialized) { return; } - // Compute the starting row for reach block - // (In theory these are all the same size, but you never know...) 
+ // Compute the starting row for each block auto &left_table = *gsink.tables[0]; const auto left_blocks = left_table.BlockCount(); - idx_t left_base = 0; - - for (size_t lhs = 0; lhs < left_blocks; ++lhs) { - left_bases.emplace_back(left_base); - left_base += left_table.BlockSize(lhs); - } auto &right_table = *gsink.tables[1]; const auto right_blocks = right_table.BlockCount(); - idx_t right_base = 0; - for (size_t rhs = 0; rhs < right_blocks; ++rhs) { - right_bases.emplace_back(right_base); - right_base += right_table.BlockSize(rhs); - } // Outer join block counts if (left_table.found_match) { @@ -864,27 +1047,34 @@ class IEJoinGlobalSourceState : public GlobalSourceState { return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount(); } - void GetNextPair(ClientContext &client, IEJoinLocalSourceState &lstate) { + void GetNextPair(ExecutionContext &context, IEJoinLocalSourceState &lstate) { + using ChunkRange = IEJoinUnion::ChunkRange; auto &left_table = *gsink.tables[0]; auto &right_table = *gsink.tables[1]; const auto left_blocks = left_table.BlockCount(); + const auto left_ranges = (left_blocks + left_per_thread - 1) / left_per_thread; + const auto right_blocks = right_table.BlockCount(); - const auto pair_count = left_blocks * right_blocks; + const auto right_ranges = (right_blocks + right_per_thread - 1) / right_per_thread; + + const auto pair_count = left_ranges * right_ranges; // Regular block const auto i = next_pair++; if (i < pair_count) { - const auto b1 = i / right_blocks; - const auto b2 = i % right_blocks; + const auto b1 = (i / right_ranges) * left_per_thread; + const auto b2 = (i % right_ranges) * right_per_thread; - lstate.left_block_index = b1; - lstate.left_base = left_bases[b1]; + ChunkRange l_range {b1, MinValue(left_blocks, b1 + left_per_thread)}; + lstate.left_block_index = l_range.first; + lstate.left_base = left_table.BlockStart(l_range.first); - lstate.right_block_index = b2; - lstate.right_base = right_bases[b2]; + 
ChunkRange r_range {b2, MinValue(right_blocks, b2 + right_per_thread)}; + lstate.right_block_index = r_range.first; + lstate.right_base = right_table.BlockStart(r_range.first); - lstate.joiner = make_uniq(client, op, left_table, b1, right_table, b2); + lstate.joiner = make_uniq(context, op, left_table, l_range, right_table, r_range); return; } @@ -895,7 +1085,7 @@ class IEJoinGlobalSourceState : public GlobalSourceState { // Spin wait for regular blocks to finish(!) while (completed < pair_count) { - std::this_thread::yield(); + TaskScheduler::GetScheduler(context.client).YieldThread(); } // Left outer blocks @@ -903,7 +1093,7 @@ class IEJoinGlobalSourceState : public GlobalSourceState { if (l < left_outers) { lstate.joiner = nullptr; lstate.left_block_index = l; - lstate.left_base = left_bases[l]; + lstate.left_base = left_table.BlockStart(l); lstate.left_matches = left_table.found_match.get() + lstate.left_base; lstate.outer_idx = 0; @@ -918,7 +1108,7 @@ class IEJoinGlobalSourceState : public GlobalSourceState { if (r < right_outers) { lstate.joiner = nullptr; lstate.right_block_index = r; - lstate.right_base = right_bases[r]; + lstate.right_base = right_table.BlockStart(r); lstate.right_matches = right_table.found_match.get() + lstate.right_base; lstate.outer_idx = 0; @@ -929,10 +1119,10 @@ class IEJoinGlobalSourceState : public GlobalSourceState { } } - void PairCompleted(ClientContext &client, IEJoinLocalSourceState &lstate) { + void PairCompleted(ExecutionContext &context, IEJoinLocalSourceState &lstate) { lstate.joiner.reset(); ++completed; - GetNextPair(client, lstate); + GetNextPair(context, lstate); } ProgressData GetProgress() const { @@ -962,16 +1152,14 @@ class IEJoinGlobalSourceState : public GlobalSourceState { const PhysicalIEJoin &op; IEJoinGlobalState &gsink; - bool initialized; + bool initialized = false; // Join queue state + const idx_t left_per_thread = 1024; + const idx_t right_per_thread = 1024; atomic next_pair; atomic completed; - // Block 
base row number - vector left_bases; - vector right_bases; - // Outer joins atomic left_outers; atomic next_left; @@ -1001,10 +1189,10 @@ SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &r auto &ie_gstate = input.global_state.Cast(); auto &ie_lstate = input.local_state.Cast(); - ie_gstate.Initialize(); + ie_gstate.Initialize(context.client); if (!ie_lstate.joiner && !ie_lstate.left_matches && !ie_lstate.right_matches) { - ie_gstate.GetNextPair(context.client, ie_lstate); + ie_gstate.GetNextPair(context, ie_lstate); } // Process INNER results @@ -1015,26 +1203,38 @@ SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &r return SourceResultType::HAVE_MORE_OUTPUT; } - ie_gstate.PairCompleted(context.client, ie_lstate); + ie_gstate.PairCompleted(context, ie_lstate); } // Process LEFT OUTER results const auto left_cols = children[0].get().GetTypes().size(); + auto &chunk = ie_lstate.unprojected; while (ie_lstate.left_matches) { const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.left_matches); if (!count) { - ie_gstate.GetNextPair(context.client, ie_lstate); + ie_gstate.GetNextPair(context, ie_lstate); continue; } - auto &chunk = ie_lstate.unprojected; - chunk.Reset(); - SliceSortedPayload(chunk, ie_sink.tables[0]->global_sort_state, ie_lstate.left_block_index, ie_lstate.true_sel, - count); + auto &left_table = *ie_sink.tables[0]; + auto &lpayload = ie_lstate.lpayload; + auto &left_iterator = *ie_lstate.left_iterator; + auto &left_chunk_state = ie_lstate.left_chunk_state; + auto &left_block_index = ie_lstate.left_block_index; + auto &left_scan_state = *ie_lstate.left_scan_state; + + left_table.Repin(left_iterator); + SliceSortedPayload(lpayload, left_table, left_iterator, left_chunk_state, left_block_index, ie_lstate.true_sel, + count, left_scan_state); // Fill in NULLs to the right - for (auto col_idx = left_cols; col_idx < chunk.ColumnCount(); ++col_idx) { - 
chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].Reference(lpayload.data[col_idx]); + } else { + chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[col_idx], true); + } } ProjectResult(chunk, result); @@ -1048,19 +1248,31 @@ SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &r while (ie_lstate.right_matches) { const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.right_matches); if (!count) { - ie_gstate.GetNextPair(context.client, ie_lstate); + ie_gstate.GetNextPair(context, ie_lstate); continue; } - auto &chunk = ie_lstate.unprojected; - chunk.Reset(); - SliceSortedPayload(chunk, ie_sink.tables[1]->global_sort_state, ie_lstate.right_block_index, ie_lstate.true_sel, - count, left_cols); + auto &right_table = *ie_sink.tables[1]; + auto &rsel = ie_lstate.true_sel; + auto &rpayload = ie_lstate.rpayload; + auto &right_iterator = *ie_lstate.right_iterator; + auto &right_chunk_state = ie_lstate.right_chunk_state; + auto &right_block_index = ie_lstate.right_block_index; + auto &right_scan_state = *ie_lstate.right_scan_state; + + right_table.Repin(right_iterator); + SliceSortedPayload(rpayload, right_table, right_iterator, right_chunk_state, right_block_index, rsel, count, + right_scan_state); // Fill in NULLs to the left - for (idx_t col_idx = 0; col_idx < left_cols; ++col_idx) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); + chunk.Reset(); + for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(chunk.data[col_idx], true); + } else { + 
chunk.data[col_idx].Reference(rpayload.data[col_idx - left_cols]); + } } ProjectResult(chunk, result); diff --git a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp index 1bd48ab62..c29f58a68 100644 --- a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp @@ -1,11 +1,8 @@ #include "duckdb/execution/operator/join/physical_piecewise_merge_join.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" +#include "duckdb/common/sorting/sort_key.hpp" +#include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/operator/join/outer_join_marker.hpp" #include "duckdb/main/client_context.hpp" @@ -65,15 +62,14 @@ class MergeJoinGlobalState : public GlobalSinkState { using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; public: - MergeJoinGlobalState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op) { - RowLayout rhs_layout; - rhs_layout.Initialize(op.children[1].get().GetTypes()); + MergeJoinGlobalState(ClientContext &client, const PhysicalPiecewiseMergeJoin &op) { + const auto &rhs_types = op.children[1].get().GetTypes(); vector rhs_order; rhs_order.emplace_back(op.rhs_orders[0].Copy()); - table = make_uniq(context, rhs_order, rhs_layout, op); + table = make_uniq(client, rhs_order, rhs_types, op); if (op.filter_pushdown) { skip_filter_pushdown = op.filter_pushdown->probe_info.empty(); - global_filter_state = op.filter_pushdown->GetGlobalState(context, op); + global_filter_state = op.filter_pushdown->GetGlobalState(client, op); } } @@ 
-81,8 +77,9 @@ class MergeJoinGlobalState : public GlobalSinkState { return table->count; } - void Sink(DataChunk &input, MergeJoinLocalState &lstate); + void Sink(ExecutionContext &context, DataChunk &input, MergeJoinLocalState &lstate); + //! The sorted table unique_ptr table; //! Should we not bother pushing down filters? bool skip_filter_pushdown = false; @@ -92,16 +89,19 @@ class MergeJoinGlobalState : public GlobalSinkState { class MergeJoinLocalState : public LocalSinkState { public: - explicit MergeJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, MergeJoinGlobalState &gstate, - const idx_t child) - : table(context, op, child) { + using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; + using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; + + MergeJoinLocalState(ExecutionContext &context, MergeJoinGlobalState &gstate, const idx_t child) + : table(context, *gstate.table, child) { + auto &op = gstate.table->op; if (op.filter_pushdown) { local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); } } //! The local sort state - PhysicalRangeJoin::LocalSortedTable table; + LocalSortedTable table; //! 
Local state for accumulating filter statistics unique_ptr local_filter_state; }; @@ -113,20 +113,12 @@ unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSinkState(Clien unique_ptr PhysicalPiecewiseMergeJoin::GetLocalSinkState(ExecutionContext &context) const { // We only sink the RHS auto &gstate = sink_state->Cast(); - return make_uniq(context.client, *this, gstate, 1U); + return make_uniq(context, gstate, 1U); } -void MergeJoinGlobalState::Sink(DataChunk &input, MergeJoinLocalState &lstate) { - auto &global_sort_state = table->global_sort_state; - auto &local_sort_state = lstate.table.local_sort_state; - +void MergeJoinGlobalState::Sink(ExecutionContext &context, DataChunk &input, MergeJoinLocalState &lstate) { // Sink the data into the local sort state - lstate.table.Sink(input, global_sort_state); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= table->memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } + lstate.table.Sink(context, input); } SinkResultType PhysicalPiecewiseMergeJoin::Sink(ExecutionContext &context, DataChunk &chunk, @@ -134,7 +126,7 @@ SinkResultType PhysicalPiecewiseMergeJoin::Sink(ExecutionContext &context, DataC auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - gstate.Sink(chunk, lstate); + gstate.Sink(context, chunk, lstate); if (filter_pushdown && !gstate.skip_filter_pushdown) { filter_pushdown->Sink(lstate.table.keys, *lstate.local_filter_state); @@ -147,7 +139,7 @@ SinkCombineResultType PhysicalPiecewiseMergeJoin::Combine(ExecutionContext &cont OperatorSinkCombineInput &input) const { auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - gstate.table->Combine(lstate.table); + gstate.table->Combine(context, lstate.table); auto &client_profiler = QueryProfiler::Get(context.client); context.thread.profiler.Flush(*this); @@ -162,25 +154,28 @@ SinkCombineResultType 
PhysicalPiecewiseMergeJoin::Combine(ExecutionContext &cont //===--------------------------------------------------------------------===// // Finalize //===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, +SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, OperatorSinkFinalizeInput &input) const { auto &gstate = input.global_state.Cast(); if (filter_pushdown && !gstate.skip_filter_pushdown) { - (void)filter_pushdown->Finalize(context, nullptr, *gstate.global_filter_state, *this); + (void)filter_pushdown->Finalize(client, nullptr, *gstate.global_filter_state, *this); } - auto &global_sort_state = gstate.table->global_sort_state; + + gstate.table->Finalize(client, input.interrupt_state); if (PropagatesBuildSide(join_type)) { // for FULL/RIGHT OUTER JOIN, initialize found_match to false for every tuple gstate.table->IntializeMatches(); } - if (global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { + + if (gstate.table->Count() == 0 && EmptyResultIfRHSIsEmpty()) { // Empty input! 
+ gstate.table->MaterializeEmpty(client); return SinkFinalizeType::NO_OUTPUT_POSSIBLE; } // Sort the current input child - gstate.table->Finalize(pipeline, event); + gstate.table->Materialize(pipeline, event); return SinkFinalizeType::READY; } @@ -191,46 +186,50 @@ SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event class PiecewiseMergeJoinState : public CachingOperatorState { public: using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; + using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; - PiecewiseMergeJoinState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op, bool force_external) - : context(context), allocator(Allocator::Get(context)), op(op), - buffer_manager(BufferManager::GetBufferManager(context)), force_external(force_external), - left_outer(IsLeftOuterJoin(op.join_type)), left_position(0), first_fetch(true), finished(true), - right_position(0), right_chunk_index(0), rhs_executor(context) { - vector condition_types; - for (auto &order : op.lhs_orders) { - condition_types.push_back(order.expression->return_type); - } + PiecewiseMergeJoinState(ClientContext &client, const PhysicalPiecewiseMergeJoin &op) + : client(client), allocator(Allocator::Get(client)), op(op), left_outer(IsLeftOuterJoin(op.join_type)), + left_position(0), first_fetch(true), finished(true), right_position(0), right_chunk_index(0), + rhs_executor(client) { left_outer.Initialize(STANDARD_VECTOR_SIZE); - lhs_layout.Initialize(op.children[0].get().GetTypes()); - lhs_payload.Initialize(allocator, op.children[0].get().GetTypes()); + lhs_payload.Initialize(client, op.children[0].get().GetTypes()); + // Sort on the first column lhs_order.emplace_back(op.lhs_orders[0].Copy()); // Set up shared data for multiple predicates sel.Initialize(STANDARD_VECTOR_SIZE); - condition_types.clear(); + vector condition_types; for (auto &order : op.rhs_orders) { rhs_executor.AddExpression(*order.expression); 
condition_types.push_back(order.expression->return_type); } - rhs_keys.Initialize(allocator, condition_types); + rhs_keys.Initialize(client, condition_types); + rhs_input.Initialize(client, op.children[1].get().GetTypes()); + + auto &gsink = op.sink_state->Cast(); + auto &rhs_table = *gsink.table; + rhs_iterator = rhs_table.CreateIteratorState(); + rhs_table.InitializePayloadState(rhs_chunk_state); + rhs_scan_state = rhs_table.CreateScanState(client); + + // Since we have now materialized the payload, the keys will not have payloads? + sort_key_type = rhs_table.GetSortKeyType(); } - ClientContext &context; + ClientContext &client; Allocator &allocator; const PhysicalPiecewiseMergeJoin &op; - BufferManager &buffer_manager; - bool force_external; // Block sorting DataChunk lhs_payload; OuterJoinMarker left_outer; vector lhs_order; - RowLayout lhs_layout; + unique_ptr lhs_global_table; unique_ptr lhs_local_table; - unique_ptr lhs_global_state; - unique_ptr scanner; + SortKeyType sort_key_type; + TupleDataScanState lhs_scan; // Simple scans idx_t left_position; @@ -238,178 +237,127 @@ class PiecewiseMergeJoinState : public CachingOperatorState { // Complex scans bool first_fetch; bool finished; + unique_ptr lhs_iterator; + unique_ptr rhs_iterator; idx_t right_position; idx_t right_chunk_index; idx_t right_base; idx_t prev_left_index; + TupleDataChunkState rhs_chunk_state; + unique_ptr rhs_scan_state; // Secondary predicate shared data SelectionVector sel; DataChunk rhs_keys; DataChunk rhs_input; ExpressionExecutor rhs_executor; - vector payload_heap_handles; public: - void ResolveJoinKeys(DataChunk &input) { + void ResolveJoinKeys(ExecutionContext &context, DataChunk &input) { // sort by join key - lhs_global_state = make_uniq(context, lhs_order, lhs_layout); - lhs_local_table = make_uniq(context, op, 0U); - lhs_local_table->Sink(input, *lhs_global_state); - - // Set external (can be forced with the PRAGMA) - lhs_global_state->external = force_external; - 
lhs_global_state->AddLocalState(lhs_local_table->local_sort_state); - lhs_global_state->PrepareMergePhase(); - while (lhs_global_state->sorted_blocks.size() > 1) { - MergeSorter merge_sorter(*lhs_global_state, buffer_manager); - merge_sorter.PerformInMergeRound(); - lhs_global_state->CompleteMergeRound(); - } - - // Scan the sorted payload - D_ASSERT(lhs_global_state->sorted_blocks.size() == 1); - - scanner = make_uniq(*lhs_global_state->sorted_blocks[0]->payload_data, *lhs_global_state); - lhs_payload.Reset(); - scanner->Scan(lhs_payload); + const auto &lhs_types = lhs_payload.GetTypes(); + lhs_global_table = make_uniq(context.client, lhs_order, lhs_types, op); + lhs_local_table = make_uniq(context, *lhs_global_table, 0U); + lhs_local_table->Sink(context, input); + lhs_global_table->Combine(context, *lhs_local_table); + + InterruptState interrupt; + lhs_global_table->Finalize(context.client, interrupt); + lhs_global_table->Materialize(context, interrupt); + + // Scan the sorted payload (minus the primary sort column) + auto &lhs_table = *lhs_global_table; + auto &lhs_payload_data = *lhs_table.sorted->payload_data; + lhs_payload_data.InitializeScan(lhs_scan); + lhs_payload_data.Scan(lhs_scan, lhs_payload); // Recompute the sorted keys from the sorted input - lhs_local_table->keys.Reset(); - lhs_local_table->executor.Execute(lhs_payload, lhs_local_table->keys); - } + auto &lhs_keys = lhs_local_table->keys; + lhs_keys.Reset(); + lhs_local_table->executor.Execute(lhs_payload, lhs_keys); - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - if (lhs_local_table) { - context.thread.profiler.Flush(op); - } + lhs_iterator = lhs_table.CreateIteratorState(); } }; unique_ptr PhysicalPiecewiseMergeJoin::GetOperatorState(ExecutionContext &context) const { - bool force_external = ClientConfig::GetConfig(context.client).force_external; - return make_uniq(context.client, *this, force_external); + return make_uniq(context.client, *this); } -static 
inline idx_t SortedBlockNotNull(const idx_t base, const idx_t count, const idx_t not_null) { - return MinValue(base + count, MaxValue(base, not_null)) - base; +static inline idx_t SortedChunkNotNull(const idx_t chunk_idx, const idx_t count, const idx_t has_null) { + const auto chunk_begin = chunk_idx * STANDARD_VECTOR_SIZE; + const auto chunk_end = MinValue(chunk_begin + STANDARD_VECTOR_SIZE, count); + const auto not_null = count - has_null; + return MinValue(chunk_end, MaxValue(chunk_begin, not_null)) - chunk_begin; } -static int MergeJoinComparisonValue(ExpressionType comparison) { +static bool MergeJoinStrictComparison(ExpressionType comparison) { switch (comparison) { case ExpressionType::COMPARE_LESSTHAN: case ExpressionType::COMPARE_GREATERTHAN: - return -1; + return true; case ExpressionType::COMPARE_LESSTHANOREQUALTO: case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; + return false; default: throw InternalException("Unimplemented comparison type for merge join!"); } } -struct BlockMergeInfo { - GlobalSortState &state; - //! The block being scanned - const idx_t block_idx; - //! The number of not-NULL values in the block (they are at the end) - const idx_t not_null; - //! 
The current offset in the block - idx_t &entry_idx; - SelectionVector result; - - BlockMergeInfo(GlobalSortState &state, idx_t block_idx, idx_t &entry_idx, idx_t not_null) - : state(state), block_idx(block_idx), not_null(not_null), entry_idx(entry_idx), result(STANDARD_VECTOR_SIZE) { - } -}; - -static void MergeJoinPinSortingBlock(SBScanState &scan, const idx_t block_idx) { - scan.SetIndices(block_idx, 0); - scan.PinRadix(block_idx); - - auto &sd = *scan.sb->blob_sorting_data; - if (block_idx < sd.data_blocks.size()) { - scan.PinData(sd); +// Compare using +bool MergeJoinBefore(const T &lhs, const T &rhs, const bool strict) { + const bool less_than = lhs < rhs; + if (!less_than && !strict) { + return !(rhs < lhs); } + return less_than; } -static data_ptr_t MergeJoinRadixPtr(SBScanState &scan, const idx_t entry_idx) { - scan.entry_idx = entry_idx; - return scan.RadixPtr(); -} +template +static idx_t TemplatedMergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &gstate, + bool *found_match, const bool strict) { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; -static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &rstate, bool *found_match, - const ExpressionType comparison) { - const auto cmp = MergeJoinComparisonValue(comparison); - - // The sort parameters should all be the same - auto &lsort = *lstate.lhs_global_state; - auto &rsort = rstate.table->global_sort_state; - D_ASSERT(lsort.sort_layout.all_constant == rsort.sort_layout.all_constant); - const auto all_constant = lsort.sort_layout.all_constant; - D_ASSERT(lsort.external == rsort.external); - const auto external = lsort.external; - - // There should only be one sorted block if they have been sorted - D_ASSERT(lsort.sorted_blocks.size() == 1); - SBScanState lread(lsort.buffer_manager, lsort); - lread.sb = lsort.sorted_blocks[0].get(); - - const idx_t l_block_idx = 0; - idx_t l_entry_idx = 0; - const auto lhs_not_null = 
lstate.lhs_local_table->count - lstate.lhs_local_table->has_null; - MergeJoinPinSortingBlock(lread, l_block_idx); - auto l_ptr = MergeJoinRadixPtr(lread, l_entry_idx); - - D_ASSERT(rsort.sorted_blocks.size() == 1); - SBScanState rread(rsort.buffer_manager, rsort); - rread.sb = rsort.sorted_blocks[0].get(); + // We only need the keys because we are extracting the row numbers + auto &lhs_table = *lstate.lhs_global_table; + D_ASSERT(SORT_KEY_TYPE == lhs_table.GetSortKeyType()); + auto &lhs_iterator = *lstate.lhs_iterator; + const auto lhs_not_null = lhs_table.count - lhs_table.has_null; - const auto cmp_size = lsort.sort_layout.comparison_size; - const auto entry_size = lsort.sort_layout.entry_size; + auto &rhs_table = *gstate.table; + auto &rhs_iterator = *lstate.rhs_iterator; + const auto rhs_not_null = rhs_table.count - rhs_table.has_null; - idx_t right_base = 0; - for (idx_t r_block_idx = 0; r_block_idx < rread.sb->radix_sorting_data.size(); r_block_idx++) { - // we only care about the BIGGEST value in each of the RHS data blocks + idx_t l_entry_idx = 0; + BLOCK_ITERATOR lhs_itr(lhs_iterator); + BLOCK_ITERATOR rhs_itr(rhs_iterator); + for (idx_t r_idx = 0; r_idx < rhs_not_null; r_idx += STANDARD_VECTOR_SIZE) { + // Repin the RHS to release memory + // This is safe because we only return the LHS values + // Note we only do this for the RHS because the LHS is only one chunk. 
+ rhs_table.Repin(rhs_iterator); + + // we only care about the BIGGEST value in the RHS // because we want to figure out if the LHS values are less than [or equal] to ANY value - // get the biggest value from the RHS chunk - MergeJoinPinSortingBlock(rread, r_block_idx); - - auto &rblock = *rread.sb->radix_sorting_data[r_block_idx]; - const auto r_not_null = - SortedBlockNotNull(right_base, rblock.count, rstate.table->count - rstate.table->has_null); - if (r_not_null == 0) { - break; - } - const auto r_entry_idx = r_not_null - 1; - right_base += rblock.count; - - auto r_ptr = MergeJoinRadixPtr(rread, r_entry_idx); + const auto r_entry_idx = MinValue(r_idx + STANDARD_VECTOR_SIZE, rhs_not_null) - 1; // now we start from the current lpos value and check if we found a new value that is [<= OR <] the max RHS // value while (true) { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); - } else { - lread.entry_idx = l_entry_idx; - rread.entry_idx = r_entry_idx; - comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, lsort.sort_layout, external); - } - - if (comp_res <= cmp) { + // Note that both subscripts here are table indices, not chunk indices. + if (MergeJoinBefore(lhs_itr[l_entry_idx], rhs_itr[r_entry_idx], strict)) { // found a match for lpos, set it in the found_match vector found_match[l_entry_idx] = true; l_entry_idx++; - l_ptr += entry_size; if (l_entry_idx >= lhs_not_null) { // early out: we exhausted the entire LHS and they all match return 0; } } else { // we found no match: any subsequent value from the LHS we scan now will be bigger and thus also not - // match move to the next RHS chunk + // match. 
Move to the next RHS chunk break; } } @@ -417,13 +365,42 @@ static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlo return 0; } +static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &gstate, bool *match, + const ExpressionType comparison) { + const auto strict = MergeJoinStrictComparison(comparison); + + switch (lstate.sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinSimpleBlocks(lstate, gstate, match, strict); + default: + throw NotImplementedException("MergeJoinSimpleBlocks for %s", EnumUtil::ToString(lstate.sort_key_type)); + } +} + void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, OperatorState &state_p) const { auto &state = state_p.Cast(); auto &gstate = sink_state->Cast(); - state.ResolveJoinKeys(input); - auto &lhs_table = *state.lhs_local_table; + state.ResolveJoinKeys(context, input); + auto &lhs_table = *state.lhs_global_table; + auto &lhs_keys = state.lhs_local_table->keys; // perform the actual join bool 
found_match[STANDARD_VECTOR_SIZE]; @@ -439,8 +416,8 @@ void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, Da case JoinType::MARK: { // The only part of the join keys that is actually used is the validity mask. // Since the payload is sorted, we can just set the tail end of the validity masks to invalid. - for (auto &key : lhs_table.keys.data) { - key.Flatten(lhs_table.keys.size()); + for (auto &key : lhs_keys.data) { + key.Flatten(lhs_keys.size()); auto &mask = FlatVector::Validity(key); if (mask.AllValid()) { continue; @@ -451,7 +428,7 @@ void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, Da } } // So we make a set of keys that have the validity mask set for the - PhysicalJoin::ConstructMarkJoinResult(lhs_table.keys, payload, chunk, found_match, gstate.table->has_null); + PhysicalJoin::ConstructMarkJoinResult(lhs_keys, payload, chunk, found_match, gstate.table->has_null); break; } case JoinType::SEMI: @@ -465,41 +442,42 @@ void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, Da } } -static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const ExpressionType comparison, - idx_t &prev_left_index) { - const auto cmp = MergeJoinComparisonValue(comparison); - - // The sort parameters should all be the same - D_ASSERT(l.state.sort_layout.all_constant == r.state.sort_layout.all_constant); - const auto all_constant = r.state.sort_layout.all_constant; - D_ASSERT(l.state.external == r.state.external); - const auto external = l.state.external; - - // There should only be one sorted block if they have been sorted - D_ASSERT(l.state.sorted_blocks.size() == 1); - SBScanState lread(l.state.buffer_manager, l.state); - lread.sb = l.state.sorted_blocks[0].get(); - D_ASSERT(lread.sb->radix_sorting_data.size() == 1); - MergeJoinPinSortingBlock(lread, l.block_idx); - auto l_start = MergeJoinRadixPtr(lread, 0); - auto l_ptr = MergeJoinRadixPtr(lread, l.entry_idx); - - 
D_ASSERT(r.state.sorted_blocks.size() == 1); - SBScanState rread(r.state.buffer_manager, r.state); - rread.sb = r.state.sorted_blocks[0].get(); +struct ChunkMergeInfo { + //! The iteration state + ExternalBlockIteratorState &state; + //! The block being scanned + const idx_t block_idx; + //! The number of not-NULL values in the chunk (they are at the end) + const idx_t not_null; + //! The current offset in the chunk + idx_t &entry_idx; + //! The offsets that match + SelectionVector result; - if (r.entry_idx >= r.not_null) { - return 0; + ChunkMergeInfo(ExternalBlockIteratorState &state, idx_t block_idx, idx_t &entry_idx, idx_t not_null) + : state(state), block_idx(block_idx), not_null(not_null), entry_idx(entry_idx), result(STANDARD_VECTOR_SIZE) { } - MergeJoinPinSortingBlock(rread, r.block_idx); - auto r_ptr = MergeJoinRadixPtr(rread, r.entry_idx); + idx_t GetIndex() const { + return state.GetIndex(block_idx, entry_idx); + } +}; + +template +static idx_t TemplatedMergeJoinComplexBlocks(ChunkMergeInfo &l, ChunkMergeInfo &r, const bool strict, + idx_t &prev_left_index) { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; - const auto cmp_size = l.state.sort_layout.comparison_size; - const auto entry_size = l.state.sort_layout.entry_size; + if (r.entry_idx >= r.not_null) { + return 0; + } idx_t result_count = 0; + BLOCK_ITERATOR l_ptr(l.state); + BLOCK_ITERATOR r_ptr(r.state); while (true) { + if (l.entry_idx < prev_left_index) { // left side smaller: found match l.result.set_index(result_count, sel_t(l.entry_idx)); @@ -507,7 +485,7 @@ static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const result_count++; // move left side forward l.entry_idx++; - l_ptr += entry_size; + ++l_ptr; if (result_count == STANDARD_VECTOR_SIZE) { // out of space! 
break; @@ -515,22 +493,14 @@ static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const continue; } if (l.entry_idx < l.not_null) { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); - } else { - lread.entry_idx = l.entry_idx; - rread.entry_idx = r.entry_idx; - comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, l.state.sort_layout, external); - } - if (comp_res <= cmp) { + if (MergeJoinBefore(l_ptr[l.GetIndex()], r_ptr[r.GetIndex()], strict)) { // left side smaller: found match l.result.set_index(result_count, sel_t(l.entry_idx)); r.result.set_index(result_count, sel_t(r.entry_idx)); result_count++; // move left side forward l.entry_idx++; - l_ptr += entry_size; + ++l_ptr; if (result_count == STANDARD_VECTOR_SIZE) { // out of space! break; @@ -546,27 +516,53 @@ static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const if (r.entry_idx >= r.not_null) { break; } - r_ptr += entry_size; + ++r_ptr; - l_ptr = l_start; l.entry_idx = 0; } return result_count; } +static idx_t MergeJoinComplexBlocks(const SortKeyType &sort_key_type, ChunkMergeInfo &l, ChunkMergeInfo &r, + const ExpressionType comparison, idx_t &prev_left_index) { + const auto strict = MergeJoinStrictComparison(comparison); + + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_FIXED_16: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_FIXED_24: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_FIXED_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_FIXED_16: + return TemplatedMergeJoinComplexBlocks(l, r, strict, 
prev_left_index); + case SortKeyType::PAYLOAD_FIXED_24: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_FIXED_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + case SortKeyType::PAYLOAD_VARIABLE_32: + return TemplatedMergeJoinComplexBlocks(l, r, strict, prev_left_index); + default: + throw NotImplementedException("MergeJoinSimpleBlocks for %s", EnumUtil::ToString(sort_key_type)); + } +} + OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, OperatorState &state_p) const { auto &state = state_p.Cast(); auto &gstate = sink_state->Cast(); - auto &rsorted = *gstate.table->global_sort_state.sorted_blocks[0]; const auto left_cols = input.ColumnCount(); const auto tail_cols = conditions.size() - 1; - state.payload_heap_handles.clear(); do { if (state.first_fetch) { - state.ResolveJoinKeys(input); + state.ResolveJoinKeys(context, input); + state.lhs_payload.Verify(); state.right_chunk_index = 0; state.right_base = 0; @@ -588,36 +584,44 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte return OperatorResultType::NEED_MORE_INPUT; } - auto &lhs_table = *state.lhs_local_table; + auto &lhs_table = *state.lhs_global_table; const auto lhs_not_null = lhs_table.count - lhs_table.has_null; - BlockMergeInfo left_info(*state.lhs_global_state, 0, state.left_position, lhs_not_null); + ChunkMergeInfo left_info(*state.lhs_iterator, 0, state.left_position, lhs_not_null); + + auto &rhs_table = *gstate.table; + auto &rhs_iterator = *state.rhs_iterator; + const auto rhs_not_null = SortedChunkNotNull(state.right_chunk_index, rhs_table.count, rhs_table.has_null); + ChunkMergeInfo right_info(rhs_iterator, state.right_chunk_index, state.right_position, rhs_not_null); - const auto &rblock = *rsorted.radix_sorting_data[state.right_chunk_index]; - const auto rhs_not_null = - SortedBlockNotNull(state.right_base, 
rblock.count, gstate.table->count - gstate.table->has_null); - BlockMergeInfo right_info(gstate.table->global_sort_state, state.right_chunk_index, state.right_position, - rhs_not_null); + // Repin so we don't hang on to data after we have scanned it + // Note we only do this for the RHS because the LHS is only one chunk. + rhs_table.Repin(rhs_iterator); - idx_t result_count = - MergeJoinComplexBlocks(left_info, right_info, conditions[0].comparison, state.prev_left_index); + idx_t result_count = MergeJoinComplexBlocks(state.sort_key_type, left_info, right_info, + conditions[0].comparison, state.prev_left_index); if (result_count == 0) { // exhausted this chunk on the right side // move to the next right chunk state.left_position = 0; state.right_position = 0; - state.right_base += rsorted.radix_sorting_data[state.right_chunk_index]->count; + state.right_base += STANDARD_VECTOR_SIZE; state.right_chunk_index++; - if (state.right_chunk_index >= rsorted.radix_sorting_data.size()) { + if (state.right_chunk_index >= rhs_table.BlockCount()) { state.finished = true; } } else { // found matches: extract them + SliceSortedPayload(state.rhs_input, rhs_table, rhs_iterator, state.rhs_chunk_state, right_info.block_idx, + right_info.result, result_count, *state.rhs_scan_state); + chunk.Reset(); - for (idx_t c = 0; c < state.lhs_payload.ColumnCount(); ++c) { - chunk.data[c].Slice(state.lhs_payload.data[c], left_info.result, result_count); + for (idx_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { + if (col_idx < left_cols) { + chunk.data[col_idx].Slice(state.lhs_payload.data[col_idx], left_info.result, result_count); + } else { + chunk.data[col_idx].Reference(state.rhs_input.data[col_idx - left_cols]); + } } - state.payload_heap_handles.push_back(SliceSortedPayload(chunk, right_info.state, right_info.block_idx, - right_info.result, result_count, left_cols)); chunk.SetCardinality(result_count); auto sel = FlatVector::IncrementalSelectionVector(); @@ -625,13 +629,12 @@ 
OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte // If there are more expressions to compute, // split the result chunk into the left and right halves // so we can compute the values for comparison. - chunk.Split(state.rhs_input, left_cols); state.rhs_executor.SetChunk(state.rhs_input); state.rhs_keys.Reset(); auto tail_count = result_count; for (size_t cmp_idx = 1; cmp_idx < conditions.size(); ++cmp_idx) { - Vector left(lhs_table.keys.data[cmp_idx]); + Vector left(state.lhs_local_table->keys.data[cmp_idx]); left.Slice(left_info.result, result_count); auto &right = state.rhs_keys.data[cmp_idx]; @@ -645,7 +648,6 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionConte SelectJoinTail(conditions[cmp_idx].comparison, left, right, sel, tail_count, &state.sel); sel = &state.sel; } - chunk.Fuse(state.rhs_input); if (tail_count < result_count) { result_count = tail_count; @@ -713,54 +715,78 @@ OperatorResultType PhysicalPiecewiseMergeJoin::ExecuteInternal(ExecutionContext //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -class PiecewiseJoinScanState : public GlobalSourceState { +class PiecewiseJoinGlobalScanState : public GlobalSourceState { +public: + explicit PiecewiseJoinGlobalScanState(TupleDataCollection &payload) : payload(payload), right_outer_position(0) { + payload.InitializeScan(parallel_scan); + } + + idx_t Scan(TupleDataLocalScanState &local_scan, DataChunk &chunk) { + lock_guard guard(lock); + const auto result = right_outer_position; + payload.Scan(parallel_scan, local_scan, chunk); + right_outer_position += chunk.size(); + return result; + } + + TupleDataCollection &payload; + public: - explicit PiecewiseJoinScanState(const PhysicalPiecewiseMergeJoin &op) : op(op), right_outer_position(0) { + idx_t MaxThreads() override { + return payload.ChunkCount(); } +private: mutex lock; - const 
PhysicalPiecewiseMergeJoin &op; - unique_ptr scanner; + TupleDataParallelScanState parallel_scan; idx_t right_outer_position; +}; +class PiecewiseJoinLocalScanState : public LocalSourceState { public: - idx_t MaxThreads() override { - auto &sink = op.sink_state->Cast(); - return sink.Count() / (STANDARD_VECTOR_SIZE * idx_t(10)); + explicit PiecewiseJoinLocalScanState(PiecewiseJoinGlobalScanState &gstate) : rsel(STANDARD_VECTOR_SIZE) { + gstate.payload.InitializeScan(scanner); + gstate.payload.InitializeChunk(rhs_chunk); } + + TupleDataLocalScanState scanner; + DataChunk rhs_chunk; + SelectionVector rsel; }; unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); + auto &gsink = sink_state->Cast(); + return make_uniq(*gsink.table->sorted->payload_data); +} + +unique_ptr PhysicalPiecewiseMergeJoin::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq(gstate.Cast()); } SourceResultType PhysicalPiecewiseMergeJoin::GetData(ExecutionContext &context, DataChunk &result, - OperatorSourceInput &input) const { + OperatorSourceInput &source) const { D_ASSERT(PropagatesBuildSide(join_type)); // check if we need to scan any unmatched tuples from the RHS for the full/right outer join - auto &sink = sink_state->Cast(); - auto &state = input.global_state.Cast(); - - lock_guard l(state.lock); - if (!state.scanner) { - // Initialize scanner (if not yet initialized) - auto &sort_state = sink.table->global_sort_state; - if (sort_state.sorted_blocks.empty()) { - return SourceResultType::FINISHED; - } - state.scanner = make_uniq(*sort_state.sorted_blocks[0]->payload_data, sort_state); + auto &gsink = sink_state->Cast(); + auto &gsource = source.global_state.Cast(); + + // RHS was empty, so nothing to do? 
+ if (!gsink.table->count) { + return SourceResultType::FINISHED; } // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan the found_match for any chunks we // still need to output - const auto found_match = sink.table->found_match.get(); + const auto found_match = gsink.table->found_match.get(); - DataChunk rhs_chunk; - rhs_chunk.Initialize(Allocator::Get(context.client), sink.table->global_sort_state.payload_layout.GetTypes()); - SelectionVector rsel(STANDARD_VECTOR_SIZE); + auto &lsource = source.local_state.Cast(); + auto &rhs_chunk = lsource.rhs_chunk; + auto &rsel = lsource.rsel; for (;;) { // Read the next sorted chunk - state.scanner->Scan(rhs_chunk); + rhs_chunk.Reset(); + const auto rhs_pos = gsource.Scan(lsource.scanner, rhs_chunk); const auto count = rhs_chunk.size(); if (count == 0) { @@ -770,11 +796,10 @@ SourceResultType PhysicalPiecewiseMergeJoin::GetData(ExecutionContext &context, idx_t result_count = 0; // figure out which tuples didn't find a match in the RHS for (idx_t i = 0; i < count; i++) { - if (!found_match[state.right_outer_position + i]) { + if (!found_match[rhs_pos + i]) { rsel.set_index(result_count++, i); } } - state.right_outer_position += count; if (result_count > 0) { // if there were any tuples that didn't find a match, output them diff --git a/src/duckdb/src/execution/operator/join/physical_range_join.cpp b/src/duckdb/src/execution/operator/join/physical_range_join.cpp index 4fefafbd4..f92a2bb88 100644 --- a/src/duckdb/src/execution/operator/join/physical_range_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_range_join.cpp @@ -1,10 +1,7 @@ #include "duckdb/execution/operator/join/physical_range_join.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sorting/sort_key.hpp" #include 
"duckdb/common/types/validity_mask.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/unordered_map.hpp" @@ -14,15 +11,15 @@ #include "duckdb/parallel/base_pipeline_event.hpp" #include "duckdb/parallel/thread_context.hpp" #include "duckdb/parallel/executor_task.hpp" - -#include +#include "duckdb/planner/expression/bound_reference_expression.hpp" namespace duckdb { -PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, +PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ExecutionContext &context, GlobalSortedTable &global_table, const idx_t child) - : op(op), executor(context), has_null(0), count(0) { + : global_table(global_table), executor(context.client), has_null(0), count(0) { // Initialize order clause expression executor and key DataChunk + const auto &op = global_table.op; vector types; for (const auto &cond : op.conditions) { const auto &expr = child ? cond.right : cond.left; @@ -30,15 +27,19 @@ PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, co types.push_back(expr->return_type); } - auto &allocator = Allocator::Get(context); + auto &allocator = Allocator::Get(context.client); keys.Initialize(allocator, types); + + local_sink = global_table.sort->GetLocalSinkState(context); + + // Only sort the primary key + types.resize(1); + const auto &payload_types = op.children[child].get().types; + types.insert(types.end(), payload_types.begin(), payload_types.end()); + sort_chunk.InitializeEmpty(types); } -void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState &global_sort_state) { - // Initialize local state (if necessary) - if (!local_sort_state.initialized) { - local_sort_state.Initialize(global_sort_state, global_sort_state.buffer_manager); - } +void PhysicalRangeJoin::LocalSortedTable::Sink(ExecutionContext &context, DataChunk &input) { // Obtain sorting columns keys.Reset(); @@ -47,121 +48,180 @@ void 
PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState // Do not operate on primary key directly to avoid modifying the input chunk Vector primary = keys.data[0]; // Count the NULLs so we can exclude them later - has_null += MergeNulls(primary, op.conditions); + has_null += MergeNulls(primary, global_table.op.conditions); count += keys.size(); // Only sort the primary key - DataChunk join_head; - join_head.data.emplace_back(primary); - join_head.SetCardinality(keys.size()); + sort_chunk.data[0].Reference(primary); + for (column_t col_idx = 0; col_idx < input.ColumnCount(); ++col_idx) { + sort_chunk.data[col_idx + 1].Reference(input.data[col_idx]); + } + sort_chunk.SetCardinality(input); // Sink the data into the local sort state - local_sort_state.SinkChunk(join_head, input); + InterruptState interrupt; + OperatorSinkInput sink {*global_table.global_sink, *local_sink, interrupt}; + global_table.sort->Sink(context, sort_chunk, sink); } -PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context, const vector &orders, - RowLayout &payload_layout, const PhysicalOperator &op_p) - : op(op_p), global_sort_state(context, orders, payload_layout), has_null(0), count(0), memory_per_thread(0) { +PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &client, + const vector &order_bys, + const vector &payload_types, + const PhysicalRangeJoin &op) + : op(op), has_null(0), count(0), tasks_completed(0) { + + // Set up the sort. We will materialize keys ourselves, so just set up references. 
+ vector orders; + vector input_types; + for (const auto &order_by : order_bys) { + auto order = order_by.Copy(); + const auto type = order.expression->return_type; + input_types.emplace_back(type); + order.expression = make_uniq(type, orders.size()); + orders.emplace_back(std::move(order)); + } + + vector projection_map; + for (const auto &type : payload_types) { + projection_map.emplace_back(input_types.size()); + input_types.emplace_back(type); + } + + sort = make_uniq(client, orders, input_types, projection_map); - // Set external (can be forced with the PRAGMA) - global_sort_state.external = ClientConfig::GetConfig(context).force_external; - memory_per_thread = PhysicalRangeJoin::GetMaxThreadMemory(context); + global_sink = sort->GetGlobalSinkState(client); } -void PhysicalRangeJoin::GlobalSortedTable::Combine(LocalSortedTable <able) { - global_sort_state.AddLocalState(ltable.local_sort_state); +void PhysicalRangeJoin::GlobalSortedTable::Combine(ExecutionContext &context, LocalSortedTable <able) { + InterruptState interrupt; + OperatorSinkCombineInput combine {*global_sink, *ltable.local_sink, interrupt}; + sort->Combine(context, combine); has_null += ltable.has_null; count += ltable.count; } +void PhysicalRangeJoin::GlobalSortedTable::Finalize(ClientContext &client, InterruptState &interrupt) { + OperatorSinkFinalizeInput finalize {*global_sink, interrupt}; + sort->Finalize(client, finalize); +} + void PhysicalRangeJoin::GlobalSortedTable::IntializeMatches() { found_match = make_unsafe_uniq_array_uninitialized(Count()); memset(found_match.get(), 0, sizeof(bool) * Count()); } +void PhysicalRangeJoin::GlobalSortedTable::MaterializeEmpty(ClientContext &client) { + D_ASSERT(!sorted); + sorted = make_uniq(client, *sort, false); +} + void PhysicalRangeJoin::GlobalSortedTable::Print() { - global_sort_state.Print(); + D_ASSERT(sorted); + auto &collection = *sorted->payload_data; + TupleDataScanState scanner; + collection.InitializeScan(scanner); + + DataChunk 
payload; + collection.InitializeScanChunk(scanner, payload); + + while (collection.Scan(scanner, payload)) { + payload.Print(); + } } -class RangeJoinMergeTask : public ExecutorTask { +//===--------------------------------------------------------------------===// +// RangeJoinMaterializeTask +//===--------------------------------------------------------------------===// +class RangeJoinMaterializeTask : public ExecutorTask { public: using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; public: - RangeJoinMergeTask(shared_ptr event_p, ClientContext &context, GlobalSortedTable &table) - : ExecutorTask(context, std::move(event_p), table.op), context(context), table(table) { + RangeJoinMaterializeTask(Pipeline &pipeline, shared_ptr event, ClientContext &client, + GlobalSortedTable &table, idx_t tasks_scheduled) + : ExecutorTask(client, std::move(event), table.op), pipeline(pipeline), table(table), + tasks_scheduled(tasks_scheduled) { } TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - // Initialize iejoin sorted and iterate until done - auto &global_sort_state = table.global_sort_state; - MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context)); - merge_sorter.PerformInMergeRound(); - event->FinishTask(); + ExecutionContext execution(pipeline.GetClientContext(), *thread_context, &pipeline); + auto &sort = *table.sort; + auto &sort_global = *table.global_source; + auto sort_local = sort.GetLocalSourceState(execution, sort_global); + InterruptState interrupt((weak_ptr(shared_from_this()))); + OperatorSourceInput input {sort_global, *sort_local, interrupt}; + sort.MaterializeSortedRun(execution, input); + if (++table.tasks_completed == tasks_scheduled) { + table.sorted = sort.GetSortedRun(sort_global); + if (!table.sorted) { + table.MaterializeEmpty(execution.client); + } + } + event->FinishTask(); return TaskExecutionResult::TASK_FINISHED; } string TaskType() const override { - return "RangeJoinMergeTask"; + 
return "RangeJoinMaterializeTask"; } private: - ClientContext &context; + Pipeline &pipeline; GlobalSortedTable &table; + const idx_t tasks_scheduled; }; -class RangeJoinMergeEvent : public BasePipelineEvent { +//===--------------------------------------------------------------------===// +// RangeJoinMaterializeEvent +//===--------------------------------------------------------------------===// +class RangeJoinMaterializeEvent : public BasePipelineEvent { public: using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; public: - RangeJoinMergeEvent(GlobalSortedTable &table_p, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), table(table_p) { + RangeJoinMaterializeEvent(GlobalSortedTable &table, Pipeline &pipeline) + : BasePipelineEvent(pipeline), table(table) { } GlobalSortedTable &table; public: void Schedule() override { - auto &context = pipeline->GetClientContext(); + auto &client = pipeline->GetClientContext(); - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); + // Schedule as many tasks as the sort will allow + auto &ts = TaskScheduler::GetScheduler(client); auto num_threads = NumericCast(ts.NumberOfThreads()); - - vector> iejoin_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - iejoin_tasks.push_back(make_uniq(shared_from_this(), context, table)); + vector> tasks; + + auto &sort = *table.sort; + auto &global_sink = *table.global_sink; + table.global_source = sort.GetGlobalSourceState(client, global_sink); + const auto tasks_scheduled = MinValue(num_threads, table.global_source->MaxThreads()); + for (idx_t tnum = 0; tnum < tasks_scheduled; ++tnum) { + tasks.push_back( + make_uniq(*pipeline, shared_from_this(), client, table, tasks_scheduled)); } - SetTasks(std::move(iejoin_tasks)); - } - - void FinishEvent() override { - auto &global_sort_state = table.global_sort_state; - global_sort_state.CompleteMergeRound(true); - if 
(global_sort_state.sorted_blocks.size() > 1) { - // Multiple blocks remaining: Schedule the next round - table.ScheduleMergeTasks(*pipeline, *this); - } + SetTasks(std::move(tasks)); } }; -void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) { - // Initialize global sort state for a round of merging - global_sort_state.InitializeMergeRound(); - auto new_event = make_shared_ptr(*this, pipeline); - event.InsertEvent(std::move(new_event)); +void PhysicalRangeJoin::GlobalSortedTable::Materialize(Pipeline &pipeline, Event &event) { + // Schedule all the sorts for maximum thread utilisation + auto sort_event = make_shared_ptr(*this, pipeline); + event.InsertEvent(std::move(sort_event)); } -void PhysicalRangeJoin::GlobalSortedTable::Finalize(Pipeline &pipeline, Event &event) { - // Prepare for merge sort phase - global_sort_state.PrepareMergePhase(); - - // Start the merge phase or finish if a merge is not necessary - if (global_sort_state.sorted_blocks.size() > 1) { - ScheduleMergeTasks(pipeline, event); +void PhysicalRangeJoin::GlobalSortedTable::Materialize(ExecutionContext &context, InterruptState &interrupt) { + global_source = sort->GetGlobalSourceState(context.client, *global_sink); + auto local_source = sort->GetLocalSourceState(context, *global_source); + OperatorSourceInput source {*global_source, *local_source, interrupt}; + sort->MaterializeSortedRun(context, source); + sorted = sort->GetSortedRun(*global_source); + if (!sorted) { + MaterializeEmpty(context.client); } } @@ -336,56 +396,75 @@ void PhysicalRangeJoin::ProjectResult(DataChunk &chunk, DataChunk &result) const result.SetCardinality(chunk); } -BufferHandle PhysicalRangeJoin::SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, - const SelectionVector &result, const idx_t result_count, - const idx_t left_cols) { - // There should only be one sorted block if they have been sorted - D_ASSERT(state.sorted_blocks.size() == 1); - 
SBScanState read_state(state.buffer_manager, state); - read_state.sb = state.sorted_blocks[0].get(); - auto &sorted_data = *read_state.sb->payload_data; - - read_state.SetIndices(block_idx, 0); - read_state.PinData(sorted_data); - const auto data_ptr = read_state.DataPtr(sorted_data); - data_ptr_t heap_ptr = nullptr; - - // Set up a batch of pointers to scan data from - Vector addresses(LogicalType::POINTER, result_count); - auto data_pointers = FlatVector::GetData(addresses); - - // Set up the data pointers for the values that are actually referenced - const idx_t &row_width = sorted_data.layout.GetRowWidth(); - - auto prev_idx = result.get_index(0); - SelectionVector gsel(result_count); - idx_t addr_count = 0; - gsel.set_index(0, addr_count); - data_pointers[addr_count] = data_ptr + prev_idx * row_width; - for (idx_t i = 1; i < result_count; ++i) { - const auto row_idx = result.get_index(i); - if (row_idx != prev_idx) { - data_pointers[++addr_count] = data_ptr + row_idx * row_width; - prev_idx = row_idx; - } - gsel.set_index(i, addr_count); +template +static void TemplatedSliceSortedPayload(DataChunk &chunk, const SortedRun &sorted_run, + ExternalBlockIteratorState &state, Vector &sort_key_pointers, + SortedRunScanState &scan_state, const idx_t chunk_idx, SelectionVector &result, + const idx_t result_count) { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; + BLOCK_ITERATOR itr(state, chunk_idx, 0); + + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + for (idx_t i = 0; i < result_count; ++i) { + const auto idx = state.GetIndex(chunk_idx, result.get_index(i)); + sort_keys[i] = &itr[idx]; } - ++addr_count; - // Unswizzle the offsets back to pointers (if needed) - if (!sorted_data.layout.AllConstant() && state.external) { - heap_ptr = read_state.payload_heap_handle.Ptr(); - } + // Scan + chunk.Reset(); + scan_state.Scan(sorted_run, sort_key_pointers, result_count, chunk); +} - // Deserialize the payload data - auto sel = 
FlatVector::IncrementalSelectionVector(); - for (idx_t col_no = 0; col_no < sorted_data.layout.ColumnCount(); col_no++) { - auto &col = payload.data[left_cols + col_no]; - RowOperations::Gather(addresses, *sel, col, *sel, addr_count, sorted_data.layout, col_no, 0, heap_ptr); - col.Slice(gsel, result_count); +void PhysicalRangeJoin::SliceSortedPayload(DataChunk &chunk, GlobalSortedTable &table, + ExternalBlockIteratorState &state, TupleDataChunkState &chunk_state, + const idx_t chunk_idx, SelectionVector &result, const idx_t result_count, + SortedRunScanState &scan_state) { + + auto &sorted = *table.sorted; + auto &sort_keys = chunk_state.row_locations; + const auto sort_key_type = table.GetSortKeyType(); + + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_FIXED_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_FIXED_16: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_FIXED_24: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + case SortKeyType::PAYLOAD_FIXED_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, 
result_count); + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + TemplatedSliceSortedPayload(chunk, sorted, state, sort_keys, scan_state, + chunk_idx, result, result_count); + break; + default: + throw NotImplementedException("MergeJoinSimpleBlocks for %s", EnumUtil::ToString(sort_key_type)); } - - return std::move(read_state.payload_heap_handle); } idx_t PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp index 95b519d4d..025af5ba1 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp @@ -90,7 +90,7 @@ class CollectionMerger { auto &collection = data_table.GetOptimisticCollection(context, collection_indexes[i]); TableScanState scan_state; scan_state.Initialize(column_ids); - collection.collection->InitializeScan(scan_state.local_state, column_ids, nullptr); + collection.collection->InitializeScan(context, scan_state.local_state, column_ids, nullptr); while (true) { scan_chunk.Reset(); @@ -194,7 +194,10 @@ class BatchInsertLocalState : public LocalSinkState { void CreateNewCollection(ClientContext &context, DuckTableEntry &table_entry, const vector &insert_types) { - auto collection = OptimisticDataWriter::CreateCollection(table_entry.GetStorage(), insert_types); + if (!optimistic_writer) { + optimistic_writer = make_uniq(context, table_entry.GetStorage()); + } + auto collection = optimistic_writer->CreateCollection(table_entry.GetStorage(), insert_types); auto &row_collection = *collection->collection; row_collection.InitializeEmpty(); row_collection.InitializeAppend(current_append_state); @@ -526,9 +529,6 @@ SinkResultType PhysicalBatchInsert::Sink(ExecutionContext &context, DataChunk &i lock_guard l(gstate.lock); // no collection yet: create a new one 
lstate.CreateNewCollection(context.client, table, insert_types); - if (!lstate.optimistic_writer) { - lstate.optimistic_writer = make_uniq(context.client, table.GetStorage()); - } } if (lstate.current_index != batch_index) { diff --git a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp index 97c31c4ba..8c152c2b4 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp @@ -651,14 +651,14 @@ SinkResultType PhysicalInsert::Sink(ExecutionContext &context, DataChunk &insert D_ASSERT(!return_chunk); auto &data_table = gstate.table.GetStorage(); if (!lstate.collection_index.IsValid()) { + lock_guard l(gstate.lock); + lstate.optimistic_writer = make_uniq(context.client, data_table); // Create the local row group collection. - auto optimistic_collection = OptimisticDataWriter::CreateCollection(storage, insert_types); + auto optimistic_collection = lstate.optimistic_writer->CreateCollection(storage, insert_types); auto &collection = *optimistic_collection->collection; collection.InitializeEmpty(); collection.InitializeAppend(lstate.local_append_state); - lock_guard l(gstate.lock); - lstate.optimistic_writer = make_uniq(context.client, data_table); lstate.collection_index = data_table.CreateOptimisticCollection(context.client, std::move(optimistic_collection)); } diff --git a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp index 5759583c5..c457e9826 100644 --- a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp @@ -27,7 +27,7 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera // // ∏ * \ pk // | - // Γ pk;first(P),arg_xxx(B,inequality) + // Γ pk;first(P),arg_xxx_null(B,inequality) // | // ∏ *,inequality // | @@ -88,13 +88,13 @@ 
PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera case ExpressionType::COMPARE_GREATERTHAN: D_ASSERT(asof_idx == op.conditions.size()); asof_idx = i; - arg_min_max = "arg_max"; + arg_min_max = "arg_max_null"; break; case ExpressionType::COMPARE_LESSTHANOREQUALTO: case ExpressionType::COMPARE_LESSTHAN: D_ASSERT(asof_idx == op.conditions.size()); asof_idx = i; - arg_min_max = "arg_min"; + arg_min_max = "arg_min_null"; break; case ExpressionType::COMPARE_EQUAL: case ExpressionType::COMPARE_NOTEQUAL: diff --git a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp index bde4c1479..64135df05 100644 --- a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp +++ b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp @@ -10,6 +10,7 @@ #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/parser/expression_map.hpp" #include "duckdb/parallel/thread_context.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/main/settings.hpp" namespace duckdb { diff --git a/src/duckdb/src/function/cast/default_casts.cpp b/src/duckdb/src/function/cast/default_casts.cpp index 0c0c1c058..558329f70 100644 --- a/src/duckdb/src/function/cast/default_casts.cpp +++ b/src/duckdb/src/function/cast/default_casts.cpp @@ -162,6 +162,8 @@ BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const L return EnumCastSwitch(input, source, target); case LogicalTypeId::ARRAY: return ArrayCastSwitch(input, source, target); + case LogicalTypeId::GEOMETRY: + return GeoCastSwitch(input, source, target); case LogicalTypeId::BIGNUM: return BignumCastSwitch(input, source, target); case LogicalTypeId::AGGREGATE_STATE: diff --git a/src/duckdb/src/function/cast/geo_casts.cpp b/src/duckdb/src/function/cast/geo_casts.cpp new file mode 100644 index 000000000..b57b4cb9e --- /dev/null +++ b/src/duckdb/src/function/cast/geo_casts.cpp 
@@ -0,0 +1,24 @@ +#include "duckdb/common/types/geometry.hpp" +#include "duckdb/function/cast/default_casts.hpp" +#include "duckdb/function/cast/vector_cast_helpers.hpp" + +namespace duckdb { + +static bool GeometryToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + UnaryExecutor::Execute( + source, result, count, [&](const string_t &input) -> string_t { return Geometry::ToString(result, input); }); + return true; +} + +BoundCastInfo DefaultCasts::GeoCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { + + // now switch on the result type + switch (target.id()) { + case LogicalTypeId::VARCHAR: + return GeometryToVarcharCast; + default: + return TryVectorNullCast; + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/cast/string_cast.cpp b/src/duckdb/src/function/cast/string_cast.cpp index 511d09a86..930231808 100644 --- a/src/duckdb/src/function/cast/string_cast.cpp +++ b/src/duckdb/src/function/cast/string_cast.cpp @@ -490,6 +490,8 @@ BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const Logical return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); case LogicalTypeId::UUID: return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + case LogicalTypeId::GEOMETRY: + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); case LogicalTypeId::SQLNULL: return &DefaultCasts::TryVectorNullCast; case LogicalTypeId::VARCHAR: diff --git a/src/duckdb/src/function/cast/variant/from_variant.cpp b/src/duckdb/src/function/cast/variant/from_variant.cpp index ca377b326..aa129b463 100644 --- a/src/duckdb/src/function/cast/variant/from_variant.cpp +++ b/src/duckdb/src/function/cast/variant/from_variant.cpp @@ -1,3 +1,4 @@ +#include "yyjson_utils.hpp" #include "duckdb/function/cast/default_casts.hpp" #include "duckdb/common/types/variant.hpp" #include "duckdb/function/scalar/variant_utils.hpp" @@ -49,22 +50,6 @@ struct DecimalConversionPayloadFromVariant { idx_t 
scale; }; -struct ConvertedJSONHolder { -public: - ~ConvertedJSONHolder() { - if (doc) { - yyjson_mut_doc_free(doc); - } - if (stringified_json) { - free(stringified_json); - } - } - -public: - yyjson_mut_doc *doc = nullptr; - char *stringified_json = nullptr; -}; - } // namespace //===--------------------------------------------------------------------===// @@ -550,6 +535,11 @@ static bool CastVariant(FromVariantConversionData &conversion_data, Vector &resu return CastVariantToPrimitive>( conversion_data, result, sel, offset, count, row, string_payload); } + case LogicalTypeId::GEOMETRY: { + StringConversionPayload string_payload(result); + return CastVariantToPrimitive>( + conversion_data, result, sel, offset, count, row, string_payload); + } case LogicalTypeId::VARCHAR: { if (target_type.IsJSONType()) { return CastVariantToJSON(conversion_data, result, sel, offset, count, row); @@ -686,6 +676,8 @@ BoundCastInfo DefaultCasts::VariantCastSwitch(BindCastInput &input, const Logica case LogicalTypeId::UUID: case LogicalTypeId::ARRAY: return BoundCastInfo(CastFromVARIANT); + case LogicalTypeId::GEOMETRY: + return BoundCastInfo(CastFromVARIANT); case LogicalTypeId::VARCHAR: { return BoundCastInfo(CastFromVARIANT); } diff --git a/src/duckdb/src/function/cast/variant/to_json.cpp b/src/duckdb/src/function/cast/variant/to_json.cpp index 9d35c142c..85d957867 100644 --- a/src/duckdb/src/function/cast/variant/to_json.cpp +++ b/src/duckdb/src/function/cast/variant/to_json.cpp @@ -149,6 +149,12 @@ yyjson_mut_val *VariantCasts::ConvertVariantToJSON(yyjson_mut_doc *doc, const Re auto val_str = Value::BLOB(const_data_ptr_cast(string_data), string_length).ToString(); return yyjson_mut_strncpy(doc, val_str.c_str(), val_str.size()); } + case VariantLogicalType::GEOMETRY: { + auto string_length = VarintDecode(ptr); + auto string_data = reinterpret_cast(ptr); + auto val_str = Value::GEOMETRY(const_data_ptr_cast(string_data), string_length).ToString(); + return yyjson_mut_strncpy(doc, 
val_str.c_str(), val_str.size()); + } case VariantLogicalType::VARCHAR: { auto string_length = VarintDecode(ptr); auto string_data = reinterpret_cast(ptr); diff --git a/src/duckdb/src/function/scalar/variant/variant_utils.cpp b/src/duckdb/src/function/scalar/variant/variant_utils.cpp index 44a370251..05e96a905 100644 --- a/src/duckdb/src/function/scalar/variant/variant_utils.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_utils.cpp @@ -22,9 +22,19 @@ VariantDecimalData VariantUtils::DecodeDecimalData(const UnifiedVariantVectorDat VariantDecimalData result; result.width = VarintDecode(ptr); result.scale = VarintDecode(ptr); + result.value_ptr = ptr; return result; } +string_t VariantUtils::DecodeStringData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index) { + auto byte_offset = variant.GetByteOffset(row, value_index); + auto data = const_data_ptr_cast(variant.GetData(row).GetData()); + auto ptr = data + byte_offset; + + auto length = VarintDecode(ptr); + return string_t(reinterpret_cast(ptr), length); +} + VariantNestedData VariantUtils::DecodeNestedData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index) { D_ASSERT(IsNestedType(variant, row, value_index)); diff --git a/src/duckdb/src/function/table/direct_file_reader.cpp b/src/duckdb/src/function/table/direct_file_reader.cpp index 8aa6aba35..eacfe1de1 100644 --- a/src/duckdb/src/function/table/direct_file_reader.cpp +++ b/src/duckdb/src/function/table/direct_file_reader.cpp @@ -64,7 +64,7 @@ void DirectFileReader::Scan(ClientContext &context, GlobalTableFunctionState &gl if (FileSystem::IsRemoteFile(file.path)) { flags |= FileFlags::FILE_FLAGS_DIRECT_IO; } - file_handle = fs.OpenFile(QueryContext(context), file, flags); + file_handle = fs.OpenFile(context, file, flags); } for (idx_t col_idx = 0; col_idx < state.column_ids.size(); col_idx++) { diff --git a/src/duckdb/src/function/table/system/duckdb_connection_count.cpp 
b/src/duckdb/src/function/table/system/duckdb_connection_count.cpp new file mode 100644 index 000000000..ce7857f3b --- /dev/null +++ b/src/duckdb/src/function/table/system/duckdb_connection_count.cpp @@ -0,0 +1,45 @@ +#include "duckdb/function/table/system_functions.hpp" + +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/connection_manager.hpp" + +namespace duckdb { + +struct DuckDBConnectionCountData : public GlobalTableFunctionState { + DuckDBConnectionCountData() : count(0), finished(false) { + } + idx_t count; + bool finished; +}; + +static unique_ptr DuckDBConnectionCountBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + names.emplace_back("count"); + return_types.emplace_back(LogicalType::UBIGINT); + return nullptr; +} + +unique_ptr DuckDBConnectionCountInit(ClientContext &context, TableFunctionInitInput &input) { + auto result = make_uniq(); + auto &conn_manager = context.db->GetConnectionManager(); + result->count = conn_manager.GetConnectionCount(); + return std::move(result); +} + +void DuckDBConnectionCountFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.global_state->Cast(); + if (data.finished) { + return; + } + output.SetValue(0, 0, Value::UBIGINT(data.count)); + output.SetCardinality(1); + data.finished = true; +} + +void DuckDBConnectionCountFun::RegisterFunction(BuiltinFunctions &set) { + set.AddFunction(TableFunction("duckdb_connection_count", {}, DuckDBConnectionCountFunction, + DuckDBConnectionCountBind, DuckDBConnectionCountInit)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_storage_info.cpp b/src/duckdb/src/function/table/system/pragma_storage_info.cpp index 5500c1c5d..7ba6cfd69 100644 --- a/src/duckdb/src/function/table/system/pragma_storage_info.cpp +++ b/src/duckdb/src/function/table/system/pragma_storage_info.cpp @@ -88,7 +88,7 @@ static 
unique_ptr PragmaStorageInfoBind(ClientContext &context, Ta Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); auto &table_entry = Catalog::GetEntry(context, qname.catalog, qname.schema, qname.name); auto result = make_uniq(table_entry); - result->column_segments_info = table_entry.GetColumnSegmentInfo(); + result->column_segments_info = table_entry.GetColumnSegmentInfo(context); return std::move(result); } diff --git a/src/duckdb/src/function/table/system_functions.cpp b/src/duckdb/src/function/table/system_functions.cpp index d10ec5d31..0a6a03507 100644 --- a/src/duckdb/src/function/table/system_functions.cpp +++ b/src/duckdb/src/function/table/system_functions.cpp @@ -18,6 +18,7 @@ void BuiltinFunctions::RegisterSQLiteFunctions() { PragmaDatabaseSize::RegisterFunction(*this); PragmaUserAgent::RegisterFunction(*this); + DuckDBConnectionCountFun::RegisterFunction(*this); DuckDBApproxDatabaseCountFun::RegisterFunction(*this); DuckDBColumnsFun::RegisterFunction(*this); DuckDBConstraintsFun::RegisterFunction(*this); diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index c3f5f0a6b..4ed97cc23 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev383" +#define DUCKDB_PATCH_VERSION "0-dev720" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 5 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.5.0-dev383" +#define DUCKDB_VERSION "v1.5.0-dev720" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "07d170f87e" +#define DUCKDB_SOURCE_ID "5657cbdc0b" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/function/window/window_merge_sort_tree.cpp 
b/src/duckdb/src/function/window/window_merge_sort_tree.cpp index 6af3d0e5b..5943e6228 100644 --- a/src/duckdb/src/function/window/window_merge_sort_tree.cpp +++ b/src/duckdb/src/function/window/window_merge_sort_tree.cpp @@ -1,5 +1,8 @@ #include "duckdb/function/window/window_merge_sort_tree.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/common/types/column/column_data_collection.hpp" #include #include diff --git a/src/duckdb/src/function/window/window_rank_function.cpp b/src/duckdb/src/function/window/window_rank_function.cpp index af70521a0..e9898762a 100644 --- a/src/duckdb/src/function/window/window_rank_function.cpp +++ b/src/duckdb/src/function/window/window_rank_function.cpp @@ -1,6 +1,7 @@ #include "duckdb/function/window/window_rank_function.hpp" #include "duckdb/function/window/window_shared_expressions.hpp" #include "duckdb/function/window/window_token_tree.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" namespace duckdb { diff --git a/src/duckdb/src/function/window/window_rownumber_function.cpp b/src/duckdb/src/function/window/window_rownumber_function.cpp index f0929d642..8b5401c5d 100644 --- a/src/duckdb/src/function/window/window_rownumber_function.cpp +++ b/src/duckdb/src/function/window/window_rownumber_function.cpp @@ -1,6 +1,7 @@ #include "duckdb/function/window/window_rownumber_function.hpp" #include "duckdb/function/window/window_shared_expressions.hpp" #include "duckdb/function/window/window_token_tree.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp index 0cf71fa73..7f206a43d 100644 --- 
a/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/duck_table_entry.hpp @@ -47,7 +47,7 @@ class DuckTableEntry : public TableCatalogEntry { TableFunction GetScanFunction(ClientContext &context, unique_ptr &bind_data) override; - vector GetColumnSegmentInfo() override; + vector GetColumnSegmentInfo(const QueryContext &context) override; TableStorageInfo GetStorageInfo(ClientContext &context) override; diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp index 5cab72c59..1e40319cf 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp @@ -111,7 +111,7 @@ class TableCatalogEntry : public StandardEntry { static string ColumnNamesToSQL(const ColumnList &columns); //! Returns a list of segment information for this table, if exists - virtual vector GetColumnSegmentInfo(); + virtual vector GetColumnSegmentInfo(const QueryContext &context); //! 
Returns the storage info of this table virtual TableStorageInfo GetStorageInfo(ClientContext &context) = 0; diff --git a/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp b/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp index e2265c8c7..f3a71b594 100644 --- a/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp +++ b/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp @@ -19,7 +19,7 @@ struct DefaultType { LogicalTypeId type; }; -using builtin_type_array = std::array; +using builtin_type_array = std::array; static constexpr const builtin_type_array BUILTIN_TYPES{{ {"decimal", LogicalTypeId::DECIMAL}, @@ -97,7 +97,8 @@ static constexpr const builtin_type_array BUILTIN_TYPES{{ {"real", LogicalTypeId::FLOAT}, {"float4", LogicalTypeId::FLOAT}, {"double", LogicalTypeId::DOUBLE}, - {"float8", LogicalTypeId::DOUBLE} + {"float8", LogicalTypeId::DOUBLE}, + {"geometry", LogicalTypeId::GEOMETRY} }}; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enum_util.hpp b/src/duckdb/src/include/duckdb/common/enum_util.hpp index d07e93d02..b43824c46 100644 --- a/src/duckdb/src/include/duckdb/common/enum_util.hpp +++ b/src/duckdb/src/include/duckdb/common/enum_util.hpp @@ -202,6 +202,8 @@ enum class FunctionStability : uint8_t; enum class GateStatus : uint8_t; +enum class GeometryType : uint32_t; + enum class HLLStorageType : uint8_t; enum class HTTPStatusCode : uint16_t; @@ -706,6 +708,9 @@ const char* EnumUtil::ToChars(FunctionStability value); template<> const char* EnumUtil::ToChars(GateStatus value); +template<> +const char* EnumUtil::ToChars(GeometryType value); + template<> const char* EnumUtil::ToChars(HLLStorageType value); @@ -1334,6 +1339,9 @@ FunctionStability EnumUtil::FromString(const char *value); template<> GateStatus EnumUtil::FromString(const char *value); +template<> +GeometryType EnumUtil::FromString(const char *value); + template<> HLLStorageType 
EnumUtil::FromString(const char *value); diff --git a/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp b/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp index 8fd2790ab..bb0760897 100644 --- a/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp @@ -74,6 +74,7 @@ enum class MetricsType : uint8_t { OPTIMIZER_SUM_REWRITER, OPTIMIZER_LATE_MATERIALIZATION, OPTIMIZER_CTE_INLINING, + OPTIMIZER_COMMON_SUBPLAN, }; struct MetricsTypeHashFunction { diff --git a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp index b57823028..36bb672c9 100644 --- a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp @@ -42,7 +42,8 @@ enum class OptimizerType : uint32_t { MATERIALIZED_CTE, SUM_REWRITER, LATE_MATERIALIZATION, - CTE_INLINING + CTE_INLINING, + COMMON_SUBPLAN, }; string OptimizerTypeToString(OptimizerType type); diff --git a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp index d5e35ee96..60d84686d 100644 --- a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp +++ b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp @@ -28,7 +28,8 @@ enum class ExtraTypeInfoType : uint8_t { ARRAY_TYPE_INFO = 9, ANY_TYPE_INFO = 10, INTEGER_LITERAL_TYPE_INFO = 11, - TEMPLATE_TYPE_INFO = 12 + TEMPLATE_TYPE_INFO = 12, + GEO_TYPE_INFO = 13 }; struct ExtraTypeInfo { @@ -278,4 +279,16 @@ struct TemplateTypeInfo : public ExtraTypeInfo { TemplateTypeInfo(); }; +struct GeoTypeInfo : public ExtraTypeInfo { +public: + GeoTypeInfo(); + + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &source); + shared_ptr Copy() const override; + +protected: + bool EqualsInternal(ExtraTypeInfo *other_p) const override; +}; + } // namespace duckdb diff 
--git a/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp index ac55e5a69..e495e9760 100644 --- a/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp +++ b/src/duckdb/src/include/duckdb/common/operator/cast_operators.hpp @@ -1070,6 +1070,19 @@ bool TryCastBlobToUUID::Operation(string_t input, hugeint_t &result, Vector &res template <> bool TryCastBlobToUUID::Operation(string_t input, hugeint_t &result, bool strict); +//===--------------------------------------------------------------------===// +// GEOMETRY +//===--------------------------------------------------------------------===// +struct TryCastToGeometry { + template + static inline bool Operation(SRC input, DST &result, Vector &result_vector, CastParameters ¶meters) { + throw InternalException("Unsupported type for try cast to geometry"); + } +}; + +template <> +bool TryCastToGeometry::Operation(string_t input, string_t &result, Vector &result_vector, CastParameters ¶meters); + //===--------------------------------------------------------------------===// // Pointers //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/include/duckdb/common/sorting/sort.hpp b/src/duckdb/src/include/duckdb/common/sorting/sort.hpp index 597b8261b..de1e33f3b 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sort.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sort.hpp @@ -8,25 +8,44 @@ #pragma once -#include "duckdb/common/sorting/sorted_run.hpp" -#include "duckdb/common/types/row/tuple_data_layout.hpp" #include "duckdb/execution/physical_operator_states.hpp" +#include "duckdb/execution/progress_data.hpp" #include "duckdb/common/sorting/sort_projection_column.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" namespace duckdb { class SortLocalSinkState; class SortGlobalSinkState; + class SortLocalSourceState; class SortGlobalSourceState; +class SortedRun; 
+class SortedRunScanState; + +class SortedRunMerger; +class SortedRunMergerLocalState; +class SortedRunMergerGlobalState; + +class TupleDataLayout; +class ColumnDataCollection; + //! Class that sorts the data, follows the PhysicalOperator interface class Sort { friend class SortLocalSinkState; friend class SortGlobalSinkState; + friend class SortLocalSourceState; friend class SortGlobalSourceState; + friend class SortedRun; + friend class SortedRunScanState; + + friend class SortedRunMerger; + friend class SortedRunMergerLocalState; + friend class SortedRunMergerGlobalState; + public: Sort(ClientContext &context, const vector &orders, const vector &input_types, vector projection_map, bool is_index_sort = false); @@ -45,7 +64,7 @@ class Sort { vector input_projection_map; vector output_projection_columns; - //! Whether to force an external sort + //! Whether to force an approximate sort bool is_index_sort; public: diff --git a/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp b/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp index fe0d67e32..53ce1c5c0 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp @@ -9,18 +9,40 @@ #pragma once #include "duckdb/common/types/row/tuple_data_states.hpp" +#include "duckdb/execution/expression_executor.hpp" namespace duckdb { +class Sort; +class SortedRun; class BufferManager; class DataChunk; class TupleDataCollection; class TupleDataLayout; +class SortedRunScanState { +public: + SortedRunScanState(ClientContext &context, const Sort &sort); + +public: + void Scan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, DataChunk &chunk); + +private: + template + void TemplatedScan(const SortedRun &sorted_run, const Vector &sort_key_pointers, const idx_t &count, + DataChunk &chunk); + +private: + const Sort &sort; + ExpressionExecutor key_executor; + DataChunk key; + DataChunk decoded_key; + 
TupleDataScanState payload_state; +}; + class SortedRun { public: - SortedRun(ClientContext &context, shared_ptr key_layout, - shared_ptr payload_layout, bool is_index_sort); + SortedRun(ClientContext &context, const Sort &sort, bool is_index_sort); unique_ptr CreateRunForMaterialization() const; ~SortedRun(); @@ -36,8 +58,13 @@ class SortedRun { //! Size of this sorted run idx_t SizeInBytes() const; +private: + mutex merger_global_state_lock; + unique_ptr merge_global_state; + public: ClientContext &context; + const Sort &sort; //! Key and payload collections (and associated append states) unique_ptr key_data; diff --git a/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp b/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp index 21a56df83..fd894d698 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp @@ -9,10 +9,10 @@ #pragma once #include "duckdb/execution/physical_operator_states.hpp" -#include "duckdb/common/sorting/sort_projection_column.hpp" namespace duckdb { +class Sort; class TupleDataLayout; struct BoundOrderByNode; struct ProgressData; @@ -24,9 +24,7 @@ class SortedRunMerger { friend class SortedRunMergerGlobalState; public: - SortedRunMerger(const Expression &decode_sort_key, shared_ptr key_layout, - vector> &&sorted_runs, - const vector &output_projection_columns, idx_t partition_size, bool external, + SortedRunMerger(const Sort &sort, vector> &&sorted_runs, idx_t partition_size, bool external, bool is_index_sort); public: @@ -44,14 +42,12 @@ class SortedRunMerger { //===--------------------------------------------------------------------===// // Non-Standard Interface //===--------------------------------------------------------------------===// - SourceResultType MaterializeMerge(ExecutionContext &context, OperatorSourceInput &input) const; - unique_ptr GetMaterialized(GlobalSourceState &global_state); + 
SourceResultType MaterializeSortedRun(ExecutionContext &context, OperatorSourceInput &input) const; + unique_ptr GetSortedRun(GlobalSourceState &global_state); public: - const Expression &decode_sort_key; - shared_ptr key_layout; + const Sort &sort; vector> sorted_runs; - const vector &output_projection_columns; const idx_t total_count; const idx_t partition_size; diff --git a/src/duckdb/src/include/duckdb/common/string_util.hpp b/src/duckdb/src/include/duckdb/common/string_util.hpp index 8c0c19bef..1448c559a 100644 --- a/src/duckdb/src/include/duckdb/common/string_util.hpp +++ b/src/duckdb/src/include/duckdb/common/string_util.hpp @@ -318,6 +318,8 @@ class StringUtil { //! Transforms an complex JSON to a JSON string DUCKDB_API static string ToComplexJSONMap(const ComplexJSON &complex_json); + DUCKDB_API static string ValidateJSON(const char *data, const idx_t &len); + DUCKDB_API static string GetFileName(const string &file_path); DUCKDB_API static string GetFileExtension(const string &file_name); DUCKDB_API static string GetFileStem(const string &file_name); diff --git a/src/duckdb/src/include/duckdb/common/types.hpp b/src/duckdb/src/include/duckdb/common/types.hpp index 0f7ddbb2d..6d85ce2de 100644 --- a/src/duckdb/src/include/duckdb/common/types.hpp +++ b/src/duckdb/src/include/duckdb/common/types.hpp @@ -230,6 +230,8 @@ enum class LogicalTypeId : uint8_t { VALIDITY = 53, UUID = 54, + GEOMETRY = 60, + STRUCT = 100, LIST = 101, MAP = 102, @@ -430,6 +432,7 @@ struct LogicalType { DUCKDB_API static LogicalType UNION(child_list_t members); // NOLINT DUCKDB_API static LogicalType ARRAY(const LogicalType &child, optional_idx index); // NOLINT DUCKDB_API static LogicalType ENUM(Vector &ordered_data, idx_t size); // NOLINT + DUCKDB_API static LogicalType GEOMETRY(); // NOLINT // ANY but with special rules (default is LogicalType::ANY, 5) DUCKDB_API static LogicalType ANY_PARAMS(LogicalType target, idx_t cast_score = 5); // NOLINT DUCKDB_API static LogicalType 
TEMPLATE(const string &name); // NOLINT diff --git a/src/duckdb/src/include/duckdb/common/types/geometry.hpp b/src/duckdb/src/include/duckdb/common/types/geometry.hpp new file mode 100644 index 000000000..2cca6fe29 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/geometry.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/time.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types.hpp" + +namespace duckdb { + +enum class GeometryType : uint32_t { + INVALID = 0, + POINT = 1, + LINESTRING = 2, + POLYGON = 3, + MULTIPOINT = 4, + MULTILINESTRING = 5, + MULTIPOLYGON = 6, + GEOMETRYCOLLECTION = 7, +}; + +class Geometry { +public: + static constexpr auto MAX_RECURSION_DEPTH = 16; + + DUCKDB_API static bool FromString(const string_t &wkt_text, string_t &result, Vector &result_vector, bool strict); + DUCKDB_API static string_t ToString(Vector &result, const string_t &geom); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp b/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp index c29b094a8..f9ec233eb 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/block_iterator.hpp @@ -183,15 +183,14 @@ class ExternalBlockIteratorState { key_scan_state.pin_state.row_handles.acquire_handles(pins); key_scan_state.pin_state.heap_handles.acquire_handles(pins); } - key_data.FetchChunk(key_scan_state, 0, chunk_idx, false); + key_data.FetchChunk(key_scan_state, chunk_idx, false); if (pin_payload && payload_data) { if (keep_pinned) { payload_scan_state.pin_state.row_handles.acquire_handles(pins); payload_scan_state.pin_state.heap_handles.acquire_handles(pins); } - const auto chunk_count = 
payload_data->FetchChunk(payload_scan_state, 0, chunk_idx, false); + const auto chunk_count = payload_data->FetchChunk(payload_scan_state, chunk_idx, false); const auto sort_keys = reinterpret_cast(key_ptrs); - payload_data->FetchChunk(payload_scan_state, 0, chunk_idx, false); const auto payload_ptrs = FlatVector::GetData(payload_scan_state.chunk_state.row_locations); for (idx_t i = 0; i < chunk_count; i++) { sort_keys[i]->SetPayload(payload_ptrs[i]); diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp index d759341ee..e0f0a0fd2 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp @@ -50,6 +50,7 @@ class TupleDataCollection { public: //! Constructs a TupleDataCollection with the specified layout TupleDataCollection(BufferManager &buffer_manager, shared_ptr layout_ptr); + TupleDataCollection(ClientContext &context, shared_ptr layout_ptr); ~TupleDataCollection(); @@ -185,8 +186,8 @@ class TupleDataCollection { //! Initialize a parallel scan over the tuple data collection over a subset of the columns void InitializeScan(TupleDataParallelScanState &gstate, vector column_ids, TupleDataPinProperties properties = TupleDataPinProperties::UNPIN_AFTER_DONE) const; - //! Grab the chunk state for the given segment and chunk index, returns the count of the chunk - idx_t FetchChunk(TupleDataScanState &state, idx_t segment_idx, idx_t chunk_idx, bool init_heap); + //! Grab the chunk state for the given chunk index, returns the count of the chunk + idx_t FetchChunk(TupleDataScanState &state, idx_t chunk_idx, bool init_heap); //! Scans a DataChunk from the TupleDataCollection bool Scan(TupleDataScanState &state, DataChunk &result); //! 
Scans a DataChunk from the TupleDataCollection diff --git a/src/duckdb/src/include/duckdb/common/types/value.hpp b/src/duckdb/src/include/duckdb/common/types/value.hpp index 1993d0295..bba9a7297 100644 --- a/src/duckdb/src/include/duckdb/common/types/value.hpp +++ b/src/duckdb/src/include/duckdb/common/types/value.hpp @@ -201,6 +201,8 @@ class Value { DUCKDB_API static Value BIGNUM(const_data_ptr_t data, idx_t len); DUCKDB_API static Value BIGNUM(const string &data); + DUCKDB_API static Value GEOMETRY(const_data_ptr_t data, idx_t len); + //! Creates an aggregate state DUCKDB_API static Value AGGREGATE_STATE(const LogicalType &type, const_data_ptr_t data, idx_t len); // NOLINT diff --git a/src/duckdb/src/include/duckdb/common/types/variant.hpp b/src/duckdb/src/include/duckdb/common/types/variant.hpp index cc8a9ffa6..0d5917892 100644 --- a/src/duckdb/src/include/duckdb/common/types/variant.hpp +++ b/src/duckdb/src/include/duckdb/common/types/variant.hpp @@ -31,6 +31,7 @@ struct VariantNestedData { struct VariantDecimalData { uint32_t width; uint32_t scale; + const_data_ptr_t value_ptr; }; struct VariantVectorData { @@ -105,6 +106,7 @@ enum class VariantLogicalType : uint8_t { ARRAY = 30, BIGNUM = 31, BITSTRING = 32, + GEOMETRY = 33, ENUM_SIZE /* always kept as last item of the enum */ }; diff --git a/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp b/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp index 6a0f0346a..07c49b541 100644 --- a/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/virtual_file_system.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/file_system.hpp" #include "duckdb/common/map.hpp" #include "duckdb/common/unordered_set.hpp" +#include "duckdb/main/extension_helper.hpp" namespace duckdb { @@ -82,8 +83,10 @@ class VirtualFileSystem : public FileSystem { } private: + FileSystem &FindFileSystem(const string &path, optional_ptr file_opener); + FileSystem &FindFileSystem(const 
string &path, optional_ptr database_instance); FileSystem &FindFileSystem(const string &path); - FileSystem &FindFileSystemInternal(const string &path); + optional_ptr FindFileSystemInternal(const string &path); private: vector> sub_systems; diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp index 6089a728b..f46496c9a 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp @@ -40,15 +40,6 @@ class PhysicalAsOfJoin : public PhysicalComparisonJoin { // Predicate (join conditions that don't reference both sides) unique_ptr predicate; -public: - // Operator Interface - unique_ptr GetGlobalOperatorState(ClientContext &context) const override; - unique_ptr GetOperatorState(ExecutionContext &context) const override; - - bool ParallelOperator() const override { - return true; - } - protected: // CachingOperator Interface OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, @@ -83,6 +74,9 @@ class PhysicalAsOfJoin : public PhysicalComparisonJoin { bool ParallelSink() const override { return true; } + +public: + void BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp index 4da01aff3..12d974ddd 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_piecewise_merge_join.hpp @@ -46,6 +46,8 @@ class PhysicalPiecewiseMergeJoin : public PhysicalRangeJoin { public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; + 
unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; bool IsSource() const override { diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp index 4ee6ef557..85463b954 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp @@ -8,36 +8,35 @@ #pragma once +#include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/execution/operator/join/physical_comparison_join.hpp" -#include "duckdb/planner/bound_result_modifier.hpp" -#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sorting/sort.hpp" +#include "duckdb/common/sorting/sorted_run.hpp" namespace duckdb { -struct GlobalSortState; - //! PhysicalRangeJoin represents one or more inequality range join predicates between //! two tables class PhysicalRangeJoin : public PhysicalComparisonJoin { public: + class GlobalSortedTable; + class LocalSortedTable { public: - LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, const idx_t child); + LocalSortedTable(ExecutionContext &context, GlobalSortedTable &global_table, const idx_t child); - void Sink(DataChunk &input, GlobalSortState &global_sort_state); + void Sink(ExecutionContext &context, DataChunk &input); - inline void Sort(GlobalSortState &global_sort_state) { - local_sort_state.Sort(global_sort_state, true); - } - - //! The hosting operator - const PhysicalRangeJoin &op; + //! The global table we are connected to + GlobalSortedTable &global_table; //! The local sort state - LocalSortState local_sort_state; + unique_ptr local_sink; //! Local copy of the sorting expression executor ExpressionExecutor executor; //! 
Holds a vector of incoming sorting columns DataChunk keys; + //! The sort data + DataChunk sort_chunk; //! The number of NULL values idx_t has_null; //! The total number of rows @@ -50,45 +49,89 @@ class PhysicalRangeJoin : public PhysicalComparisonJoin { class GlobalSortedTable { public: - GlobalSortedTable(ClientContext &context, const vector &orders, RowLayout &payload_layout, - const PhysicalOperator &op); + GlobalSortedTable(ClientContext &client, const vector &orders, + const vector &payload_layout, const PhysicalRangeJoin &op); inline idx_t Count() const { return count; } inline idx_t BlockCount() const { - if (global_sort_state.sorted_blocks.empty()) { - return 0; - } - D_ASSERT(global_sort_state.sorted_blocks.size() == 1); - return global_sort_state.sorted_blocks[0]->radix_sorting_data.size(); + return sorted->key_data->ChunkCount(); + } + + inline idx_t BlockStart(idx_t i) const { + return MinValue(i * STANDARD_VECTOR_SIZE, count); + } + + inline idx_t BlockEnd(idx_t i) const { + return BlockStart(i + 1) - 1; } inline idx_t BlockSize(idx_t i) const { - return global_sort_state.sorted_blocks[0]->radix_sorting_data[i]->count; + return i < BlockCount() ? MinValue(STANDARD_VECTOR_SIZE, count - BlockStart(i)) : 0; + } + + inline SortKeyType GetSortKeyType() const { + return sorted->key_data->GetLayout().GetSortKeyType(); } - void Combine(LocalSortedTable <able); void IntializeMatches(); + + //! Combine local states + void Combine(ExecutionContext &context, LocalSortedTable <able); + //! Prepare for sorting. + void Finalize(ClientContext &client, InterruptState &interrupt); + //! Schedules the materialisation process. + void Materialize(Pipeline &pipeline, Event &event); + //! Single-threaded materialisation. + void Materialize(ExecutionContext &context, InterruptState &interrupt); + //! Materialize an empty sorted run. + void MaterializeEmpty(ClientContext &client); + //! Print the table to the console void Print(); - //! Starts the sorting process. 
- void Finalize(Pipeline &pipeline, Event &event); - //! Schedules tasks to merge sort the current child's data during a Finalize phase - void ScheduleMergeTasks(Pipeline &pipeline, Event &event); + //! Create an iteration state + unique_ptr CreateIteratorState() { + auto state = make_uniq(*sorted->key_data, sorted->payload_data.get()); + + // Unless we do this, we will only get values from the first chunk + Repin(*state); + + return state; + } + //! Reset the pins for an iterator so we release memory in a timely manner + static void Repin(ExternalBlockIteratorState &iter) { + iter.SetKeepPinned(true); + iter.SetPinPayload(true); + } + //! Create an iteration state + unique_ptr CreateScanState(ClientContext &client) { + return make_uniq(client, *sort); + } + //! Initialize a payload scanning state + void InitializePayloadState(TupleDataChunkState &state) { + sorted->payload_data->InitializeChunkState(state); + } //! The hosting operator - const PhysicalOperator &op; - GlobalSortState global_sort_state; + const PhysicalRangeJoin &op; + //! The sort description + unique_ptr sort; + //! The shared sort state + unique_ptr global_sink; //! Whether or not the RHS has NULL values atomic has_null; //! The total number of rows in the RHS atomic count; + //! The number of materialisation tasks completed in parallel + atomic tasks_completed; + //! The shared materialisation state + unique_ptr global_source; + //! The materialized data + unique_ptr sorted; //! A bool indicating for each tuple in the RHS if they found a match (only used in FULL OUTER JOIN) unsafe_unique_array found_match; - //! Memory usage per thread - idx_t memory_per_thread; }; public: @@ -106,10 +149,9 @@ class PhysicalRangeJoin : public PhysicalComparisonJoin { public: // Gather the result values and slice the payload columns to those values. 
- // Returns a buffer handle to the pinned heap block (if any) - static BufferHandle SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, - const SelectionVector &result, const idx_t result_count, - const idx_t left_cols = 0); + static void SliceSortedPayload(DataChunk &chunk, GlobalSortedTable &table, ExternalBlockIteratorState &state, + TupleDataChunkState &chunk_state, const idx_t chunk_idx, SelectionVector &result, + const idx_t result_count, SortedRunScanState &scan_state); // Apply a tail condition to the current selection static idx_t SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel); diff --git a/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp b/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp index d87a3a976..107b482b7 100644 --- a/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp @@ -170,6 +170,7 @@ struct DefaultCasts { static BoundCastInfo UnionCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo VariantCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo UUIDCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); + static BoundCastInfo GeoCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo BignumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target); static BoundCastInfo ImplicitToUnionCast(BindCastInput &input, const LogicalType &source, const LogicalType &target); diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp index 2e7fbf68e..9aa105bd3 100644 --- 
a/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/primitive_to_variant.hpp @@ -357,6 +357,9 @@ bool ConvertPrimitiveToVariant(ToVariantSourceData &source, ToVariantGlobalResul case LogicalTypeId::CHAR: return ConvertPrimitiveTemplated( source, result, count, selvec, values_index_selvec, empty_payload, is_root); + case LogicalTypeId::GEOMETRY: + return ConvertPrimitiveTemplated( + source, result, count, selvec, values_index_selvec, empty_payload, is_root); case LogicalTypeId::BLOB: return ConvertPrimitiveTemplated( source, result, count, selvec, values_index_selvec, empty_payload, is_root); diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp index 5a8b088ae..209598a74 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/struct_to_variant.hpp @@ -98,7 +98,7 @@ bool ConvertStructToVariant(ToVariantSourceData &source, ToVariantGlobalResultDa } } if (WRITE_DATA) { - //! Now forward the selection to point to the next index in the children.values_index + //! 
Now move the selection forward to write the value_id for the next struct child, for each row for (idx_t i = 0; i < sel.count; i++) { sel.children_selection[i]++; } diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp index 28d9db96b..cdbf698bd 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp @@ -240,7 +240,7 @@ bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalRe } } else if (source_type_id == VariantLogicalType::BITSTRING || source_type_id == VariantLogicalType::BIGNUM || source_type_id == VariantLogicalType::VARCHAR || - source_type_id == VariantLogicalType::BLOB) { + source_type_id == VariantLogicalType::BLOB || source_type_id == VariantLogicalType::GEOMETRY) { auto str_blob_data = source_blob_data + source_byte_offset; auto str_length = VarintDecode(str_blob_data); auto str_length_varint_size = GetVarintSize(str_length); diff --git a/src/duckdb/src/include/duckdb/function/compression_function.hpp b/src/duckdb/src/include/duckdb/function/compression_function.hpp index 64b1c2a58..97fff72b1 100644 --- a/src/duckdb/src/include/duckdb/function/compression_function.hpp +++ b/src/duckdb/src/include/duckdb/function/compression_function.hpp @@ -17,6 +17,7 @@ #include "duckdb/storage/data_pointer.hpp" #include "duckdb/storage/storage_info.hpp" #include "duckdb/storage/block_manager.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/storage/storage_lock.hpp" namespace duckdb { @@ -28,7 +29,6 @@ class SegmentStatistics; class TableFilter; struct TableFilterState; struct ColumnSegmentState; - struct ColumnFetchState; struct ColumnScanState; struct PrefetchState; @@ -174,7 +174,8 @@ typedef void (*compression_compress_finalize_t)(CompressionState &state); // Uncompress / Scan 
//===--------------------------------------------------------------------===// typedef void (*compression_init_prefetch_t)(ColumnSegment &segment, PrefetchState &prefetch_state); -typedef unique_ptr (*compression_init_segment_scan_t)(ColumnSegment &segment); +typedef unique_ptr (*compression_init_segment_scan_t)(const QueryContext &context, + ColumnSegment &segment); //! Function prototype used for reading an entire vector (STANDARD_VECTOR_SIZE) typedef void (*compression_scan_vector_t)(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, @@ -221,7 +222,8 @@ typedef void (*compression_cleanup_state_t)(ColumnSegment &segment); // GetSegmentInfo (optional) //===--------------------------------------------------------------------===// //! Function prototype for retrieving segment information straight from the column segment -typedef InsertionOrderPreservingMap (*compression_get_segment_info_t)(ColumnSegment &segment); +typedef InsertionOrderPreservingMap (*compression_get_segment_info_t)(QueryContext context, + ColumnSegment &segment); enum class CompressionValidity : uint8_t { REQUIRES_VALIDITY, NO_VALIDITY_REQUIRED }; diff --git a/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp index 6408639ec..7c9ce455d 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/variant_functions.hpp @@ -29,7 +29,7 @@ struct VariantTypeofFun { static constexpr const char *Name = "variant_typeof"; static constexpr const char *Parameters = "input_variant"; static constexpr const char *Description = "Returns the internal type of the `input_variant`."; - static constexpr const char *Example = "variant_typeof({'a': 42, 'b': [1,2,3])::VARIANT)"; + static constexpr const char *Example = "variant_typeof({'a': 42, 'b': [1,2,3]})::VARIANT)"; static constexpr const char *Categories = "variant"; static ScalarFunction 
GetFunction(); diff --git a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp index f0c4cb82b..d8605eb1f 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp @@ -66,6 +66,8 @@ struct VariantUtils { uint32_t value_index); DUCKDB_API static VariantNestedData DecodeNestedData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index); + DUCKDB_API static string_t DecodeStringData(const UnifiedVariantVectorData &variant, idx_t row, + uint32_t value_index); DUCKDB_API static vector GetObjectKeys(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data); DUCKDB_API static VariantChildDataCollectionResult FindChildValues(const UnifiedVariantVectorData &variant, diff --git a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp index e325b2f46..49c5e794c 100644 --- a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp @@ -47,6 +47,10 @@ struct DuckDBSchemasFun { static void RegisterFunction(BuiltinFunctions &set); }; +struct DuckDBConnectionCountFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + struct DuckDBApproxDatabaseCountFun { static void RegisterFunction(BuiltinFunctions &set); }; diff --git a/src/duckdb/src/include/duckdb/main/attached_database.hpp b/src/duckdb/src/include/duckdb/main/attached_database.hpp index 7333d9adb..a9ec117e9 100644 --- a/src/duckdb/src/include/duckdb/main/attached_database.hpp +++ b/src/duckdb/src/include/duckdb/main/attached_database.hpp @@ -42,6 +42,9 @@ struct StoredDatabasePath { DatabaseFilePathManager &manager; string path; + +public: + void OnDetach(); }; //! AttachOptions holds information about a database we plan to attach. 
These options are generalized, i.e., @@ -115,6 +118,7 @@ class AttachedDatabase : public CatalogEntry, public enable_shared_from_this context) : context(context) { // NOLINT: allow implicit construction } + QueryContext(ClientContext &context) : context(&context) { // NOLINT: allow implicit construction + } + QueryContext(weak_ptr context) // NOLINT: allow implicit construction + : owning_context(context.lock()), context(owning_context.get()) { + } public: bool Valid() const { @@ -347,6 +352,7 @@ class QueryContext { } private: + shared_ptr owning_context; optional_ptr context; }; diff --git a/src/duckdb/src/include/duckdb/main/connection.hpp b/src/duckdb/src/include/duckdb/main/connection.hpp index c27d84d21..72c61209c 100644 --- a/src/duckdb/src/include/duckdb/main/connection.hpp +++ b/src/duckdb/src/include/duckdb/main/connection.hpp @@ -50,7 +50,6 @@ class Connection { DUCKDB_API ~Connection(); shared_ptr context; - warning_callback_t warning_cb; public: //! Returns query profiling information for the current query diff --git a/src/duckdb/src/include/duckdb/main/connection_manager.hpp b/src/duckdb/src/include/duckdb/main/connection_manager.hpp index 7fa5c66b5..1c647ce02 100644 --- a/src/duckdb/src/include/duckdb/main/connection_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/connection_manager.hpp @@ -40,7 +40,6 @@ class ConnectionManager { mutex connections_lock; reference_map_t> connections; atomic connection_count; - atomic current_connection_id; }; diff --git a/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp b/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp index 1912a90bf..90d028035 100644 --- a/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp @@ -20,10 +20,11 @@ struct AttachOptions; enum class InsertDatabasePathResult { SUCCESS, ALREADY_EXISTS }; struct DatabasePathInfo { - explicit DatabasePathInfo(string name_p) : 
name(std::move(name_p)) { + explicit DatabasePathInfo(string name_p) : name(std::move(name_p)), is_attached(true) { } string name; + bool is_attached; }; //! The DatabaseFilePathManager is used to ensure we only ever open a single database file once @@ -34,11 +35,15 @@ class DatabaseFilePathManager { AttachOptions &options); //! Erase a database path - indicating we are done with using it void EraseDatabasePath(const string &path); + //! Called when a database is detached, but before it is fully finished being used + void DetachDatabase(const string &path); private: //! The lock to add entries to the db_paths map mutable mutex db_paths_lock; - //! A set containing all attached database paths mapped to their attached database name + //! A set containing all attached database path + //! This allows to attach many databases efficiently, and to avoid attaching the + //! same file path twice case_insensitive_map_t db_paths; }; diff --git a/src/duckdb/src/include/duckdb/main/extension_entries.hpp b/src/duckdb/src/include/duckdb/main/extension_entries.hpp index a32331c9b..4f119c4a0 100644 --- a/src/duckdb/src/include/duckdb/main/extension_entries.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_entries.hpp @@ -779,6 +779,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"var_pop", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"var_samp", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"variance", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"variant_to_parquet_variant", "parquet", CatalogType::SCALAR_FUNCTION_ENTRY}, {"vector_type", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"version", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"vss_join", "vss", CatalogType::TABLE_MACRO_ENTRY}, diff --git a/src/duckdb/src/include/duckdb/main/profiling_info.hpp b/src/duckdb/src/include/duckdb/main/profiling_info.hpp index 904f0205d..709314375 100644 --- 
a/src/duckdb/src/include/duckdb/main/profiling_info.hpp +++ b/src/duckdb/src/include/duckdb/main/profiling_info.hpp @@ -32,9 +32,6 @@ class ProfilingInfo { profiler_settings_t expanded_settings; //! Contains all enabled metrics. profiler_metrics_t metrics; - //! Additional metrics. - // FIXME: move to metrics. - InsertionOrderPreservingMap extra_info; public: ProfilingInfo() = default; @@ -102,6 +99,7 @@ class ProfilingInfo { return MaxValue(old_value, new_value); }); } + template void MetricMax(const MetricsType type, const METRIC_TYPE &value) { auto new_value = Value::CreateValue(value); @@ -109,4 +107,19 @@ class ProfilingInfo { } }; +// Specialization for InsertionOrderPreservingMap +template <> +inline InsertionOrderPreservingMap +ProfilingInfo::GetMetricValue>(const MetricsType type) const { + auto val = metrics.at(type); + InsertionOrderPreservingMap result; + auto children = MapValue::GetChildren(val); + for (auto &child : children) { + auto struct_children = StructValue::GetChildren(child); + auto key = struct_children[0].GetValue(); + auto value = struct_children[1].GetValue(); + result.insert(key, value); + } + return result; +} } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/query_profiler.hpp b/src/duckdb/src/include/duckdb/main/query_profiler.hpp index 0f7b8812d..db27be1e4 100644 --- a/src/duckdb/src/include/duckdb/main/query_profiler.hpp +++ b/src/duckdb/src/include/duckdb/main/query_profiler.hpp @@ -180,7 +180,8 @@ class QueryProfiler { DUCKDB_API string ToString(ExplainFormat format = ExplainFormat::DEFAULT) const; DUCKDB_API string ToString(ProfilerPrintFormat format) const; - static InsertionOrderPreservingMap JSONSanitize(const InsertionOrderPreservingMap &input); + // Sanitize a Value::MAP + static Value JSONSanitize(const Value &input); static string JSONSanitize(const string &text); static string DrawPadded(const string &str, idx_t width); DUCKDB_API string ToJSON() const; diff --git 
a/src/duckdb/src/include/duckdb/main/secret/secret.hpp b/src/duckdb/src/include/duckdb/main/secret/secret.hpp index ed8034413..fd8a1b241 100644 --- a/src/duckdb/src/include/duckdb/main/secret/secret.hpp +++ b/src/duckdb/src/include/duckdb/main/secret/secret.hpp @@ -296,7 +296,9 @@ class KeyValueSecretReader { Value result; auto lookup_result = TryGetSecretKeyOrSetting(secret_key, setting_name, result); if (lookup_result) { - value_out = result.GetValue(); + if (!result.IsNull()) { + value_out = result.GetValue(); + } } return lookup_result; } diff --git a/src/duckdb/src/include/duckdb/optimizer/common_subplan_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/common_subplan_optimizer.hpp new file mode 100644 index 000000000..8d8e35ea1 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/common_subplan_optimizer.hpp @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/common_subplan_optimizer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/logical_operator.hpp" + +namespace duckdb { + +class Optimizer; +class LogicalOperator; + +//! The CommonSubplanOptimizer optimizer detects common subplans, and converts them to refs of a materialized CTE +class CommonSubplanOptimizer { +public: + explicit CommonSubplanOptimizer(Optimizer &optimizer); + +public: + unique_ptr Optimize(unique_ptr op); + +private: + //! 
The optimizer + Optimizer &optimizer; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp b/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp index 90439b11e..97529f6ee 100644 --- a/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/cte_inlining.hpp @@ -25,6 +25,7 @@ class CTEInlining { public: explicit CTEInlining(Optimizer &optimizer); unique_ptr Optimize(unique_ptr op); + static bool EndsInAggregateOrDistinct(const LogicalOperator &op); private: void TryInlining(unique_ptr &op); diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp index c2fc87a52..5e920ca2d 100644 --- a/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp @@ -96,6 +96,8 @@ class FilterPushdown { unique_ptr FinishPushdown(unique_ptr op); //! Adds a filter to the set of filters. Returns FilterResult::UNSATISFIABLE if the subtree should be stripped, or //! FilterResult::SUCCESS otherwise + + unique_ptr PushFiltersIntoDelimJoin(unique_ptr op); FilterResult AddFilter(unique_ptr expr); //! Extract filter bindings to compare them with expressions in an operator and determine if the filter //! 
can be pushed down diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index 5a664f2dc..a2bac2325 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -100,6 +100,65 @@ struct CorrelatedColumnInfo { } }; +struct CorrelatedColumns { +private: + using container_type = vector; + +public: + CorrelatedColumns() : delim_index(1ULL << 63) { + } + + void AddColumn(container_type::value_type info) { + // Add to beginning + correlated_columns.insert(correlated_columns.begin(), std::move(info)); + delim_index++; + } + + void SetDelimIndexToZero() { + delim_index = 0; + } + + idx_t GetDelimIndex() const { + return delim_index; + } + + const container_type::value_type &operator[](const idx_t &index) const { + return correlated_columns.at(index); + } + + idx_t size() const { // NOLINT: match stl case + return correlated_columns.size(); + } + + bool empty() const { // NOLINT: match stl case + return correlated_columns.empty(); + } + + void clear() { // NOLINT: match stl case + correlated_columns.clear(); + } + + container_type::iterator begin() { // NOLINT: match stl case + return correlated_columns.begin(); + } + + container_type::iterator end() { // NOLINT: match stl case + return correlated_columns.end(); + } + + container_type::const_iterator begin() const { // NOLINT: match stl case + return correlated_columns.begin(); + } + + container_type::const_iterator end() const { // NOLINT: match stl case + return correlated_columns.end(); + } + +private: + container_type correlated_columns; + idx_t delim_index; +}; + //! Bind the parsed query tree to the actual columns present in the catalog. /*! The binder is responsible for binding tables and columns to actual physical @@ -122,7 +181,7 @@ class Binder : public enable_shared_from_this { BindContext bind_context; //! 
The set of correlated columns bound by this binder (FIXME: this should probably be an unordered_set and not a //! vector) - vector correlated_columns; + CorrelatedColumns correlated_columns; //! The set of parameter expressions bound by this binder optional_ptr parameters; //! The alias for the currently processing subquery, if it exists @@ -198,7 +257,7 @@ class Binder : public enable_shared_from_this { vector> &GetActiveBinders(); - void MergeCorrelatedColumns(vector &other); + void MergeCorrelatedColumns(CorrelatedColumns &other); //! Add a correlated column to this binder (if it does not exist) void AddCorrelatedColumn(const CorrelatedColumnInfo &info); @@ -426,7 +485,7 @@ class Binder : public enable_shared_from_this { void PlanSubqueries(unique_ptr &expr, unique_ptr &root); unique_ptr PlanSubquery(BoundSubqueryExpression &expr, unique_ptr &root); unique_ptr PlanLateralJoin(unique_ptr left, unique_ptr right, - vector &correlated_columns, + CorrelatedColumns &correlated_columns, JoinType join_type = JoinType::INNER, unique_ptr condition = nullptr); diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp index eb68a0cdf..55f046cd7 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/lateral_binder.hpp @@ -24,7 +24,7 @@ class LateralBinder : public ExpressionBinder { return !correlated_columns.empty(); } - static void ReduceExpressionDepth(LogicalOperator &op, const vector &info); + static void ReduceExpressionDepth(LogicalOperator &op, const CorrelatedColumns &info); protected: BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, @@ -37,7 +37,7 @@ class LateralBinder : public ExpressionBinder { void ExtractCorrelatedColumns(Expression &expr); private: - vector correlated_columns; + CorrelatedColumns correlated_columns; }; } // namespace duckdb diff 
--git a/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp index cd2ed3c21..0548cd4e7 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_cte.hpp @@ -35,6 +35,6 @@ class LogicalCTE : public LogicalOperator { string ctename; idx_t table_index; idx_t column_count; - vector correlated_columns; + CorrelatedColumns correlated_columns; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp index 724f2bc57..5e4c83919 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_dependent_join.hpp @@ -27,7 +27,7 @@ class LogicalDependentJoin : public LogicalComparisonJoin { public: explicit LogicalDependentJoin(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition); explicit LogicalDependentJoin(JoinType type); @@ -35,7 +35,7 @@ class LogicalDependentJoin : public LogicalComparisonJoin { //! The conditions of the join unique_ptr join_condition; //! 
The list of columns that have correlations with the right - vector correlated_columns; + CorrelatedColumns correlated_columns; SubqueryType subquery_type = SubqueryType::INVALID; bool perform_delim = true; @@ -51,7 +51,7 @@ class LogicalDependentJoin : public LogicalComparisonJoin { public: static unique_ptr Create(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp b/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp index 2f343e901..14ad4510c 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp @@ -18,7 +18,7 @@ namespace duckdb { //! The FlattenDependentJoins class is responsible for pushing the dependent join down into the plan to create a //! 
flattened subquery struct FlattenDependentJoins { - FlattenDependentJoins(Binder &binder, const vector &correlated, bool perform_delim = true, + FlattenDependentJoins(Binder &binder, const CorrelatedColumns &correlated, bool perform_delim = true, bool any_join = false, optional_ptr parent = nullptr); static unique_ptr DecorrelateIndependent(Binder &binder, unique_ptr plan); @@ -47,7 +47,7 @@ struct FlattenDependentJoins { reference_map_t has_correlated_expressions; column_binding_map_t correlated_map; column_binding_map_t replacement_map; - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; vector delim_types; bool perform_delim; diff --git a/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp b/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp index 6b238ffcc..81a097b49 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/has_correlated_expressions.hpp @@ -16,7 +16,7 @@ namespace duckdb { //! 
Helper class to recursively detect correlated expressions inside a single LogicalOperator class HasCorrelatedExpressions : public LogicalOperatorVisitor { public: - explicit HasCorrelatedExpressions(const vector &correlated, bool lateral = false, + explicit HasCorrelatedExpressions(const CorrelatedColumns &correlated, bool lateral = false, idx_t lateral_depth = 0); void VisitOperator(LogicalOperator &op) override; @@ -28,7 +28,7 @@ class HasCorrelatedExpressions : public LogicalOperatorVisitor { unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; unique_ptr VisitReplace(BoundSubqueryExpression &expr, unique_ptr *expr_ptr) override; - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; // Tracks number of nested laterals idx_t lateral_depth; }; diff --git a/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp b/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp index e2c507e73..72886f80e 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/rewrite_cte_scan.hpp @@ -17,13 +17,13 @@ namespace duckdb { //! 
Helper class to rewrite correlated cte scans within a single LogicalOperator class RewriteCTEScan : public LogicalOperatorVisitor { public: - RewriteCTEScan(idx_t table_index, const vector &correlated_columns); + RewriteCTEScan(idx_t table_index, const CorrelatedColumns &correlated_columns); void VisitOperator(LogicalOperator &op) override; private: idx_t table_index; - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp index 38c83c95f..299189624 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp @@ -47,7 +47,7 @@ class BoundJoinRef : public BoundTableRef { //! Whether or not this is a lateral join bool lateral; //! The correlated columns of the right-side with the left-side - vector correlated_columns; + CorrelatedColumns correlated_columns; //! The mark index, for mark joins generated by the relational API idx_t mark_index {}; }; diff --git a/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp b/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp index 619e89a5a..fd281e141 100644 --- a/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp @@ -58,6 +58,7 @@ class BufferManager { virtual void ReAllocate(shared_ptr &handle, idx_t block_size) = 0; //! Pin a block handle. virtual BufferHandle Pin(shared_ptr &handle) = 0; + virtual BufferHandle Pin(const QueryContext &context, shared_ptr &handle) = 0; //! Pre-fetch a series of blocks. //! Using this function is a performance suggestion. 
virtual void Prefetch(vector> &handles) = 0; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp index bac590d0e..13eecf42b 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_analyze.hpp @@ -82,7 +82,12 @@ unique_ptr AlpInitAnalyze(ColumnData &col_data, PhysicalType type) */ template bool AlpAnalyze(AnalyzeState &state, Vector &input, idx_t count) { - auto &analyze_state = (AlpAnalyzeState &)state; + if (state.info.GetBlockSize() + state.info.GetBlockHeaderSize() < DEFAULT_BLOCK_ALLOC_SIZE) { + return false; + } + + auto &analyze_state = state.Cast>(); + bool must_skip_current_vector = alp::AlpUtils::MustSkipSamplingFromCurrentVector( analyze_state.vectors_count, analyze_state.vectors_sampled_count, count); analyze_state.vectors_count += 1; diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp index 28b52b848..8c7d12e67 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp @@ -201,7 +201,7 @@ struct AlpScanState : public SegmentScanState { }; template -unique_ptr AlpInitScan(ColumnSegment &segment) { +unique_ptr AlpInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp index 25901667e..da7f8bda0 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_analyze.hpp @@ -47,8 +47,12 @@ unique_ptr AlpRDInitAnalyze(ColumnData &col_data, PhysicalType typ */ template 
bool AlpRDAnalyze(AnalyzeState &state, Vector &input, idx_t count) { + if (state.info.GetBlockSize() + state.info.GetBlockHeaderSize() < DEFAULT_BLOCK_ALLOC_SIZE) { + return false; + } + using EXACT_TYPE = typename FloatingToExact::TYPE; - auto &analyze_state = (AlpRDAnalyzeState &)state; + auto &analyze_state = state.Cast>(); bool must_skip_current_vector = alp::AlpUtils::MustSkipSamplingFromCurrentVector( analyze_state.vectors_count, analyze_state.vectors_sampled_count, count); diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp index a3feb94b5..520d38fa2 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp @@ -208,7 +208,7 @@ struct AlpRDScanState : public SegmentScanState { }; template -unique_ptr AlpRDInitScan(ColumnSegment &segment) { +unique_ptr AlpRDInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp index de11979cb..970b35f23 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/chimp_scan.hpp @@ -252,7 +252,7 @@ struct ChimpScanState : public SegmentScanState { }; template -unique_ptr ChimpInitScan(ColumnSegment &segment) { +unique_ptr ChimpInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp b/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp index 1118f77f2..476faf89f 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp +++ 
b/src/duckdb/src/include/duckdb/storage/compression/empty_validity.hpp @@ -77,7 +77,7 @@ class EmptyValidityCompression { auto &checkpoint_state = checkpoint_data.GetCheckpointState(); checkpoint_state.FlushSegment(std::move(compressed_segment), std::move(handle), 0); } - static unique_ptr InitScan(ColumnSegment &segment) { + static unique_ptr InitScan(const QueryContext &context, ColumnSegment &segment) { return make_uniq(); } static void ScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp index b523600e3..4261d2d23 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp @@ -204,7 +204,7 @@ struct PatasScanState : public SegmentScanState { }; template -unique_ptr PatasInitScan(ColumnSegment &segment) { +unique_ptr PatasInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq_base>(segment); return result; } diff --git a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp index 0ae44d7f3..54f6f239f 100644 --- a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp +++ b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp @@ -19,6 +19,7 @@ namespace duckdb { class Serializer; class Deserializer; +class QueryContext; struct ColumnSegmentState { virtual ~ColumnSegmentState() { diff --git a/src/duckdb/src/include/duckdb/storage/data_table.hpp b/src/duckdb/src/include/duckdb/storage/data_table.hpp index bc8727a18..bf355ed05 100644 --- a/src/duckdb/src/include/duckdb/storage/data_table.hpp +++ b/src/duckdb/src/include/duckdb/storage/data_table.hpp @@ -196,7 +196,7 @@ class DataTable : public enable_shared_from_this { //! 
Remove the chunk with the specified set of row identifiers from all indexes of the table void RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, Vector &row_identifiers); //! Remove the row identifiers from all the indexes of the table - void RemoveFromIndexes(Vector &row_identifiers, idx_t count); + void RemoveFromIndexes(const QueryContext &context, Vector &row_identifiers, idx_t count); void SetAsMainTable() { this->version = DataTableVersion::MAIN_TABLE; @@ -234,7 +234,7 @@ class DataTable : public enable_shared_from_this { idx_t ColumnCount() const; idx_t GetTotalRows() const; - vector GetColumnSegmentInfo(); + vector GetColumnSegmentInfo(const QueryContext &context); //! Scans the next chunk for the CREATE INDEX operator bool CreateIndexScan(TableScanState &state, DataChunk &result, TableScanType type); diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp index cd63a96b8..06a451a52 100644 --- a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp @@ -65,7 +65,7 @@ class MetadataManager { MetadataHandle AllocateHandle(); MetadataHandle Pin(const MetadataPointer &pointer); - MetadataHandle Pin(QueryContext context, const MetadataPointer &pointer); + MetadataHandle Pin(const QueryContext &context, const MetadataPointer &pointer); MetaBlockPointer GetDiskPointer(const MetadataPointer &pointer, uint32_t offset = 0); MetadataPointer FromDiskPointer(MetaBlockPointer pointer); diff --git a/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp b/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp index 1ded8bba6..98d371437 100644 --- a/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp +++ b/src/duckdb/src/include/duckdb/storage/optimistic_data_writer.hpp @@ -14,11 +14,16 @@ namespace duckdb { class PartialBlockManager; struct 
OptimisticWriteCollection { + ~OptimisticWriteCollection(); + shared_ptr collection; idx_t last_flushed = 0; idx_t complete_row_groups = 0; + vector> partial_block_managers; }; +enum class OptimisticWritePartialManagers { PER_COLUMN, GLOBAL }; + class OptimisticDataWriter { public: OptimisticDataWriter(ClientContext &context, DataTable &table); @@ -26,8 +31,9 @@ class OptimisticDataWriter { ~OptimisticDataWriter(); //! Creates a collection to write to - static unique_ptr CreateCollection(DataTable &storage, - const vector &insert_types); + unique_ptr + CreateCollection(DataTable &storage, const vector &insert_types, + OptimisticWritePartialManagers type = OptimisticWritePartialManagers::PER_COLUMN); //! Write a new row group to disk (if possible) void WriteNewRowGroup(OptimisticWriteCollection &row_groups); //! Write the last row group of a collection to disk @@ -35,9 +41,10 @@ class OptimisticDataWriter { //! Final flush of the optimistic writer - fully flushes the partial block manager void FinalFlush(); //! Flushes a specific row group to disk - void FlushToDisk(const vector> &row_groups); + void FlushToDisk(OptimisticWriteCollection &collection, const vector> &row_groups); //! Merge the partially written blocks from one optimistic writer into another void Merge(OptimisticDataWriter &other); + void Merge(unique_ptr &other_manager); //! 
Rollback void Rollback(); diff --git a/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp b/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp index d0a54c597..f6b91ed1e 100644 --- a/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp @@ -71,7 +71,7 @@ class StandardBufferManager : public BufferManager { void ReAllocate(shared_ptr &handle, idx_t block_size) final; BufferHandle Pin(shared_ptr &handle) final; - BufferHandle Pin(QueryContext context, shared_ptr &handle); + BufferHandle Pin(const QueryContext &context, shared_ptr &handle) final; void Prefetch(vector> &handles) final; void Unpin(shared_ptr &handle) final; diff --git a/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp index 0982f8905..6e5814a36 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/string_stats.hpp @@ -71,6 +71,8 @@ struct StringStats { ExpressionType comparison_type, const string &value); DUCKDB_API static void Update(BaseStatistics &stats, const string_t &value); + DUCKDB_API static void SetMin(BaseStatistics &stats, const string_t &value); + DUCKDB_API static void SetMax(BaseStatistics &stats, const string_t &value); DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); diff --git a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp index b5342829c..7ca16e38e 100644 --- a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp +++ b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp @@ -67,7 +67,7 @@ struct UncompressedStringStorage { static unique_ptr 
StringInitAnalyze(ColumnData &col_data, PhysicalType type); static bool StringAnalyze(AnalyzeState &state_p, Vector &input, idx_t count); static idx_t StringFinalAnalyze(AnalyzeState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); diff --git a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp index abc9577a3..f4c943a79 100644 --- a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp @@ -65,8 +65,8 @@ class ArrayColumnData : public ColumnData { PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; void Verify(RowGroup &parent) override; }; diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp index 400daeaa6..b688f8ed7 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp @@ -39,14 +39,16 @@ struct PersistentColumnData; using column_segment_vector_t = vector>; struct ColumnCheckpointInfo { - ColumnCheckpointInfo(RowGroupWriteInfo &info, idx_t column_idx) : info(info), column_idx(column_idx) { - } + ColumnCheckpointInfo(RowGroupWriteInfo &info, idx_t column_idx); - 
RowGroupWriteInfo &info; idx_t column_idx; public: + PartialBlockManager &GetPartialBlockManager(); CompressionType GetCompressionType(); + +private: + RowGroupWriteInfo &info; }; class ColumnData { @@ -178,7 +180,8 @@ class ColumnData { static shared_ptr Deserialize(BlockManager &block_manager, DataTableInfo &info, idx_t column_index, idx_t start_row, ReadStream &source, const LogicalType &type); - virtual void GetColumnSegmentInfo(idx_t row_group_index, vector col_path, vector &result); + virtual void GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, + vector &result); virtual void Verify(RowGroup &parent); FilterPropagateResult CheckZonemap(TableFilter &filter); diff --git a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp index 61b2c0d4f..e99664958 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp @@ -29,7 +29,6 @@ class DatabaseInstance; class TableFilter; class Transaction; class UpdateSegment; - struct ColumnAppendState; struct ColumnFetchState; struct ColumnScanState; diff --git a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp index c8e75d136..98d8c662b 100644 --- a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp @@ -63,8 +63,8 @@ class ListColumnData : public ColumnData { PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; private: uint64_t 
FetchListOffset(idx_t row_idx); diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp index 242e19121..d003d1378 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp @@ -53,13 +53,22 @@ class StorageCommitState; struct RowGroupWriteInfo { RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, - CheckpointType checkpoint_type = CheckpointType::FULL_CHECKPOINT) - : manager(manager), compression_types(compression_types), checkpoint_type(checkpoint_type) { - } + CheckpointType checkpoint_type = CheckpointType::FULL_CHECKPOINT); + RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, + vector> &column_partial_block_managers_p); +private: PartialBlockManager &manager; + +public: const vector &compression_types; CheckpointType checkpoint_type; + +public: + PartialBlockManager &GetPartialBlockManager(idx_t column_idx); + +private: + optional_ptr>> column_partial_block_managers; }; struct RowGroupWriteData { @@ -174,7 +183,7 @@ class RowGroup : public SegmentBase { void MergeIntoStatistics(TableStatistics &other); unique_ptr GetStatistics(idx_t column_idx); - void GetColumnSegmentInfo(idx_t row_group_index, vector &result); + void GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector &result); PartitionStatistics GetPartitionStats() const; idx_t GetAllocationSize() const { diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp index 32808ff4c..e5c745829 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp @@ -61,13 +61,14 @@ class RowGroupCollection { void Verify(); void Destroy(); - void InitializeScan(CollectionScanState &state, const 
vector &column_ids, + void InitializeScan(const QueryContext &context, CollectionScanState &state, const vector &column_ids, optional_ptr table_filters); void InitializeCreateIndexScan(CreateIndexScanState &state); - void InitializeScanWithOffset(CollectionScanState &state, const vector &column_ids, idx_t start_row, - idx_t end_row); - static bool InitializeScanInRowGroup(CollectionScanState &state, RowGroupCollection &collection, - RowGroup &row_group, idx_t vector_index, idx_t max_row); + void InitializeScanWithOffset(const QueryContext &context, CollectionScanState &state, + const vector &column_ids, idx_t start_row, idx_t end_row); + static bool InitializeScanInRowGroup(const QueryContext &context, CollectionScanState &state, + RowGroupCollection &collection, RowGroup &row_group, idx_t vector_index, + idx_t max_row); void InitializeParallelScan(ParallelCollectionScanState &state); bool NextParallelScan(ClientContext &context, ParallelCollectionScanState &state, CollectionScanState &scan_state); @@ -97,7 +98,7 @@ class RowGroupCollection { optional_ptr commit_state); bool IsPersistent() const; - void RemoveFromIndexes(TableIndexList &indexes, Vector &row_identifiers, idx_t count); + void RemoveFromIndexes(const QueryContext &context, TableIndexList &indexes, Vector &row_identifiers, idx_t count); idx_t Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count); void Update(TransactionData transaction, row_t *ids, const vector &column_ids, DataChunk &updates); @@ -116,7 +117,7 @@ class RowGroupCollection { void CommitDropTable(); vector GetPartitionStats() const; - vector GetColumnSegmentInfo(); + vector GetColumnSegmentInfo(const QueryContext &context); const vector &GetTypes() const; shared_ptr AddColumn(ClientContext &context, ColumnDefinition &new_column, @@ -124,7 +125,7 @@ class RowGroupCollection { shared_ptr RemoveColumn(idx_t col_idx); shared_ptr AlterType(ClientContext &context, idx_t changed_idx, const LogicalType &target_type, 
vector bound_columns, Expression &cast_expr); - void VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint); + void VerifyNewConstraint(const QueryContext &context, DataTable &parent, const BoundConstraint &constraint); void CopyStats(TableStatistics &stats); unique_ptr CopyStats(column_t column_id); diff --git a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp index 97416cbf2..8e85053ec 100644 --- a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp @@ -78,6 +78,8 @@ struct IndexScanState { typedef unordered_map buffer_handle_set_t; struct ColumnScanState { + //! The query context for this scan + QueryContext context; //! The column segment that is currently being scanned ColumnSegment *current = nullptr; //! Column segment tree @@ -105,9 +107,9 @@ struct ColumnScanState { optional_ptr scan_options; public: - void Initialize(const LogicalType &type, const vector &children, + void Initialize(const QueryContext &context_p, const LogicalType &type, const vector &children, optional_ptr options); - void Initialize(const LogicalType &type, optional_ptr options); + void Initialize(const QueryContext &context_p, const LogicalType &type, optional_ptr options); //! Move the scan state forward by "count" rows (including all child states) void Next(idx_t count); //! Move ONLY this state forward by "count" rows (i.e. not the child states) @@ -115,6 +117,8 @@ struct ColumnScanState { }; struct ColumnFetchState { + //! The query context for this fetch + QueryContext context; //! The set of pinned block handles for this set of fetches buffer_handle_set_t handles; //! 
Any child states of the fetch @@ -202,7 +206,7 @@ class CollectionScanState { RandomEngine random; public: - void Initialize(const vector &types); + void Initialize(const QueryContext &context, const vector &types); const vector &GetColumnIds(); ScanFilterInfo &GetFilterInfo(); ScanSamplingInfo &GetSamplingInfo(); diff --git a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp index 48ac6ccb7..8d233139a 100644 --- a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp @@ -61,8 +61,8 @@ class StandardColumnData : public ColumnData { void CheckpointScan(ColumnSegment &segment, ColumnScanState &state, idx_t row_group_start, idx_t count, Vector &scan_vector) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; bool IsPersistent() override; bool HasAnyChanges() const override; diff --git a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp index d05436bfc..91c7f1e19 100644 --- a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp @@ -63,8 +63,8 @@ class StructColumnData : public ColumnData { PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; - void GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) override; + void GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) override; void Verify(RowGroup &parent) override; }; diff --git 
a/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp b/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp index 0de2faabd..21f117674 100644 --- a/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp +++ b/src/duckdb/src/include/duckdb/transaction/cleanup_state.hpp @@ -11,6 +11,7 @@ #include "duckdb/transaction/undo_buffer.hpp" #include "duckdb/common/types/data_chunk.hpp" #include "duckdb/common/unordered_map.hpp" +#include "duckdb/main/client_context.hpp" namespace duckdb { @@ -21,7 +22,7 @@ struct UpdateInfo; class CleanupState { public: - explicit CleanupState(transaction_t lowest_active_transaction); + explicit CleanupState(const QueryContext &context, transaction_t lowest_active_transaction); ~CleanupState(); // all tables with indexes that possibly need a vacuum (after e.g. a delete) @@ -31,6 +32,7 @@ class CleanupState { void CleanupEntry(UndoFlags type, data_ptr_t data); private: + QueryContext context; //! Lowest active transaction transaction_t lowest_active_transaction; // data for index cleanup diff --git a/src/duckdb/src/include/duckdb/transaction/local_storage.hpp b/src/duckdb/src/include/duckdb/transaction/local_storage.hpp index 5d29da46c..1cab839d0 100644 --- a/src/duckdb/src/include/duckdb/transaction/local_storage.hpp +++ b/src/duckdb/src/include/duckdb/transaction/local_storage.hpp @@ -40,6 +40,8 @@ class LocalTableStorage : public enable_shared_from_this { ExpressionExecutor &default_executor); ~LocalTableStorage(); + QueryContext context; + reference table_ref; Allocator &allocator; @@ -189,6 +191,10 @@ class LocalStorage { void VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint); + ClientContext &GetClientContext() const { + return context; + } + private: ClientContext &context; DuckTransaction &transaction; diff --git a/src/duckdb/src/main/attached_database.cpp b/src/duckdb/src/main/attached_database.cpp index e98070c18..e879e17a1 100644 --- a/src/duckdb/src/main/attached_database.cpp +++ 
b/src/duckdb/src/main/attached_database.cpp @@ -22,6 +22,10 @@ StoredDatabasePath::~StoredDatabasePath() { manager.EraseDatabasePath(path); } +void StoredDatabasePath::OnDetach() { + manager.DetachDatabase(path); +} + //===--------------------------------------------------------------------===// // Attach Options //===--------------------------------------------------------------------===// @@ -157,6 +161,13 @@ bool AttachedDatabase::NameIsReserved(const string &name) { return name == DEFAULT_SCHEMA || name == TEMP_CATALOG || name == SYSTEM_CATALOG; } +string AttachedDatabase::StoredPath() const { + if (stored_database_path) { + return stored_database_path->path; + } + return string(); +} + static string RemoveQueryParams(const string &name) { auto vec = StringUtil::Split(name, "?"); D_ASSERT(!vec.empty()); @@ -181,7 +192,7 @@ void AttachedDatabase::Initialize(optional_ptr context) { catalog->Initialize(context, false); } if (storage) { - storage->Initialize(QueryContext(context)); + storage->Initialize(context); } } @@ -232,6 +243,9 @@ void AttachedDatabase::OnDetach(ClientContext &context) { if (catalog) { catalog->OnDetach(context); } + if (stored_database_path && visibility != AttachVisibility::HIDDEN) { + stored_database_path->OnDetach(); + } } void AttachedDatabase::Close() { diff --git a/src/duckdb/src/main/client_data.cpp b/src/duckdb/src/main/client_data.cpp index 1348c0b09..50df05563 100644 --- a/src/duckdb/src/main/client_data.cpp +++ b/src/duckdb/src/main/client_data.cpp @@ -56,6 +56,9 @@ class ClientBufferManager : public BufferManager { return buffer_manager.ReAllocate(handle, block_size); } BufferHandle Pin(shared_ptr &handle) override { + return Pin(QueryContext(), handle); + } + BufferHandle Pin(const QueryContext &context, shared_ptr &handle) override { return buffer_manager.Pin(handle); } void Prefetch(vector> &handles) override { diff --git a/src/duckdb/src/main/connection.cpp b/src/duckdb/src/main/connection.cpp index e561a3cb9..af76cfd17 
100644 --- a/src/duckdb/src/main/connection.cpp +++ b/src/duckdb/src/main/connection.cpp @@ -19,7 +19,7 @@ namespace duckdb { Connection::Connection(DatabaseInstance &database) - : context(make_shared_ptr(database.shared_from_this())), warning_cb(nullptr) { + : context(make_shared_ptr(database.shared_from_this())) { auto &connection_manager = ConnectionManager::Get(database); connection_manager.AddConnection(*context); connection_manager.AssignConnectionId(*this); @@ -31,18 +31,15 @@ Connection::Connection(DatabaseInstance &database) } Connection::Connection(DuckDB &database) : Connection(*database.instance) { - // Initialization of warning_cb happens in the other constructor } -Connection::Connection(Connection &&other) noexcept : warning_cb(nullptr) { +Connection::Connection(Connection &&other) noexcept { std::swap(context, other.context); - std::swap(warning_cb, other.warning_cb); std::swap(connection_id, other.connection_id); } Connection &Connection::operator=(Connection &&other) noexcept { std::swap(context, other.context); - std::swap(warning_cb, other.warning_cb); std::swap(connection_id, other.connection_id); return *this; } diff --git a/src/duckdb/src/main/database_file_path_manager.cpp b/src/duckdb/src/main/database_file_path_manager.cpp index 05adeadfe..f1825780e 100644 --- a/src/duckdb/src/main/database_file_path_manager.cpp +++ b/src/duckdb/src/main/database_file_path_manager.cpp @@ -22,7 +22,12 @@ InsertDatabasePathResult DatabaseFilePathManager::InsertDatabasePath(const strin if (!entry.second) { auto &existing = entry.first->second; if (on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT && existing.name == name) { - return InsertDatabasePathResult::ALREADY_EXISTS; + if (existing.is_attached) { + return InsertDatabasePathResult::ALREADY_EXISTS; + } + throw BinderException("Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is in " + "the process of being detached", + name, path); } throw BinderException("Unique file handle 
conflict: Cannot attach \"%s\" - the database file \"%s\" is already " "attached by database \"%s\"", @@ -40,4 +45,15 @@ void DatabaseFilePathManager::EraseDatabasePath(const string &path) { db_paths.erase(path); } +void DatabaseFilePathManager::DetachDatabase(const string &path) { + if (path.empty() || path == IN_MEMORY_PATH) { + return; + } + lock_guard path_lock(db_paths_lock); + auto entry = db_paths.find(path); + if (entry != db_paths.end()) { + entry->second.is_attached = false; + } +} + } // namespace duckdb diff --git a/src/duckdb/src/main/database_manager.cpp b/src/duckdb/src/main/database_manager.cpp index ae0a6447d..f59cc2719 100644 --- a/src/duckdb/src/main/database_manager.cpp +++ b/src/duckdb/src/main/database_manager.cpp @@ -303,7 +303,7 @@ void DatabaseManager::GetDatabaseType(ClientContext &context, AttachInfo &info, // Try to extract the database type from the path. if (options.db_type.empty()) { auto &fs = FileSystem::GetFileSystem(context); - DBPathAndType::CheckMagicBytes(QueryContext(context), fs, info.path, options.db_type); + DBPathAndType::CheckMagicBytes(context, fs, info.path, options.db_type); } if (options.db_type.empty()) { diff --git a/src/duckdb/src/main/http/http_util.cpp b/src/duckdb/src/main/http/http_util.cpp index a51fb3e7f..554346489 100644 --- a/src/duckdb/src/main/http/http_util.cpp +++ b/src/duckdb/src/main/http/http_util.cpp @@ -367,7 +367,9 @@ HTTPUtil::RunRequestWithRetry(const std::function(void) try { response = on_request(); - response->url = request.url; + if (response) { + response->url = request.url; + } } catch (IOException &e) { exception_error = e.what(); caught_e = std::current_exception(); diff --git a/src/duckdb/src/main/profiling_info.cpp b/src/duckdb/src/main/profiling_info.cpp index 8f744d51b..276e0c198 100644 --- a/src/duckdb/src/main/profiling_info.cpp +++ b/src/duckdb/src/main/profiling_info.cpp @@ -104,6 +104,7 @@ void ProfilingInfo::ResetMetrics() { metrics[metric] = Value::CreateValue(0); break; case 
MetricsType::EXTRA_INFO: + metrics[metric] = Value::MAP(InsertionOrderPreservingMap()); break; default: throw InternalException("MetricsType" + EnumUtil::ToString(metric) + "not implemented"); @@ -149,21 +150,10 @@ string ProfilingInfo::GetMetricAsString(const MetricsType metric) const { throw InternalException("Metric %s not enabled", EnumUtil::ToString(metric)); } - if (metric == MetricsType::EXTRA_INFO) { - string result; - for (auto &it : extra_info) { - if (!result.empty()) { - result += ", "; - } - result += StringUtil::Format("%s: %s", it.first, it.second); - } - return "\"" + result + "\""; - } - // The metric cannot be NULL and must be initialized. D_ASSERT(!metrics.at(metric).IsNull()); if (metric == MetricsType::OPERATOR_TYPE) { - auto type = PhysicalOperatorType(metrics.at(metric).GetValue()); + const auto type = PhysicalOperatorType(metrics.at(metric).GetValue()); return EnumUtil::ToString(type); } return metrics.at(metric).ToString(); @@ -178,18 +168,25 @@ void ProfilingInfo::WriteMetricsToJSON(yyjson_mut_doc *doc, yyjson_mut_val *dest if (metric == MetricsType::EXTRA_INFO) { auto extra_info_obj = yyjson_mut_obj(doc); - for (auto &it : extra_info) { - auto &key = it.first; - auto &value = it.second; - auto splits = StringUtil::Split(value, "\n"); + auto extra_info = metrics.at(metric); + auto children = MapValue::GetChildren(extra_info); + for (auto &child : children) { + auto struct_children = StructValue::GetChildren(child); + auto key = struct_children[0].GetValue(); + auto value = struct_children[1].GetValue(); + + auto key_mut = unsafe_yyjson_mut_strncpy(doc, key.c_str(), key.size()); + auto value_mut = unsafe_yyjson_mut_strncpy(doc, value.c_str(), value.size()); + + auto splits = StringUtil::Split(value_mut, "\n"); if (splits.size() > 1) { auto list_items = yyjson_mut_arr(doc); for (auto &split : splits) { yyjson_mut_arr_add_strcpy(doc, list_items, split.c_str()); } - yyjson_mut_obj_add_val(doc, extra_info_obj, key.c_str(), list_items); + 
yyjson_mut_obj_add_val(doc, extra_info_obj, key_mut, list_items); } else { - yyjson_mut_obj_add_strcpy(doc, extra_info_obj, key.c_str(), value.c_str()); + yyjson_mut_obj_add_strcpy(doc, extra_info_obj, key_mut, value_mut); } } yyjson_mut_obj_add_val(doc, dest, key_ptr, extra_info_obj); diff --git a/src/duckdb/src/main/query_profiler.cpp b/src/duckdb/src/main/query_profiler.cpp index 4c9c9328a..eb4116cca 100644 --- a/src/duckdb/src/main/query_profiler.cpp +++ b/src/duckdb/src/main/query_profiler.cpp @@ -16,6 +16,7 @@ #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/storage/buffer/buffer_pool.hpp" #include "yyjson.hpp" +#include "yyjson_utils.hpp" #include #include @@ -557,7 +558,7 @@ void QueryProfiler::Flush(OperatorProfiler &profiler) { info.MetricSum(MetricsType::RESULT_SET_SIZE, node.second.result_set_size); } if (ProfilingInfo::Enabled(profiler.settings, MetricsType::EXTRA_INFO)) { - info.extra_info = node.second.extra_info; + info.metrics[MetricsType::EXTRA_INFO] = Value::MAP(node.second.extra_info); } if (ProfilingInfo::Enabled(profiler.settings, MetricsType::SYSTEM_PEAK_BUFFER_MEMORY)) { query_metrics.query_global_info.MetricMax(MetricsType::SYSTEM_PEAK_BUFFER_MEMORY, @@ -721,18 +722,24 @@ void QueryProfiler::QueryTreeToStream(std::ostream &ss) const { } } -InsertionOrderPreservingMap QueryProfiler::JSONSanitize(const InsertionOrderPreservingMap &input) { +Value QueryProfiler::JSONSanitize(const Value &input) { + D_ASSERT(input.type().id() == LogicalTypeId::MAP); + InsertionOrderPreservingMap result; - for (auto &it : input) { - auto key = it.first; + auto children = MapValue::GetChildren(input); + for (auto &child : children) { + auto struct_children = StructValue::GetChildren(child); + auto key = struct_children[0].GetValue(); + auto value = struct_children[1].GetValue(); + if (StringUtil::StartsWith(key, "__")) { key = StringUtil::Replace(key, "__", ""); key = StringUtil::Replace(key, "_", " "); key = 
StringUtil::Title(key); } - result[key] = it.second; + result[key] = value; } - return result; + return Value::MAP(result); } string QueryProfiler::JSONSanitize(const std::string &text) { @@ -772,7 +779,10 @@ string QueryProfiler::JSONSanitize(const std::string &text) { static yyjson_mut_val *ToJSONRecursive(yyjson_mut_doc *doc, ProfilingNode &node) { auto result_obj = yyjson_mut_obj(doc); auto &profiling_info = node.GetProfilingInfo(); - profiling_info.extra_info = QueryProfiler::JSONSanitize(profiling_info.extra_info); + + profiling_info.metrics[MetricsType::EXTRA_INFO] = + QueryProfiler::JSONSanitize(profiling_info.metrics.at(MetricsType::EXTRA_INFO)); + profiling_info.WriteMetricsToJSON(doc, result_obj); auto children_list = yyjson_mut_arr(doc); @@ -784,44 +794,43 @@ static yyjson_mut_val *ToJSONRecursive(yyjson_mut_doc *doc, ProfilingNode &node) return result_obj; } -static string StringifyAndFree(yyjson_mut_doc *doc, yyjson_mut_val *object) { - auto data = yyjson_mut_val_write_opts(object, YYJSON_WRITE_ALLOW_INF_AND_NAN | YYJSON_WRITE_PRETTY, nullptr, - nullptr, nullptr); - if (!data) { - yyjson_mut_doc_free(doc); +static string StringifyAndFree(ConvertedJSONHolder &json_holder, yyjson_mut_val *object) { + json_holder.stringified_json = yyjson_mut_val_write_opts( + object, YYJSON_WRITE_ALLOW_INF_AND_NAN | YYJSON_WRITE_PRETTY, nullptr, nullptr, nullptr); + if (!json_holder.stringified_json) { throw InternalException("The plan could not be rendered as JSON, yyjson failed"); } - auto result = string(data); - free(data); - yyjson_mut_doc_free(doc); + auto result = string(json_holder.stringified_json); return result; } string QueryProfiler::ToJSON() const { lock_guard guard(lock); - auto doc = yyjson_mut_doc_new(nullptr); - auto result_obj = yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, result_obj); + ConvertedJSONHolder json_holder; + + json_holder.doc = yyjson_mut_doc_new(nullptr); + auto result_obj = yyjson_mut_obj(json_holder.doc); + 
yyjson_mut_doc_set_root(json_holder.doc, result_obj); if (query_metrics.query.empty() && !root) { - yyjson_mut_obj_add_str(doc, result_obj, "result", "empty"); - return StringifyAndFree(doc, result_obj); + yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "empty"); + return StringifyAndFree(json_holder, result_obj); } if (!root) { - yyjson_mut_obj_add_str(doc, result_obj, "result", "error"); - return StringifyAndFree(doc, result_obj); + yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "error"); + return StringifyAndFree(json_holder, result_obj); } auto &settings = root->GetProfilingInfo(); - settings.WriteMetricsToJSON(doc, result_obj); + settings.WriteMetricsToJSON(json_holder.doc, result_obj); // recursively print the physical operator tree - auto children_list = yyjson_mut_arr(doc); - yyjson_mut_obj_add_val(doc, result_obj, "children", children_list); - auto child = ToJSONRecursive(doc, *root->GetChild(0)); + auto children_list = yyjson_mut_arr(json_holder.doc); + yyjson_mut_obj_add_val(json_holder.doc, result_obj, "children", children_list); + auto child = ToJSONRecursive(json_holder.doc, *root->GetChild(0)); yyjson_mut_arr_add_val(children_list, child); - return StringifyAndFree(doc, result_obj); + return StringifyAndFree(json_holder, result_obj); } void QueryProfiler::WriteToFile(const char *path, string &info) const { @@ -871,7 +880,7 @@ unique_ptr QueryProfiler::CreateTree(const PhysicalOperator &root info.MetricSum(MetricsType::OPERATOR_TYPE, static_cast(root_p.type)); } if (info.Enabled(info.settings, MetricsType::EXTRA_INFO)) { - info.extra_info = root_p.ParamsToString(); + info.metrics[MetricsType::EXTRA_INFO] = Value::MAP(root_p.ParamsToString()); } tree_map.insert(make_pair(reference(root_p), reference(*node))); @@ -905,12 +914,13 @@ string QueryProfiler::RenderDisabledMessage(ProfilerPrintFormat format) const { } )"; case ProfilerPrintFormat::JSON: { - auto doc = yyjson_mut_doc_new(nullptr); - auto result_obj = 
yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, result_obj); + ConvertedJSONHolder json_holder; + json_holder.doc = yyjson_mut_doc_new(nullptr); + auto result_obj = yyjson_mut_obj(json_holder.doc); + yyjson_mut_doc_set_root(json_holder.doc, result_obj); - yyjson_mut_obj_add_str(doc, result_obj, "result", "disabled"); - return StringifyAndFree(doc, result_obj); + yyjson_mut_obj_add_str(json_holder.doc, result_obj, "result", "disabled"); + return StringifyAndFree(json_holder, result_obj); } default: throw InternalException("Unknown ProfilerPrintFormat \"%s\"", EnumUtil::ToString(format)); diff --git a/src/duckdb/src/optimizer/common_subplan_optimizer.cpp b/src/duckdb/src/optimizer/common_subplan_optimizer.cpp new file mode 100644 index 000000000..0c3c9cb35 --- /dev/null +++ b/src/duckdb/src/optimizer/common_subplan_optimizer.cpp @@ -0,0 +1,575 @@ +#include "duckdb/optimizer/common_subplan_optimizer.hpp" + +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/optimizer/cte_inlining.hpp" +#include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/planner/operator/list.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Subplan Signature/Info +//===--------------------------------------------------------------------===// +struct PlanSignatureCreateState { + PlanSignatureCreateState() : stream(DEFAULT_BLOCK_ALLOC_SIZE), serializer(stream) { + } + + void Reset() { + to_canonical.clear(); + from_canonical.clear(); + table_indices.clear(); + expression_info.clear(); + } + + MemoryStream stream; + BinarySerializer serializer; + + unordered_map to_canonical; + unordered_map from_canonical; + + vector table_indices; + vector> expression_info; +}; + +class PlanSignature { +private: + PlanSignature(const MemoryStream &stream_p, idx_t offset_p, idx_t length_p, + vector> 
&&child_signatures_p, idx_t operator_count_p) + : stream(stream_p), offset(offset_p), length(length_p), + signature_hash(Hash(stream_p.GetData() + offset, length)), child_signatures(std::move(child_signatures_p)), + operator_count(operator_count_p) { + } + +public: + static unique_ptr Create(PlanSignatureCreateState &state, LogicalOperator &op, + vector> &&child_signatures, + const idx_t operator_count) { + state.Reset(); + if (!OperatorIsSupported(op)) { + return nullptr; + } + + if (op.type == LogicalOperatorType::LOGICAL_CHUNK_GET && + op.Cast().collection->Count() > 1000) { + // Avoid serializing massive amounts of data (this is here because of the "Test TPCH arrow roundtrip" test) + return nullptr; + } + + // Construct maps for converting column bindings to canonical representation and back + static constexpr idx_t CANONICAL_TABLE_INDEX_OFFSET = 10000000000000; + for (const auto &child_op : op.children) { + for (const auto &child_cb : child_op->GetColumnBindings()) { + const auto &original = child_cb.table_index; + auto it = state.to_canonical.find(original); + if (it != state.to_canonical.end()) { + continue; // We've seen this table index before + } + const auto canonical = CANONICAL_TABLE_INDEX_OFFSET + state.to_canonical.size(); + state.to_canonical[original] = canonical; + state.from_canonical[canonical] = original; + } + } + + // Convert operators to canonical table indices + ConvertTableIndices(op, state.table_indices); + + // Convert expressions to canonical (table indices, aliases, query locations) + bool can_materialize = ConvertExpressions(op, state.to_canonical, state.expression_info); + + // Temporarily move children here as we don't want to serialize them + auto children = std::move(op.children); + op.children.clear(); + + // TODO: to allow for better detection of equivalent plans, we could: + // 1. Sort the children of operators + // 2. 
Sort the expressions of operators + + // Serialize canonical representation of operator + const auto offset = state.stream.GetPosition(); + state.serializer.Begin(); + try { // Operators will throw if they cannot serialize, so we need to try/catch here + op.Serialize(state.serializer); + } catch (std::exception &) { + can_materialize = false; + } + state.serializer.End(); + const auto length = state.stream.GetPosition() - offset; + + // Convert back from canonical + ConvertTableIndices(op, state.table_indices); + ConvertExpressions(op, state.from_canonical, state.expression_info); + + // Restore children + op.children = std::move(children); + + if (can_materialize) { + return unique_ptr( + new PlanSignature(state.stream, offset, length, std::move(child_signatures), operator_count)); + } + return nullptr; + } + + idx_t OperatorCount() const { + return operator_count; + } + + hash_t HashSignature() const { + auto res = signature_hash; + for (auto &child : child_signatures) { + res = CombineHash(res, child.get().HashSignature()); + } + return res; + } + + bool Equals(const PlanSignature &other) const { + if (this->GetSignature() != other.GetSignature()) { + return false; + } + if (this->child_signatures.size() != other.child_signatures.size()) { + return false; + } + for (idx_t child_idx = 0; child_idx < this->child_signatures.size(); ++child_idx) { + if (!this->child_signatures[child_idx].get().Equals(other.child_signatures[child_idx].get())) { + return false; + } + } + return true; + } + +private: + String GetSignature() const { + return String(char_ptr_cast(stream.GetData() + offset), NumericCast(length)); + } + + static bool OperatorIsSupported(const LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_PROJECTION: + case LogicalOperatorType::LOGICAL_FILTER: + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + case LogicalOperatorType::LOGICAL_WINDOW: + case LogicalOperatorType::LOGICAL_UNNEST: + case 
LogicalOperatorType::LOGICAL_LIMIT: + case LogicalOperatorType::LOGICAL_ORDER_BY: + case LogicalOperatorType::LOGICAL_TOP_N: + case LogicalOperatorType::LOGICAL_DISTINCT: + case LogicalOperatorType::LOGICAL_PIVOT: + case LogicalOperatorType::LOGICAL_GET: + case LogicalOperatorType::LOGICAL_CHUNK_GET: + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: + case LogicalOperatorType::LOGICAL_EMPTY_RESULT: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: + case LogicalOperatorType::LOGICAL_ANY_JOIN: + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: + return true; + default: + // Unsupported: + // - case LogicalOperatorType::LOGICAL_COPY_TO_FILE: + // - case LogicalOperatorType::LOGICAL_SAMPLE: + // - case LogicalOperatorType::LOGICAL_COPY_DATABASE: + // - case LogicalOperatorType::LOGICAL_DELIM_GET: + // - case LogicalOperatorType::LOGICAL_CTE_REF: + // - case LogicalOperatorType::LOGICAL_JOIN: + // - case LogicalOperatorType::LOGICAL_DELIM_JOIN: + // - case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: + // - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: + // - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: + // - case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR + return false; + } + } + + template + static void ConvertTableIndices(LogicalOperator &op, vector &table_indices) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_GET: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_CHUNK_GET: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_EXPRESSION_GET: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_DUMMY_SCAN: { + 
ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_CTE_REF: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_PROJECTION: { + ConvertTableIndicesGeneric(op, table_indices); + break; + } + case LogicalOperatorType::LOGICAL_PIVOT: { + auto &pivot = op.Cast(); + if (TO_CANONICAL) { + table_indices.emplace_back(pivot.pivot_index); + } + pivot.pivot_index = TO_CANONICAL ? 0 : table_indices[0]; + break; + } + case LogicalOperatorType::LOGICAL_UNNEST: { + auto &unnest = op.Cast(); + if (TO_CANONICAL) { + table_indices.emplace_back(unnest.unnest_index); + } + unnest.unnest_index = TO_CANONICAL ? 0 : table_indices[0]; + break; + } + case LogicalOperatorType::LOGICAL_WINDOW: { + auto &window = op.Cast(); + if (TO_CANONICAL) { + table_indices.emplace_back(window.window_index); + } + window.window_index = TO_CANONICAL ? 0 : table_indices[0]; + break; + } + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + auto &aggregate = op.Cast(); + if (TO_CANONICAL) { + table_indices.emplace_back(aggregate.group_index); + table_indices.emplace_back(aggregate.aggregate_index); + table_indices.emplace_back(aggregate.groupings_index); + } + aggregate.group_index = TO_CANONICAL ? 0 : table_indices[0]; + aggregate.aggregate_index = TO_CANONICAL ? 1 : table_indices[1]; + aggregate.groupings_index = TO_CANONICAL ? 2 : table_indices[2]; + break; + } + case LogicalOperatorType::LOGICAL_UNION: + case LogicalOperatorType::LOGICAL_EXCEPT: + case LogicalOperatorType::LOGICAL_INTERSECT: { + auto &setop = op.Cast(); + if (TO_CANONICAL) { + table_indices.emplace_back(setop.table_index); + } + setop.table_index = TO_CANONICAL ? 
0 : table_indices[0]; + break; + } + default: + break; + } + } + + template + static void ConvertTableIndicesGeneric(LogicalOperator &op, vector &table_idxs) { + auto &generic = op.Cast(); + if (TO_CANONICAL) { + table_idxs.emplace_back(generic.table_index); + } + generic.table_index = TO_CANONICAL ? 0 : table_idxs[0]; + } + + static bool ConvertExpressions(LogicalOperator &op, const unordered_map &table_index_mapping, + vector> &expression_info) { + bool can_materialize = true; + const auto to_canonical = expression_info.empty(); + idx_t info_idx = 0; + LogicalOperatorVisitor::EnumerateExpressions(op, [&](unique_ptr *expr) { + ExpressionIterator::EnumerateExpression(*expr, [&](unique_ptr &child) { + if (child->GetExpressionClass() == ExpressionClass::BOUND_COLUMN_REF) { + auto &col_ref = child->Cast(); + auto &table_index = col_ref.binding.table_index; + auto it = table_index_mapping.find(table_index); + D_ASSERT(it != table_index_mapping.end()); + table_index = it->second; + } + if (to_canonical) { + expression_info.emplace_back(std::move(child->alias), child->query_location); + child->alias.clear(); + child->query_location.SetInvalid(); + } else { + auto &info = expression_info[info_idx++]; + child->alias = std::move(info.first); + child->query_location = info.second; + } + if (child->IsVolatile()) { + can_materialize = false; + } + }); + }); + return can_materialize; + } + +private: + const MemoryStream &stream; + const idx_t offset; + const idx_t length; + + const hash_t signature_hash; + + const vector> child_signatures; + const idx_t operator_count; +}; + +struct PlanSignatureHash { + std::size_t operator()(const PlanSignature &k) const { + return k.HashSignature(); + } +}; + +struct PlanSignatureEquality { + bool operator()(const PlanSignature &a, const PlanSignature &b) const { + return a.Equals(b); + } +}; + +struct SubplanInfo { + explicit SubplanInfo(unique_ptr &op) : subplans({op}), lowest_common_ancestor(op) { + } + vector>> subplans; + reference> 
lowest_common_ancestor; +}; + +using subplan_map_t = unordered_map, SubplanInfo, PlanSignatureHash, PlanSignatureEquality>; + +//===--------------------------------------------------------------------===// +// CommonSubplanFinder +//===--------------------------------------------------------------------===// +class CommonSubplanFinder { +public: + CommonSubplanFinder() { + } + +private: + struct OperatorInfo { + OperatorInfo(unique_ptr &parent_p, const idx_t &depth_p) : parent(parent_p), depth(depth_p) { + } + + unique_ptr &parent; + const idx_t depth; + unique_ptr signature; + }; + + struct StackNode { + explicit StackNode(unique_ptr &op_p) : op(op_p), child_index(0) { + } + + bool HasMoreChildren() const { + return child_index < op->children.size(); + } + + unique_ptr &GetNextChild() { + D_ASSERT(child_index < op->children.size()); + return op->children[child_index++]; + }; + + unique_ptr &op; + idx_t child_index; + }; + +public: + subplan_map_t FindCommonSubplans(reference> root) { + // Find first operator with more than 1 child + while (root.get()->children.size() == 1) { + root = root.get()->children[0]; + } + + // Recurse through query plan using stack-based recursion + vector stack; + stack.emplace_back(root); + operator_infos.emplace(root, OperatorInfo(root, 0)); + + while (!stack.empty()) { + auto ¤t = stack.back(); + + // Depth-first + if (current.HasMoreChildren()) { + auto &child = current.GetNextChild(); + operator_infos.emplace(child, OperatorInfo(current.op, stack.size())); + stack.emplace_back(child); + continue; + } + + if (!RefersToSameObject(current.op, root.get())) { + // We have all child information for this operator now, compute signature + auto &signature = operator_infos.find(current.op)->second.signature; + signature = CreatePlanSignature(current.op); + + // Add to subplans (if we got actually got a signature) + if (signature) { + auto it = subplans.find(*signature); + if (it == subplans.end()) { + subplans.emplace(*signature, 
SubplanInfo(current.op)); + } else { + auto &info = it->second; + info.subplans.emplace_back(current.op); + info.lowest_common_ancestor = LowestCommonAncestor(info.lowest_common_ancestor, current.op); + } + } + } + + // Done with current + stack.pop_back(); + } + + // Filter out redundant or ineligible subplans before returning + for (auto it = subplans.begin(); it != subplans.end();) { + if (it->first.get().OperatorCount() == 1) { + it = subplans.erase(it); // Just one operator in this subplan + continue; + } + if (it->second.subplans.size() == 1) { + it = subplans.erase(it); // No other identical subplan + continue; + } + auto &subplan = it->second.subplans[0].get(); + auto &parent = operator_infos.find(subplan)->second.parent; + auto &parent_signature = operator_infos.find(parent)->second.signature; + if (parent_signature) { + auto parent_it = subplans.find(*parent_signature); + if (parent_it != subplans.end() && it->second.subplans.size() == parent_it->second.subplans.size()) { + it = subplans.erase(it); // Parent has exact same number of identical subplans + continue; + } + } + if (!CTEInlining::EndsInAggregateOrDistinct(*subplan)) { + it = subplans.erase(it); // Not eligible for materialization + continue; + } + it++; // This subplan might be useful + } + + return std::move(subplans); + } + +private: + unique_ptr CreatePlanSignature(const unique_ptr &op) { + vector> child_signatures; + idx_t operator_count = 1; + for (auto &child : op->children) { + auto it = operator_infos.find(child); + D_ASSERT(it != operator_infos.end()); + if (!it->second.signature) { + return nullptr; // Failed to create signature from one of the children + } + child_signatures.emplace_back(*it->second.signature); + operator_count += it->second.signature->OperatorCount(); + } + return PlanSignature::Create(state, *op, std::move(child_signatures), operator_count); + } + + unique_ptr &LowestCommonAncestor(reference> a, + reference> b) { + auto a_it = operator_infos.find(a); + auto b_it = 
operator_infos.find(b); + D_ASSERT(a_it != operator_infos.end() && b_it != operator_infos.end()); + + // Get parents of a and b until they're at the same depth + while (a_it->second.depth > b_it->second.depth) { + a = a_it->second.parent; + a_it = operator_infos.find(a); + D_ASSERT(a_it != operator_infos.end()); + } + while (b_it->second.depth > a_it->second.depth) { + b = b_it->second.parent; + b_it = operator_infos.find(b); + D_ASSERT(b_it != operator_infos.end()); + } + + // Move up one level at a time for both until ancestor is the same + while (!RefersToSameObject(a, b)) { + a_it = operator_infos.find(a); + b_it = operator_infos.find(b); + D_ASSERT(a_it != operator_infos.end() && b_it != operator_infos.end()); + a = a_it->second.parent; + b = b_it->second.parent; + } + + return a.get(); + } + +private: + //! Mapping from operator to info + reference_map_t, OperatorInfo> operator_infos; + //! Mapping from subplan signature to subplan information + subplan_map_t subplans; + //! State for creating PlanSignature with reusable data structures + PlanSignatureCreateState state; +}; + +//===--------------------------------------------------------------------===// +// CommonSubplanOptimizer +//===--------------------------------------------------------------------===// +CommonSubplanOptimizer::CommonSubplanOptimizer(Optimizer &optimizer_p) : optimizer(optimizer_p) { +} + +static void ConvertSubplansToCTE(Optimizer &optimizer, unique_ptr &op, SubplanInfo &subplan_info) { + const auto cte_index = optimizer.binder.GenerateTableIndex(); + const auto cte_name = StringUtil::Format("__common_subplan_1"); + + // Resolve types to be used for creating the materialized CTE and refs + op->ResolveOperatorTypes(); + + // Get types and names + const auto &types = subplan_info.subplans[0].get()->types; + vector col_names; + for (idx_t i = 0; i < types.size(); i++) { + col_names.emplace_back(StringUtil::Format("%s_col_%llu", cte_name, i)); + } + + // Create CTE refs and figure out 
column binding replacements + vector> cte_refs; + ColumnBindingReplacer replacer; + for (auto &subplan : subplan_info.subplans) { + cte_refs.emplace_back( + make_uniq(optimizer.binder.GenerateTableIndex(), cte_index, types, col_names)); + const auto old_bindings = subplan.get()->GetColumnBindings(); + const auto new_bindings = cte_refs.back()->GetColumnBindings(); + D_ASSERT(old_bindings.size() == new_bindings.size()); + for (idx_t i = 0; i < old_bindings.size(); i++) { + replacer.replacement_bindings.emplace_back(old_bindings[i], new_bindings[i]); + } + } + + // Create the materialized CTE and replace the common subplans with references to it + auto &lowest_common_ancestor = subplan_info.lowest_common_ancestor.get(); + auto cte = + make_uniq(cte_name, cte_index, types.size(), std::move(subplan_info.subplans[0].get()), + std::move(lowest_common_ancestor), CTEMaterialize::CTE_MATERIALIZE_DEFAULT); + for (idx_t i = 0; i < subplan_info.subplans.size(); i++) { + subplan_info.subplans[i].get() = std::move(cte_refs[i]); + } + lowest_common_ancestor = std::move(cte); + + // Replace bindings of subplans with those of the CTE refs + replacer.stop_operator = lowest_common_ancestor.get(); + replacer.VisitOperator(*op); // Replace from the root until CTE + replacer.VisitOperator(*lowest_common_ancestor->children[1]); // Replace in CTE child +} + +unique_ptr CommonSubplanOptimizer::Optimize(unique_ptr op) { + // Bottom-up identification of identical subplans + CommonSubplanFinder finder; + auto subplans = finder.FindCommonSubplans(op); + + // Identify the single best subplan (TODO: for now, in the future we should identify multiple) + if (subplans.empty()) { + return op; // No eligible subplans + } + auto best_it = subplans.begin(); + for (auto it = ++subplans.begin(); it != subplans.end(); it++) { + if (it->first.get().OperatorCount() > best_it->first.get().OperatorCount()) { + best_it = it; + } + } + + // Create a CTE! 
+ ConvertSubplansToCTE(optimizer, op, best_it->second); + return op; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/cte_inlining.cpp b/src/duckdb/src/optimizer/cte_inlining.cpp index 0b9e942ee..116d64768 100644 --- a/src/duckdb/src/optimizer/cte_inlining.cpp +++ b/src/duckdb/src/optimizer/cte_inlining.cpp @@ -55,10 +55,14 @@ static bool ContainsLimit(const LogicalOperator &op) { return false; } -static bool EndsInAggregateOrDistinct(const LogicalOperator &op) { - if (op.type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY || - op.type == LogicalOperatorType::LOGICAL_DISTINCT) { +bool CTEInlining::EndsInAggregateOrDistinct(const LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: + case LogicalOperatorType::LOGICAL_DISTINCT: + case LogicalOperatorType::LOGICAL_WINDOW: return true; + default: + break; } if (op.children.size() != 1) { return false; @@ -146,8 +150,7 @@ void CTEInlining::TryInlining(unique_ptr &op) { } } -bool CTEInlining::Inline(unique_ptr &op, LogicalOperator &materialized_cte, - bool requires_copy) { +bool CTEInlining::Inline(unique_ptr &op, LogicalOperator &materialized_cte, bool requires_copy) { if (op->type == LogicalOperatorType::LOGICAL_CTE_REF) { auto &cteref = op->Cast(); auto &cte = materialized_cte.Cast(); diff --git a/src/duckdb/src/optimizer/filter_pushdown.cpp b/src/duckdb/src/optimizer/filter_pushdown.cpp index c4f7bb04b..4fa17f7d0 100644 --- a/src/duckdb/src/optimizer/filter_pushdown.cpp +++ b/src/duckdb/src/optimizer/filter_pushdown.cpp @@ -276,51 +276,52 @@ unique_ptr FilterPushdown::PushFinalFilters(unique_ptr FilterPushdown::FinishPushdown(unique_ptr op) { - if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - for (idx_t i = 0; i < filters.size(); i++) { - auto &f = *filters[i]; - for (auto &child : op->children) { - FilterPushdown pushdown(optimizer, convert_mark_joins); +unique_ptr FilterPushdown::PushFiltersIntoDelimJoin(unique_ptr op) { + for 
(idx_t i = 0; i < filters.size(); i++) { + auto &f = *filters[i]; + for (auto &child : op->children) { + FilterPushdown pushdown(optimizer, convert_mark_joins); - // check if filter bindings can be applied to the child bindings. - auto child_bindings = child->GetColumnBindings(); - unordered_set child_bindings_table; - for (auto &binding : child_bindings) { - child_bindings_table.insert(binding.table_index); - } + // check if filter bindings can be applied to the child bindings. + auto child_bindings = child->GetColumnBindings(); + unordered_set child_bindings_table; + for (auto &binding : child_bindings) { + child_bindings_table.insert(binding.table_index); + } - // Check if ALL bindings of the filter are present in the child - bool should_push = true; - for (auto &binding : f.bindings) { - if (child_bindings_table.find(binding) == child_bindings_table.end()) { - should_push = false; - break; - } + // Check if ALL bindings of the filter are present in the child + bool should_push = true; + for (auto &binding : f.bindings) { + if (child_bindings_table.find(binding) == child_bindings_table.end()) { + should_push = false; + break; } + } - if (!should_push) { - continue; - } + if (!should_push) { + continue; + } - // copy the filter - auto filter_copy = f.filter->Copy(); - if (pushdown.AddFilter(std::move(filter_copy)) == FilterResult::UNSATISFIABLE) { - return make_uniq(std::move(op)); - } + // copy the filter + auto filter_copy = f.filter->Copy(); + if (pushdown.AddFilter(std::move(filter_copy)) == FilterResult::UNSATISFIABLE) { + return make_uniq(std::move(op)); + } - // push the filter into the child. - pushdown.GenerateFilters(); - child = pushdown.Rewrite(std::move(child)); + // push the filter into the child. 
+ pushdown.GenerateFilters(); + child = pushdown.Rewrite(std::move(child)); - // Don't push same filter again - filters.erase_at(i); - i--; - break; - } + // Don't push same filter again + filters.erase_at(i); + i--; + break; } } + return op; +} +unique_ptr FilterPushdown::FinishPushdown(unique_ptr op) { // unhandled type, first perform filter pushdown in its children for (auto &child : op->children) { FilterPushdown pushdown(optimizer, convert_mark_joins); diff --git a/src/duckdb/src/optimizer/optimizer.cpp b/src/duckdb/src/optimizer/optimizer.cpp index ce6cb0045..3007fa9ac 100644 --- a/src/duckdb/src/optimizer/optimizer.cpp +++ b/src/duckdb/src/optimizer/optimizer.cpp @@ -34,6 +34,7 @@ #include "duckdb/optimizer/topn_optimizer.hpp" #include "duckdb/optimizer/unnest_rewriter.hpp" #include "duckdb/optimizer/late_materialization.hpp" +#include "duckdb/optimizer/common_subplan_optimizer.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/planner.hpp" @@ -127,6 +128,12 @@ void Optimizer::RunBuiltInOptimizers() { plan = cte_inlining.Optimize(std::move(plan)); }); + // convert common subplans into materialized CTEs + RunOptimizer(OptimizerType::COMMON_SUBPLAN, [&]() { + CommonSubplanOptimizer common_subplan_optimizer(*this); + plan = common_subplan_optimizer.Optimize(std::move(plan)); + }); + // Rewrites SUM(x + C) into SUM(x) + C * COUNT(x) RunOptimizer(OptimizerType::SUM_REWRITER, [&]() { SumRewriterOptimizer optimizer(*this); diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp index 8370f4ca9..e2e4730d1 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_inner_join.cpp @@ -14,6 +14,7 @@ unique_ptr FilterPushdown::PushdownInnerJoin(unique_ptrCast(); D_ASSERT(join.join_type == JoinType::INNER); if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + op = PushFiltersIntoDelimJoin(std::move(op)); return 
FinishPushdown(std::move(op)); } // inner join: gather all the conditions of the inner join and add to the filter list diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp index 9e56ed9d6..1ebf3cedd 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_left_join.cpp @@ -78,6 +78,7 @@ unique_ptr FilterPushdown::PushdownLeftJoin(unique_ptr &right_bindings) { auto &join = op->Cast(); if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + op = PushFiltersIntoDelimJoin(std::move(op)); return FinishPushdown(std::move(op)); } FilterPushdown left_pushdown(optimizer, convert_mark_joins), right_pushdown(optimizer, convert_mark_joins); diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp index 7d240e3f6..0b937fe25 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_semi_anti_join.cpp @@ -12,6 +12,7 @@ using Filter = FilterPushdown::Filter; unique_ptr FilterPushdown::PushdownSemiAntiJoin(unique_ptr op) { auto &join = op->Cast(); if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + op = PushFiltersIntoDelimJoin(std::move(op)); return FinishPushdown(std::move(op)); } diff --git a/src/duckdb/src/parser/parser.cpp b/src/duckdb/src/parser/parser.cpp index 552b6e180..3695f75dc 100644 --- a/src/duckdb/src/parser/parser.cpp +++ b/src/duckdb/src/parser/parser.cpp @@ -165,27 +165,29 @@ bool Parser::StripUnicodeSpaces(const string &query_str, string &new_query) { return ReplaceUnicodeSpaces(query_str, new_query, unicode_spaces); } -vector SplitQueryStringIntoStatements(const string &query) { - // Break sql string down into sql statements using the tokenizer - vector query_statements; - auto tokens = Parser::Tokenize(query); - idx_t next_statement_start = 0; - for (idx_t i = 1; i < 
tokens.size(); ++i) { - auto &t_prev = tokens[i - 1]; - auto &t = tokens[i]; - if (t_prev.type == SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR) { - // LCOV_EXCL_START - for (idx_t c = t_prev.start; c <= t.start; ++c) { - if (query.c_str()[c] == ';') { - query_statements.emplace_back(query.substr(next_statement_start, t.start - next_statement_start)); - next_statement_start = tokens[i].start; - } +vector SplitQueries(const string &input_query) { + vector queries; + auto tokenized_input = Parser::Tokenize(input_query); + size_t last_split = 0; + + for (const auto &token : tokenized_input) { + if (token.type == SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR && input_query[token.start] == ';') { + string segment = input_query.substr(last_split, token.start - last_split); + StringUtil::Trim(segment); + if (!segment.empty()) { + segment.append(";"); + queries.push_back(std::move(segment)); } - // LCOV_EXCL_STOP + last_split = token.start + 1; } } - query_statements.emplace_back(query.substr(next_statement_start, query.size() - next_statement_start)); - return query_statements; + string final_segment = input_query.substr(last_split); + StringUtil::Trim(final_segment); + if (!final_segment.empty()) { + final_segment.append(";"); + queries.push_back(std::move(final_segment)); + } + return queries; } void Parser::ParseQuery(const string &query) { @@ -250,9 +252,9 @@ void Parser::ParseQuery(const string &query) { throw ParserException::SyntaxError(query, parser_error, parser_error_location); } else { // split sql string into statements and re-parse using extension - auto query_statements = SplitQueryStringIntoStatements(query); + auto queries = SplitQueries(query); idx_t stmt_loc = 0; - for (auto const &query_statement : query_statements) { + for (auto const &query_statement : queries) { ErrorData another_parser_error; // Creating a new scope to allow extensions to use PostgresParser, which is not reentrant { diff --git a/src/duckdb/src/planner/binder.cpp 
b/src/duckdb/src/planner/binder.cpp index 2ba52b64f..01e6fbfca 100644 --- a/src/duckdb/src/planner/binder.cpp +++ b/src/duckdb/src/planner/binder.cpp @@ -434,7 +434,7 @@ void Binder::MoveCorrelatedExpressions(Binder &other) { other.correlated_columns.clear(); } -void Binder::MergeCorrelatedColumns(vector &other) { +void Binder::MergeCorrelatedColumns(CorrelatedColumns &other) { for (idx_t i = 0; i < other.size(); i++) { AddCorrelatedColumn(other[i]); } @@ -443,7 +443,7 @@ void Binder::MergeCorrelatedColumns(vector &other) { void Binder::AddCorrelatedColumn(const CorrelatedColumnInfo &info) { // we only add correlated columns to the list if they are not already there if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { - correlated_columns.push_back(info); + correlated_columns.AddColumn(info); } } diff --git a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp index 2664903d3..6b979bf17 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp @@ -186,9 +186,10 @@ static unique_ptr PlanUncorrelatedSubquery(Binder &binder, BoundSubq } } -static unique_ptr -CreateDuplicateEliminatedJoin(const vector &correlated_columns, JoinType join_type, - unique_ptr original_plan, bool perform_delim) { +static unique_ptr CreateDuplicateEliminatedJoin(const CorrelatedColumns &correlated_columns, + JoinType join_type, + unique_ptr original_plan, + bool perform_delim) { auto delim_join = make_uniq(join_type); delim_join->correlated_columns = correlated_columns; delim_join->perform_delim = perform_delim; @@ -216,7 +217,7 @@ static bool PerformDelimOnType(const LogicalType &type) { return true; } -static bool PerformDuplicateElimination(Binder &binder, vector &correlated_columns) { +static bool PerformDuplicateElimination(Binder &binder, CorrelatedColumns &correlated_columns) { if 
(!ClientConfig::GetConfig(binder.context).enable_optimizer) { // if optimizations are disabled we always do a delim join return true; @@ -235,7 +236,8 @@ static bool PerformDuplicateElimination(Binder &binder, vector &expr_ptr, unique_ptr Binder::PlanLateralJoin(unique_ptr left, unique_ptr right, - vector &correlated, JoinType join_type, + CorrelatedColumns &correlated, JoinType join_type, unique_ptr condition) { // scan the right operator for correlated columns // correlated LATERAL JOIN diff --git a/src/duckdb/src/planner/binder/statement/bind_create.cpp b/src/duckdb/src/planner/binder/statement/bind_create.cpp index 76b43f60a..f1ffd7496 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create.cpp @@ -345,11 +345,7 @@ SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { try { dummy_binder->Bind(*query_node); } catch (const std::exception &ex) { - // TODO: we would like to do something like "error = ErrorData(ex);" here, - // but that breaks macro's like "create macro m(x) as table (from query_table(x));", - // because dummy-binding these always throws an error instead of a ParameterNotResolvedException. - // So, for now, we allow macro's with bind errors to be created. - // Binding is still useful because we can create the dependencies. 
+ error = ErrorData(ex); } } diff --git a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp index 87a9726ec..b52a04cf2 100644 --- a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp @@ -232,10 +232,18 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { auto bound_join_node = Bind(join); auto root = CreatePlan(*bound_join_node); + auto join_ref = reference(*root); + while (join_ref.get().children.size() == 1) { + join_ref = *join_ref.get().children[0]; + } + if (join_ref.get().children.size() != 2) { + throw NotImplementedException("Expected a join after binding a join operator - but got a %s", + join_ref.get().type); + } // kind of hacky, CreatePlan turns a RIGHT join into a LEFT join so the children get reversed from what we need bool inverted = join.type == JoinType::RIGHT; - auto &source = root->children[inverted ? 1 : 0]; - auto &get = root->children[inverted ? 0 : 1]->Cast(); + auto &source = join_ref.get().children[inverted ? 1 : 0]; + auto &get = join_ref.get().children[inverted ? 
0 : 1]->Cast(); auto merge_into = make_uniq(table); merge_into->table_index = GenerateTableIndex(); diff --git a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp index 2eb211530..b0d0fffcb 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp @@ -58,10 +58,15 @@ static void ConstructPivots(PivotRef &ref, vector &pivot_valu } } -static void ExtractPivotExpressions(ParsedExpression &root_expr, case_insensitive_set_t &handled_columns) { +static void ExtractPivotExpressions(ParsedExpression &root_expr, case_insensitive_set_t &handled_columns, + optional_ptr macro_binding) { ParsedExpressionIterator::VisitExpression( root_expr, [&](const ColumnRefExpression &child_colref) { if (child_colref.IsQualified()) { + if (child_colref.column_names[0].find(DummyBinding::DUMMY_NAME) != string::npos && macro_binding && + macro_binding->HasMatchingBinding(child_colref.GetName())) { + throw ParameterNotResolvedException(); + } throw BinderException(child_colref, "PIVOT expression cannot contain qualified columns"); } handled_columns.insert(child_colref.GetColumnName()); @@ -492,7 +497,7 @@ unique_ptr Binder::BindPivot(PivotRef &ref, vector Binder::BindPivot(PivotRef &ref, vector Binder::Bind(TableFunctionRef &ref) { } } - auto get = BindTableFunctionInternal(table_function, ref, std::move(parameters), std::move(named_parameters), - std::move(input_table_types), std::move(input_table_names)); + unique_ptr get; + try { + get = BindTableFunctionInternal(table_function, ref, std::move(parameters), std::move(named_parameters), + std::move(input_table_types), std::move(input_table_names)); + } catch (std::exception &ex) { + error = ErrorData(ex); + error.AddQueryLocation(ref); + error.Throw(); + } auto table_function_ref = make_uniq(std::move(get)); table_function_ref->subquery = std::move(subquery); return std::move(table_function_ref); diff --git 
a/src/duckdb/src/planner/expression_binder.cpp b/src/duckdb/src/planner/expression_binder.cpp index 5141765bb..220714733 100644 --- a/src/duckdb/src/planner/expression_binder.cpp +++ b/src/duckdb/src/planner/expression_binder.cpp @@ -103,7 +103,9 @@ BindResult ExpressionBinder::BindExpression(unique_ptr &expr, case ExpressionClass::STAR: return BindResult(BinderException::Unsupported(expr_ref, "STAR expression is not supported here")); default: - throw NotImplementedException("Unimplemented expression class"); + return BindResult( + NotImplementedException("Unimplemented expression class in ExpressionBinder::BindExpression: %s", + EnumUtil::ToString(expr_ref.GetExpressionClass()))); } } diff --git a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp index 205b644e8..d532a7a40 100644 --- a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp @@ -17,7 +17,7 @@ void LateralBinder::ExtractCorrelatedColumns(Expression &expr) { // add the correlated column info CorrelatedColumnInfo info(bound_colref); if (std::find(correlated_columns.begin(), correlated_columns.end(), info) == correlated_columns.end()) { - correlated_columns.push_back(std::move(info)); + correlated_columns.AddColumn(std::move(info)); // TODO is adding to the front OK here? 
} } } @@ -54,8 +54,7 @@ string LateralBinder::UnsupportedAggregateMessage() { return "LATERAL join cannot contain aggregates!"; } -static void ReduceColumnRefDepth(BoundColumnRefExpression &expr, - const vector &correlated_columns) { +static void ReduceColumnRefDepth(BoundColumnRefExpression &expr, const CorrelatedColumns &correlated_columns) { // don't need to reduce this if (expr.depth == 0) { return; @@ -69,8 +68,7 @@ static void ReduceColumnRefDepth(BoundColumnRefExpression &expr, } } -static void ReduceColumnDepth(vector &columns, - const vector &affected_columns) { +static void ReduceColumnDepth(CorrelatedColumns &columns, const CorrelatedColumns &affected_columns) { for (auto &s_correlated : columns) { for (auto &affected : affected_columns) { if (affected == s_correlated) { @@ -83,8 +81,7 @@ static void ReduceColumnDepth(vector &columns, class ExpressionDepthReducerRecursive : public BoundNodeVisitor { public: - explicit ExpressionDepthReducerRecursive(const vector &correlated) - : correlated_columns(correlated) { + explicit ExpressionDepthReducerRecursive(const CorrelatedColumns &correlated) : correlated_columns(correlated) { } void VisitExpression(unique_ptr &expression) override { @@ -106,20 +103,19 @@ class ExpressionDepthReducerRecursive : public BoundNodeVisitor { BoundNodeVisitor::VisitBoundTableRef(ref); } - static void ReduceExpressionSubquery(BoundSubqueryExpression &expr, - const vector &correlated_columns) { + static void ReduceExpressionSubquery(BoundSubqueryExpression &expr, const CorrelatedColumns &correlated_columns) { ReduceColumnDepth(expr.binder->correlated_columns, correlated_columns); ExpressionDepthReducerRecursive recursive(correlated_columns); recursive.VisitBoundQueryNode(*expr.subquery); } private: - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; }; class ExpressionDepthReducer : public LogicalOperatorVisitor { public: - explicit ExpressionDepthReducer(const vector &correlated) : 
correlated_columns(correlated) { + explicit ExpressionDepthReducer(const CorrelatedColumns &correlated) : correlated_columns(correlated) { } protected: @@ -133,10 +129,10 @@ class ExpressionDepthReducer : public LogicalOperatorVisitor { return nullptr; } - const vector &correlated_columns; + const CorrelatedColumns &correlated_columns; }; -void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const vector &correlated) { +void LateralBinder::ReduceExpressionDepth(LogicalOperator &op, const CorrelatedColumns &correlated) { ExpressionDepthReducer depth_reducer(correlated); depth_reducer.VisitOperator(op); } diff --git a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp index 198bd072b..720dbe37d 100644 --- a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp @@ -27,9 +27,13 @@ BindResult TableFunctionBinder::BindColumnReference(unique_ptr if (lambda_ref) { return BindLambdaReference(lambda_ref->Cast(), depth); } + if (binder.macro_binding && binder.macro_binding->HasMatchingBinding(col_ref.GetName())) { throw ParameterNotResolvedException(); } + } else if (col_ref.column_names[0].find(DummyBinding::DUMMY_NAME) != string::npos && binder.macro_binding && + binder.macro_binding->HasMatchingBinding(col_ref.GetName())) { + throw ParameterNotResolvedException(); } auto query_location = col_ref.GetQueryLocation(); diff --git a/src/duckdb/src/planner/operator/logical_dependent_join.cpp b/src/duckdb/src/planner/operator/logical_dependent_join.cpp index 2e46dbc78..70af8444a 100644 --- a/src/duckdb/src/planner/operator/logical_dependent_join.cpp +++ b/src/duckdb/src/planner/operator/logical_dependent_join.cpp @@ -3,7 +3,7 @@ namespace duckdb { LogicalDependentJoin::LogicalDependentJoin(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns 
correlated_columns, JoinType type, unique_ptr condition) : LogicalComparisonJoin(type, LogicalOperatorType::LOGICAL_DEPENDENT_JOIN), join_condition(std::move(condition)), correlated_columns(std::move(correlated_columns)) { @@ -17,7 +17,7 @@ LogicalDependentJoin::LogicalDependentJoin(JoinType join_type) unique_ptr LogicalDependentJoin::Create(unique_ptr left, unique_ptr right, - vector correlated_columns, JoinType type, + CorrelatedColumns correlated_columns, JoinType type, unique_ptr condition) { return make_uniq(std::move(left), std::move(right), std::move(correlated_columns), type, std::move(condition)); diff --git a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp index 7b2909c6d..2fd7a50f4 100644 --- a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp +++ b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp @@ -18,9 +18,8 @@ namespace duckdb { -FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector &correlated, - bool perform_delim, bool any_join, - optional_ptr parent) +FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const CorrelatedColumns &correlated, bool perform_delim, + bool any_join, optional_ptr parent) : binder(binder), delim_offset(DConstants::INVALID_INDEX), correlated_columns(correlated), perform_delim(perform_delim), any_join(any_join), parent(parent) { for (idx_t i = 0; i < correlated_columns.size(); i++) { @@ -30,8 +29,7 @@ FlattenDependentJoins::FlattenDependentJoins(Binder &binder, const vector &correlated_columns, +static void CreateDelimJoinConditions(LogicalComparisonJoin &delim_join, const CorrelatedColumns &correlated_columns, vector bindings, idx_t base_offset, bool perform_delim) { auto col_count = perform_delim ? 
correlated_columns.size() : 1; for (idx_t i = 0; i < col_count; i++) { @@ -50,7 +48,7 @@ static void CreateDelimJoinConditions(LogicalComparisonJoin &delim_join, unique_ptr FlattenDependentJoins::DecorrelateIndependent(Binder &binder, unique_ptr plan) { - vector correlated; + CorrelatedColumns correlated; FlattenDependentJoins flatten(binder, correlated); return flatten.Decorrelate(std::move(plan)); } @@ -80,12 +78,12 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptrsecond = false; // rewrite - idx_t lateral_depth = 0; + idx_t next_lateral_depth = 0; - RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, lateral_depth); + RewriteCorrelatedExpressions rewriter(base_binding, correlated_map, next_lateral_depth); rewriter.VisitOperator(*plan); - RewriteCorrelatedExpressions recursive_rewriter(base_binding, correlated_map, lateral_depth, true); + RewriteCorrelatedExpressions recursive_rewriter(base_binding, correlated_map, next_lateral_depth, true); recursive_rewriter.VisitOperator(*plan); } else { op.children[0] = Decorrelate(std::move(op.children[0])); @@ -94,8 +92,8 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptr(op.correlated_columns[0].binding.table_index); + const auto &op_col = op.correlated_columns[op.correlated_columns.GetDelimIndex()]; + auto window = make_uniq(op_col.binding.table_index); auto row_number = make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; @@ -114,9 +112,9 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptrchildren[1], op.is_lateral_join, lateral_depth); if (delim_join->children[1]->type == LogicalOperatorType::LOGICAL_MATERIALIZED_CTE) { - auto &cte = delim_join->children[1]->Cast(); + auto &cte_ref = delim_join->children[1]->Cast(); // check if the left side of the CTE has correlated expressions - auto entry = flatten.has_correlated_expressions.find(*cte.children[0]); + auto entry = 
flatten.has_correlated_expressions.find(*cte_ref.children[0]); if (entry != flatten.has_correlated_expressions.end()) { if (!entry->second) { // the left side of the CTE has no correlated expressions, we can push the DEPENDENT_JOIN down @@ -132,7 +130,7 @@ unique_ptr FlattenDependentJoins::Decorrelate(unique_ptrchildren[1] = flatten.PushDownDependentJoin(std::move(delim_join->children[1]), propagate_null_values, lateral_depth); data_offset = flatten.data_offset; - auto left_offset = delim_join->children[0]->GetColumnBindings().size(); + const auto left_offset = delim_join->children[0]->GetColumnBindings().size(); if (!parent) { delim_offset = left_offset + flatten.delim_offset; } diff --git a/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp b/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp index 9f1c679a1..8554f3f5b 100644 --- a/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp +++ b/src/duckdb/src/planner/subquery/has_correlated_expressions.cpp @@ -7,7 +7,7 @@ namespace duckdb { -HasCorrelatedExpressions::HasCorrelatedExpressions(const vector &correlated, bool lateral, +HasCorrelatedExpressions::HasCorrelatedExpressions(const CorrelatedColumns &correlated, bool lateral, idx_t lateral_depth) : has_correlated_expressions(false), lateral(lateral), correlated_columns(correlated), lateral_depth(lateral_depth) { diff --git a/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp b/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp index 78b3b21ec..f846d9b36 100644 --- a/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp +++ b/src/duckdb/src/planner/subquery/rewrite_cte_scan.cpp @@ -14,7 +14,7 @@ namespace duckdb { -RewriteCTEScan::RewriteCTEScan(idx_t table_index, const vector &correlated_columns) +RewriteCTEScan::RewriteCTEScan(idx_t table_index, const CorrelatedColumns &correlated_columns) : table_index(table_index), correlated_columns(correlated_columns) { } @@ -49,7 +49,7 @@ void 
RewriteCTEScan::VisitOperator(LogicalOperator &op) { // The correlated columns must be placed at the beginning of the // correlated_columns list. Otherwise, further column accesses // and rewrites will fail. - join.correlated_columns.emplace(join.correlated_columns.begin(), corr); + join.correlated_columns.AddColumn(std::move(corr)); } } } diff --git a/src/duckdb/src/storage/compression/bitpacking.cpp b/src/duckdb/src/storage/compression/bitpacking.cpp index fa1ffaeba..ae3550c5e 100644 --- a/src/duckdb/src/storage/compression/bitpacking.cpp +++ b/src/duckdb/src/storage/compression/bitpacking.cpp @@ -341,8 +341,6 @@ unique_ptr BitpackingInitAnalyze(ColumnData &col_data, PhysicalTyp template bool BitpackingAnalyze(AnalyzeState &state, Vector &input, idx_t count) { - auto &analyze_state = state.Cast>(); - // We use BITPACKING_METADATA_GROUP_SIZE tuples, which can exceed the block size. // In that case, we disable bitpacking. // we are conservative here by multiplying by 2 @@ -351,6 +349,7 @@ bool BitpackingAnalyze(AnalyzeState &state, Vector &input, idx_t count) { return false; } + auto &analyze_state = state.Cast>(); UnifiedVectorFormat vdata; input.ToUnifiedFormat(count, vdata); @@ -629,9 +628,9 @@ static T DeltaDecode(T *data, T previous_value, const size_t size) { template ::type> struct BitpackingScanState : public SegmentScanState { public: - explicit BitpackingScanState(ColumnSegment &segment) : current_segment(segment) { + explicit BitpackingScanState(const QueryContext &context, ColumnSegment &segment) : current_segment(segment) { auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - handle = buffer_manager.Pin(segment.block); + handle = buffer_manager.Pin(context, segment.block); auto data_ptr = handle.Ptr(); // load offset to bitpacking widths pointer @@ -782,8 +781,8 @@ struct BitpackingScanState : public SegmentScanState { }; template -unique_ptr BitpackingInitScan(ColumnSegment &segment) { - auto result = make_uniq>(segment); +unique_ptr 
BitpackingInitScan(const QueryContext &context, ColumnSegment &segment) { + auto result = make_uniq>(context, segment); return std::move(result); } @@ -892,7 +891,7 @@ void BitpackingScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_c template void BitpackingFetchRow(ColumnSegment &segment, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) { - BitpackingScanState scan_state(segment); + BitpackingScanState scan_state(state.context, segment); scan_state.Skip(segment, NumericCast(row_id)); D_ASSERT(scan_state.current_group_offset < BITPACKING_METADATA_GROUP_SIZE); @@ -956,10 +955,10 @@ void BitpackingSkip(ColumnSegment &segment, ColumnScanState &state, idx_t skip_c // GetSegmentInfo //===--------------------------------------------------------------------===// template -InsertionOrderPreservingMap BitpackingGetSegmentInfo(ColumnSegment &segment) { +InsertionOrderPreservingMap BitpackingGetSegmentInfo(QueryContext context, ColumnSegment &segment) { map counts; auto tuple_count = segment.count.load(); - BitpackingScanState scan_state(segment); + BitpackingScanState scan_state(context, segment); for (idx_t i = 0; i < tuple_count; i += BITPACKING_METADATA_GROUP_SIZE) { if (i) { scan_state.LoadNextGroup(); diff --git a/src/duckdb/src/storage/compression/dict_fsst.cpp b/src/duckdb/src/storage/compression/dict_fsst.cpp index c43567c52..636f5db6d 100644 --- a/src/duckdb/src/storage/compression/dict_fsst.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst.cpp @@ -56,7 +56,7 @@ struct DictFSSTCompressionStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); template static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); 
@@ -111,7 +111,8 @@ void DictFSSTCompressionStorage::FinalizeCompress(CompressionState &state_p) { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -unique_ptr DictFSSTCompressionStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr DictFSSTCompressionStorage::StringInitScan(const QueryContext &context, + ColumnSegment &segment) { auto &buffer_manager = BufferManager::GetBufferManager(segment.db); auto state = make_uniq(segment, buffer_manager.Pin(segment.block)); state->Initialize(true); diff --git a/src/duckdb/src/storage/compression/dictionary_compression.cpp b/src/duckdb/src/storage/compression/dictionary_compression.cpp index fa027edd9..3529c119a 100644 --- a/src/duckdb/src/storage/compression/dictionary_compression.cpp +++ b/src/duckdb/src/storage/compression/dictionary_compression.cpp @@ -57,7 +57,7 @@ struct DictionaryCompressionStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); template static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); @@ -118,7 +118,8 @@ void DictionaryCompressionStorage::FinalizeCompress(CompressionState &state_p) { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -unique_ptr DictionaryCompressionStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr DictionaryCompressionStorage::StringInitScan(const QueryContext &context, + ColumnSegment &segment) { auto &buffer_manager = BufferManager::GetBufferManager(segment.db); auto state = 
make_uniq(buffer_manager.Pin(segment.block)); state->Initialize(segment, true); diff --git a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp index afd335dab..89c718525 100644 --- a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp @@ -143,10 +143,10 @@ struct FixedSizeScanState : public SegmentScanState { BufferHandle handle; }; -unique_ptr FixedSizeInitScan(ColumnSegment &segment) { +unique_ptr FixedSizeInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(); auto &buffer_manager = BufferManager::GetBufferManager(segment.db); - result->handle = buffer_manager.Pin(segment.block); + result->handle = buffer_manager.Pin(context, segment.block); return std::move(result); } diff --git a/src/duckdb/src/storage/compression/fsst.cpp b/src/duckdb/src/storage/compression/fsst.cpp index cbb3b3ac7..7e07f7f6f 100644 --- a/src/duckdb/src/storage/compression/fsst.cpp +++ b/src/duckdb/src/storage/compression/fsst.cpp @@ -50,7 +50,7 @@ struct FSSTStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); template static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); @@ -569,7 +569,7 @@ struct FSSTScanState : public StringScanState { } }; -unique_ptr FSSTStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr FSSTStorage::StringInitScan(const QueryContext &context, ColumnSegment &segment) { auto block_size = segment.GetBlockManager().GetBlockSize(); auto string_block_limit = StringUncompressed::GetStringBlockLimit(block_size); auto state = make_uniq(string_block_limit); 
diff --git a/src/duckdb/src/storage/compression/numeric_constant.cpp b/src/duckdb/src/storage/compression/numeric_constant.cpp index a4d1e789b..411a85f6d 100644 --- a/src/duckdb/src/storage/compression/numeric_constant.cpp +++ b/src/duckdb/src/storage/compression/numeric_constant.cpp @@ -11,7 +11,7 @@ namespace duckdb { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -unique_ptr ConstantInitScan(ColumnSegment &segment) { +unique_ptr ConstantInitScan(const QueryContext &context, ColumnSegment &segment) { return nullptr; } diff --git a/src/duckdb/src/storage/compression/rle.cpp b/src/duckdb/src/storage/compression/rle.cpp index 57ebaf1fa..ed26d824d 100644 --- a/src/duckdb/src/storage/compression/rle.cpp +++ b/src/duckdb/src/storage/compression/rle.cpp @@ -303,7 +303,7 @@ struct RLEScanState : public SegmentScanState { }; template -unique_ptr RLEInitScan(ColumnSegment &segment) { +unique_ptr RLEInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq>(segment); return std::move(result); } diff --git a/src/duckdb/src/storage/compression/roaring/common.cpp b/src/duckdb/src/storage/compression/roaring/common.cpp index 80f7004de..3d9230787 100644 --- a/src/duckdb/src/storage/compression/roaring/common.cpp +++ b/src/duckdb/src/storage/compression/roaring/common.cpp @@ -208,7 +208,7 @@ void RoaringFinalizeCompress(CompressionState &state_p) { state.Finalize(); } -unique_ptr RoaringInitScan(ColumnSegment &segment) { +unique_ptr RoaringInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(segment); return std::move(result); } diff --git a/src/duckdb/src/storage/compression/string_uncompressed.cpp b/src/duckdb/src/storage/compression/string_uncompressed.cpp index af3b826bf..201e97787 100644 --- a/src/duckdb/src/storage/compression/string_uncompressed.cpp +++ 
b/src/duckdb/src/storage/compression/string_uncompressed.cpp @@ -77,7 +77,8 @@ void UncompressedStringInitPrefetch(ColumnSegment &segment, PrefetchState &prefe } } -unique_ptr UncompressedStringStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr UncompressedStringStorage::StringInitScan(const QueryContext &context, + ColumnSegment &segment) { auto result = make_uniq(); auto &buffer_manager = BufferManager::GetBufferManager(segment.db); result->handle = buffer_manager.Pin(segment.block); diff --git a/src/duckdb/src/storage/compression/validity_uncompressed.cpp b/src/duckdb/src/storage/compression/validity_uncompressed.cpp index 5a71b8974..1e31061c2 100644 --- a/src/duckdb/src/storage/compression/validity_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/validity_uncompressed.cpp @@ -207,7 +207,7 @@ struct ValidityScanState : public SegmentScanState { block_id_t block_id; }; -unique_ptr ValidityInitScan(ColumnSegment &segment) { +unique_ptr ValidityInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(); auto &buffer_manager = BufferManager::GetBufferManager(segment.db); result->handle = buffer_manager.Pin(segment.block); diff --git a/src/duckdb/src/storage/compression/zstd.cpp b/src/duckdb/src/storage/compression/zstd.cpp index 408855284..58ddefa61 100644 --- a/src/duckdb/src/storage/compression/zstd.cpp +++ b/src/duckdb/src/storage/compression/zstd.cpp @@ -81,7 +81,7 @@ struct ZSTDStorage { static void Compress(CompressionState &state_p, Vector &scan_vector, idx_t count); static void FinalizeCompress(CompressionState &state_p); - static unique_ptr StringInitScan(ColumnSegment &segment); + static unique_ptr StringInitScan(const QueryContext &context, ColumnSegment &segment); static void StringScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset); static void StringScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_count, Vector &result); @@ 
-961,7 +961,7 @@ struct ZSTDScanState : public SegmentScanState { AllocatedData skip_buffer; }; -unique_ptr ZSTDStorage::StringInitScan(ColumnSegment &segment) { +unique_ptr ZSTDStorage::StringInitScan(const QueryContext &context, ColumnSegment &segment) { auto result = make_uniq(segment); return std::move(result); } diff --git a/src/duckdb/src/storage/data_table.cpp b/src/duckdb/src/storage/data_table.cpp index 7d19449bb..75f8dd694 100644 --- a/src/duckdb/src/storage/data_table.cpp +++ b/src/duckdb/src/storage/data_table.cpp @@ -245,7 +245,7 @@ void DataTable::InitializeScan(ClientContext &context, DuckTransaction &transact state.checkpoint_lock = transaction.SharedLockTable(*info); auto &local_storage = LocalStorage::Get(transaction); state.Initialize(column_ids, context, table_filters); - row_groups->InitializeScan(state.table_state, column_ids, table_filters); + row_groups->InitializeScan(context, state.table_state, column_ids, table_filters); local_storage.InitializeScan(*this, state.local_state, table_filters); } @@ -253,7 +253,7 @@ void DataTable::InitializeScanWithOffset(DuckTransaction &transaction, TableScan const vector &column_ids, idx_t start_row, idx_t end_row) { state.checkpoint_lock = transaction.SharedLockTable(*info); state.Initialize(column_ids); - row_groups->InitializeScanWithOffset(state.table_state, column_ids, start_row, end_row); + row_groups->InitializeScanWithOffset(transaction.context, state.table_state, column_ids, start_row, end_row); } idx_t DataTable::GetRowGroupSize() const { @@ -681,7 +681,7 @@ void DataTable::VerifyNewConstraint(LocalStorage &local_storage, DataTable &pare throw NotImplementedException("FIXME: ALTER COLUMN with such constraint is not supported yet"); } - parent.row_groups->VerifyNewConstraint(parent, constraint); + parent.row_groups->VerifyNewConstraint(local_storage.GetClientContext(), parent, constraint); local_storage.VerifyNewConstraint(parent, constraint); } @@ -1270,9 +1270,9 @@ void 
DataTable::RemoveFromIndexes(TableAppendState &state, DataChunk &chunk, Vec }); } -void DataTable::RemoveFromIndexes(Vector &row_identifiers, idx_t count) { +void DataTable::RemoveFromIndexes(const QueryContext &context, Vector &row_identifiers, idx_t count) { D_ASSERT(IsMainTable()); - row_groups->RemoveFromIndexes(info->indexes, row_identifiers, count); + row_groups->RemoveFromIndexes(context, info->indexes, row_identifiers, count); } //===--------------------------------------------------------------------===// @@ -1649,9 +1649,9 @@ void DataTable::CommitDropTable() { //===--------------------------------------------------------------------===// // Column Segment Info //===--------------------------------------------------------------------===// -vector DataTable::GetColumnSegmentInfo() { +vector DataTable::GetColumnSegmentInfo(const QueryContext &context) { auto lock = GetSharedCheckpointLock(); - return row_groups->GetColumnSegmentInfo(); + return row_groups->GetColumnSegmentInfo(context); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/storage/local_storage.cpp b/src/duckdb/src/storage/local_storage.cpp index e3cbb8f3b..4c58c2d5d 100644 --- a/src/duckdb/src/storage/local_storage.cpp +++ b/src/duckdb/src/storage/local_storage.cpp @@ -16,12 +16,12 @@ namespace duckdb { LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &table) - : table_ref(table), allocator(Allocator::Get(table.db)), deleted_rows(0), optimistic_writer(context, table), - merged_storage(false) { + : context(context), table_ref(table), allocator(Allocator::Get(table.db)), deleted_rows(0), + optimistic_writer(context, table), merged_storage(false) { auto types = table.GetTypes(); auto data_table_info = table.GetDataTableInfo(); - row_groups = OptimisticDataWriter::CreateCollection(table, types); + row_groups = optimistic_writer.CreateCollection(table, types, OptimisticWritePartialManagers::GLOBAL); auto &collection = 
*row_groups->collection; collection.InitializeEmpty(); @@ -63,8 +63,8 @@ LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &table) LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &new_data_table, LocalTableStorage &parent, const idx_t alter_column_index, const LogicalType &target_type, const vector &bound_columns, Expression &cast_expr) - : table_ref(new_data_table), allocator(Allocator::Get(new_data_table.db)), deleted_rows(parent.deleted_rows), - optimistic_collections(std::move(parent.optimistic_collections)), + : context(context), table_ref(new_data_table), allocator(Allocator::Get(new_data_table.db)), + deleted_rows(parent.deleted_rows), optimistic_collections(std::move(parent.optimistic_collections)), optimistic_writer(new_data_table, parent.optimistic_writer), merged_storage(parent.merged_storage) { // Alter the column type. @@ -115,7 +115,7 @@ void LocalTableStorage::InitializeScan(CollectionScanState &state, optional_ptr< if (collection.GetTotalRows() == 0) { throw InternalException("No rows in LocalTableStorage row group for scan"); } - collection.InitializeScan(state, state.GetColumnIds(), table_filters.get()); + collection.InitializeScan(context, state, state.GetColumnIds(), table_filters.get()); } idx_t LocalTableStorage::EstimatedSize() { @@ -564,7 +564,7 @@ idx_t LocalStorage::Delete(DataTable &table, Vector &row_ids, idx_t count) { // delete from unique indices (if any) if (!storage->append_indexes.Empty()) { - storage->GetCollection().RemoveFromIndexes(storage->append_indexes, row_ids, count); + storage->GetCollection().RemoveFromIndexes(context, storage->append_indexes, row_ids, count); } auto ids = FlatVector::GetData(row_ids); @@ -752,7 +752,7 @@ void LocalStorage::VerifyNewConstraint(DataTable &parent, const BoundConstraint if (!storage) { return; } - storage->GetCollection().VerifyNewConstraint(parent, constraint); + storage->GetCollection().VerifyNewConstraint(context, parent, constraint); } } // 
namespace duckdb diff --git a/src/duckdb/src/storage/metadata/metadata_manager.cpp b/src/duckdb/src/storage/metadata/metadata_manager.cpp index 8674f742d..efcdc68d7 100644 --- a/src/duckdb/src/storage/metadata/metadata_manager.cpp +++ b/src/duckdb/src/storage/metadata/metadata_manager.cpp @@ -99,7 +99,7 @@ MetadataHandle MetadataManager::Pin(const MetadataPointer &pointer) { return Pin(QueryContext(), pointer); } -MetadataHandle MetadataManager::Pin(QueryContext context, const MetadataPointer &pointer) { +MetadataHandle MetadataManager::Pin(const QueryContext &context, const MetadataPointer &pointer) { D_ASSERT(pointer.index < METADATA_BLOCK_COUNT); shared_ptr block_handle; { diff --git a/src/duckdb/src/storage/optimistic_data_writer.cpp b/src/duckdb/src/storage/optimistic_data_writer.cpp index 4f595223f..fbe481364 100644 --- a/src/duckdb/src/storage/optimistic_data_writer.cpp +++ b/src/duckdb/src/storage/optimistic_data_writer.cpp @@ -6,6 +6,9 @@ namespace duckdb { +OptimisticWriteCollection::~OptimisticWriteCollection() { +} + OptimisticDataWriter::OptimisticDataWriter(ClientContext &context, DataTable &table) : context(context), table(table) { } @@ -28,14 +31,14 @@ bool OptimisticDataWriter::PrepareWrite() { // allocate the partial block-manager if none is allocated yet if (!partial_manager) { auto &block_manager = table.GetTableIOManager().GetBlockManagerForRowData(); - partial_manager = - make_uniq(QueryContext(context), block_manager, PartialBlockType::APPEND_TO_TABLE); + partial_manager = make_uniq(context, block_manager, PartialBlockType::APPEND_TO_TABLE); } return true; } unique_ptr OptimisticDataWriter::CreateCollection(DataTable &storage, - const vector &insert_types) { + const vector &insert_types, + OptimisticWritePartialManagers type) { auto table_info = storage.GetDataTableInfo(); auto &io_manager = TableIOManager::Get(storage); @@ -45,6 +48,13 @@ unique_ptr OptimisticDataWriter::CreateCollection(Dat auto result = make_uniq(); result->collection = 
std::move(row_groups); + if (type == OptimisticWritePartialManagers::PER_COLUMN) { + for (idx_t i = 0; i < insert_types.size(); i++) { + auto &block_manager = table.GetTableIOManager().GetBlockManagerForRowData(); + result->partial_block_managers.push_back(make_uniq( + QueryContext(context), block_manager, PartialBlockType::APPEND_TO_TABLE)); + } + } return result; } @@ -62,7 +72,7 @@ void OptimisticDataWriter::WriteNewRowGroup(OptimisticWriteCollection &row_group for (idx_t i = row_groups.last_flushed; i < row_groups.complete_row_groups; i++) { to_flush.push_back(*row_groups.collection->GetRowGroup(NumericCast(i))); } - FlushToDisk(to_flush); + FlushToDisk(row_groups, to_flush); row_groups.last_flushed = row_groups.complete_row_groups; } } @@ -79,30 +89,40 @@ void OptimisticDataWriter::WriteLastRowGroup(OptimisticWriteCollection &row_grou } // add the last (incomplete) row group to_flush.push_back(*row_groups.collection->GetRowGroup(-1)); - FlushToDisk(to_flush); + FlushToDisk(row_groups, to_flush); + + for (auto &partial_manager : row_groups.partial_block_managers) { + Merge(partial_manager); + } + row_groups.partial_block_managers.clear(); } -void OptimisticDataWriter::FlushToDisk(const vector> &row_groups) { +void OptimisticDataWriter::FlushToDisk(OptimisticWriteCollection &collection, + const vector> &row_groups) { //! 
The set of column compression types (if any) vector compression_types; D_ASSERT(compression_types.empty()); for (auto &column : table.Columns()) { compression_types.push_back(column.CompressionType()); } - RowGroupWriteInfo info(*partial_manager, compression_types); + RowGroupWriteInfo info(*partial_manager, compression_types, collection.partial_block_managers); RowGroup::WriteToDisk(info, row_groups); } -void OptimisticDataWriter::Merge(OptimisticDataWriter &other) { - if (!other.partial_manager) { +void OptimisticDataWriter::Merge(unique_ptr &other_manager) { + if (!other_manager) { return; } if (!partial_manager) { - partial_manager = std::move(other.partial_manager); + partial_manager = std::move(other_manager); return; } - partial_manager->Merge(*other.partial_manager); - other.partial_manager.reset(); + partial_manager->Merge(*other_manager); + other_manager.reset(); +} + +void OptimisticDataWriter::Merge(OptimisticDataWriter &other) { + Merge(other.partial_manager); } void OptimisticDataWriter::FinalFlush() { diff --git a/src/duckdb/src/storage/serialization/serialize_types.cpp b/src/duckdb/src/storage/serialization/serialize_types.cpp index 453961009..963d5646e 100644 --- a/src/duckdb/src/storage/serialization/serialize_types.cpp +++ b/src/duckdb/src/storage/serialization/serialize_types.cpp @@ -42,6 +42,9 @@ shared_ptr ExtraTypeInfo::Deserialize(Deserializer &deserializer) case ExtraTypeInfoType::GENERIC_TYPE_INFO: result = make_shared_ptr(type); break; + case ExtraTypeInfoType::GEO_TYPE_INFO: + result = GeoTypeInfo::Deserialize(deserializer); + break; case ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO: result = IntegerLiteralTypeInfo::Deserialize(deserializer); break; @@ -136,6 +139,15 @@ unique_ptr ExtensionTypeInfo::Deserialize(Deserializer &deser return result; } +void GeoTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); +} + +shared_ptr GeoTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = 
duckdb::shared_ptr(new GeoTypeInfo()); + return std::move(result); +} + void IntegerLiteralTypeInfo::Serialize(Serializer &serializer) const { ExtraTypeInfo::Serialize(serializer); serializer.WriteProperty(200, "constant_value", constant_value); diff --git a/src/duckdb/src/storage/standard_buffer_manager.cpp b/src/duckdb/src/storage/standard_buffer_manager.cpp index e15986e1c..73705c43a 100644 --- a/src/duckdb/src/storage/standard_buffer_manager.cpp +++ b/src/duckdb/src/storage/standard_buffer_manager.cpp @@ -338,7 +338,7 @@ BufferHandle StandardBufferManager::Pin(shared_ptr &handle) { return Pin(QueryContext(), handle); } -BufferHandle StandardBufferManager::Pin(QueryContext context, shared_ptr &handle) { +BufferHandle StandardBufferManager::Pin(const QueryContext &context, shared_ptr &handle) { // we need to be careful not to return the BufferHandle to this block while holding the BlockHandle's lock // as exiting this function's scope may cause the destructor of the BufferHandle to be called while holding the lock // the destructor calls Unpin, which grabs the BlockHandle's lock again, causing a deadlock diff --git a/src/duckdb/src/storage/statistics/string_stats.cpp b/src/duckdb/src/storage/statistics/string_stats.cpp index e7d232692..3fe22ecac 100644 --- a/src/duckdb/src/storage/statistics/string_stats.cpp +++ b/src/duckdb/src/storage/statistics/string_stats.cpp @@ -170,6 +170,14 @@ void StringStats::Update(BaseStatistics &stats, const string_t &value) { } } +void StringStats::SetMin(BaseStatistics &stats, const string_t &value) { + ConstructValue(const_data_ptr_cast(value.GetData()), value.GetSize(), GetDataUnsafe(stats).min); +} + +void StringStats::SetMax(BaseStatistics &stats, const string_t &value) { + ConstructValue(const_data_ptr_cast(value.GetData()), value.GetSize(), GetDataUnsafe(stats).max); +} + void StringStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { if (other.GetType().id() == LogicalTypeId::VALIDITY) { return; diff --git 
a/src/duckdb/src/storage/table/array_column_data.cpp b/src/duckdb/src/storage/table/array_column_data.cpp index 7c8a12f13..d92562a94 100644 --- a/src/duckdb/src/storage/table/array_column_data.cpp +++ b/src/duckdb/src/storage/table/array_column_data.cpp @@ -256,7 +256,7 @@ void ArrayColumnData::FetchRow(TransactionData transaction, ColumnFetchState &st // We need to fetch between [row_id * array_size, (row_id + 1) * array_size) auto child_state = make_uniq(); - child_state->Initialize(child_type, nullptr); + child_state->Initialize(state.context, child_type, nullptr); const auto child_offset = start + (UnsafeNumericCast(row_id) - start) * array_size; @@ -302,8 +302,8 @@ unique_ptr ArrayColumnData::CreateCheckpointState(RowGrou unique_ptr ArrayColumnData::Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { - - auto checkpoint_state = make_uniq(row_group, *this, checkpoint_info.info.manager); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto checkpoint_state = make_uniq(row_group, *this, partial_block_manager); checkpoint_state->validity_state = validity.Checkpoint(row_group, checkpoint_info); checkpoint_state->child_state = child_column->Checkpoint(row_group, checkpoint_info); return std::move(checkpoint_state); @@ -332,12 +332,12 @@ void ArrayColumnData::InitializeColumn(PersistentColumnData &column_data, BaseSt this->count = validity.count.load(); } -void ArrayColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_path, +void ArrayColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, vector &result) { col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); + validity.GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.back() = 1; - child_column->GetColumnSegmentInfo(row_group_index, col_path, result); + child_column->GetColumnSegmentInfo(context, row_group_index, col_path, result); } void 
ArrayColumnData::Verify(RowGroup &parent) { diff --git a/src/duckdb/src/storage/table/column_data.cpp b/src/duckdb/src/storage/table/column_data.cpp index c212fcb18..c38fff709 100644 --- a/src/duckdb/src/storage/table/column_data.cpp +++ b/src/duckdb/src/storage/table/column_data.cpp @@ -673,7 +673,8 @@ void ColumnData::CheckpointScan(ColumnSegment &segment, ColumnScanState &state, unique_ptr ColumnData::Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { // scan the segments of the column data // set up the checkpoint state - auto checkpoint_state = CreateCheckpointState(row_group, checkpoint_info.info.manager); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto checkpoint_state = CreateCheckpointState(row_group, partial_block_manager); checkpoint_state->global_stats = BaseStatistics::CreateEmpty(type).ToUnique(); auto &nodes = data.ReferenceSegments(); @@ -699,6 +700,7 @@ void ColumnData::InitializeColumn(PersistentColumnData &column_data, BaseStatist this->count = 0; for (auto &data_pointer : column_data.pointers) { // Update the count and statistics + data_pointer.row_start = start + count; this->count += data_pointer.tuple_count; // Merge the statistics. 
If this is a child column, the target_stats reference will point into the parents stats @@ -909,7 +911,7 @@ shared_ptr ColumnData::Deserialize(BlockManager &block_manager, Data return entry; } -void ColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_path, +void ColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, vector &result) { D_ASSERT(!col_path.empty()); @@ -958,7 +960,7 @@ void ColumnData::GetColumnSegmentInfo(idx_t row_group_index, vector col_p column_info.additional_blocks = segment_state->GetAdditionalBlocks(); } if (compression_function.get_segment_info) { - auto segment_info = compression_function.get_segment_info(*segment); + auto segment_info = compression_function.get_segment_info(context, *segment); vector sinfo; for (auto &item : segment_info) { auto &mode = item.first; diff --git a/src/duckdb/src/storage/table/column_segment.cpp b/src/duckdb/src/storage/table/column_segment.cpp index 347463fbe..6ee49be19 100644 --- a/src/duckdb/src/storage/table/column_segment.cpp +++ b/src/duckdb/src/storage/table/column_segment.cpp @@ -80,7 +80,6 @@ ColumnSegment::ColumnSegment(DatabaseInstance &db, shared_ptr block } ColumnSegment::ColumnSegment(ColumnSegment &other, const idx_t start) - : SegmentBase(start, other.count.load()), db(other.db), type(std::move(other.type)), type_size(other.type_size), segment_type(other.segment_type), stats(std::move(other.stats)), block(std::move(other.block)), function(other.function), block_id(other.block_id), offset(other.offset), @@ -109,7 +108,7 @@ void ColumnSegment::InitializePrefetch(PrefetchState &prefetch_state, ColumnScan } void ColumnSegment::InitializeScan(ColumnScanState &state) { - state.scan_state = function.get().init_scan(*this); + state.scan_state = function.get().init_scan(state.context, *this); } void ColumnSegment::Scan(ColumnScanState &state, idx_t scan_count, Vector &result, idx_t result_offset, diff --git 
a/src/duckdb/src/storage/table/list_column_data.cpp b/src/duckdb/src/storage/table/list_column_data.cpp index 7685d16ca..986b32dc7 100644 --- a/src/duckdb/src/storage/table/list_column_data.cpp +++ b/src/duckdb/src/storage/table/list_column_data.cpp @@ -312,7 +312,7 @@ void ListColumnData::FetchRow(TransactionData transaction, ColumnFetchState &sta auto &child_type = ListType::GetChildType(result.GetType()); Vector child_scan(child_type, child_scan_count); // seek the scan towards the specified position and read [length] entries - child_state->Initialize(child_type, nullptr); + child_state->Initialize(state.context, child_type, nullptr); child_column->InitializeScanWithOffset(*child_state, start + start_offset); D_ASSERT(child_type.InternalType() == PhysicalType::STRUCT || child_state->row_index + child_scan_count - this->start <= child_column->GetMaxEntry()); @@ -391,13 +391,13 @@ void ListColumnData::InitializeColumn(PersistentColumnData &column_data, BaseSta child_column->InitializeColumn(column_data.child_columns[1], child_stats); } -void ListColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) { - ColumnData::GetColumnSegmentInfo(row_group_index, col_path, result); +void ListColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, + vector &result) { + ColumnData::GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); + validity.GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.back() = 1; - child_column->GetColumnSegmentInfo(row_group_index, col_path, result); + child_column->GetColumnSegmentInfo(context, row_group_index, col_path, result); } } // namespace duckdb diff --git a/src/duckdb/src/storage/table/row_group.cpp b/src/duckdb/src/storage/table/row_group.cpp index 40e20d2d4..eb02ba7e9 100644 --- a/src/duckdb/src/storage/table/row_group.cpp 
+++ b/src/duckdb/src/storage/table/row_group.cpp @@ -183,10 +183,11 @@ void RowGroup::InitializeEmpty(const vector &types) { } } -void ColumnScanState::Initialize(const LogicalType &type, const vector &children, - optional_ptr options) { +void ColumnScanState::Initialize(const QueryContext &context_p, const LogicalType &type, + const vector &children, optional_ptr options) { // Register the options in the state scan_options = options; + context = context_p; if (type.id() == LogicalTypeId::VALIDITY) { // validity - nothing to initialize @@ -201,7 +202,7 @@ void ColumnScanState::Initialize(const LogicalType &type, const vector options) { +void ColumnScanState::Initialize(const QueryContext &context_p, const LogicalType &type, + optional_ptr options) { vector children; - Initialize(type, children, options); + Initialize(context_p, type, children, options); } -void CollectionScanState::Initialize(const vector &types) { +void CollectionScanState::Initialize(const QueryContext &context, const vector &types) { auto &column_ids = GetColumnIds(); column_scans = make_unsafe_uniq_array(column_ids.size()); for (idx_t i = 0; i < column_ids.size(); i++) { @@ -245,7 +247,7 @@ void CollectionScanState::Initialize(const vector &types) { continue; } auto col_id = column_ids[i].GetPrimaryIndex(); - column_scans[i].Initialize(types[col_id], column_ids[i].GetChildIndexes(), &GetOptions()); + column_scans[i].Initialize(context, types[col_id], column_ids[i].GetChildIndexes(), &GetOptions()); } } @@ -310,7 +312,7 @@ unique_ptr RowGroup::AlterType(RowGroupCollection &new_collection, con column_data->InitializeAppend(append_state); // scan the original table, and fill the new column with the transformed value - scan_state.Initialize(GetCollection().GetTypes()); + scan_state.Initialize(executor.GetContext(), GetCollection().GetTypes()); InitializeScan(scan_state); DataChunk append_chunk; @@ -914,6 +916,32 @@ void RowGroup::MergeIntoStatistics(TableStatistics &other) { } } 
+ColumnCheckpointInfo::ColumnCheckpointInfo(RowGroupWriteInfo &info, idx_t column_idx) + : column_idx(column_idx), info(info) { +} + +RowGroupWriteInfo::RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, + CheckpointType checkpoint_type) + : manager(manager), compression_types(compression_types), checkpoint_type(checkpoint_type) { +} + +RowGroupWriteInfo::RowGroupWriteInfo(PartialBlockManager &manager, const vector &compression_types, + vector> &column_partial_block_managers_p) + : manager(manager), compression_types(compression_types), checkpoint_type(CheckpointType::FULL_CHECKPOINT), + column_partial_block_managers(column_partial_block_managers_p) { +} + +PartialBlockManager &RowGroupWriteInfo::GetPartialBlockManager(idx_t column_idx) { + if (column_partial_block_managers && !column_partial_block_managers->empty()) { + return *column_partial_block_managers->at(column_idx); + } + return manager; +} + +PartialBlockManager &ColumnCheckpointInfo::GetPartialBlockManager() { + return info.GetPartialBlockManager(column_idx); +} + CompressionType ColumnCheckpointInfo::GetCompressionType() { return info.compression_types[column_idx]; } @@ -1220,10 +1248,11 @@ PartitionStatistics RowGroup::GetPartitionStats() const { //===--------------------------------------------------------------------===// // GetColumnSegmentInfo //===--------------------------------------------------------------------===// -void RowGroup::GetColumnSegmentInfo(idx_t row_group_index, vector &result) { +void RowGroup::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, + vector &result) { for (idx_t col_idx = 0; col_idx < GetColumnCount(); col_idx++) { auto &col_data = GetColumn(col_idx); - col_data.GetColumnSegmentInfo(row_group_index, {col_idx}, result); + col_data.GetColumnSegmentInfo(context, row_group_index, {col_idx}, result); } } diff --git a/src/duckdb/src/storage/table/row_group_collection.cpp 
b/src/duckdb/src/storage/table/row_group_collection.cpp index d44ca0544..42c453ea0 100644 --- a/src/duckdb/src/storage/table/row_group_collection.cpp +++ b/src/duckdb/src/storage/table/row_group_collection.cpp @@ -17,6 +17,7 @@ #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/main/settings.hpp" +#include "duckdb/transaction/duck_transaction.hpp" namespace duckdb { @@ -153,13 +154,14 @@ void RowGroupCollection::Verify() { //===--------------------------------------------------------------------===// // Scan //===--------------------------------------------------------------------===// -void RowGroupCollection::InitializeScan(CollectionScanState &state, const vector &column_ids, +void RowGroupCollection::InitializeScan(const QueryContext &context, CollectionScanState &state, + const vector &column_ids, optional_ptr table_filters) { auto row_group = row_groups->GetRootSegment(); D_ASSERT(row_group); state.row_groups = row_groups.get(); state.max_row = row_start + total_rows; - state.Initialize(GetTypes()); + state.Initialize(context, GetTypes()); while (row_group && !row_group->InitializeScan(state)) { row_group = row_groups->GetNextSegment(row_group); } @@ -169,26 +171,28 @@ void RowGroupCollection::InitializeCreateIndexScan(CreateIndexScanState &state) state.segment_lock = row_groups->Lock(); } -void RowGroupCollection::InitializeScanWithOffset(CollectionScanState &state, const vector &column_ids, - idx_t start_row, idx_t end_row) { +void RowGroupCollection::InitializeScanWithOffset(const QueryContext &context, CollectionScanState &state, + const vector &column_ids, idx_t start_row, + idx_t end_row) { auto row_group = row_groups->GetSegment(start_row); D_ASSERT(row_group); state.row_groups = row_groups.get(); state.max_row = end_row; - state.Initialize(GetTypes()); + state.Initialize(context, GetTypes()); idx_t start_vector = (start_row - row_group->start) / STANDARD_VECTOR_SIZE; if 
(!row_group->InitializeScanWithOffset(state, start_vector)) { throw InternalException("Failed to initialize row group scan with offset"); } } -bool RowGroupCollection::InitializeScanInRowGroup(CollectionScanState &state, RowGroupCollection &collection, - RowGroup &row_group, idx_t vector_index, idx_t max_row) { +bool RowGroupCollection::InitializeScanInRowGroup(const QueryContext &context, CollectionScanState &state, + RowGroupCollection &collection, RowGroup &row_group, + idx_t vector_index, idx_t max_row) { state.max_row = max_row; state.row_groups = collection.row_groups.get(); if (!state.column_scans) { // initialize the scan state - state.Initialize(collection.GetTypes()); + state.Initialize(context, collection.GetTypes()); } return row_group.InitializeScanWithOffset(state, vector_index); } @@ -242,7 +246,8 @@ bool RowGroupCollection::NextParallelScan(ClientContext &context, ParallelCollec D_ASSERT(row_group); // initialize the scan for this row group - bool need_to_scan = InitializeScanInRowGroup(scan_state, *collection, *row_group, vector_index, max_row); + bool need_to_scan = + InitializeScanInRowGroup(context, scan_state, *collection, *row_group, vector_index, max_row); if (!need_to_scan) { // skip this row group continue; @@ -266,7 +271,7 @@ bool RowGroupCollection::Scan(DuckTransaction &transaction, const vector(row_identifiers); // Collect all indexed columns. @@ -717,7 +723,7 @@ void RowGroupCollection::RemoveFromIndexes(TableIndexList &indexes, Vector &row_ auto base_row_id = row_group_vector_idx * STANDARD_VECTOR_SIZE + row_group->start; // Fetch the current vector into fetch_chunk. 
- state.table_state.Initialize(GetTypes()); + state.table_state.Initialize(context, GetTypes()); row_group->InitializeScanWithOffset(state.table_state, row_group_vector_idx); row_group->ScanCommitted(state.table_state, fetch_chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); fetch_chunk.Verify(); @@ -887,7 +893,7 @@ class VacuumTask : public BaseCheckpointTask { TableScanState scan_state; scan_state.Initialize(column_ids); - scan_state.table_state.Initialize(types); + scan_state.table_state.Initialize(QueryContext(), types); scan_state.table_state.max_row = idx_t(-1); idx_t merged_groups = 0; idx_t total_row_groups = vacuum_state.row_group_counts.size(); @@ -1248,11 +1254,11 @@ vector RowGroupCollection::GetPartitionStats() const { //===--------------------------------------------------------------------===// // GetColumnSegmentInfo //===--------------------------------------------------------------------===// -vector RowGroupCollection::GetColumnSegmentInfo() { +vector RowGroupCollection::GetColumnSegmentInfo(const QueryContext &context) { vector result; auto lock = row_groups->Lock(); for (auto &row_group : row_groups->Segments(lock)) { - row_group.GetColumnSegmentInfo(row_group.index, result); + row_group.GetColumnSegmentInfo(context, row_group.index, result); } return result; } @@ -1349,7 +1355,8 @@ shared_ptr RowGroupCollection::AlterType(ClientContext &cont return result; } -void RowGroupCollection::VerifyNewConstraint(DataTable &parent, const BoundConstraint &constraint) { +void RowGroupCollection::VerifyNewConstraint(const QueryContext &context, DataTable &parent, + const BoundConstraint &constraint) { if (total_rows == 0) { return; } @@ -1371,7 +1378,7 @@ void RowGroupCollection::VerifyNewConstraint(DataTable &parent, const BoundConst CreateIndexScanState state; auto scan_type = TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED; state.Initialize(column_ids, nullptr); - InitializeScan(state.table_state, column_ids, nullptr); + 
InitializeScan(context, state.table_state, column_ids, nullptr); InitializeCreateIndexScan(state); diff --git a/src/duckdb/src/storage/table/standard_column_data.cpp b/src/duckdb/src/storage/table/standard_column_data.cpp index c657c63ee..fde7d2463 100644 --- a/src/duckdb/src/storage/table/standard_column_data.cpp +++ b/src/duckdb/src/storage/table/standard_column_data.cpp @@ -241,9 +241,10 @@ unique_ptr StandardColumnData::Checkpoint(RowGroup &row_g // to prevent reading the validity data immediately after it is checkpointed we first checkpoint the main column // this is necessary for concurrent checkpointing as due to the partial block manager checkpointed data might be // flushed to disk by a different thread than the one that wrote it, causing a data race - auto base_state = CreateCheckpointState(row_group, checkpoint_info.info.manager); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto base_state = CreateCheckpointState(row_group, partial_block_manager); base_state->global_stats = BaseStatistics::CreateEmpty(type).ToUnique(); - auto validity_state_p = validity.CreateCheckpointState(row_group, checkpoint_info.info.manager); + auto validity_state_p = validity.CreateCheckpointState(row_group, partial_block_manager); validity_state_p->global_stats = BaseStatistics::CreateEmpty(validity.type).ToUnique(); auto &validity_state = *validity_state_p; @@ -294,11 +295,12 @@ void StandardColumnData::InitializeColumn(PersistentColumnData &column_data, Bas validity.InitializeColumn(column_data.child_columns[0], target_stats); } -void StandardColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, +void StandardColumnData::GetColumnSegmentInfo(const QueryContext &context, duckdb::idx_t row_group_index, + vector col_path, vector &result) { - ColumnData::GetColumnSegmentInfo(row_group_index, col_path, result); + ColumnData::GetColumnSegmentInfo(context, row_group_index, col_path, result); col_path.push_back(0); - 
validity.GetColumnSegmentInfo(row_group_index, std::move(col_path), result); + validity.GetColumnSegmentInfo(context, row_group_index, std::move(col_path), result); } void StandardColumnData::Verify(RowGroup &parent) { diff --git a/src/duckdb/src/storage/table/struct_column_data.cpp b/src/duckdb/src/storage/table/struct_column_data.cpp index 5137330ef..65f322e79 100644 --- a/src/duckdb/src/storage/table/struct_column_data.cpp +++ b/src/duckdb/src/storage/table/struct_column_data.cpp @@ -311,7 +311,8 @@ unique_ptr StructColumnData::CreateCheckpointState(RowGro unique_ptr StructColumnData::Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) { - auto checkpoint_state = make_uniq(row_group, *this, checkpoint_info.info.manager); + auto &partial_block_manager = checkpoint_info.GetPartialBlockManager(); + auto checkpoint_state = make_uniq(row_group, *this, partial_block_manager); checkpoint_state->validity_state = validity.Checkpoint(row_group, checkpoint_info); for (auto &sub_column : sub_columns) { checkpoint_state->child_states.push_back(sub_column->Checkpoint(row_group, checkpoint_info)); @@ -361,13 +362,13 @@ void StructColumnData::InitializeColumn(PersistentColumnData &column_data, BaseS this->count = validity.count.load(); } -void StructColumnData::GetColumnSegmentInfo(duckdb::idx_t row_group_index, vector col_path, - vector &result) { +void StructColumnData::GetColumnSegmentInfo(const QueryContext &context, idx_t row_group_index, vector col_path, + vector &result) { col_path.push_back(0); - validity.GetColumnSegmentInfo(row_group_index, col_path, result); + validity.GetColumnSegmentInfo(context, row_group_index, col_path, result); for (idx_t i = 0; i < sub_columns.size(); i++) { col_path.back() = i + 1; - sub_columns[i]->GetColumnSegmentInfo(row_group_index, col_path, result); + sub_columns[i]->GetColumnSegmentInfo(context, row_group_index, col_path, result); } } diff --git a/src/duckdb/src/storage/wal_replay.cpp 
b/src/duckdb/src/storage/wal_replay.cpp index 77eca9cf7..695db78df 100644 --- a/src/duckdb/src/storage/wal_replay.cpp +++ b/src/duckdb/src/storage/wal_replay.cpp @@ -573,7 +573,7 @@ void WriteAheadLogDeserializer::ReplayIndexData(IndexStorageInfo &info) { // Convert the buffer handle to a persistent block and store the block id. if (!deserialize_only) { auto block_id = block_manager->GetFreeBlockId(); - block_manager->ConvertToPersistent(QueryContext(context), block_id, std::move(block_handle), + block_manager->ConvertToPersistent(context, block_id, std::move(block_handle), std::move(buffer_handle)); data_info.block_pointers[j].block_id = block_id; } diff --git a/src/duckdb/src/transaction/cleanup_state.cpp b/src/duckdb/src/transaction/cleanup_state.cpp index f9a17f265..96483dbc9 100644 --- a/src/duckdb/src/transaction/cleanup_state.cpp +++ b/src/duckdb/src/transaction/cleanup_state.cpp @@ -13,7 +13,7 @@ namespace duckdb { -CleanupState::CleanupState(transaction_t lowest_active_transaction) +CleanupState::CleanupState(const QueryContext &context, transaction_t lowest_active_transaction) : lowest_active_transaction(lowest_active_transaction), current_table(nullptr), count(0) { } @@ -97,7 +97,7 @@ void CleanupState::Flush() { // delete the tuples from all the indexes try { - current_table->RemoveFromIndexes(row_identifiers, count); + current_table->RemoveFromIndexes(context, row_identifiers, count); } catch (...) 
{ // NOLINT: ignore errors here } diff --git a/src/duckdb/src/transaction/duck_transaction_manager.cpp b/src/duckdb/src/transaction/duck_transaction_manager.cpp index eace5283c..06d17189b 100644 --- a/src/duckdb/src/transaction/duck_transaction_manager.cpp +++ b/src/duckdb/src/transaction/duck_transaction_manager.cpp @@ -216,7 +216,7 @@ void DuckTransactionManager::Checkpoint(ClientContext &context, bool force) { options.type = CheckpointType::CONCURRENT_CHECKPOINT; } - storage_manager.CreateCheckpoint(QueryContext(context), options); + storage_manager.CreateCheckpoint(context, options); } unique_ptr DuckTransactionManager::SharedCheckpointLock() { @@ -353,7 +353,7 @@ ErrorData DuckTransactionManager::CommitTransaction(ClientContext &context, Tran options.type = checkpoint_decision.type; auto &storage_manager = db.GetStorageManager(); try { - storage_manager.CreateCheckpoint(QueryContext(context), options); + storage_manager.CreateCheckpoint(context, options); } catch (std::exception &ex) { error.Merge(ErrorData(ex)); } diff --git a/src/duckdb/src/transaction/undo_buffer.cpp b/src/duckdb/src/transaction/undo_buffer.cpp index 4408a972b..b584f5401 100644 --- a/src/duckdb/src/transaction/undo_buffer.cpp +++ b/src/duckdb/src/transaction/undo_buffer.cpp @@ -15,6 +15,7 @@ #include "duckdb/transaction/delete_info.hpp" #include "duckdb/transaction/rollback_state.hpp" #include "duckdb/transaction/wal_write_state.hpp" +#include "duckdb/transaction/duck_transaction.hpp" namespace duckdb { constexpr uint32_t UNDO_ENTRY_HEADER_SIZE = sizeof(UndoFlags) + sizeof(uint32_t); @@ -176,7 +177,7 @@ void UndoBuffer::Cleanup(transaction_t lowest_active_transaction) { // the chunks) // (2) there is no active transaction with start_id < commit_id of this // transaction - CleanupState state(lowest_active_transaction); + CleanupState state(transaction.context, lowest_active_transaction); UndoBuffer::IteratorState iterator_state; IterateEntries(iterator_state, [&](UndoFlags type, data_ptr_t 
data) { state.CleanupEntry(type, data); }); diff --git a/src/duckdb/third_party/httplib/httplib.hpp b/src/duckdb/third_party/httplib/httplib.hpp index 4aa0458dc..409c47d0b 100644 --- a/src/duckdb/third_party/httplib/httplib.hpp +++ b/src/duckdb/third_party/httplib/httplib.hpp @@ -7077,7 +7077,12 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { } auto location = res.get_header_value("location"); - if (location.empty()) { return false; } + if (location.empty()) { + // s3 requests will not return a location header, and instead a + // X-Amx-Region-Bucket header. Return true so all response headers + // are returned to the httpfs/calling extension + return true; + } const Regex re( R"((?:(https?):)?(?://(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); diff --git a/src/duckdb/third_party/yyjson/include/yyjson_utils.hpp b/src/duckdb/third_party/yyjson/include/yyjson_utils.hpp new file mode 100644 index 000000000..a848d44a9 --- /dev/null +++ b/src/duckdb/third_party/yyjson/include/yyjson_utils.hpp @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// yyjson_utils.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "yyjson.hpp" + +using namespace duckdb_yyjson; // NOLINT + +namespace duckdb { + +struct ConvertedJSONHolder { +public: + ~ConvertedJSONHolder() { + if (doc) { + yyjson_mut_doc_free(doc); + } + if (stringified_json) { + free(stringified_json); + } + } + +public: + yyjson_mut_doc *doc = nullptr; + char *stringified_json = nullptr; +}; + +} // namespace duckdb diff --git a/src/duckdb/ub_extension_parquet_writer.cpp b/src/duckdb/ub_extension_parquet_writer.cpp index cca73c331..5efcd2c3c 100644 --- a/src/duckdb/ub_extension_parquet_writer.cpp +++ b/src/duckdb/ub_extension_parquet_writer.cpp @@ -10,5 +10,7 @@ #include "extension/parquet/writer/primitive_column_writer.cpp" +#include 
"extension/parquet/writer/variant_column_writer.cpp" + #include "extension/parquet/writer/struct_column_writer.cpp" diff --git a/src/duckdb/ub_extension_parquet_writer_variant.cpp b/src/duckdb/ub_extension_parquet_writer_variant.cpp new file mode 100644 index 000000000..88e32d186 --- /dev/null +++ b/src/duckdb/ub_extension_parquet_writer_variant.cpp @@ -0,0 +1,2 @@ +#include "extension/parquet/writer/variant/convert_variant.cpp" + diff --git a/src/duckdb/ub_src_common_types.cpp b/src/duckdb/ub_src_common_types.cpp index 7f181227e..5bcfc4f96 100644 --- a/src/duckdb/ub_src_common_types.cpp +++ b/src/duckdb/ub_src_common_types.cpp @@ -54,3 +54,5 @@ #include "src/common/types/vector_constants.cpp" +#include "src/common/types/geometry.cpp" + diff --git a/src/duckdb/ub_src_function_cast.cpp b/src/duckdb/ub_src_function_cast.cpp index fcf41bbee..99f3378ca 100644 --- a/src/duckdb/ub_src_function_cast.cpp +++ b/src/duckdb/ub_src_function_cast.cpp @@ -12,6 +12,8 @@ #include "src/function/cast/enum_casts.cpp" +#include "src/function/cast/geo_casts.cpp" + #include "src/function/cast/list_casts.cpp" #include "src/function/cast/map_cast.cpp" diff --git a/src/duckdb/ub_src_function_table_system.cpp b/src/duckdb/ub_src_function_table_system.cpp index afa17b21b..5ca818791 100644 --- a/src/duckdb/ub_src_function_table_system.cpp +++ b/src/duckdb/ub_src_function_table_system.cpp @@ -1,3 +1,5 @@ +#include "src/function/table/system/duckdb_connection_count.cpp" + #include "src/function/table/system/duckdb_approx_database_count.cpp" #include "src/function/table/system/duckdb_columns.cpp" diff --git a/src/duckdb/ub_src_optimizer.cpp b/src/duckdb/ub_src_optimizer.cpp index f8238dab4..cc2d15d70 100644 --- a/src/duckdb/ub_src_optimizer.cpp +++ b/src/duckdb/ub_src_optimizer.cpp @@ -8,6 +8,8 @@ #include "src/optimizer/common_aggregate_optimizer.cpp" +#include "src/optimizer/common_subplan_optimizer.cpp" + #include "src/optimizer/compressed_materialization.cpp" #include 
"src/optimizer/cse_optimizer.cpp" From 3105398dd885fc47d9c3430597359d0641280eac Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Tue, 7 Oct 2025 03:30:36 +0000 Subject: [PATCH 2/6] Update vendored DuckDB sources to b3c8acdc0e --- src/duckdb/src/function/table/version/pragma_version.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index 4ed97cc23..94f46ab3f 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev720" +#define DUCKDB_PATCH_VERSION "0-dev722" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 5 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.5.0-dev720" +#define DUCKDB_VERSION "v1.5.0-dev722" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "5657cbdc0b" +#define DUCKDB_SOURCE_ID "b3c8acdc0e" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" From cd2d80d8ed6034c6f907ef14c705729b0bedf244 Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Wed, 8 Oct 2025 06:08:00 +0000 Subject: [PATCH 3/6] Update vendored DuckDB sources to f793ea27c6 --- src/duckdb/src/common/enum_util.cpp | 7 +- .../function/table/version/pragma_version.cpp | 6 +- .../common/insertion_order_preserving_map.hpp | 4 + .../src/include/duckdb/parser/query_node.hpp | 3 +- .../include/duckdb/parser/query_node/list.hpp | 1 + .../parser/query_node/statement_node.hpp | 42 +++ .../src/include/duckdb/parser/tokens.hpp | 1 + .../include/duckdb/planner/bind_context.hpp | 8 +- .../src/include/duckdb/planner/binder.hpp | 27 +- .../expression/bound_subquery_expression.hpp | 2 +- .../duckdb/planner/expression_iterator.hpp | 15 - 
.../duckdb/planner/logical_operator.hpp | 1 + .../planner/query_node/bound_cte_node.hpp | 6 +- .../query_node/bound_recursive_cte_node.hpp | 4 +- .../query_node/bound_set_operation_node.hpp | 23 +- .../planner/tableref/bound_subqueryref.hpp | 4 +- .../src/parser/query_node/statement_node.cpp | 43 +++ src/duckdb/src/planner/bind_context.cpp | 6 +- src/duckdb/src/planner/binder.cpp | 100 ++----- .../expression/bind_subquery_expression.cpp | 14 +- .../binder/query_node/bind_cte_node.cpp | 56 ++-- .../query_node/bind_recursive_cte_node.cpp | 52 ++-- .../binder/query_node/bind_select_node.cpp | 123 ++++---- .../binder/query_node/bind_setop_node.cpp | 263 +++++++++++------- .../binder/query_node/bind_statement_node.cpp | 26 ++ .../binder/query_node/plan_cte_node.cpp | 38 +-- .../query_node/plan_recursive_cte_node.cpp | 8 +- .../planner/binder/query_node/plan_setop.cpp | 19 +- .../binder/query_node/plan_subquery.cpp | 2 +- .../binder/tableref/bind_basetableref.cpp | 12 +- .../planner/binder/tableref/bind_pivot.cpp | 44 +-- .../binder/tableref/bind_subqueryref.cpp | 4 +- .../binder/tableref/bind_table_function.cpp | 12 +- .../binder/tableref/plan_subqueryref.cpp | 2 +- .../expression_binder/lateral_binder.cpp | 27 +- .../src/planner/expression_iterator.cpp | 152 ---------- src/duckdb/src/planner/logical_operator.cpp | 13 + .../rewrite_correlated_expressions.cpp | 50 ++-- src/duckdb/ub_src_parser_query_node.cpp | 2 + .../ub_src_planner_binder_query_node.cpp | 2 + 40 files changed, 596 insertions(+), 628 deletions(-) create mode 100644 src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp create mode 100644 src/duckdb/src/parser/query_node/statement_node.cpp create mode 100644 src/duckdb/src/planner/binder/query_node/bind_statement_node.cpp diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index 49417ac45..6862e4295 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -3624,19 +3624,20 
@@ const StringUtil::EnumStringLiteral *GetQueryNodeTypeValues() { { static_cast(QueryNodeType::SET_OPERATION_NODE), "SET_OPERATION_NODE" }, { static_cast(QueryNodeType::BOUND_SUBQUERY_NODE), "BOUND_SUBQUERY_NODE" }, { static_cast(QueryNodeType::RECURSIVE_CTE_NODE), "RECURSIVE_CTE_NODE" }, - { static_cast(QueryNodeType::CTE_NODE), "CTE_NODE" } + { static_cast(QueryNodeType::CTE_NODE), "CTE_NODE" }, + { static_cast(QueryNodeType::STATEMENT_NODE), "STATEMENT_NODE" } }; return values; } template<> const char* EnumUtil::ToChars(QueryNodeType value) { - return StringUtil::EnumToString(GetQueryNodeTypeValues(), 5, "QueryNodeType", static_cast(value)); + return StringUtil::EnumToString(GetQueryNodeTypeValues(), 6, "QueryNodeType", static_cast(value)); } template<> QueryNodeType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 5, "QueryNodeType", value)); + return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 6, "QueryNodeType", value)); } const StringUtil::EnumStringLiteral *GetQueryResultTypeValues() { diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index 94f46ab3f..f797f6910 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev722" +#define DUCKDB_PATCH_VERSION "0-dev732" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 5 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.5.0-dev722" +#define DUCKDB_VERSION "v1.5.0-dev732" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "b3c8acdc0e" +#define DUCKDB_SOURCE_ID "f793ea27c6" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git 
a/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp b/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp index 7240d62d8..5b6f11730 100644 --- a/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp +++ b/src/duckdb/src/include/duckdb/common/insertion_order_preserving_map.hpp @@ -95,6 +95,10 @@ class InsertionOrderPreservingMap { map.resize(nz); } + void clear() { // NOLINT: match stl API + map.clear(); + } + void insert(const string &key, V &&value) { // NOLINT: match stl API if (contains(key)) { return; diff --git a/src/duckdb/src/include/duckdb/parser/query_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node.hpp index ec03da095..956bd63f7 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node.hpp @@ -25,7 +25,8 @@ enum class QueryNodeType : uint8_t { SET_OPERATION_NODE = 2, BOUND_SUBQUERY_NODE = 3, RECURSIVE_CTE_NODE = 4, - CTE_NODE = 5 + CTE_NODE = 5, + STATEMENT_NODE = 6 }; struct CommonTableExpressionInfo; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/list.hpp b/src/duckdb/src/include/duckdb/parser/query_node/list.hpp index 94bfd3438..3a2894cc4 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/list.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/list.hpp @@ -2,3 +2,4 @@ #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/query_node/statement_node.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp new file mode 100644 index 000000000..9e813335c --- /dev/null +++ b/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp @@ -0,0 +1,42 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// 
duckdb/parser/query_node/statement_node.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" +#include "duckdb/parser/sql_statement.hpp" + +namespace duckdb { + +class StatementNode : public QueryNode { +public: + static constexpr const QueryNodeType TYPE = QueryNodeType::STATEMENT_NODE; + +public: + explicit StatementNode(SQLStatement &stmt_p); + + SQLStatement &stmt; + +public: + const vector> &GetSelectList() const override; + //! Convert the query node to a string + string ToString() const override; + + bool Equals(const QueryNode *other) const override; + //! Create a copy of this SelectNode + unique_ptr Copy() const override; + + //! Serializes a QueryNode to a stand-alone binary blob + //! Deserializes a blob back into a QueryNode + + void Serialize(Serializer &serializer) const override; + static unique_ptr Deserialize(Deserializer &source); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tokens.hpp b/src/duckdb/src/include/duckdb/parser/tokens.hpp index 6eeb8c5e2..d5646739c 100644 --- a/src/duckdb/src/include/duckdb/parser/tokens.hpp +++ b/src/duckdb/src/include/duckdb/parser/tokens.hpp @@ -53,6 +53,7 @@ class SelectNode; class SetOperationNode; class RecursiveCTENode; class CTENode; +class StatementNode; //===--------------------------------------------------------------------===// // Expressions diff --git a/src/duckdb/src/include/duckdb/planner/bind_context.hpp b/src/duckdb/src/include/duckdb/planner/bind_context.hpp index d9c20dd1d..d4d487400 100644 --- a/src/duckdb/src/include/duckdb/planner/bind_context.hpp +++ b/src/duckdb/src/include/duckdb/planner/bind_context.hpp @@ -23,7 +23,7 @@ namespace duckdb { class Binder; class LogicalGet; -class BoundQueryNode; +struct BoundStatement; class StarExpression; @@ -105,11 +105,11 @@ class BindContext { const vector &types, vector 
&bound_column_ids, optional_ptr entry, virtual_column_map_t virtual_columns); //! Adds a table view with a given alias to the BindContext. - void AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery, ViewCatalogEntry &view); + void AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery, ViewCatalogEntry &view); //! Adds a subquery with a given alias to the BindContext. - void AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery); + void AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery); //! Adds a subquery with a given alias to the BindContext. - void AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundQueryNode &subquery); + void AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundStatement &subquery); //! Adds a binding to a catalog entry with a given alias to the BindContext. void AddEntryBinding(idx_t index, const string &alias, const vector &names, const vector &types, StandardEntry &entry); diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index a2bac2325..8f1608112 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -69,6 +69,7 @@ struct PivotColumnEntry; struct UnpivotEntry; struct CopyInfo; struct CopyOption; +struct BoundSetOpChild; template class IndexVector; @@ -403,23 +404,25 @@ class Binder : public enable_shared_from_this { unique_ptr BindTableMacro(FunctionExpression &function, TableMacroCatalogEntry ¯o_func, idx_t depth); - unique_ptr BindMaterializedCTE(CommonTableExpressionMap &cte_map); - unique_ptr BindCTE(CTENode &statement); + BoundStatement BindCTE(CTENode &statement); - unique_ptr BindNode(SelectNode &node); - unique_ptr BindNode(SetOperationNode &node); - unique_ptr BindNode(RecursiveCTENode &node); - unique_ptr BindNode(CTENode 
&node); - unique_ptr BindNode(QueryNode &node); + BoundStatement BindNode(SelectNode &node); + BoundStatement BindNode(SetOperationNode &node); + BoundStatement BindNode(RecursiveCTENode &node); + BoundStatement BindNode(CTENode &node); + BoundStatement BindNode(QueryNode &node); + BoundStatement BindNode(StatementNode &node); unique_ptr VisitQueryNode(BoundQueryNode &node, unique_ptr root); unique_ptr CreatePlan(BoundRecursiveCTENode &node); unique_ptr CreatePlan(BoundCTENode &node); - unique_ptr CreatePlan(BoundCTENode &node, unique_ptr base); unique_ptr CreatePlan(BoundSelectNode &statement); unique_ptr CreatePlan(BoundSetOperationNode &node); unique_ptr CreatePlan(BoundQueryNode &node); + BoundSetOpChild BindSetOpChild(QueryNode &child); + unique_ptr BindSetOpNode(SetOperationNode &statement); + unique_ptr BindJoin(Binder &parent, TableRef &ref); unique_ptr Bind(BaseTableRef &ref); unique_ptr Bind(BoundRefWrapper &ref); @@ -489,8 +492,8 @@ class Binder : public enable_shared_from_this { JoinType join_type = JoinType::INNER, unique_ptr condition = nullptr); - unique_ptr CastLogicalOperatorToTypes(vector &source_types, - vector &target_types, + unique_ptr CastLogicalOperatorToTypes(const vector &source_types, + const vector &target_types, unique_ptr op); BindingAlias FindBinding(const string &using_column, const string &join_side); @@ -522,7 +525,9 @@ class Binder : public enable_shared_from_this { LogicalType BindLogicalTypeInternal(const LogicalType &type, optional_ptr catalog, const string &schema); - unique_ptr BindSelectNode(SelectNode &statement, unique_ptr from_table); + BoundStatement BindSelectNode(SelectNode &statement, unique_ptr from_table); + unique_ptr BindSelectNodeInternal(SelectNode &statement); + unique_ptr BindSelectNodeInternal(SelectNode &statement, unique_ptr from_table); unique_ptr BindCopyDatabaseSchema(Catalog &source_catalog, const string &target_database_name); unique_ptr BindCopyDatabaseData(Catalog &source_catalog, const string 
&target_database_name); diff --git a/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp b/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp index aa07a67b9..35792c8d4 100644 --- a/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression/bound_subquery_expression.hpp @@ -29,7 +29,7 @@ class BoundSubqueryExpression : public Expression { //! The binder used to bind the subquery node shared_ptr binder; //! The bound subquery node - unique_ptr subquery; + BoundStatement subquery; //! The subquery type SubqueryType subquery_type; //! the child expressions to compare with (in case of IN, ANY, ALL operators) diff --git a/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp b/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp index 5b2e2e8a8..804adc56f 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp @@ -14,7 +14,6 @@ #include namespace duckdb { -class BoundQueryNode; class BoundTableRef; class ExpressionIterator { @@ -47,18 +46,4 @@ class ExpressionIterator { } }; -class BoundNodeVisitor { -public: - virtual ~BoundNodeVisitor() = default; - - virtual void VisitBoundQueryNode(BoundQueryNode &op); - virtual void VisitBoundTableRef(BoundTableRef &ref); - virtual void VisitExpression(unique_ptr &expression); - -protected: - // The VisitExpressionChildren method is called at the end of every call to VisitExpression to recursively visit all - // expressions in an expression tree. It can be overloaded to prevent automatically visiting the entire tree. 
- virtual void VisitExpressionChildren(Expression &expression); -}; - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/logical_operator.hpp b/src/duckdb/src/include/duckdb/planner/logical_operator.hpp index e7f533bdd..743a9153b 100644 --- a/src/duckdb/src/include/duckdb/planner/logical_operator.hpp +++ b/src/duckdb/src/include/duckdb/planner/logical_operator.hpp @@ -45,6 +45,7 @@ class LogicalOperator { public: virtual vector GetColumnBindings(); + virtual idx_t GetRootIndex(); static string ColumnBindingsToString(const vector &bindings); void PrintColumnBindings(); static vector GenerateColumnBindings(idx_t table_idx, idx_t column_count); diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp index cbfdecd1f..67c076ab6 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp @@ -25,9 +25,9 @@ class BoundCTENode : public BoundQueryNode { string ctename; //! The cte node - unique_ptr query; + BoundStatement query; //! The child node - unique_ptr child; + BoundStatement child; //! Index used by the set operation idx_t setop_index; //! The binder used by the query side of the CTE @@ -39,7 +39,7 @@ class BoundCTENode : public BoundQueryNode { public: idx_t GetRootIndex() override { - return child->GetRootIndex(); + return child.plan->GetRootIndex(); } }; diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp index 3da295e2a..6a1819464 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp @@ -27,9 +27,9 @@ class BoundRecursiveCTENode : public BoundQueryNode { bool union_all; //! 
The left side of the set operation - unique_ptr left; + BoundStatement left; //! The right side of the set operation - unique_ptr right; + BoundStatement right; //! Target columns for the recursive key variant vector> key_targets; diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp index 01fa37caf..391ca26e6 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp @@ -13,13 +13,7 @@ #include "duckdb/planner/bound_query_node.hpp" namespace duckdb { - -struct BoundSetOpChild { - unique_ptr node; - shared_ptr binder; - //! Exprs used by the UNION BY NAME operations to add a new projection - vector> reorder_expressions; -}; +struct BoundSetOpChild; //! Bound equivalent of SetOperationNode class BoundSetOperationNode : public BoundQueryNode { @@ -29,6 +23,7 @@ class BoundSetOperationNode : public BoundQueryNode { public: BoundSetOperationNode() : BoundQueryNode(QueryNodeType::SET_OPERATION_NODE) { } + ~BoundSetOperationNode() override; //! The type of set operation SetOperationType setop_type = SetOperationType::NONE; @@ -46,4 +41,18 @@ class BoundSetOperationNode : public BoundQueryNode { } }; +struct BoundSetOpChild { + unique_ptr bound_node; + BoundStatement node; + shared_ptr binder; + //! Original select list (if this was a SELECT statement) + vector> select_list; + //! 
Exprs used by the UNION BY NAME operations to add a new projection + vector> reorder_expressions; + + const vector &GetNames(); + const vector &GetTypes(); + idx_t GetRootIndex(); +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp index 2d1061c98..a07994f8a 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp @@ -20,13 +20,13 @@ class BoundSubqueryRef : public BoundTableRef { static constexpr const TableReferenceType TYPE = TableReferenceType::SUBQUERY; public: - BoundSubqueryRef(shared_ptr binder_p, unique_ptr subquery) + BoundSubqueryRef(shared_ptr binder_p, BoundStatement subquery) : BoundTableRef(TableReferenceType::SUBQUERY), binder(std::move(binder_p)), subquery(std::move(subquery)) { } //! The binder used to bind the subquery shared_ptr binder; //! The bound subquery node (if any) - unique_ptr subquery; + BoundStatement subquery; }; } // namespace duckdb diff --git a/src/duckdb/src/parser/query_node/statement_node.cpp b/src/duckdb/src/parser/query_node/statement_node.cpp new file mode 100644 index 000000000..e27b2e6c0 --- /dev/null +++ b/src/duckdb/src/parser/query_node/statement_node.cpp @@ -0,0 +1,43 @@ +#include "duckdb/parser/query_node/statement_node.hpp" + +namespace duckdb { + +StatementNode::StatementNode(SQLStatement &stmt_p) : QueryNode(QueryNodeType::STATEMENT_NODE), stmt(stmt_p) { +} + +const vector> &StatementNode::GetSelectList() const { + throw InternalException("StatementNode has no select list"); +} +//! 
Convert the query node to a string +string StatementNode::ToString() const { + return stmt.ToString(); +} + +bool StatementNode::Equals(const QueryNode *other_p) const { + if (!QueryNode::Equals(other_p)) { + return false; + } + if (this == other_p) { + return true; + } + auto &other = other_p->Cast(); + return RefersToSameObject(stmt, other.stmt); +} + +//! Create a copy of this SelectNode +unique_ptr StatementNode::Copy() const { + return make_uniq(stmt); +} + +//! Serializes a QueryNode to a stand-alone binary blob +//! Deserializes a blob back into a QueryNode + +void StatementNode::Serialize(Serializer &serializer) const { + throw InternalException("StatementNode cannot be serialized"); +} + +unique_ptr StatementNode::Deserialize(Deserializer &source) { + throw InternalException("StatementNode cannot be deserialized"); +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/bind_context.cpp b/src/duckdb/src/planner/bind_context.cpp index b6e5df81f..10f32432a 100644 --- a/src/duckdb/src/planner/bind_context.cpp +++ b/src/duckdb/src/planner/bind_context.cpp @@ -686,7 +686,7 @@ vector BindContext::AliasColumnNames(const string &table_name, const vec return result; } -void BindContext::AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery) { +void BindContext::AddSubquery(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddGenericBinding(index, alias, names, subquery.types); } @@ -696,13 +696,13 @@ void BindContext::AddEntryBinding(idx_t index, const string &alias, const vector AddBinding(make_uniq(alias, types, names, index, entry)); } -void BindContext::AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundQueryNode &subquery, +void BindContext::AddView(idx_t index, const string &alias, SubqueryRef &ref, BoundStatement &subquery, ViewCatalogEntry &view) { auto names = AliasColumnNames(alias, 
subquery.names, ref.column_name_alias); AddEntryBinding(index, alias, names, subquery.types, view.Cast()); } -void BindContext::AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundQueryNode &subquery) { +void BindContext::AddSubquery(idx_t index, const string &alias, TableFunctionRef &ref, BoundStatement &subquery) { auto names = AliasColumnNames(alias, subquery.names, ref.column_name_alias); AddGenericBinding(index, alias, names, subquery.types); } diff --git a/src/duckdb/src/planner/binder.cpp b/src/duckdb/src/planner/binder.cpp index 01e6fbfca..440152476 100644 --- a/src/duckdb/src/planner/binder.cpp +++ b/src/duckdb/src/planner/binder.cpp @@ -70,7 +70,13 @@ Binder::Binder(ClientContext &context, shared_ptr parent_p, BinderType b } } -unique_ptr Binder::BindMaterializedCTE(CommonTableExpressionMap &cte_map) { +template +BoundStatement Binder::BindWithCTE(T &statement) { + auto &cte_map = statement.cte_map; + if (cte_map.map.empty()) { + return Bind(statement); + } + // Extract materialized CTEs from cte_map vector> materialized_ctes; for (auto &cte : cte_map.map) { @@ -83,58 +89,18 @@ unique_ptr Binder::BindMaterializedCTE(CommonTableExpressionMap &c materialized_ctes.push_back(std::move(mat_cte)); } - if (materialized_ctes.empty()) { - return nullptr; - } - - unique_ptr cte_root = nullptr; + unique_ptr cte_root = make_uniq(statement); while (!materialized_ctes.empty()) { unique_ptr node_result; node_result = std::move(materialized_ctes.back()); node_result->cte_map = cte_map.Copy(); - if (cte_root) { - node_result->child = std::move(cte_root); - } else { - node_result->child = nullptr; - } + node_result->child = std::move(cte_root); cte_root = std::move(node_result); materialized_ctes.pop_back(); } AddCTEMap(cte_map); - auto bound_cte = BindCTE(cte_root->Cast()); - - return bound_cte; -} - -template -BoundStatement Binder::BindWithCTE(T &statement) { - BoundStatement bound_statement; - auto bound_cte = 
BindMaterializedCTE(statement.template Cast().cte_map); - if (bound_cte) { - reference tail_ref = *bound_cte; - - while (tail_ref.get().child && tail_ref.get().child->type == QueryNodeType::CTE_NODE) { - tail_ref = tail_ref.get().child->Cast(); - } - - auto &tail = tail_ref.get(); - bound_statement = tail.child_binder->Bind(statement.template Cast()); - - tail.types = bound_statement.types; - tail.names = bound_statement.names; - - for (auto &c : tail.query_binder->correlated_columns) { - tail.child_binder->AddCorrelatedColumn(c); - } - MoveCorrelatedExpressions(*tail.child_binder); - - auto plan = std::move(bound_statement.plan); - bound_statement.plan = CreatePlan(*bound_cte, std::move(plan)); - } else { - bound_statement = Bind(statement.template Cast()); - } - return bound_statement; + return Bind(*cte_root); } BoundStatement Binder::Bind(SQLStatement &statement) { @@ -204,54 +170,28 @@ void Binder::AddCTEMap(CommonTableExpressionMap &cte_map) { } } -unique_ptr Binder::BindNode(QueryNode &node) { +BoundStatement Binder::BindNode(QueryNode &node) { // first we visit the set of CTEs and add them to the bind context AddCTEMap(node.cte_map); // now we bind the node - unique_ptr result; switch (node.type) { case QueryNodeType::SELECT_NODE: - result = BindNode(node.Cast()); - break; + return BindNode(node.Cast()); case QueryNodeType::RECURSIVE_CTE_NODE: - result = BindNode(node.Cast()); - break; + return BindNode(node.Cast()); case QueryNodeType::CTE_NODE: - result = BindNode(node.Cast()); - break; + return BindNode(node.Cast()); + case QueryNodeType::SET_OPERATION_NODE: + return BindNode(node.Cast()); + case QueryNodeType::STATEMENT_NODE: + return BindNode(node.Cast()); default: - D_ASSERT(node.type == QueryNodeType::SET_OPERATION_NODE); - result = BindNode(node.Cast()); - break; + throw InternalException("Unsupported query node type"); } - return result; } BoundStatement Binder::Bind(QueryNode &node) { - BoundStatement result; - auto bound_node = BindNode(node); - 
- result.names = bound_node->names; - result.types = bound_node->types; - - // and plan it - result.plan = CreatePlan(*bound_node); - return result; -} - -unique_ptr Binder::CreatePlan(BoundQueryNode &node) { - switch (node.type) { - case QueryNodeType::SELECT_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::SET_OPERATION_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::RECURSIVE_CTE_NODE: - return CreatePlan(node.Cast()); - case QueryNodeType::CTE_NODE: - return CreatePlan(node.Cast()); - default: - throw InternalException("Unsupported bound query node type"); - } + return BindNode(node); } unique_ptr Binder::Bind(TableRef &ref) { diff --git a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp index d413c88ed..8e15f3b28 100644 --- a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp @@ -13,14 +13,14 @@ class BoundSubqueryNode : public QueryNode { static constexpr const QueryNodeType TYPE = QueryNodeType::BOUND_SUBQUERY_NODE; public: - BoundSubqueryNode(shared_ptr subquery_binder, unique_ptr bound_node, + BoundSubqueryNode(shared_ptr subquery_binder, BoundStatement bound_node, unique_ptr subquery) : QueryNode(QueryNodeType::BOUND_SUBQUERY_NODE), subquery_binder(std::move(subquery_binder)), bound_node(std::move(bound_node)), subquery(std::move(subquery)) { } shared_ptr subquery_binder; - unique_ptr bound_node; + BoundStatement bound_node; unique_ptr subquery; const vector> &GetSelectList() const override { @@ -116,15 +116,15 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept idx_t expected_columns = 1; if (expr.child) { auto &child = BoundExpression::GetExpression(*expr.child); - ExtractSubqueryChildren(child, child_expressions, bound_subquery.bound_node->types); + ExtractSubqueryChildren(child, child_expressions, 
bound_subquery.bound_node.types); if (child_expressions.empty()) { child_expressions.push_back(std::move(child)); } expected_columns = child_expressions.size(); } - if (bound_subquery.bound_node->types.size() != expected_columns) { + if (bound_subquery.bound_node.types.size() != expected_columns) { throw BinderException(expr, "Subquery returns %zu columns - expected %d", - bound_subquery.bound_node->types.size(), expected_columns); + bound_subquery.bound_node.types.size(), expected_columns); } } // both binding the child and binding the subquery was successful @@ -132,7 +132,7 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept auto subquery_binder = std::move(bound_subquery.subquery_binder); auto bound_node = std::move(bound_subquery.bound_node); LogicalType return_type = - expr.subquery_type == SubqueryType::SCALAR ? bound_node->types[0] : LogicalType(LogicalTypeId::BOOLEAN); + expr.subquery_type == SubqueryType::SCALAR ? bound_node.types[0] : LogicalType(LogicalTypeId::BOOLEAN); if (return_type.id() == LogicalTypeId::UNKNOWN) { return_type = LogicalType::SQLNULL; } @@ -144,7 +144,7 @@ BindResult ExpressionBinder::BindExpression(SubqueryExpression &expr, idx_t dept for (idx_t child_idx = 0; child_idx < child_expressions.size(); child_idx++) { auto &child = child_expressions[child_idx]; auto child_type = ExpressionBinder::GetExpressionReturnType(*child); - auto &subquery_type = bound_node->types[child_idx]; + auto &subquery_type = bound_node.types[child_idx]; LogicalType compare_type; if (!LogicalType::TryGetMaxLogicalType(context, child_type, subquery_type, compare_type)) { throw BinderException( diff --git a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp index 2a7cf8346..8e8ab74ad 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp @@ -8,7 +8,7 @@ namespace duckdb { 
-unique_ptr Binder::BindNode(CTENode &statement) { +BoundStatement Binder::BindNode(CTENode &statement) { // first recursively visit the materialized CTE operations // the left side is visited first and is added to the BindContext of the right side D_ASSERT(statement.query); @@ -16,28 +16,28 @@ unique_ptr Binder::BindNode(CTENode &statement) { return BindCTE(statement); } -unique_ptr Binder::BindCTE(CTENode &statement) { - auto result = make_uniq(); +BoundStatement Binder::BindCTE(CTENode &statement) { + BoundCTENode result; // first recursively visit the materialized CTE operations // the left side is visited first and is added to the BindContext of the right side D_ASSERT(statement.query); - result->ctename = statement.ctename; - result->materialized = statement.materialized; - result->setop_index = GenerateTableIndex(); + result.ctename = statement.ctename; + result.materialized = statement.materialized; + result.setop_index = GenerateTableIndex(); - AddCTE(result->ctename); + AddCTE(result.ctename); - result->query_binder = Binder::CreateBinder(context, this); - result->query = result->query_binder->BindNode(*statement.query); + result.query_binder = Binder::CreateBinder(context, this); + result.query = result.query_binder->BindNode(*statement.query); // the result types of the CTE are the types of the LHS - result->types = result->query->types; + result.types = result.query.types; // names are picked from the LHS, unless aliases are explicitly specified - result->names = result->query->names; - for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { - result->names[i] = statement.aliases[i]; + result.names = result.query.names; + for (idx_t i = 0; i < statement.aliases.size() && i < result.names.size(); i++) { + result.names[i] = statement.aliases[i]; } // Rename columns if duplicate names are detected @@ -45,7 +45,7 @@ unique_ptr Binder::BindCTE(CTENode &statement) { vector names; // Use a case-insensitive set to track names 
case_insensitive_set_t ci_names; - for (auto &n : result->names) { + for (auto &n : result.names) { string name = n; while (ci_names.find(name) != ci_names.end()) { name = n + "_" + std::to_string(index++); @@ -55,16 +55,16 @@ unique_ptr Binder::BindCTE(CTENode &statement) { } // This allows the right side to reference the CTE - bind_context.AddGenericBinding(result->setop_index, statement.ctename, names, result->types); + bind_context.AddGenericBinding(result.setop_index, statement.ctename, names, result.types); - result->child_binder = Binder::CreateBinder(context, this); + result.child_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context // If there is already a binding for the CTE, we need to remove it first // as we are binding a CTE currently, we take precendence over the existing binding. // This implements the CTE shadowing behavior. - result->child_binder->bind_context.RemoveCTEBinding(statement.ctename); - result->child_binder->bind_context.AddCTEBinding(result->setop_index, statement.ctename, names, result->types); + result.child_binder->bind_context.RemoveCTEBinding(statement.ctename); + result.child_binder->bind_context.AddCTEBinding(result.setop_index, statement.ctename, names, result.types); if (statement.child) { // Move all modifiers to the child node. 
@@ -74,21 +74,25 @@ unique_ptr Binder::BindCTE(CTENode &statement) { statement.modifiers.clear(); - result->child = result->child_binder->BindNode(*statement.child); - for (auto &c : result->query_binder->correlated_columns) { - result->child_binder->AddCorrelatedColumn(c); + result.child = result.child_binder->BindNode(*statement.child); + for (auto &c : result.query_binder->correlated_columns) { + result.child_binder->AddCorrelatedColumn(c); } // the result types of the CTE are the types of the LHS - result->types = result->child->types; - result->names = result->child->names; + result.types = result.child.types; + result.names = result.child.names; - MoveCorrelatedExpressions(*result->child_binder); + MoveCorrelatedExpressions(*result.child_binder); } - MoveCorrelatedExpressions(*result->query_binder); + MoveCorrelatedExpressions(*result.query_binder); - return result; + BoundStatement result_statement; + result_statement.types = result.types; + result_statement.names = result.names; + result_statement.plan = CreatePlan(result); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp index 54e9e9fa5..18f22ae50 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp @@ -8,8 +8,8 @@ namespace duckdb { -unique_ptr Binder::BindNode(RecursiveCTENode &statement) { - auto result = make_uniq(); +BoundStatement Binder::BindNode(RecursiveCTENode &statement) { + BoundRecursiveCTENode result; // first recursively visit the recursive CTE operations // the left side is visited first and is added to the BindContext of the right side @@ -19,53 +19,53 @@ unique_ptr Binder::BindNode(RecursiveCTENode &statement) { throw BinderException("UNION ALL cannot be used with USING KEY in recursive CTE."); } - result->ctename = statement.ctename; 
- result->union_all = statement.union_all; - result->setop_index = GenerateTableIndex(); + result.ctename = statement.ctename; + result.union_all = statement.union_all; + result.setop_index = GenerateTableIndex(); - result->left_binder = Binder::CreateBinder(context, this); - result->left = result->left_binder->BindNode(*statement.left); + result.left_binder = Binder::CreateBinder(context, this); + result.left = result.left_binder->BindNode(*statement.left); // the result types of the CTE are the types of the LHS - result->types = result->left->types; + result.types = result.left.types; // names are picked from the LHS, unless aliases are explicitly specified - result->names = result->left->names; - for (idx_t i = 0; i < statement.aliases.size() && i < result->names.size(); i++) { - result->names[i] = statement.aliases[i]; + result.names = result.left.names; + for (idx_t i = 0; i < statement.aliases.size() && i < result.names.size(); i++) { + result.names[i] = statement.aliases[i]; } // This allows the right side to reference the CTE recursively - bind_context.AddGenericBinding(result->setop_index, statement.ctename, result->names, result->types); + bind_context.AddGenericBinding(result.setop_index, statement.ctename, result.names, result.types); - result->right_binder = Binder::CreateBinder(context, this); + result.right_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context // If there is already a binding for the CTE, we need to remove it first // as we are binding a CTE currently, we take precendence over the existing binding. // This implements the CTE shadowing behavior. 
- result->right_binder->bind_context.RemoveCTEBinding(statement.ctename); - result->right_binder->bind_context.AddCTEBinding(result->setop_index, statement.ctename, result->names, - result->types, !statement.key_targets.empty()); + result.right_binder->bind_context.RemoveCTEBinding(statement.ctename); + result.right_binder->bind_context.AddCTEBinding(result.setop_index, statement.ctename, result.names, result.types, + !statement.key_targets.empty()); - result->right = result->right_binder->BindNode(*statement.right); - for (auto &c : result->left_binder->correlated_columns) { - result->right_binder->AddCorrelatedColumn(c); + result.right = result.right_binder->BindNode(*statement.right); + for (auto &c : result.left_binder->correlated_columns) { + result.right_binder->AddCorrelatedColumn(c); } // move the correlated expressions from the child binders to this binder - MoveCorrelatedExpressions(*result->left_binder); - MoveCorrelatedExpressions(*result->right_binder); + MoveCorrelatedExpressions(*result.left_binder); + MoveCorrelatedExpressions(*result.right_binder); // bind specified keys to the referenced column auto expression_binder = ExpressionBinder(*this, context); for (unique_ptr &expr : statement.key_targets) { auto bound_expr = expression_binder.Bind(expr); D_ASSERT(bound_expr->type == ExpressionType::BOUND_COLUMN_REF); - result->key_targets.push_back(std::move(bound_expr)); + result.key_targets.push_back(std::move(bound_expr)); } // now both sides have been bound we can resolve types - if (result->left->types.size() != result->right->types.size()) { + if (result.left.types.size() != result.right.types.size()) { throw BinderException("Set operations can only apply to expressions with the " "same number of result columns"); } @@ -74,7 +74,11 @@ unique_ptr Binder::BindNode(RecursiveCTENode &statement) { throw NotImplementedException("FIXME: bind modifiers in recursive CTE"); } - return std::move(result); + BoundStatement result_statement; + 
result_statement.types = result.types; + result_statement.names = result.names; + result_statement.plan = CreatePlan(result); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp index 4f52dfc4a..6edddac51 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp @@ -363,7 +363,7 @@ void Binder::BindModifiers(BoundQueryNode &result, idx_t table_index, const vect } } -unique_ptr Binder::BindNode(SelectNode &statement) { +BoundStatement Binder::BindNode(SelectNode &statement) { D_ASSERT(statement.from_table); // first bind the FROM table statement @@ -372,6 +372,15 @@ unique_ptr Binder::BindNode(SelectNode &statement) { return BindSelectNode(statement, std::move(from_table)); } +unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement) { + D_ASSERT(statement.from_table); + + // first bind the FROM table statement + auto from = std::move(statement.from_table); + auto from_table = Bind(*from); + return BindSelectNodeInternal(statement, std::move(from_table)); +} + void Binder::BindWhereStarExpression(unique_ptr &expr) { // expand any expressions in the upper AND recursively if (expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { @@ -403,21 +412,23 @@ void Binder::BindWhereStarExpression(unique_ptr &expr) { } } -unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ptr from_table) { +unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement, + unique_ptr from_table) { D_ASSERT(from_table); D_ASSERT(!statement.from_table); - auto result = make_uniq(); - result->projection_index = GenerateTableIndex(); - result->group_index = GenerateTableIndex(); - result->aggregate_index = GenerateTableIndex(); - result->groupings_index = GenerateTableIndex(); - result->window_index = GenerateTableIndex(); - result->prune_index = 
GenerateTableIndex(); - - result->from_table = std::move(from_table); + auto result_ptr = make_uniq(); + auto &result = *result_ptr; + result.projection_index = GenerateTableIndex(); + result.group_index = GenerateTableIndex(); + result.aggregate_index = GenerateTableIndex(); + result.groupings_index = GenerateTableIndex(); + result.window_index = GenerateTableIndex(); + result.prune_index = GenerateTableIndex(); + + result.from_table = std::move(from_table); // bind the sample clause if (statement.sample) { - result->sample_options = std::move(statement.sample); + result.sample_options = std::move(statement.sample); } // visit the select list and expand any "*" statements @@ -429,19 +440,19 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ } statement.select_list = std::move(new_select_list); - auto &bind_state = result->bind_state; + auto &bind_state = result.bind_state; for (idx_t i = 0; i < statement.select_list.size(); i++) { auto &expr = statement.select_list[i]; - result->names.push_back(expr->GetName()); + result.names.push_back(expr->GetName()); ExpressionBinder::QualifyColumnNames(*this, expr); if (!expr->GetAlias().empty()) { bind_state.alias_map[expr->GetAlias()] = i; - result->names[i] = expr->GetAlias(); + result.names[i] = expr->GetAlias(); } bind_state.projection_map[*expr] = i; bind_state.original_expressions.push_back(expr->Copy()); } - result->column_count = statement.select_list.size(); + result.column_count = statement.select_list.size(); // first visit the WHERE clause // the WHERE clause happens before the GROUP BY, PROJECTION or HAVING clauses @@ -452,12 +463,12 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ ColumnAliasBinder alias_binder(bind_state); WhereBinder where_binder(*this, context, &alias_binder); unique_ptr condition = std::move(statement.where_clause); - result->where_clause = where_binder.Bind(condition); + result.where_clause = where_binder.Bind(condition); } // now bind all the result 
modifiers; including DISTINCT and ORDER BY targets OrderBinder order_binder({*this}, statement, bind_state); - PrepareModifiers(order_binder, statement, *result); + PrepareModifiers(order_binder, statement, result); vector> unbound_groups; BoundGroupInformation info; @@ -465,7 +476,7 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (!group_expressions.empty()) { // the statement has a GROUP BY clause, bind it unbound_groups.resize(group_expressions.size()); - GroupBinder group_binder(*this, context, statement, result->group_index, bind_state, info.alias_map); + GroupBinder group_binder(*this, context, statement, result.group_index, bind_state, info.alias_map); for (idx_t i = 0; i < group_expressions.size(); i++) { // we keep a copy of the unbound expression; @@ -489,7 +500,7 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (!contains_subquery && requires_collation) { // if there is a collation on a group x, we should group by the collated expr, // but also push a first(x) aggregate in case x is selected (uncollated) - info.collated_groups[i] = result->aggregates.size(); + info.collated_groups[i] = result.aggregates.size(); auto first_fun = FirstFunctionGetter::GetFunction(bound_expr_ref.return_type); vector> first_children; @@ -499,9 +510,9 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ FunctionBinder function_binder(*this); auto function = function_binder.BindAggregateFunction(first_fun, std::move(first_children)); function->SetAlias("__collated_group"); - result->aggregates.push_back(std::move(function)); + result.aggregates.push_back(std::move(function)); } - result->groups.group_expressions.push_back(std::move(bound_expr)); + result.groups.group_expressions.push_back(std::move(bound_expr)); // in the unbound expression we DO bind the table names of any ColumnRefs // we do this to make sure that "table.a" and "a" are treated the same @@ -512,13 +523,13 @@ unique_ptr 
Binder::BindSelectNode(SelectNode &statement, unique_ info.map[*unbound_groups[i]] = i; } } - result->groups.grouping_sets = std::move(statement.groups.grouping_sets); + result.groups.grouping_sets = std::move(statement.groups.grouping_sets); // bind the HAVING clause, if any if (statement.having) { - HavingBinder having_binder(*this, context, *result, info, statement.aggregate_handling); + HavingBinder having_binder(*this, context, result, info, statement.aggregate_handling); ExpressionBinder::QualifyColumnNames(having_binder, statement.having); - result->having = having_binder.Bind(statement.having); + result.having = having_binder.Bind(statement.having); } // bind the QUALIFY clause, if any @@ -527,9 +538,9 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES) { throw BinderException("Combining QUALIFY with GROUP BY ALL is not supported yet"); } - QualifyBinder qualify_binder(*this, context, *result, info); + QualifyBinder qualify_binder(*this, context, result, info); ExpressionBinder::QualifyColumnNames(*this, statement.qualify); - result->qualify = qualify_binder.Bind(statement.qualify); + result.qualify = qualify_binder.Bind(statement.qualify); if (qualify_binder.HasBoundColumns()) { if (qualify_binder.BoundAggregates()) { throw BinderException("Cannot mix aggregates with non-aggregated columns!"); @@ -539,7 +550,7 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ } // after that, we bind to the SELECT list - SelectBinder select_binder(*this, context, *result, info); + SelectBinder select_binder(*this, context, result, info); // if we expand select-list expressions, e.g., via UNNEST, then we need to possibly // adjust the column index of the already bound ORDER BY modifiers, and not only set their types @@ -549,13 +560,13 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ for (idx_t i = 0; i < statement.select_list.size(); i++) { bool 
is_window = statement.select_list[i]->IsWindow(); - idx_t unnest_count = result->unnests.size(); + idx_t unnest_count = result.unnests.size(); LogicalType result_type; auto expr = select_binder.Bind(statement.select_list[i], &result_type, true); - bool is_original_column = i < result->column_count; + bool is_original_column = i < result.column_count; bool can_group_by_all = statement.aggregate_handling == AggregateHandling::FORCE_AGGREGATES && is_original_column; - result->bound_column_count++; + result.bound_column_count++; if (expr->GetExpressionType() == ExpressionType::BOUND_EXPANDED) { if (!is_original_column) { @@ -571,9 +582,9 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ for (auto &struct_expr : struct_expressions) { new_names.push_back(struct_expr->GetName()); - result->types.push_back(struct_expr->return_type); + result.types.push_back(struct_expr->return_type); internal_sql_types.push_back(struct_expr->return_type); - result->select_list.push_back(std::move(struct_expr)); + result.select_list.push_back(std::move(struct_expr)); } bind_state.AddExpandedColumn(struct_expressions.size()); continue; @@ -594,7 +605,7 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ if (is_window) { throw BinderException("Cannot group on a window clause"); } - if (result->unnests.size() > unnest_count) { + if (result.unnests.size() > unnest_count) { throw BinderException("Cannot group on an UNNEST or UNLIST clause"); } // we are forcing aggregates, and the node has columns bound @@ -602,10 +613,10 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ group_by_all_indexes.push_back(i); } - result->select_list.push_back(std::move(expr)); + result.select_list.push_back(std::move(expr)); if (is_original_column) { - new_names.push_back(std::move(result->names[i])); - result->types.push_back(result_type); + new_names.push_back(std::move(result.names[i])); + result.types.push_back(result_type); } 
internal_sql_types.push_back(result_type); @@ -617,31 +628,31 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ // push the GROUP BY ALL expressions into the group set for (auto &group_by_all_index : group_by_all_indexes) { - auto &expr = result->select_list[group_by_all_index]; + auto &expr = result.select_list[group_by_all_index]; auto group_ref = make_uniq( - expr->return_type, ColumnBinding(result->group_index, result->groups.group_expressions.size())); - result->groups.group_expressions.push_back(std::move(expr)); + expr->return_type, ColumnBinding(result.group_index, result.groups.group_expressions.size())); + result.groups.group_expressions.push_back(std::move(expr)); expr = std::move(group_ref); } set group_by_all_indexes_set; if (!group_by_all_indexes.empty()) { - idx_t num_set_indexes = result->groups.group_expressions.size(); + idx_t num_set_indexes = result.groups.group_expressions.size(); for (idx_t i = 0; i < num_set_indexes; i++) { group_by_all_indexes_set.insert(i); } - D_ASSERT(result->groups.grouping_sets.empty()); - result->groups.grouping_sets.push_back(group_by_all_indexes_set); + D_ASSERT(result.groups.grouping_sets.empty()); + result.groups.grouping_sets.push_back(group_by_all_indexes_set); } - result->column_count = new_names.size(); - result->names = std::move(new_names); - result->need_prune = result->select_list.size() > result->column_count; + result.column_count = new_names.size(); + result.names = std::move(new_names); + result.need_prune = result.select_list.size() > result.column_count; // in the normal select binder, we bind columns as if there is no aggregation // i.e. 
in the query [SELECT i, SUM(i) FROM integers;] the "i" will be bound as a normal column // since we have an aggregation, we need to either (1) throw an error, or (2) wrap the column in a FIRST() aggregate // we choose the former one [CONTROVERSIAL: this is the PostgreSQL behavior] - if (!result->groups.group_expressions.empty() || !result->aggregates.empty() || statement.having || - !result->groups.grouping_sets.empty()) { + if (!result.groups.group_expressions.empty() || !result.aggregates.empty() || statement.having || + !result.groups.grouping_sets.empty()) { if (statement.aggregate_handling == AggregateHandling::NO_AGGREGATES_ALLOWED) { throw BinderException("Aggregates cannot be present in a Project relation!"); } else { @@ -672,13 +683,23 @@ unique_ptr Binder::BindSelectNode(SelectNode &statement, unique_ // QUALIFY clause requires at least one window function to be specified in at least one of the SELECT column list or // the filter predicate of the QUALIFY clause - if (statement.qualify && result->windows.empty()) { + if (statement.qualify && result.windows.empty()) { throw BinderException("at least one window function must appear in the SELECT column or QUALIFY clause"); } // now that the SELECT list is bound, we set the types of DISTINCT/ORDER BY expressions - BindModifiers(*result, result->projection_index, result->names, internal_sql_types, bind_state); - return std::move(result); + BindModifiers(result, result.projection_index, result.names, internal_sql_types, bind_state); + return result_ptr; +} + +BoundStatement Binder::BindSelectNode(SelectNode &statement, unique_ptr from_table) { + auto result = BindSelectNodeInternal(statement, std::move(from_table)); + + BoundStatement result_statement; + result_statement.types = result->types; + result_statement.names = result->names; + result_statement.plan = CreatePlan(*result); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp 
b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp index 50c6b3c06..d70f6d2cc 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp @@ -14,82 +14,108 @@ namespace duckdb { -static void GatherAliases(BoundQueryNode &node, SelectBindState &bind_state, const vector &reorder_idx) { - if (node.type == QueryNodeType::SET_OPERATION_NODE) { - // setop, recurse - auto &setop = node.Cast(); +BoundSetOperationNode::~BoundSetOperationNode() { +} - // create new reorder index - if (setop.setop_type == SetOperationType::UNION_BY_NAME) { - // for UNION BY NAME - create a new re-order index - case_insensitive_map_t reorder_map; - for (idx_t col_idx = 0; col_idx < setop.names.size(); ++col_idx) { - reorder_map[setop.names[col_idx]] = reorder_idx[col_idx]; - } +struct SetOpAliasGatherer { +public: + explicit SetOpAliasGatherer(SelectBindState &bind_state_p) : bind_state(bind_state_p) { + } - // use new reorder index - for (auto &child : setop.bound_children) { - vector new_reorder_idx; - for (idx_t col_idx = 0; col_idx < child.node->names.size(); col_idx++) { - auto &col_name = child.node->names[col_idx]; - auto entry = reorder_map.find(col_name); - if (entry == reorder_map.end()) { - throw InternalException("SetOp - Column name not found in reorder_map in UNION BY NAME"); - } - new_reorder_idx.push_back(entry->second); - } - GatherAliases(*child.node, bind_state, new_reorder_idx); - } - return; - } + void GatherAliases(BoundSetOpChild &node, const vector &reorder_idx); + void GatherAliases(BoundSetOperationNode &node, const vector &reorder_idx); - for (auto &child : setop.bound_children) { - GatherAliases(*child.node, bind_state, reorder_idx); - } - } else { - // query node - D_ASSERT(node.type == QueryNodeType::SELECT_NODE); - auto &select = node.Cast(); - // fill the alias lists with the names - D_ASSERT(reorder_idx.size() == select.names.size()); - for (idx_t i = 0; i < 
select.names.size(); i++) { - auto &name = select.names[i]; - // first check if the alias is already in there - auto entry = bind_state.alias_map.find(name); +private: + SelectBindState &bind_state; +}; - idx_t index = reorder_idx[i]; +const vector &BoundSetOpChild::GetNames() { + return bound_node ? bound_node->names : node.names; +} +const vector &BoundSetOpChild::GetTypes() { + return bound_node ? bound_node->types : node.types; +} +idx_t BoundSetOpChild::GetRootIndex() { + return bound_node ? bound_node->GetRootIndex() : node.plan->GetRootIndex(); +} +void SetOpAliasGatherer::GatherAliases(BoundSetOpChild &node, const vector &reorder_idx) { + if (node.bound_node) { + GatherAliases(*node.bound_node, reorder_idx); + return; + } + + // query node + auto &select_names = node.GetNames(); + // fill the alias lists with the names + D_ASSERT(reorder_idx.size() == select_names.size()); + for (idx_t i = 0; i < select_names.size(); i++) { + auto &name = select_names[i]; + // first check if the alias is already in there + auto entry = bind_state.alias_map.find(name); - if (entry == bind_state.alias_map.end()) { - // the alias is not in there yet, just assign it - bind_state.alias_map[name] = index; + idx_t index = reorder_idx[i]; + + if (entry == bind_state.alias_map.end()) { + // the alias is not in there yet, just assign it + bind_state.alias_map[name] = index; + } + } + // check if the expression matches one of the expressions in the original expression list + for (idx_t i = 0; i < node.select_list.size(); i++) { + auto &expr = node.select_list[i]; + idx_t index = reorder_idx[i]; + // now check if the node is already in the set of expressions + auto expr_entry = bind_state.projection_map.find(*expr); + if (expr_entry != bind_state.projection_map.end()) { + // the node is in there + // repeat the same as with the alias: if there is an ambiguity we insert "-1" + if (expr_entry->second != index) { + bind_state.projection_map[*expr] = DConstants::INVALID_INDEX; } + } else { 
+ // not in there yet, just place it in there + bind_state.projection_map[*expr] = index; } - // check if the expression matches one of the expressions in the original expression list - for (idx_t i = 0; i < select.bind_state.original_expressions.size(); i++) { - auto &expr = select.bind_state.original_expressions[i]; - idx_t index = reorder_idx[i]; - // now check if the node is already in the set of expressions - auto expr_entry = bind_state.projection_map.find(*expr); - if (expr_entry != bind_state.projection_map.end()) { - // the node is in there - // repeat the same as with the alias: if there is an ambiguity we insert "-1" - if (expr_entry->second != index) { - bind_state.projection_map[*expr] = DConstants::INVALID_INDEX; + } +} + +void SetOpAliasGatherer::GatherAliases(BoundSetOperationNode &setop, const vector &reorder_idx) { + // create new reorder index + if (setop.setop_type == SetOperationType::UNION_BY_NAME) { + // for UNION BY NAME - create a new re-order index + case_insensitive_map_t reorder_map; + for (idx_t col_idx = 0; col_idx < setop.names.size(); ++col_idx) { + reorder_map[setop.names[col_idx]] = reorder_idx[col_idx]; + } + + // use new reorder index + for (auto &child : setop.bound_children) { + vector new_reorder_idx; + auto &child_names = child.GetNames(); + for (idx_t col_idx = 0; col_idx < child_names.size(); col_idx++) { + auto &col_name = child_names[col_idx]; + auto entry = reorder_map.find(col_name); + if (entry == reorder_map.end()) { + throw InternalException("SetOp - Column name not found in reorder_map in UNION BY NAME"); } - } else { - // not in there yet, just place it in there - bind_state.projection_map[*expr] = index; + new_reorder_idx.push_back(entry->second); } + GatherAliases(child, new_reorder_idx); + } + } else { + for (auto &child : setop.bound_children) { + GatherAliases(child, reorder_idx); } } } -static void GatherAliases(BoundQueryNode &node, SelectBindState &bind_state) { +static void 
GatherAliases(BoundSetOperationNode &node, SelectBindState &bind_state) { + SetOpAliasGatherer gatherer(bind_state); vector reorder_idx; for (idx_t i = 0; i < node.names.size(); i++) { reorder_idx.push_back(i); } - GatherAliases(node, bind_state, reorder_idx); + gatherer.GatherAliases(node, reorder_idx); } static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode &result, bool can_contain_nulls) { @@ -101,10 +127,10 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & // We throw a binder exception if two same name in the SELECT list D_ASSERT(result.names.empty()); for (auto &child : result.bound_children) { - auto &child_node = *child.node; + auto &child_names = child.GetNames(); case_insensitive_map_t node_name_map; - for (idx_t i = 0; i < child_node.names.size(); ++i) { - auto &col_name = child_node.names[i]; + for (idx_t i = 0; i < child_names.size(); ++i) { + auto &col_name = child_names[i]; if (node_name_map.find(col_name) != node_name_map.end()) { throw BinderException( "UNION (ALL) BY NAME operation doesn't support duplicate names in the SELECT list - " @@ -129,7 +155,7 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & auto &col_name = result.names[i]; LogicalType result_type(LogicalTypeId::INVALID); for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { - auto &child = result.bound_children[child_idx]; + auto &child_types = result.bound_children[child_idx].GetTypes(); auto &child_name_map = node_name_maps[child_idx]; // check if the column exists in this child node auto entry = child_name_map.find(col_name); @@ -137,7 +163,7 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & need_reorder = true; } else { auto col_idx_in_child = entry->second; - auto &child_col_type = child.node->types[col_idx_in_child]; + auto &child_col_type = child_types[col_idx_in_child]; // the child exists in this node - compute the type if 
(result_type.id() == LogicalTypeId::INVALID) { result_type = child_col_type; @@ -179,34 +205,58 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & } else { // the column exists - reference it auto col_idx_in_child = entry->second; - auto &child_col_type = child.node->types[col_idx_in_child]; + auto &child_col_type = child.GetTypes()[col_idx_in_child]; expr = make_uniq(child_col_type, - ColumnBinding(child.node->GetRootIndex(), col_idx_in_child)); + ColumnBinding(child.GetRootIndex(), col_idx_in_child)); } child.reorder_expressions.push_back(std::move(expr)); } } } -static void GatherSetOpBinders(BoundQueryNode &node, Binder &binder, vector> &binders) { - if (node.type != QueryNodeType::SET_OPERATION_NODE) { - binders.push_back(binder); +BoundSetOpChild Binder::BindSetOpChild(QueryNode &child) { + BoundSetOpChild bound_child; + if (child.type == QueryNodeType::SET_OPERATION_NODE) { + bound_child.bound_node = BindSetOpNode(child.Cast()); + } else { + bound_child.binder = Binder::CreateBinder(context, this); + bound_child.binder->can_contain_nulls = true; + if (child.type == QueryNodeType::SELECT_NODE) { + auto &select_node = child.Cast(); + auto bound_select_node = bound_child.binder->BindSelectNodeInternal(select_node); + for (auto &expr : bound_select_node->bind_state.original_expressions) { + bound_child.select_list.push_back(expr->Copy()); + } + bound_child.node.names = bound_select_node->names; + bound_child.node.types = bound_select_node->types; + bound_child.node.plan = bound_child.binder->CreatePlan(*bound_select_node); + } else { + bound_child.node = bound_child.binder->BindNode(child); + } + } + return bound_child; +} + +static void GatherSetOpBinders(BoundSetOpChild &setop_child, vector> &binders) { + if (setop_child.binder) { + binders.push_back(*setop_child.binder); return; } - auto &setop_node = node.Cast(); + auto &setop_node = *setop_child.bound_node; for (auto &child : setop_node.bound_children) { - 
GatherSetOpBinders(*child.node, *child.binder, binders); + GatherSetOpBinders(child, binders); } } -unique_ptr Binder::BindNode(SetOperationNode &statement) { - auto result = make_uniq(); - result->setop_type = statement.setop_type; - result->setop_all = statement.setop_all; +unique_ptr Binder::BindSetOpNode(SetOperationNode &statement) { + auto result_ptr = make_uniq(); + auto &result = *result_ptr; + result.setop_type = statement.setop_type; + result.setop_all = statement.setop_all; // first recursively visit the set operations // all children have an independent BindContext and Binder - result->setop_index = GenerateTableIndex(); + result.setop_index = GenerateTableIndex(); if (statement.children.size() < 2) { throw InternalException("Set Operations must have at least 2 children"); } @@ -215,27 +265,27 @@ unique_ptr Binder::BindNode(SetOperationNode &statement) { throw InternalException("Set Operation type must have exactly 2 children - except for UNION/UNION_BY_NAME"); } for (auto &child : statement.children) { - BoundSetOpChild bound_child; - bound_child.binder = Binder::CreateBinder(context, this); - bound_child.binder->can_contain_nulls = true; - bound_child.node = bound_child.binder->BindNode(*child); - result->bound_children.push_back(std::move(bound_child)); + result.bound_children.push_back(BindSetOpChild(*child)); } + vector> binders; + for (auto &child : result.bound_children) { + GatherSetOpBinders(child, binders); + } // move the correlated expressions from the child binders to this binder - for (auto &bound_child : result->bound_children) { - MoveCorrelatedExpressions(*bound_child.binder); + for (auto &child_binder : binders) { + MoveCorrelatedExpressions(child_binder.get()); } - if (result->setop_type == SetOperationType::UNION_BY_NAME) { + if (result.setop_type == SetOperationType::UNION_BY_NAME) { // UNION BY NAME - merge the columns from all sides - BuildUnionByNameInfo(context, *result, can_contain_nulls); + BuildUnionByNameInfo(context, 
result, can_contain_nulls); } else { // UNION ALL BY POSITION - the columns of both sides must match exactly - result->names = result->bound_children[0].node->names; - auto result_columns = result->bound_children[0].node->types.size(); - for (idx_t i = 1; i < result->bound_children.size(); ++i) { - if (result->bound_children[i].node->types.size() != result_columns) { + result.names = result.bound_children[0].GetNames(); + auto result_columns = result.bound_children[0].GetTypes().size(); + for (idx_t i = 1; i < result.bound_children.size(); ++i) { + if (result.bound_children[i].GetTypes().size() != result_columns) { throw BinderException("Set operations can only apply to expressions with the " "same number of result columns"); } @@ -243,40 +293,43 @@ unique_ptr Binder::BindNode(SetOperationNode &statement) { // figure out the types of the setop result by picking the max of both for (idx_t i = 0; i < result_columns; i++) { - auto result_type = result->bound_children[0].node->types[i]; - for (idx_t child_idx = 1; child_idx < result->bound_children.size(); ++child_idx) { - auto &child_node = *result->bound_children[child_idx].node; - result_type = LogicalType::ForceMaxLogicalType(result_type, child_node.types[i]); + auto result_type = result.bound_children[0].GetTypes()[i]; + for (idx_t child_idx = 1; child_idx < result.bound_children.size(); ++child_idx) { + auto &child_types = result.bound_children[child_idx].GetTypes(); + result_type = LogicalType::ForceMaxLogicalType(result_type, child_types[i]); } if (!can_contain_nulls) { if (ExpressionBinder::ContainsNullType(result_type)) { result_type = ExpressionBinder::ExchangeNullType(result_type); } } - result->types.push_back(result_type); + result.types.push_back(result_type); } } SelectBindState bind_state; if (!statement.modifiers.empty()) { // handle the ORDER BY/DISTINCT clauses - - // we recursively visit the children of this node to extract aliases and expressions that can be referenced - // in the ORDER BYs - 
GatherAliases(*result, bind_state); + GatherAliases(result, bind_state); // now we perform the actual resolution of the ORDER BY/DISTINCT expressions - vector> binders; - for (auto &child : result->bound_children) { - GatherSetOpBinders(*child.node, *child.binder, binders); - } OrderBinder order_binder(binders, bind_state); - PrepareModifiers(order_binder, statement, *result); + PrepareModifiers(order_binder, statement, result); } // finally bind the types of the ORDER/DISTINCT clause expressions - BindModifiers(*result, result->setop_index, result->names, result->types, bind_state); - return std::move(result); + BindModifiers(result, result.setop_index, result.names, result.types, bind_state); + return result_ptr; +} + +BoundStatement Binder::BindNode(SetOperationNode &statement) { + auto result = BindSetOpNode(statement); + + BoundStatement result_statement; + result_statement.types = result->types; + result_statement.names = result->names; + result_statement.plan = CreatePlan(*result); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_statement_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_statement_node.cpp new file mode 100644 index 000000000..6f6f9941a --- /dev/null +++ b/src/duckdb/src/planner/binder/query_node/bind_statement_node.cpp @@ -0,0 +1,26 @@ +#include "duckdb/parser/query_node/statement_node.hpp" +#include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/update_statement.hpp" +#include "duckdb/parser/statement/delete_statement.hpp" +#include "duckdb/parser/statement/merge_into_statement.hpp" +#include "duckdb/planner/binder.hpp" + +namespace duckdb { + +BoundStatement Binder::BindNode(StatementNode &statement) { + // switch on type here to ensure we bind WITHOUT ctes to prevent infinite recursion + switch (statement.stmt.type) { + case StatementType::INSERT_STATEMENT: + return Bind(statement.stmt.Cast()); + case StatementType::DELETE_STATEMENT: + 
return Bind(statement.stmt.Cast()); + case StatementType::UPDATE_STATEMENT: + return Bind(statement.stmt.Cast()); + case StatementType::MERGE_INTO_STATEMENT: + return Bind(statement.stmt.Cast()); + default: + return Bind(statement.stmt); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp index 5bd06c0e5..dc4cc8770 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp @@ -10,8 +10,8 @@ namespace duckdb { unique_ptr Binder::CreatePlan(BoundCTENode &node) { // Generate the logical plan for the cte_query and child. - auto cte_query = CreatePlan(*node.query); - auto cte_child = CreatePlan(*node.child); + auto cte_query = std::move(node.query.plan); + auto cte_child = std::move(node.child.plan); auto root = make_uniq(node.ctename, node.setop_index, node.types.size(), std::move(cte_query), std::move(cte_child), node.materialized); @@ -23,38 +23,4 @@ unique_ptr Binder::CreatePlan(BoundCTENode &node) { return VisitQueryNode(node, std::move(root)); } -unique_ptr Binder::CreatePlan(BoundCTENode &node, unique_ptr base) { - // Generate the logical plan for the cte_query and child. 
- auto cte_query = CreatePlan(*node.query); - unique_ptr root; - if (node.child && node.child->type == QueryNodeType::CTE_NODE) { - root = CreatePlan(node.child->Cast(), std::move(base)); - } else if (node.child) { - root = CreatePlan(*node.child); - } else { - root = std::move(base); - } - - // Only keep the materialized CTE, if it is used - if (node.child_binder->bind_context.cte_references[node.ctename] && - *node.child_binder->bind_context.cte_references[node.ctename] > 0) { - - // Push the CTE through single-child operators so query modifiers appear ABOVE the CTE (internal issue #2652) - // Otherwise, we may have a LIMIT on top of the CTE, and an ORDER BY in the query, and we can't make a TopN - reference> cte_child = root; - while (cte_child.get()->children.size() == 1 && cte_child.get()->type != LogicalOperatorType::LOGICAL_CTE_REF) { - cte_child = cte_child.get()->children[0]; - } - cte_child.get() = - make_uniq(node.ctename, node.setop_index, node.types.size(), std::move(cte_query), - std::move(cte_child.get()), node.materialized); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || - node.child_binder->has_unplanned_dependent_joins || - node.query_binder->has_unplanned_dependent_joins; - } - return VisitQueryNode(node, std::move(root)); -} - } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp index 4064136b6..dccbae4cd 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp @@ -13,16 +13,16 @@ unique_ptr Binder::CreatePlan(BoundRecursiveCTENode &node) { node.left_binder->is_outside_flattened = is_outside_flattened; node.right_binder->is_outside_flattened = is_outside_flattened; - auto left_node = node.left_binder->CreatePlan(*node.left); - auto 
right_node = node.right_binder->CreatePlan(*node.right); + auto left_node = std::move(node.left.plan); + auto right_node = std::move(node.right.plan); // check if there are any unplanned subqueries left in either child has_unplanned_dependent_joins = has_unplanned_dependent_joins || node.left_binder->has_unplanned_dependent_joins || node.right_binder->has_unplanned_dependent_joins; // for both the left and right sides, cast them to the same types - left_node = CastLogicalOperatorToTypes(node.left->types, node.types, std::move(left_node)); - right_node = CastLogicalOperatorToTypes(node.right->types, node.types, std::move(right_node)); + left_node = CastLogicalOperatorToTypes(node.left.types, node.types, std::move(left_node)); + right_node = CastLogicalOperatorToTypes(node.right.types, node.types, std::move(right_node)); bool ref_recurring = node.right_binder->bind_context.cte_references["recurring." + node.ctename] && *node.right_binder->bind_context.cte_references["recurring." + node.ctename] != 0; diff --git a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp index 9b0fa7c94..fec93aa51 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp @@ -10,8 +10,8 @@ namespace duckdb { // Optionally push a PROJECTION operator -unique_ptr Binder::CastLogicalOperatorToTypes(vector &source_types, - vector &target_types, +unique_ptr Binder::CastLogicalOperatorToTypes(const vector &source_types, + const vector &target_types, unique_ptr op) { D_ASSERT(op); // first check if we even need to cast @@ -114,10 +114,15 @@ unique_ptr Binder::CreatePlan(BoundSetOperationNode &node) { D_ASSERT(node.bound_children.size() >= 2); vector> children; for (auto &child : node.bound_children) { - child.binder->is_outside_flattened = is_outside_flattened; + unique_ptr child_node; + if (child.bound_node) { + child_node = CreatePlan(*child.bound_node); + } else { 
+ child.binder->is_outside_flattened = is_outside_flattened; - // construct the logical plan for the child node - auto child_node = child.binder->CreatePlan(*child.node); + // construct the logical plan for the child node + child_node = std::move(child.node.plan); + } if (!child.reorder_expressions.empty()) { // if we have re-order expressions push a projection vector child_types; @@ -132,10 +137,10 @@ unique_ptr Binder::CreatePlan(BoundSetOperationNode &node) { child_node = CastLogicalOperatorToTypes(child_types, node.types, std::move(child_node)); } else { // otherwise push only casts - child_node = CastLogicalOperatorToTypes(child.node->types, node.types, std::move(child_node)); + child_node = CastLogicalOperatorToTypes(child.GetTypes(), node.types, std::move(child_node)); } // check if there are any unplanned subqueries left in any child - if (child.binder->has_unplanned_dependent_joins) { + if (child.binder && child.binder->has_unplanned_dependent_joins) { has_unplanned_dependent_joins = true; } children.push_back(std::move(child_node)); diff --git a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp index 6b979bf17..29a419ab7 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp @@ -367,7 +367,7 @@ unique_ptr Binder::PlanSubquery(BoundSubqueryExpression &expr, uniqu // first we translate the QueryNode of the subquery into a logical plan auto sub_binder = Binder::CreateBinder(context, this); sub_binder->is_outside_flattened = false; - auto subquery_root = sub_binder->CreatePlan(*expr.subquery); + auto subquery_root = std::move(expr.subquery.plan); D_ASSERT(subquery_root); // now we actually flatten the subquery diff --git a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp index da1dacb15..304ec793b 100644 --- 
a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp @@ -363,7 +363,7 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { // we bind the view subquery and the original view with different "can_contain_nulls", // but we don't want to throw an error when SQLNULL does not match up with INTEGER, // so we exchange all SQLNULL with INTEGER here before comparing - auto bound_types = ExchangeAllNullTypes(bound_subquery.subquery->types); + auto bound_types = ExchangeAllNullTypes(bound_subquery.subquery.types); auto view_types = ExchangeAllNullTypes(view_catalog_entry.types); if (bound_types != view_types) { auto actual_types = StringUtil::ToString(bound_types, ", "); @@ -372,17 +372,17 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { "Contents of view were altered: types don't match! Expected [%s], but found [%s] instead", expected_types, actual_types); } - if (bound_subquery.subquery->names.size() == view_catalog_entry.names.size() && - bound_subquery.subquery->names != view_catalog_entry.names) { - auto actual_names = StringUtil::Join(bound_subquery.subquery->names, ", "); + if (bound_subquery.subquery.names.size() == view_catalog_entry.names.size() && + bound_subquery.subquery.names != view_catalog_entry.names) { + auto actual_names = StringUtil::Join(bound_subquery.subquery.names, ", "); auto expected_names = StringUtil::Join(view_catalog_entry.names, ", "); throw BinderException( "Contents of view were altered: names don't match! 
Expected [%s], but found [%s] instead", expected_names, actual_names); } } - bind_context.AddView(bound_subquery.subquery->GetRootIndex(), subquery.alias, subquery, - *bound_subquery.subquery, view_catalog_entry); + bind_context.AddView(bound_subquery.subquery.plan->GetRootIndex(), subquery.alias, subquery, + bound_subquery.subquery, view_catalog_entry); return bound_child; } default: diff --git a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp index b0d0fffcb..e77e93ce0 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp @@ -9,7 +9,6 @@ #include "duckdb/parser/expression/conjunction_expression.hpp" #include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/parser/expression/function_expression.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/parser/expression/star_expression.hpp" #include "duckdb/common/types/value_map.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" @@ -21,6 +20,7 @@ #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" #include "duckdb/main/query_result.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" #include "duckdb/main/settings.hpp" namespace duckdb { @@ -388,19 +388,22 @@ void ExtractPivotAggregates(BoundTableRef &node, vector> throw InternalException("Pivot - Expected a subquery"); } auto &subq = node.Cast(); - if (subq.subquery->type != QueryNodeType::SELECT_NODE) { - throw InternalException("Pivot - Expected a select node"); - } - auto &select = subq.subquery->Cast(); - if (select.from_table->type != TableReferenceType::SUBQUERY) { - throw InternalException("Pivot - Expected another subquery"); - } - auto &subq2 = select.from_table->Cast(); - if (subq2.subquery->type != QueryNodeType::SELECT_NODE) { - throw InternalException("Pivot - Expected 
another select node"); + reference op(*subq.subquery.plan); + bool found_first_aggregate = false; + while (true) { + if (op.get().type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + if (found_first_aggregate) { + break; + } + found_first_aggregate = true; + } + if (op.get().children.size() != 1) { + throw InternalException("Pivot - expected an aggregate"); + } + op = *op.get().children[0]; } - auto &select2 = subq2.subquery->Cast(); - for (auto &aggr : select2.aggregates) { + auto &aggr_op = op.get().Cast(); + for (auto &aggr : aggr_op.expressions) { if (aggr->GetAlias() == "__collated_group") { continue; } @@ -864,12 +867,10 @@ unique_ptr Binder::Bind(PivotRef &ref) { // bind the generated select node auto child_binder = Binder::CreateBinder(context, this); auto bound_select_node = child_binder->BindNode(*select_node); - auto root_index = bound_select_node->GetRootIndex(); - BoundQueryNode *bound_select_ptr = bound_select_node.get(); + auto root_index = bound_select_node.plan->GetRootIndex(); - unique_ptr result; MoveCorrelatedExpressions(*child_binder); - result = make_uniq(std::move(child_binder), std::move(bound_select_node)); + auto result = make_uniq(std::move(child_binder), std::move(bound_select_node)); auto subquery_alias = ref.alias.empty() ? 
"__unnamed_pivot" : ref.alias; SubqueryRef subquery_ref(nullptr, subquery_alias); subquery_ref.column_name_alias = std::move(ref.column_name_alias); @@ -877,17 +878,16 @@ unique_ptr Binder::Bind(PivotRef &ref) { // if a WHERE clause was provided - bind a subquery holding the WHERE clause // we need to bind a new subquery here because the WHERE clause has to be applied AFTER the unnest child_binder = Binder::CreateBinder(context, this); - child_binder->bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, *bound_select_ptr); + child_binder->bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, result->subquery); auto where_query = make_uniq(); where_query->select_list.push_back(make_uniq()); where_query->where_clause = std::move(where_clause); bound_select_node = child_binder->BindSelectNode(*where_query, std::move(result)); - bound_select_ptr = bound_select_node.get(); - root_index = bound_select_node->GetRootIndex(); + root_index = bound_select_node.plan->GetRootIndex(); result = make_uniq(std::move(child_binder), std::move(bound_select_node)); } - bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, *bound_select_ptr); - return result; + bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, result->subquery); + return std::move(result); } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp index 9eed0ea61..503299d4e 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp @@ -9,7 +9,7 @@ unique_ptr Binder::Bind(SubqueryRef &ref) { binder->can_contain_nulls = true; auto subquery = binder->BindNode(*ref.subquery->node); binder->alias = ref.alias.empty() ? 
"unnamed_subquery" : ref.alias; - idx_t bind_index = subquery->GetRootIndex(); + idx_t bind_index = subquery.plan->GetRootIndex(); string subquery_alias; if (ref.alias.empty()) { auto index = unnamed_subquery_index++; @@ -22,7 +22,7 @@ unique_ptr Binder::Bind(SubqueryRef &ref) { subquery_alias = ref.alias; } auto result = make_uniq(std::move(binder), std::move(subquery)); - bind_context.AddSubquery(bind_index, subquery_alias, ref, *result->subquery); + bind_context.AddSubquery(bind_index, subquery_alias, ref, result->subquery); MoveCorrelatedExpressions(*result->binder); return std::move(result); } diff --git a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp index 91bd5775d..57ef40c3f 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp @@ -104,7 +104,7 @@ bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_functi // bind table in-out function BindTableInTableOutFunction(expressions, subquery); // fetch the arguments from the subquery - arguments = subquery->subquery->types; + arguments = subquery->subquery.types; return true; } bool seen_subquery = false; @@ -388,7 +388,7 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { binder->can_contain_nulls = true; binder->alias = ref.alias.empty() ? "unnamed_query" : ref.alias; - unique_ptr query; + BoundStatement query; try { query = binder->BindNode(*query_node); } catch (std::exception &ex) { @@ -397,13 +397,13 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { error.Throw(); } - idx_t bind_index = query->GetRootIndex(); + idx_t bind_index = query.plan->GetRootIndex(); // string alias; string alias = (ref.alias.empty() ? 
"unnamed_query" + to_string(bind_index) : ref.alias); auto result = make_uniq(std::move(binder), std::move(query)); // remember ref here is TableFunctionRef and NOT base class - bind_context.AddSubquery(bind_index, alias, ref, *result->subquery); + bind_context.AddSubquery(bind_index, alias, ref, result->subquery); MoveCorrelatedExpressions(*result->binder); return std::move(result); } @@ -438,8 +438,8 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { vector input_table_names; if (subquery) { - input_table_types = subquery->subquery->types; - input_table_names = subquery->subquery->names; + input_table_types = subquery->subquery.types; + input_table_names = subquery->subquery.names; } else if (table_function.in_out_function) { for (auto ¶m : parameters) { input_table_types.push_back(param.type()); diff --git a/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp index 821654460..745bff555 100644 --- a/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp +++ b/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp @@ -7,7 +7,7 @@ unique_ptr Binder::CreatePlan(BoundSubqueryRef &ref) { // generate the logical plan for the subquery // this happens separately from the current LogicalPlan generation ref.binder->is_outside_flattened = is_outside_flattened; - auto subquery = ref.binder->CreatePlan(*ref.subquery); + auto subquery = std::move(ref.subquery.plan); if (ref.binder->has_unplanned_dependent_joins) { has_unplanned_dependent_joins = true; } diff --git a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp index d532a7a40..0b693558a 100644 --- a/src/duckdb/src/planner/expression_binder/lateral_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/lateral_binder.cpp @@ -3,7 +3,7 @@ #include "duckdb/planner/logical_operator_visitor.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include 
"duckdb/planner/expression/bound_subquery_expression.hpp" -#include "duckdb/planner/tableref/bound_joinref.hpp" +#include "duckdb/planner/operator/logical_dependent_join.hpp" namespace duckdb { @@ -79,34 +79,35 @@ static void ReduceColumnDepth(CorrelatedColumns &columns, const CorrelatedColumn } } -class ExpressionDepthReducerRecursive : public BoundNodeVisitor { +class ExpressionDepthReducerRecursive : public LogicalOperatorVisitor { public: explicit ExpressionDepthReducerRecursive(const CorrelatedColumns &correlated) : correlated_columns(correlated) { } - void VisitExpression(unique_ptr &expression) override { - if (expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - ReduceColumnRefDepth(expression->Cast(), correlated_columns); - } else if (expression->GetExpressionType() == ExpressionType::SUBQUERY) { - ReduceExpressionSubquery(expression->Cast(), correlated_columns); + void VisitExpression(unique_ptr *expression) override { + auto &expr = **expression; + if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + ReduceColumnRefDepth(expr.Cast(), correlated_columns); + } else if (expr.GetExpressionType() == ExpressionType::SUBQUERY) { + ReduceExpressionSubquery(expr.Cast(), correlated_columns); } - BoundNodeVisitor::VisitExpression(expression); + LogicalOperatorVisitor::VisitExpression(expression); } - void VisitBoundTableRef(BoundTableRef &ref) override { - if (ref.type == TableReferenceType::JOIN) { + void VisitOperator(LogicalOperator &op) override { + if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { // rewrite correlated columns in child joins - auto &bound_join = ref.Cast(); + auto &bound_join = op.Cast(); ReduceColumnDepth(bound_join.correlated_columns, correlated_columns); } // visit the children of the table ref - BoundNodeVisitor::VisitBoundTableRef(ref); + LogicalOperatorVisitor::VisitOperator(op); } static void ReduceExpressionSubquery(BoundSubqueryExpression &expr, const CorrelatedColumns 
&correlated_columns) { ReduceColumnDepth(expr.binder->correlated_columns, correlated_columns); ExpressionDepthReducerRecursive recursive(correlated_columns); - recursive.VisitBoundQueryNode(*expr.subquery); + recursive.VisitOperator(*expr.subquery.plan); } private: diff --git a/src/duckdb/src/planner/expression_iterator.cpp b/src/duckdb/src/planner/expression_iterator.cpp index 042712732..9f67f915c 100644 --- a/src/duckdb/src/planner/expression_iterator.cpp +++ b/src/duckdb/src/planner/expression_iterator.cpp @@ -183,156 +183,4 @@ void ExpressionIterator::VisitExpressionClassMutable( *expr, [&](unique_ptr &child) { VisitExpressionClassMutable(child, expr_class, callback); }); } -void BoundNodeVisitor::VisitExpression(unique_ptr &expression) { - VisitExpressionChildren(*expression); -} - -void BoundNodeVisitor::VisitExpressionChildren(Expression &expr) { - ExpressionIterator::EnumerateChildren(expr, [&](unique_ptr &expr) { VisitExpression(expr); }); -} - -void BoundNodeVisitor::VisitBoundQueryNode(BoundQueryNode &node) { - switch (node.type) { - case QueryNodeType::SET_OPERATION_NODE: { - auto &bound_setop = node.Cast(); - for (auto &child : bound_setop.bound_children) { - VisitBoundQueryNode(*child.node); - } - break; - } - case QueryNodeType::RECURSIVE_CTE_NODE: { - auto &cte_node = node.Cast(); - VisitBoundQueryNode(*cte_node.left); - VisitBoundQueryNode(*cte_node.right); - break; - } - case QueryNodeType::CTE_NODE: { - auto &cte_node = node.Cast(); - VisitBoundQueryNode(*cte_node.child); - VisitBoundQueryNode(*cte_node.query); - break; - } - case QueryNodeType::SELECT_NODE: { - auto &bound_select = node.Cast(); - for (auto &expr : bound_select.select_list) { - VisitExpression(expr); - } - if (bound_select.where_clause) { - VisitExpression(bound_select.where_clause); - } - for (auto &expr : bound_select.groups.group_expressions) { - VisitExpression(expr); - } - if (bound_select.having) { - VisitExpression(bound_select.having); - } - for (auto &expr : 
bound_select.aggregates) { - VisitExpression(expr); - } - for (auto &entry : bound_select.unnests) { - for (auto &expr : entry.second.expressions) { - VisitExpression(expr); - } - } - for (auto &expr : bound_select.windows) { - VisitExpression(expr); - } - if (bound_select.from_table) { - VisitBoundTableRef(*bound_select.from_table); - } - break; - } - default: - throw NotImplementedException("Unimplemented query node in ExpressionIterator"); - } - for (idx_t i = 0; i < node.modifiers.size(); i++) { - switch (node.modifiers[i]->type) { - case ResultModifierType::DISTINCT_MODIFIER: - for (auto &expr : node.modifiers[i]->Cast().target_distincts) { - VisitExpression(expr); - } - break; - case ResultModifierType::ORDER_MODIFIER: - for (auto &order : node.modifiers[i]->Cast().orders) { - VisitExpression(order.expression); - } - break; - case ResultModifierType::LIMIT_MODIFIER: { - auto &limit_expr = node.modifiers[i]->Cast().limit_val.GetExpression(); - auto &offset_expr = node.modifiers[i]->Cast().offset_val.GetExpression(); - if (limit_expr) { - VisitExpression(limit_expr); - } - if (offset_expr) { - VisitExpression(offset_expr); - } - break; - } - default: - break; - } - } -} - -class LogicalBoundNodeVisitor : public LogicalOperatorVisitor { -public: - explicit LogicalBoundNodeVisitor(BoundNodeVisitor &parent) : parent(parent) { - } - - void VisitExpression(unique_ptr *expression) override { - auto &expr = **expression; - parent.VisitExpression(*expression); - VisitExpressionChildren(expr); - } - -protected: - BoundNodeVisitor &parent; -}; - -void BoundNodeVisitor::VisitBoundTableRef(BoundTableRef &ref) { - switch (ref.type) { - case TableReferenceType::EXPRESSION_LIST: { - auto &bound_expr_list = ref.Cast(); - for (auto &expr_list : bound_expr_list.values) { - for (auto &expr : expr_list) { - VisitExpression(expr); - } - } - break; - } - case TableReferenceType::JOIN: { - auto &bound_join = ref.Cast(); - if (bound_join.condition) { - 
VisitExpression(bound_join.condition); - } - VisitBoundTableRef(*bound_join.left); - VisitBoundTableRef(*bound_join.right); - break; - } - case TableReferenceType::SUBQUERY: { - auto &bound_subquery = ref.Cast(); - VisitBoundQueryNode(*bound_subquery.subquery); - break; - } - case TableReferenceType::TABLE_FUNCTION: { - auto &bound_table_function = ref.Cast(); - LogicalBoundNodeVisitor node_visitor(*this); - if (bound_table_function.get) { - node_visitor.VisitOperator(*bound_table_function.get); - } - if (bound_table_function.subquery) { - VisitBoundTableRef(*bound_table_function.subquery); - } - break; - } - case TableReferenceType::EMPTY_FROM: - case TableReferenceType::BASE_TABLE: - case TableReferenceType::CTE: - break; - default: - throw NotImplementedException("Unimplemented table reference type (%s) in ExpressionIterator", - EnumUtil::ToString(ref.type)); - } -} - } // namespace duckdb diff --git a/src/duckdb/src/planner/logical_operator.cpp b/src/duckdb/src/planner/logical_operator.cpp index e16062573..016b7d605 100644 --- a/src/duckdb/src/planner/logical_operator.cpp +++ b/src/duckdb/src/planner/logical_operator.cpp @@ -31,6 +31,19 @@ vector LogicalOperator::GetColumnBindings() { return {ColumnBinding(0, 0)}; } +idx_t LogicalOperator::GetRootIndex() { + auto bindings = GetColumnBindings(); + if (bindings.empty()) { + throw InternalException("Empty bindings in GetRootIndex"); + } + auto root_index = bindings[0].table_index; + for (idx_t i = 1; i < bindings.size(); i++) { + if (bindings[i].table_index != root_index) { + throw InternalException("GetRootIndex - multiple column bindings found"); + } + } + return root_index; +} void LogicalOperator::SetParamsEstimatedCardinality(InsertionOrderPreservingMap &result) const { if (has_estimated_cardinality) { result[RenderTreeNode::ESTIMATED_CARDINALITY] = StringUtil::Format("%llu", estimated_cardinality); diff --git a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp 
b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp index 903840dda..d376eec81 100644 --- a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp +++ b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp @@ -71,14 +71,14 @@ unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundColumnRef } //! Helper class used to recursively rewrite correlated expressions within nested subqueries. -class RewriteCorrelatedRecursive : public BoundNodeVisitor { +class RewriteCorrelatedRecursive : public LogicalOperatorVisitor { public: RewriteCorrelatedRecursive(ColumnBinding base_binding, column_binding_map_t &correlated_map); - void VisitBoundTableRef(BoundTableRef &ref) override; - void VisitExpression(unique_ptr &expression) override; + void VisitOperator(LogicalOperator &op) override; + void VisitExpression(unique_ptr *expression) override; - void RewriteCorrelatedSubquery(Binder &binder, BoundQueryNode &subquery); + void RewriteCorrelatedSubquery(Binder &binder, LogicalOperator &subquery); ColumnBinding base_binding; column_binding_map_t &correlated_map; @@ -92,7 +92,7 @@ unique_ptr RewriteCorrelatedExpressions::VisitReplace(BoundSubqueryE // subquery detected within this subquery // recursively rewrite it using the RewriteCorrelatedRecursive class RewriteCorrelatedRecursive rewrite(base_binding, correlated_map); - rewrite.RewriteCorrelatedSubquery(*expr.binder, *expr.subquery); + rewrite.RewriteCorrelatedSubquery(*expr.binder, *expr.subquery.plan); return nullptr; } @@ -101,40 +101,30 @@ RewriteCorrelatedRecursive::RewriteCorrelatedRecursive(ColumnBinding base_bindin : base_binding(base_binding), correlated_map(correlated_map) { } -void RewriteCorrelatedRecursive::VisitBoundTableRef(BoundTableRef &ref) { - if (ref.type == TableReferenceType::JOIN) { +void RewriteCorrelatedRecursive::VisitOperator(LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_DEPENDENT_JOIN) { // rewrite correlated columns in child joins - 
auto &bound_join = ref.Cast(); - for (auto &corr : bound_join.correlated_columns) { + auto &dep_join = op.Cast(); + for (auto &corr : dep_join.correlated_columns) { auto entry = correlated_map.find(corr.binding); if (entry != correlated_map.end()) { corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); } } - } else if (ref.type == TableReferenceType::SUBQUERY) { - auto &subquery = ref.Cast(); - RewriteCorrelatedSubquery(*subquery.binder, *subquery.subquery); - return; } // visit the children of the table ref - BoundNodeVisitor::VisitBoundTableRef(ref); + LogicalOperatorVisitor::VisitOperator(op); } -void RewriteCorrelatedRecursive::RewriteCorrelatedSubquery(Binder &binder, BoundQueryNode &subquery) { - // rewrite the binding in the correlated list of the subquery) - for (auto &corr : binder.correlated_columns) { - auto entry = correlated_map.find(corr.binding); - if (entry != correlated_map.end()) { - corr.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); - } - } - VisitBoundQueryNode(subquery); +void RewriteCorrelatedRecursive::RewriteCorrelatedSubquery(Binder &binder, LogicalOperator &op) { + VisitOperator(op); } -void RewriteCorrelatedRecursive::VisitExpression(unique_ptr &expression) { - if (expression->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { +void RewriteCorrelatedRecursive::VisitExpression(unique_ptr *expression) { + auto &expr = **expression; + if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { // bound column reference - auto &bound_colref = expression->Cast(); + auto &bound_colref = expr.Cast(); if (bound_colref.depth == 0) { // not a correlated column, ignore return; @@ -148,13 +138,13 @@ void RewriteCorrelatedRecursive::VisitExpression(unique_ptr &express bound_colref.binding = ColumnBinding(base_binding.table_index, base_binding.column_index + entry->second); bound_colref.depth--; } - } else if (expression->GetExpressionType() == 
ExpressionType::SUBQUERY) { + } else if (expr.GetExpressionType() == ExpressionType::SUBQUERY) { // we encountered another subquery: rewrite recursively - auto &bound_subquery = expression->Cast(); - RewriteCorrelatedSubquery(*bound_subquery.binder, *bound_subquery.subquery); + auto &bound_subquery = expr.Cast(); + RewriteCorrelatedSubquery(*bound_subquery.binder, *bound_subquery.subquery.plan); } // recurse into the children of this subquery - BoundNodeVisitor::VisitExpression(expression); + LogicalOperatorVisitor::VisitExpression(expression); } RewriteCountAggregates::RewriteCountAggregates(column_binding_map_t &replacement_map) diff --git a/src/duckdb/ub_src_parser_query_node.cpp b/src/duckdb/ub_src_parser_query_node.cpp index f0fefe80e..131571749 100644 --- a/src/duckdb/ub_src_parser_query_node.cpp +++ b/src/duckdb/ub_src_parser_query_node.cpp @@ -6,3 +6,5 @@ #include "src/parser/query_node/set_operation_node.cpp" +#include "src/parser/query_node/statement_node.cpp" + diff --git a/src/duckdb/ub_src_planner_binder_query_node.cpp b/src/duckdb/ub_src_planner_binder_query_node.cpp index 2250c80ca..3ec7b7ecb 100644 --- a/src/duckdb/ub_src_planner_binder_query_node.cpp +++ b/src/duckdb/ub_src_planner_binder_query_node.cpp @@ -6,6 +6,8 @@ #include "src/planner/binder/query_node/bind_cte_node.cpp" +#include "src/planner/binder/query_node/bind_statement_node.cpp" + #include "src/planner/binder/query_node/bind_table_macro_node.cpp" #include "src/planner/binder/query_node/plan_query_node.cpp" From e1211e79b14491acc309d6d69668fb142567296d Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Thu, 9 Oct 2025 05:22:35 +0000 Subject: [PATCH 4/6] Update vendored DuckDB sources to 353406bd7f --- .../json/json_functions/json_create.cpp | 1 + .../extension/parquet/column_writer.cpp | 2 +- src/duckdb/extension/parquet/geo_parquet.cpp | 202 +----- .../extension/parquet/include/geo_parquet.hpp | 170 +---- .../include/writer/parquet_write_stats.hpp | 30 +- 
.../extension/parquet/parquet_metadata.cpp | 15 +- .../extension/parquet/parquet_writer.cpp | 58 +- src/duckdb/src/common/enum_util.cpp | 60 +- src/duckdb/src/common/enums/metric_type.cpp | 6 + .../src/common/enums/optimizer_type.cpp | 2 + .../src/common/sort/partition_state.cpp | 671 ------------------ src/duckdb/src/common/sorting/hashed_sort.cpp | 126 +++- src/duckdb/src/common/types/geometry.cpp | 112 +++ .../operator/join/physical_asof_join.cpp | 646 +++++++++++------ .../function/table/version/pragma_version.cpp | 6 +- .../src/include/duckdb/common/enum_util.hpp | 18 +- .../duckdb/common/enums/metric_type.hpp | 1 + .../duckdb/common/enums/optimizer_type.hpp | 1 + .../duckdb/common/sort/partition_state.hpp | 245 ------- .../duckdb/common/sorting/hashed_sort.hpp | 11 +- .../include/duckdb/common/types/geometry.hpp | 171 ++++- .../operator/join/physical_range_join.hpp | 2 +- src/duckdb/src/include/duckdb/main/config.hpp | 2 + .../src/include/duckdb/main/settings.hpp | 12 + .../optimizer/topn_window_elimination.hpp | 68 ++ .../duckdb/parser/parsed_data/vacuum_info.hpp | 1 - .../include/duckdb/parser/parser_options.hpp | 1 + .../parser/tableref/bound_ref_wrapper.hpp | 5 +- .../include/duckdb/planner/bind_context.hpp | 17 +- .../src/include/duckdb/planner/binder.hpp | 129 ++-- .../include/duckdb/planner/bound_tableref.hpp | 46 -- .../include/duckdb/planner/bound_tokens.hpp | 11 - .../expression_binder/select_bind_state.hpp | 1 - .../duckdb/planner/expression_iterator.hpp | 1 - .../planner/query_node/bound_select_node.hpp | 3 +- .../include/duckdb/planner/table_binding.hpp | 13 +- .../planner/tableref/bound_basetableref.hpp | 30 - .../tableref/bound_column_data_ref.hpp | 30 - .../duckdb/planner/tableref/bound_cteref.hpp | 40 -- .../planner/tableref/bound_delimgetref.hpp | 26 - .../planner/tableref/bound_dummytableref.hpp | 26 - .../tableref/bound_expressionlistref.hpp | 33 - .../duckdb/planner/tableref/bound_joinref.hpp | 13 +- 
.../planner/tableref/bound_pivotref.hpp | 11 +- .../planner/tableref/bound_pos_join_ref.hpp | 38 - .../planner/tableref/bound_subqueryref.hpp | 32 - .../planner/tableref/bound_table_function.hpp | 31 - .../include/duckdb/planner/tableref/list.hpp | 9 - .../storage/statistics/base_statistics.hpp | 14 +- .../storage/statistics/geometry_stats.hpp | 144 ++++ .../duckdb/storage/string_uncompressed.hpp | 6 +- src/duckdb/src/main/client_context.cpp | 1 + src/duckdb/src/main/config.cpp | 13 +- src/duckdb/src/main/query_profiler.cpp | 6 +- .../main/settings/autogenerated_settings.cpp | 22 + .../src/main/settings/custom_settings.cpp | 18 + src/duckdb/src/optimizer/optimizer.cpp | 7 + .../src/optimizer/topn_window_elimination.cpp | 592 +++++++++++++++ src/duckdb/src/parser/parser.cpp | 17 +- .../src/parser/tableref/bound_ref_wrapper.cpp | 2 +- src/duckdb/src/planner/bind_context.cpp | 22 +- src/duckdb/src/planner/binder.cpp | 150 ++-- .../expression/bind_parameter_expression.cpp | 10 +- .../binder/query_node/bind_cte_node.cpp | 1 - .../query_node/bind_recursive_cte_node.cpp | 4 - .../binder/query_node/bind_select_node.cpp | 7 +- .../query_node/plan_recursive_cte_node.cpp | 11 +- .../binder/query_node/plan_select_node.cpp | 6 +- .../planner/binder/statement/bind_attach.cpp | 1 - .../planner/binder/statement/bind_call.cpp | 2 - .../planner/binder/statement/bind_create.cpp | 19 +- .../planner/binder/statement/bind_delete.cpp | 28 +- .../planner/binder/statement/bind_drop.cpp | 1 - .../planner/binder/statement/bind_execute.cpp | 2 +- .../binder/statement/bind_extension.cpp | 6 +- .../planner/binder/statement/bind_insert.cpp | 5 - .../binder/statement/bind_logical_plan.cpp | 2 +- .../binder/statement/bind_merge_into.cpp | 12 +- .../planner/binder/statement/bind_prepare.cpp | 2 +- .../planner/binder/statement/bind_simple.cpp | 9 +- .../binder/statement/bind_summarize.cpp | 3 +- .../planner/binder/statement/bind_update.cpp | 19 +- .../planner/binder/statement/bind_vacuum.cpp | 19 +- 
.../binder/tableref/bind_basetableref.cpp | 125 ++-- .../binder/tableref/bind_bound_table_ref.cpp | 4 +- .../binder/tableref/bind_column_data_ref.cpp | 21 +- .../binder/tableref/bind_delimgetref.cpp | 13 +- .../binder/tableref/bind_emptytableref.cpp | 8 +- .../tableref/bind_expressionlistref.cpp | 63 +- .../planner/binder/tableref/bind_joinref.cpp | 12 +- .../planner/binder/tableref/bind_pivot.cpp | 60 +- .../planner/binder/tableref/bind_showref.cpp | 19 +- .../binder/tableref/bind_subqueryref.cpp | 15 +- .../binder/tableref/bind_table_function.cpp | 103 +-- .../binder/tableref/plan_basetableref.cpp | 11 - .../binder/tableref/plan_column_data_ref.cpp | 15 - .../planner/binder/tableref/plan_cteref.cpp | 11 - .../binder/tableref/plan_delimgetref.cpp | 11 - .../binder/tableref/plan_dummytableref.cpp | 11 - .../tableref/plan_expressionlistref.cpp | 27 - .../planner/binder/tableref/plan_joinref.cpp | 4 +- .../planner/binder/tableref/plan_pivotref.cpp | 13 - .../binder/tableref/plan_subqueryref.cpp | 17 - .../binder/tableref/plan_table_function.cpp | 28 - .../src/planner/operator/logical_vacuum.cpp | 11 +- src/duckdb/src/planner/planner.cpp | 2 +- .../rewrite_correlated_expressions.cpp | 1 - src/duckdb/src/planner/table_binding.cpp | 4 + src/duckdb/src/storage/checkpoint_manager.cpp | 1 - .../src/storage/compression/dict_fsst.cpp | 6 +- src/duckdb/src/storage/compression/fsst.cpp | 5 +- .../storage/statistics/base_statistics.cpp | 30 + .../src/storage/statistics/geometry_stats.cpp | 171 +++++ src/duckdb/ub_src_common_sort.cpp | 2 - src/duckdb/ub_src_optimizer.cpp | 2 + src/duckdb/ub_src_planner_binder_tableref.cpp | 18 - src/duckdb/ub_src_storage_statistics.cpp | 2 + 117 files changed, 2560 insertions(+), 2654 deletions(-) delete mode 100644 src/duckdb/src/common/sort/partition_state.cpp delete mode 100644 src/duckdb/src/include/duckdb/common/sort/partition_state.hpp create mode 100644 src/duckdb/src/include/duckdb/optimizer/topn_window_elimination.hpp delete mode 100644 
src/duckdb/src/include/duckdb/planner/bound_tableref.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/tableref/bound_column_data_ref.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/tableref/bound_delimgetref.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp create mode 100644 src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp create mode 100644 src/duckdb/src/optimizer/topn_window_elimination.cpp delete mode 100644 src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp delete mode 100644 src/duckdb/src/planner/binder/tableref/plan_column_data_ref.cpp delete mode 100644 src/duckdb/src/planner/binder/tableref/plan_cteref.cpp delete mode 100644 src/duckdb/src/planner/binder/tableref/plan_delimgetref.cpp delete mode 100644 src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp delete mode 100644 src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp delete mode 100644 src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp delete mode 100644 src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp delete mode 100644 src/duckdb/src/planner/binder/tableref/plan_table_function.cpp create mode 100644 src/duckdb/src/storage/statistics/geometry_stats.cpp diff --git a/src/duckdb/extension/json/json_functions/json_create.cpp b/src/duckdb/extension/json/json_functions/json_create.cpp index 4cd00249c..8387ef750 100644 --- 
a/src/duckdb/extension/json/json_functions/json_create.cpp +++ b/src/duckdb/extension/json/json_functions/json_create.cpp @@ -616,6 +616,7 @@ static void CreateValues(const StructNames &names, yyjson_mut_doc *doc, yyjson_m case LogicalTypeId::VALIDITY: case LogicalTypeId::TABLE: case LogicalTypeId::LAMBDA: + case LogicalTypeId::GEOMETRY: // TODO! Add support for GEOMETRY throw InternalException("Unsupported type arrived at JSON create function"); } } diff --git a/src/duckdb/extension/parquet/column_writer.cpp b/src/duckdb/extension/parquet/column_writer.cpp index 55bfa2007..304e21751 100644 --- a/src/duckdb/extension/parquet/column_writer.cpp +++ b/src/duckdb/extension/parquet/column_writer.cpp @@ -97,7 +97,7 @@ bool ColumnWriterStatistics::HasGeoStats() { return false; } -optional_ptr ColumnWriterStatistics::GetGeoStats() { +optional_ptr ColumnWriterStatistics::GetGeoStats() { return nullptr; } diff --git a/src/duckdb/extension/parquet/geo_parquet.cpp b/src/duckdb/extension/parquet/geo_parquet.cpp index bddc36b43..48e2b047f 100644 --- a/src/duckdb/extension/parquet/geo_parquet.cpp +++ b/src/duckdb/extension/parquet/geo_parquet.cpp @@ -16,171 +16,6 @@ namespace duckdb { using namespace duckdb_yyjson; // NOLINT -//------------------------------------------------------------------------------ -// WKB stats -//------------------------------------------------------------------------------ -namespace { - -class BinaryReader { -public: - const char *beg; - const char *end; - const char *ptr; - - BinaryReader(const char *beg, uint32_t len) : beg(beg), end(beg + len), ptr(beg) { - } - - template - T Read() { - if (ptr + sizeof(T) > end) { - throw InvalidInputException("Unexpected end of WKB data"); - } - T val; - memcpy(&val, ptr, sizeof(T)); - ptr += sizeof(T); - return val; - } - - void Skip(idx_t len) { - if (ptr + len > end) { - throw InvalidInputException("Unexpected end of WKB data"); - } - ptr += len; - } - - const char *Reserve(idx_t len) { - if (ptr + len > end) { 
- throw InvalidInputException("Unexpected end of WKB data"); - } - auto ret = ptr; - ptr += len; - return ret; - } - - bool IsAtEnd() const { - return ptr >= end; - } -}; - -} // namespace - -static void UpdateBoundsFromVertexArray(GeometryExtent &bbox, uint32_t flag, const char *vert_array, - uint32_t vert_count) { - switch (flag) { - case 0: { // XY - constexpr auto vert_width = sizeof(double) * 2; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[2]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - } - } break; - case 1: { // XYZ - constexpr auto vert_width = sizeof(double) * 3; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[3]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - bbox.ExtendZ(vert[2]); - } - } break; - case 2: { // XYM - constexpr auto vert_width = sizeof(double) * 3; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[3]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - bbox.ExtendM(vert[2]); - } - } break; - case 3: { // XYZM - constexpr auto vert_width = sizeof(double) * 4; - for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { - double vert[4]; - memcpy(vert, vert_array + vert_idx * vert_width, vert_width); - bbox.ExtendX(vert[0]); - bbox.ExtendY(vert[1]); - bbox.ExtendZ(vert[2]); - bbox.ExtendM(vert[3]); - } - } break; - default: - break; - } -} - -void GeometryStats::Update(const string_t &wkb) { - BinaryReader reader(wkb.GetData(), wkb.GetSize()); - - bool first_geom = true; - while (!reader.IsAtEnd()) { - reader.Read(); // byte order - auto type = reader.Read(); - auto kind = type % 1000; - auto flag = type / 1000; - const auto hasz = (flag & 0x01) != 0; - const auto hasm = (flag & 0x02) != 0; - - if (first_geom) { - // Only add the 
top-level geometry type - types.Add(type); - first_geom = false; - } - - const auto vert_width = sizeof(double) * (2 + (hasz ? 1 : 0) + (hasm ? 1 : 0)); - - switch (kind) { - case 1: { // POINT - - // Point are special in that they are considered "empty" if they are all-nan - const auto vert_array = reader.Reserve(vert_width); - const auto dims_count = 2 + (hasz ? 1 : 0) + (hasm ? 1 : 0); - double vert_point[4] = {0, 0, 0, 0}; - - memcpy(vert_point, vert_array, vert_width); - - for (auto dim_idx = 0; dim_idx < dims_count; dim_idx++) { - if (!std::isnan(vert_point[dim_idx])) { - bbox.ExtendX(vert_point[0]); - bbox.ExtendY(vert_point[1]); - if (hasz && hasm) { - bbox.ExtendZ(vert_point[2]); - bbox.ExtendM(vert_point[3]); - } else if (hasz) { - bbox.ExtendZ(vert_point[2]); - } else if (hasm) { - bbox.ExtendM(vert_point[2]); - } - break; - } - } - } break; - case 2: { // LINESTRING - const auto vert_count = reader.Read(); - const auto vert_array = reader.Reserve(vert_count * vert_width); - UpdateBoundsFromVertexArray(bbox, flag, vert_array, vert_count); - } break; - case 3: { // POLYGON - const auto ring_count = reader.Read(); - for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { - const auto vert_count = reader.Read(); - const auto vert_array = reader.Reserve(vert_count * vert_width); - UpdateBoundsFromVertexArray(bbox, flag, vert_array, vert_count); - } - } break; - case 4: // MULTIPOINT - case 5: // MULTILINESTRING - case 6: // MULTIPOLYGON - case 7: { // GEOMETRYCOLLECTION - reader.Skip(sizeof(uint32_t)); - } break; - } - } -} - //------------------------------------------------------------------------------ // GeoParquetFileMetadata //------------------------------------------------------------------------------ @@ -292,7 +127,7 @@ unique_ptr GeoParquetFileMetadata::TryRead(const duckdb_ } void GeoParquetFileMetadata::AddGeoParquetStats(const string &column_name, const LogicalType &type, - const GeometryStats &stats) { + const GeometryStatsData 
&stats) { // Lock the metadata lock_guard glock(write_lock); @@ -301,12 +136,10 @@ void GeoParquetFileMetadata::AddGeoParquetStats(const string &column_name, const if (it == geometry_columns.end()) { auto &column = geometry_columns[column_name]; - column.stats.types.Combine(stats.types); - column.stats.bbox.Combine(stats.bbox); + column.stats.Merge(stats); column.insertion_index = geometry_columns.size() - 1; } else { - it->second.stats.types.Combine(stats.types); - it->second.stats.bbox.Combine(stats.bbox); + it->second.stats.Merge(stats); } } @@ -315,7 +148,7 @@ void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) // GeoParquet does not support M or ZM coordinates. So remove any columns that have them. unordered_set invalid_columns; for (auto &column : geometry_columns) { - if (column.second.stats.bbox.HasM()) { + if (column.second.stats.extent.HasM()) { invalid_columns.insert(column.first); } } @@ -358,28 +191,29 @@ void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) const auto column_json = yyjson_mut_obj_add_obj(doc, json_columns, column.first.c_str()); yyjson_mut_obj_add_str(doc, column_json, "encoding", "WKB"); const auto geometry_types = yyjson_mut_obj_add_arr(doc, column_json, "geometry_types"); + for (auto &type_name : column.second.stats.types.ToString(false)) { yyjson_mut_arr_add_strcpy(doc, geometry_types, type_name.c_str()); } - const auto &bbox = column.second.stats.bbox; + const auto &bbox = column.second.stats.extent; - if (bbox.IsSet()) { + if (bbox.HasXY()) { const auto bbox_arr = yyjson_mut_obj_add_arr(doc, column_json, "bbox"); - if (!column.second.stats.bbox.HasZ()) { - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmax); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymax); + if (!column.second.stats.extent.HasZ()) { + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_min); + 
yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_max); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_max); } else { - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.zmin); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.xmax); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.ymax); - yyjson_mut_arr_add_real(doc, bbox_arr, bbox.zmax); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.z_min); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.x_max); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.y_max); + yyjson_mut_arr_add_real(doc, bbox_arr, bbox.z_max); } } diff --git a/src/duckdb/extension/parquet/include/geo_parquet.hpp b/src/duckdb/extension/parquet/include/geo_parquet.hpp index 6dc82bc8d..424e7c324 100644 --- a/src/duckdb/extension/parquet/include/geo_parquet.hpp +++ b/src/duckdb/extension/parquet/include/geo_parquet.hpp @@ -18,172 +18,6 @@ namespace duckdb { struct ParquetColumnSchema; - -struct GeometryKindSet { - - uint8_t bits[4] = {0, 0, 0, 0}; - - void Add(uint32_t wkb_type) { - auto kind = wkb_type % 1000; - auto dims = wkb_type / 1000; - if (kind < 1 || kind > 7 || (dims) > 3) { - return; - } - bits[dims] |= (1 << (kind - 1)); - } - - void Combine(const GeometryKindSet &other) { - for (uint32_t d = 0; d < 4; d++) { - bits[d] |= other.bits[d]; - } - } - - bool IsEmpty() const { - for (uint32_t d = 0; d < 4; d++) { - if (bits[d] != 0) { - return false; - } - } - return true; - } - - template - vector ToList() const { - vector result; - for (uint32_t d = 0; d < 4; d++) { - for (uint32_t i = 1; i <= 7; i++) { - if (bits[d] & (1 << (i - 1))) { - result.push_back(i + d * 1000); - } - } - } - return result; - } - - vector ToString(bool snake_case) const { - vector result; - for (uint32_t d = 0; d < 4; d++) { - for 
(uint32_t i = 1; i <= 7; i++) { - if (bits[d] & (1 << (i - 1))) { - string str; - switch (i) { - case 1: - str = snake_case ? "point" : "Point"; - break; - case 2: - str = snake_case ? "linestring" : "LineString"; - break; - case 3: - str = snake_case ? "polygon" : "Polygon"; - break; - case 4: - str = snake_case ? "multipoint" : "MultiPoint"; - break; - case 5: - str = snake_case ? "multilinestring" : "MultiLineString"; - break; - case 6: - str = snake_case ? "multipolygon" : "MultiPolygon"; - break; - case 7: - str = snake_case ? "geometrycollection" : "GeometryCollection"; - break; - default: - str = snake_case ? "unknown" : "Unknown"; - break; - } - switch (d) { - case 1: - str += snake_case ? "_z" : " Z"; - break; - case 2: - str += snake_case ? "_m" : " M"; - break; - case 3: - str += snake_case ? "_zm" : " ZM"; - break; - default: - break; - } - - result.push_back(str); - } - } - } - return result; - } -}; - -struct GeometryExtent { - - double xmin = NumericLimits::Maximum(); - double xmax = NumericLimits::Minimum(); - double ymin = NumericLimits::Maximum(); - double ymax = NumericLimits::Minimum(); - double zmin = NumericLimits::Maximum(); - double zmax = NumericLimits::Minimum(); - double mmin = NumericLimits::Maximum(); - double mmax = NumericLimits::Minimum(); - - bool IsSet() const { - return xmin != NumericLimits::Maximum() && xmax != NumericLimits::Minimum() && - ymin != NumericLimits::Maximum() && ymax != NumericLimits::Minimum(); - } - - bool HasZ() const { - return zmin != NumericLimits::Maximum() && zmax != NumericLimits::Minimum(); - } - - bool HasM() const { - return mmin != NumericLimits::Maximum() && mmax != NumericLimits::Minimum(); - } - - void Combine(const GeometryExtent &other) { - xmin = std::min(xmin, other.xmin); - xmax = std::max(xmax, other.xmax); - ymin = std::min(ymin, other.ymin); - ymax = std::max(ymax, other.ymax); - zmin = std::min(zmin, other.zmin); - zmax = std::max(zmax, other.zmax); - mmin = std::min(mmin, other.mmin); - 
mmax = std::max(mmax, other.mmax); - } - - void Combine(const double &xmin_p, const double &xmax_p, const double &ymin_p, const double &ymax_p) { - xmin = std::min(xmin, xmin_p); - xmax = std::max(xmax, xmax_p); - ymin = std::min(ymin, ymin_p); - ymax = std::max(ymax, ymax_p); - } - - void ExtendX(const double &x) { - xmin = std::min(xmin, x); - xmax = std::max(xmax, x); - } - void ExtendY(const double &y) { - ymin = std::min(ymin, y); - ymax = std::max(ymax, y); - } - void ExtendZ(const double &z) { - zmin = std::min(zmin, z); - zmax = std::max(zmax, z); - } - void ExtendM(const double &m) { - mmin = std::min(mmin, m); - mmax = std::max(mmax, m); - } -}; - -struct GeometryStats { - GeometryKindSet types; - GeometryExtent bbox; - - void Update(const string_t &wkb); -}; - -//------------------------------------------------------------------------------ -// GeoParquetMetadata -//------------------------------------------------------------------------------ class ParquetReader; class ColumnReader; class ClientContext; @@ -204,7 +38,7 @@ struct GeoParquetColumnMetadata { GeoParquetColumnEncoding geometry_encoding; // The statistics of the geometry column - GeometryStats stats; + GeometryStatsData stats; // The crs of the geometry column (if any) in PROJJSON format string projjson; @@ -215,7 +49,7 @@ struct GeoParquetColumnMetadata { class GeoParquetFileMetadata { public: - void AddGeoParquetStats(const string &column_name, const LogicalType &type, const GeometryStats &stats); + void AddGeoParquetStats(const string &column_name, const LogicalType &type, const GeometryStatsData &stats); void Write(duckdb_parquet::FileMetaData &file_meta_data); // Try to read GeoParquet metadata. 
Returns nullptr if not found, invalid or the required spatial extension is not diff --git a/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp b/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp index 1016c81fe..a74845ad2 100644 --- a/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp +++ b/src/duckdb/extension/parquet/include/writer/parquet_write_stats.hpp @@ -28,7 +28,7 @@ class ColumnWriterStatistics { virtual bool MaxIsExact(); virtual bool HasGeoStats(); - virtual optional_ptr GetGeoStats(); + virtual optional_ptr GetGeoStats(); virtual void WriteGeoStats(duckdb_parquet::GeospatialStatistics &stats); public: @@ -255,10 +255,11 @@ class UUIDStatisticsState : public ColumnWriterStatistics { class GeoStatisticsState final : public ColumnWriterStatistics { public: explicit GeoStatisticsState() : has_stats(false) { + geo_stats.SetEmpty(); } bool has_stats; - GeometryStats geo_stats; + GeometryStatsData geo_stats; public: void Update(const string_t &val) { @@ -268,37 +269,36 @@ class GeoStatisticsState final : public ColumnWriterStatistics { bool HasGeoStats() override { return has_stats; } - optional_ptr GetGeoStats() override { + optional_ptr GetGeoStats() override { return geo_stats; } void WriteGeoStats(duckdb_parquet::GeospatialStatistics &stats) override { const auto &types = geo_stats.types; - const auto &bbox = geo_stats.bbox; - - if (bbox.IsSet()) { + const auto &bbox = geo_stats.extent; + if (bbox.HasXY()) { stats.__isset.bbox = true; - stats.bbox.xmin = bbox.xmin; - stats.bbox.xmax = bbox.xmax; - stats.bbox.ymin = bbox.ymin; - stats.bbox.ymax = bbox.ymax; + stats.bbox.xmin = bbox.x_min; + stats.bbox.xmax = bbox.x_max; + stats.bbox.ymin = bbox.y_min; + stats.bbox.ymax = bbox.y_max; if (bbox.HasZ()) { stats.bbox.__isset.zmin = true; stats.bbox.__isset.zmax = true; - stats.bbox.zmin = bbox.zmin; - stats.bbox.zmax = bbox.zmax; + stats.bbox.zmin = bbox.z_min; + stats.bbox.zmax = bbox.z_max; } if (bbox.HasM()) 
{ stats.bbox.__isset.mmin = true; stats.bbox.__isset.mmax = true; - stats.bbox.mmin = bbox.mmin; - stats.bbox.mmax = bbox.mmax; + stats.bbox.mmin = bbox.m_min; + stats.bbox.mmax = bbox.m_max; } } stats.__isset.geospatial_types = true; - stats.geospatial_types = types.ToList(); + stats.geospatial_types = types.ToWKBList(); } }; diff --git a/src/duckdb/extension/parquet/parquet_metadata.cpp b/src/duckdb/extension/parquet/parquet_metadata.cpp index 2f34efae2..af2c7f534 100644 --- a/src/duckdb/extension/parquet/parquet_metadata.cpp +++ b/src/duckdb/extension/parquet/parquet_metadata.cpp @@ -334,11 +334,20 @@ static Value ConvertParquetGeoStatsTypes(const duckdb_parquet::GeospatialStatist vector types; types.reserve(stats.geospatial_types.size()); - GeometryKindSet kind_set; + GeometryTypeSet type_set; for (auto &type : stats.geospatial_types) { - kind_set.Add(type); + const auto geom_type = (type % 1000); + const auto vert_type = (type / 1000); + if (geom_type < 1 || geom_type > 7) { + throw InvalidInputException("Unsupported geometry type in Parquet geo metadata"); + } + if (vert_type < 0 || vert_type > 3) { + throw InvalidInputException("Unsupported geometry vertex type in Parquet geo metadata"); + } + type_set.Add(static_cast(geom_type), static_cast(vert_type)); } - for (auto &type_name : kind_set.ToString(true)) { + + for (auto &type_name : type_set.ToString(true)) { types.push_back(Value(type_name)); } return Value::LIST(LogicalType::VARCHAR, types); diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp index 99a420242..a2e34b9e5 100644 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ b/src/duckdb/extension/parquet/parquet_writer.cpp @@ -336,9 +336,9 @@ struct ColumnStatsUnifier { bool can_have_nan = false; bool has_nan = false; - unique_ptr geo_stats; + unique_ptr geo_stats; - virtual void UnifyGeoStats(const GeometryStats &other) { + virtual void UnifyGeoStats(const GeometryStatsData &other) { } 
virtual void UnifyMinMax(const string &new_min, const string &new_max) = 0; @@ -686,14 +686,13 @@ struct BlobStatsUnifier : public BaseStringStatsUnifier { struct GeoStatsUnifier : public ColumnStatsUnifier { - void UnifyGeoStats(const GeometryStats &other) override { + void UnifyGeoStats(const GeometryStatsData &other) override { if (geo_stats) { - geo_stats->bbox.Combine(other.bbox); - geo_stats->types.Combine(other.types); + geo_stats->Merge(other); } else { // Make copy - geo_stats = make_uniq(); - geo_stats->bbox = other.bbox; + geo_stats = make_uniq(); + geo_stats->extent = other.extent; geo_stats->types = other.types; } } @@ -707,17 +706,17 @@ struct GeoStatsUnifier : public ColumnStatsUnifier { return string(); } - const auto &bbox = geo_stats->bbox; + const auto &bbox = geo_stats->extent; const auto &types = geo_stats->types; - const auto bbox_value = Value::STRUCT({{"xmin", bbox.xmin}, - {"xmax", bbox.xmax}, - {"ymin", bbox.ymin}, - {"ymax", bbox.ymax}, - {"zmin", bbox.zmin}, - {"zmax", bbox.zmax}, - {"mmin", bbox.mmin}, - {"mmax", bbox.mmax}}); + const auto bbox_value = Value::STRUCT({{"xmin", bbox.x_min}, + {"xmax", bbox.x_max}, + {"ymin", bbox.y_min}, + {"ymax", bbox.y_max}, + {"zmin", bbox.z_min}, + {"zmax", bbox.z_max}, + {"mmin", bbox.m_min}, + {"mmax", bbox.m_max}}); vector type_strings; for (const auto &type : types.ToString(true)) { @@ -903,22 +902,25 @@ void ParquetWriter::GatherWrittenStatistics() { column_stats["has_nan"] = Value::BOOLEAN(stats_unifier->has_nan); } if (stats_unifier->geo_stats) { - const auto &bbox = stats_unifier->geo_stats->bbox; + const auto &bbox = stats_unifier->geo_stats->extent; const auto &types = stats_unifier->geo_stats->types; - column_stats["bbox_xmin"] = Value::DOUBLE(bbox.xmin); - column_stats["bbox_xmax"] = Value::DOUBLE(bbox.xmax); - column_stats["bbox_ymin"] = Value::DOUBLE(bbox.ymin); - column_stats["bbox_ymax"] = Value::DOUBLE(bbox.ymax); + if (bbox.HasXY()) { - if (bbox.HasZ()) { - column_stats["bbox_zmin"] 
= Value::DOUBLE(bbox.zmin); - column_stats["bbox_zmax"] = Value::DOUBLE(bbox.zmax); - } + column_stats["bbox_xmin"] = Value::DOUBLE(bbox.x_min); + column_stats["bbox_xmax"] = Value::DOUBLE(bbox.x_max); + column_stats["bbox_ymin"] = Value::DOUBLE(bbox.y_min); + column_stats["bbox_ymax"] = Value::DOUBLE(bbox.y_max); + + if (bbox.HasZ()) { + column_stats["bbox_zmin"] = Value::DOUBLE(bbox.z_min); + column_stats["bbox_zmax"] = Value::DOUBLE(bbox.z_max); + } - if (bbox.HasM()) { - column_stats["bbox_mmin"] = Value::DOUBLE(bbox.mmin); - column_stats["bbox_mmax"] = Value::DOUBLE(bbox.mmax); + if (bbox.HasM()) { + column_stats["bbox_mmin"] = Value::DOUBLE(bbox.m_min); + column_stats["bbox_mmax"] = Value::DOUBLE(bbox.m_max); + } } if (!types.IsEmpty()) { diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index 6862e4295..1005c6f91 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -82,7 +82,6 @@ #include "duckdb/common/multi_file/multi_file_options.hpp" #include "duckdb/common/operator/decimal_cast_operators.hpp" #include "duckdb/common/printer.hpp" -#include "duckdb/common/sort/partition_state.hpp" #include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/types/column/column_data_scan_states.hpp" @@ -2843,6 +2842,7 @@ const StringUtil::EnumStringLiteral *GetMetricsTypeValues() { { static_cast(MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE), "OPTIMIZER_BUILD_SIDE_PROBE_SIDE" }, { static_cast(MetricsType::OPTIMIZER_LIMIT_PUSHDOWN), "OPTIMIZER_LIMIT_PUSHDOWN" }, { static_cast(MetricsType::OPTIMIZER_TOP_N), "OPTIMIZER_TOP_N" }, + { static_cast(MetricsType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION), "OPTIMIZER_TOP_N_WINDOW_ELIMINATION" }, { static_cast(MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION), "OPTIMIZER_COMPRESSED_MATERIALIZATION" }, { static_cast(MetricsType::OPTIMIZER_DUPLICATE_GROUPS), "OPTIMIZER_DUPLICATE_GROUPS" }, { 
static_cast(MetricsType::OPTIMIZER_REORDER_FILTER), "OPTIMIZER_REORDER_FILTER" }, @@ -2860,12 +2860,12 @@ const StringUtil::EnumStringLiteral *GetMetricsTypeValues() { template<> const char* EnumUtil::ToChars(MetricsType value) { - return StringUtil::EnumToString(GetMetricsTypeValues(), 55, "MetricsType", static_cast(value)); + return StringUtil::EnumToString(GetMetricsTypeValues(), 56, "MetricsType", static_cast(value)); } template<> MetricsType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMetricsTypeValues(), 55, "MetricsType", value)); + return static_cast(StringUtil::StringToEnum(GetMetricsTypeValues(), 56, "MetricsType", value)); } const StringUtil::EnumStringLiteral *GetMultiFileColumnMappingModeValues() { @@ -3088,6 +3088,7 @@ const StringUtil::EnumStringLiteral *GetOptimizerTypeValues() { { static_cast(OptimizerType::BUILD_SIDE_PROBE_SIDE), "BUILD_SIDE_PROBE_SIDE" }, { static_cast(OptimizerType::LIMIT_PUSHDOWN), "LIMIT_PUSHDOWN" }, { static_cast(OptimizerType::TOP_N), "TOP_N" }, + { static_cast(OptimizerType::TOP_N_WINDOW_ELIMINATION), "TOP_N_WINDOW_ELIMINATION" }, { static_cast(OptimizerType::COMPRESSED_MATERIALIZATION), "COMPRESSED_MATERIALIZATION" }, { static_cast(OptimizerType::DUPLICATE_GROUPS), "DUPLICATE_GROUPS" }, { static_cast(OptimizerType::REORDER_FILTER), "REORDER_FILTER" }, @@ -3105,12 +3106,12 @@ const StringUtil::EnumStringLiteral *GetOptimizerTypeValues() { template<> const char* EnumUtil::ToChars(OptimizerType value) { - return StringUtil::EnumToString(GetOptimizerTypeValues(), 30, "OptimizerType", static_cast(value)); + return StringUtil::EnumToString(GetOptimizerTypeValues(), 31, "OptimizerType", static_cast(value)); } template<> OptimizerType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 30, "OptimizerType", value)); + return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 31, "OptimizerType", value)); } const 
StringUtil::EnumStringLiteral *GetOrderByNullTypeValues() { @@ -3266,28 +3267,6 @@ ParserExtensionResultType EnumUtil::FromString(const return static_cast(StringUtil::StringToEnum(GetParserExtensionResultTypeValues(), 3, "ParserExtensionResultType", value)); } -const StringUtil::EnumStringLiteral *GetPartitionSortStageValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(PartitionSortStage::INIT), "INIT" }, - { static_cast(PartitionSortStage::SCAN), "SCAN" }, - { static_cast(PartitionSortStage::PREPARE), "PREPARE" }, - { static_cast(PartitionSortStage::MERGE), "MERGE" }, - { static_cast(PartitionSortStage::SORTED), "SORTED" }, - { static_cast(PartitionSortStage::FINISHED), "FINISHED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(PartitionSortStage value) { - return StringUtil::EnumToString(GetPartitionSortStageValues(), 6, "PartitionSortStage", static_cast(value)); -} - -template<> -PartitionSortStage EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetPartitionSortStageValues(), 6, "PartitionSortStage", value)); -} - const StringUtil::EnumStringLiteral *GetPartitionedColumnDataTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(PartitionedColumnDataType::INVALID), "INVALID" }, @@ -4250,19 +4229,20 @@ const StringUtil::EnumStringLiteral *GetStatisticsTypeValues() { { static_cast(StatisticsType::LIST_STATS), "LIST_STATS" }, { static_cast(StatisticsType::STRUCT_STATS), "STRUCT_STATS" }, { static_cast(StatisticsType::BASE_STATS), "BASE_STATS" }, - { static_cast(StatisticsType::ARRAY_STATS), "ARRAY_STATS" } + { static_cast(StatisticsType::ARRAY_STATS), "ARRAY_STATS" }, + { static_cast(StatisticsType::GEOMETRY_STATS), "GEOMETRY_STATS" } }; return values; } template<> const char* EnumUtil::ToChars(StatisticsType value) { - return StringUtil::EnumToString(GetStatisticsTypeValues(), 6, "StatisticsType", static_cast(value)); + return 
StringUtil::EnumToString(GetStatisticsTypeValues(), 7, "StatisticsType", static_cast(value)); } template<> StatisticsType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetStatisticsTypeValues(), 6, "StatisticsType", value)); + return static_cast(StringUtil::StringToEnum(GetStatisticsTypeValues(), 7, "StatisticsType", value)); } const StringUtil::EnumStringLiteral *GetStatsInfoValues() { @@ -4962,6 +4942,26 @@ VerifyExistenceType EnumUtil::FromString(const char *value) return static_cast(StringUtil::StringToEnum(GetVerifyExistenceTypeValues(), 3, "VerifyExistenceType", value)); } +const StringUtil::EnumStringLiteral *GetVertexTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(VertexType::XY), "XY" }, + { static_cast(VertexType::XYZ), "XYZ" }, + { static_cast(VertexType::XYM), "XYM" }, + { static_cast(VertexType::XYZM), "XYZM" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(VertexType value) { + return StringUtil::EnumToString(GetVertexTypeValues(), 4, "VertexType", static_cast(value)); +} + +template<> +VertexType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetVertexTypeValues(), 4, "VertexType", value)); +} + const StringUtil::EnumStringLiteral *GetWALTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(WALType::INVALID), "INVALID" }, diff --git a/src/duckdb/src/common/enums/metric_type.cpp b/src/duckdb/src/common/enums/metric_type.cpp index c2bb79f50..a4b7c7338 100644 --- a/src/duckdb/src/common/enums/metric_type.cpp +++ b/src/duckdb/src/common/enums/metric_type.cpp @@ -31,6 +31,7 @@ profiler_settings_t MetricsUtils::GetOptimizerMetrics() { MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE, MetricsType::OPTIMIZER_LIMIT_PUSHDOWN, MetricsType::OPTIMIZER_TOP_N, + MetricsType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION, MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION, 
MetricsType::OPTIMIZER_DUPLICATE_GROUPS, MetricsType::OPTIMIZER_REORDER_FILTER, @@ -96,6 +97,8 @@ MetricsType MetricsUtils::GetOptimizerMetricByType(OptimizerType type) { return MetricsType::OPTIMIZER_LIMIT_PUSHDOWN; case OptimizerType::TOP_N: return MetricsType::OPTIMIZER_TOP_N; + case OptimizerType::TOP_N_WINDOW_ELIMINATION: + return MetricsType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION; case OptimizerType::COMPRESSED_MATERIALIZATION: return MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION; case OptimizerType::DUPLICATE_GROUPS: @@ -161,6 +164,8 @@ OptimizerType MetricsUtils::GetOptimizerTypeByMetric(MetricsType type) { return OptimizerType::LIMIT_PUSHDOWN; case MetricsType::OPTIMIZER_TOP_N: return OptimizerType::TOP_N; + case MetricsType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION: + return OptimizerType::TOP_N_WINDOW_ELIMINATION; case MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: return OptimizerType::COMPRESSED_MATERIALIZATION; case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: @@ -208,6 +213,7 @@ bool MetricsUtils::IsOptimizerMetric(MetricsType type) { case MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: case MetricsType::OPTIMIZER_LIMIT_PUSHDOWN: case MetricsType::OPTIMIZER_TOP_N: + case MetricsType::OPTIMIZER_TOP_N_WINDOW_ELIMINATION: case MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: case MetricsType::OPTIMIZER_REORDER_FILTER: diff --git a/src/duckdb/src/common/enums/optimizer_type.cpp b/src/duckdb/src/common/enums/optimizer_type.cpp index be5fc8309..c7441a0fa 100644 --- a/src/duckdb/src/common/enums/optimizer_type.cpp +++ b/src/duckdb/src/common/enums/optimizer_type.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/exception/parser_exception.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/optimizer/optimizer.hpp" namespace duckdb { @@ -29,6 +30,7 @@ static const DefaultOptimizerType internal_optimizer_types[] = { {"column_lifetime", OptimizerType::COLUMN_LIFETIME}, 
{"limit_pushdown", OptimizerType::LIMIT_PUSHDOWN}, {"top_n", OptimizerType::TOP_N}, + {"top_n_window_elimination", OptimizerType::TOP_N_WINDOW_ELIMINATION}, {"build_side_probe_side", OptimizerType::BUILD_SIDE_PROBE_SIDE}, {"compressed_materialization", OptimizerType::COMPRESSED_MATERIALIZATION}, {"duplicate_groups", OptimizerType::DUPLICATE_GROUPS}, diff --git a/src/duckdb/src/common/sort/partition_state.cpp b/src/duckdb/src/common/sort/partition_state.cpp deleted file mode 100644 index 2a0a65895..000000000 --- a/src/duckdb/src/common/sort/partition_state.cpp +++ /dev/null @@ -1,671 +0,0 @@ -#include "duckdb/common/sort/partition_state.hpp" - -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/parallel/executor_task.hpp" - -namespace duckdb { - -PartitionGlobalHashGroup::PartitionGlobalHashGroup(ClientContext &context, const Orders &partitions, - const Orders &orders, const Types &payload_types, bool external) - : count(0) { - - RowLayout payload_layout; - payload_layout.Initialize(payload_types); - global_sort = make_uniq(context, orders, payload_layout); - global_sort->external = external; - - // Set up a comparator for the partition subset - partition_layout = global_sort->sort_layout.GetPrefixComparisonLayout(partitions.size()); -} - -void PartitionGlobalHashGroup::ComputeMasks(ValidityMask &partition_mask, OrderMasks &order_masks) { - D_ASSERT(count > 0); - - SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN); - SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN); - - partition_mask.SetValidUnsafe(0); - unordered_map prefixes; - for (auto &order_mask : order_masks) { - order_mask.second.SetValidUnsafe(0); - D_ASSERT(order_mask.first >= partition_layout.column_count); - prefixes[order_mask.first] = global_sort->sort_layout.GetPrefixComparisonLayout(order_mask.first); - } - - for (++curr; curr.GetIndex() < count; ++curr) { - // Compare the partition subset first because if that 
differs, then so does the full ordering - const auto part_cmp = ComparePartitions(prev, curr); - - if (part_cmp) { - partition_mask.SetValidUnsafe(curr.GetIndex()); - for (auto &order_mask : order_masks) { - order_mask.second.SetValidUnsafe(curr.GetIndex()); - } - } else { - for (auto &order_mask : order_masks) { - if (prev.Compare(curr, prefixes[order_mask.first])) { - order_mask.second.SetValidUnsafe(curr.GetIndex()); - } - } - } - ++prev; - } -} - -void PartitionGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders, - const vector> &partition_bys, - const Orders &order_bys, - const vector> &partition_stats) { - - // we sort by both 1) partition by expression list and 2) order by expressions - const auto partition_cols = partition_bys.size(); - for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) { - auto &pexpr = partition_bys[prt_idx]; - - if (partition_stats.empty() || !partition_stats[prt_idx]) { - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr); - } else { - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), - partition_stats[prt_idx]->ToUnique()); - } - partitions.emplace_back(orders.back().Copy()); - } - - for (const auto &order : order_bys) { - orders.emplace_back(order.Copy()); - } -} - -PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context, - const vector> &partition_bys, - const vector &order_bys, - const Types &payload_types, - const vector> &partition_stats, - idx_t estimated_cardinality) - : context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)), - fixed_bits(0), payload_types(payload_types), memory_per_thread(0), max_bits(1), count(0) { - - GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats); - - memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); - external = ClientConfig::GetConfig(context).force_external; - - const auto 
thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * buffer_manager.GetBlockAllocSize())); - while (max_bits < 10 && (thread_pages >> max_bits) > 1) { - ++max_bits; - } - - grouping_types_ptr = make_shared_ptr(); - if (!orders.empty()) { - if (partitions.empty()) { - // Sort early into a dedicated hash group if we only sort. - grouping_types_ptr->Initialize(payload_types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - auto new_group = make_uniq(context, partitions, orders, payload_types, external); - hash_groups.emplace_back(std::move(new_group)); - } else { - auto types = payload_types; - types.push_back(LogicalType::HASH); - grouping_types_ptr->Initialize(types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - ResizeGroupingData(estimated_cardinality); - } - } -} - -bool PartitionGlobalSinkState::HasMergeTasks() const { - if (grouping_data) { - auto &groups = grouping_data->GetPartitions(); - return !groups.empty(); - } else if (!hash_groups.empty()) { - D_ASSERT(hash_groups.size() == 1); - return hash_groups[0]->count > 0; - } else { - return false; - } -} - -void PartitionGlobalSinkState::SyncPartitioning(const PartitionGlobalSinkState &other) { - fixed_bits = other.grouping_data ? other.grouping_data->GetRadixBits() : 0; - - const auto old_bits = grouping_data ? grouping_data->GetRadixBits() : 0; - if (fixed_bits != old_bits) { - const auto hash_col_idx = payload_types.size(); - grouping_data = - make_uniq(buffer_manager, grouping_types_ptr, fixed_bits, hash_col_idx); - } -} - -unique_ptr PartitionGlobalSinkState::CreatePartition(idx_t new_bits) const { - const auto hash_col_idx = payload_types.size(); - return make_uniq(buffer_manager, grouping_types_ptr, new_bits, hash_col_idx); -} - -void PartitionGlobalSinkState::ResizeGroupingData(idx_t cardinality) { - // Have we started to combine? Then just live with it. - if (fixed_bits || (grouping_data && !grouping_data->GetPartitions().empty())) { - return; - } - // Is the average partition size too large? 
- const idx_t partition_size = DEFAULT_ROW_GROUP_SIZE; - const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0; - auto new_bits = bits ? bits : 4; - while (new_bits < max_bits && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) { - ++new_bits; - } - - // Repartition the grouping data - if (new_bits != bits) { - grouping_data = CreatePartition(new_bits); - } -} - -void PartitionGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - // We are done if the local_partition is right sized. - auto &local_radix = local_partition->Cast(); - const auto new_bits = grouping_data->GetRadixBits(); - if (local_radix.GetRadixBits() == new_bits) { - return; - } - - // If the local partition is now too small, flush it and reallocate - auto new_partition = CreatePartition(new_bits); - local_partition->FlushAppendState(*local_append); - local_partition->Repartition(context, *new_partition); - - local_partition = std::move(new_partition); - local_append = make_uniq(); - local_partition->InitializeAppendState(*local_append); -} - -void PartitionGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - // Make sure grouping_data doesn't change under us. - lock_guard guard(lock); - - if (!local_partition) { - local_partition = CreatePartition(grouping_data->GetRadixBits()); - local_append = make_uniq(); - local_partition->InitializeAppendState(*local_append); - return; - } - - // Grow the groups if they are too big - ResizeGroupingData(count); - - // Sync local partition to have the same bit count - SyncLocalPartition(local_partition, local_append); -} - -void PartitionGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - if (!local_partition) { - return; - } - local_partition->FlushAppendState(*local_append); - - // Make sure grouping_data doesn't change under us. 
- // Combine has an internal mutex, so this is single-threaded anyway. - lock_guard guard(lock); - SyncLocalPartition(local_partition, local_append); - grouping_data->Combine(*local_partition); -} - -PartitionLocalMergeState::PartitionLocalMergeState(PartitionGlobalSinkState &gstate) - : merge_state(nullptr), stage(PartitionSortStage::INIT), finished(true), executor(gstate.context) { - - // Set up the sort expression computation. - vector sort_types; - for (auto &order : gstate.orders) { - auto &oexpr = order.expression; - sort_types.emplace_back(oexpr->return_type); - executor.AddExpression(*oexpr); - } - sort_chunk.Initialize(gstate.allocator, sort_types); - payload_chunk.Initialize(gstate.allocator, gstate.payload_types); -} - -void PartitionLocalMergeState::Scan() { - if (!merge_state->group_data) { - // OVER(ORDER BY...) - // Already sorted - return; - } - - auto &group_data = *merge_state->group_data; - auto &hash_group = *merge_state->hash_group; - auto &chunk_state = merge_state->chunk_state; - // Copy the data from the group into the sort code. 
- auto &global_sort = *hash_group.global_sort; - LocalSortState local_sort; - local_sort.Initialize(global_sort, global_sort.buffer_manager); - - TupleDataScanState local_scan; - group_data.InitializeScan(local_scan, merge_state->column_ids); - while (group_data.Scan(chunk_state, local_scan, payload_chunk)) { - sort_chunk.Reset(); - executor.Execute(payload_chunk, sort_chunk); - - local_sort.SinkChunk(sort_chunk, payload_chunk); - if (local_sort.SizeInBytes() > merge_state->memory_per_thread) { - local_sort.Sort(global_sort, true); - } - hash_group.count += payload_chunk.size(); - } - - global_sort.AddLocalState(local_sort); -} - -// Per-thread sink state -PartitionLocalSinkState::PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) - : gstate(gstate_p), allocator(Allocator::Get(context)), executor(context) { - - vector group_types; - for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) { - auto &pexpr = *gstate.partitions[prt_idx].expression.get(); - group_types.push_back(pexpr.return_type); - executor.AddExpression(pexpr); - } - sort_cols = gstate.orders.size() + group_types.size(); - - if (sort_cols) { - auto payload_types = gstate.payload_types; - if (!group_types.empty()) { - // OVER(PARTITION BY...) - group_chunk.Initialize(allocator, group_types); - payload_types.emplace_back(LogicalType::HASH); - } else { - // OVER(ORDER BY...) - for (idx_t ord_idx = 0; ord_idx < gstate.orders.size(); ord_idx++) { - auto &pexpr = *gstate.orders[ord_idx].expression.get(); - group_types.push_back(pexpr.return_type); - executor.AddExpression(pexpr); - } - group_chunk.Initialize(allocator, group_types); - - // Single partition - auto &global_sort = *gstate.hash_groups[0]->global_sort; - local_sort = make_uniq(); - local_sort->Initialize(global_sort, global_sort.buffer_manager); - } - // OVER(...) 
- payload_chunk.Initialize(allocator, payload_types); - } else { - // OVER() - payload_layout.Initialize(gstate.payload_types); - } -} - -void PartitionLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { - const auto count = input_chunk.size(); - D_ASSERT(group_chunk.ColumnCount() > 0); - - // OVER(PARTITION BY...) (hash grouping) - group_chunk.Reset(); - executor.Execute(input_chunk, group_chunk); - VectorOperations::Hash(group_chunk.data[0], hash_vector, count); - for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) { - VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count); - } -} - -void PartitionLocalSinkState::Sink(DataChunk &input_chunk) { - gstate.count += input_chunk.size(); - - // OVER() - if (sort_cols == 0) { - // No sorts, so build paged row chunks - if (!rows) { - const auto entry_size = payload_layout.GetRowWidth(); - const auto block_size = gstate.buffer_manager.GetBlockSize(); - const auto capacity = MaxValue(STANDARD_VECTOR_SIZE, block_size / entry_size + 1); - rows = make_uniq(gstate.buffer_manager, capacity, entry_size); - strings = make_uniq(gstate.buffer_manager, block_size, 1U, true); - } - const auto row_count = input_chunk.size(); - const auto row_sel = FlatVector::IncrementalSelectionVector(); - Vector addresses(LogicalType::POINTER); - auto key_locations = FlatVector::GetData(addresses); - const auto prev_rows_blocks = rows->blocks.size(); - auto handles = rows->Build(row_count, key_locations, nullptr, row_sel); - auto input_data = input_chunk.ToUnifiedFormat(); - RowOperations::Scatter(input_chunk, input_data.get(), payload_layout, addresses, *strings, *row_sel, row_count); - // Mark that row blocks contain pointers (heap blocks are pinned) - if (!payload_layout.AllConstant()) { - D_ASSERT(strings->keep_pinned); - for (size_t i = prev_rows_blocks; i < rows->blocks.size(); ++i) { - rows->blocks[i]->block->SetSwizzling("PartitionLocalSinkState::Sink"); - } - } - return; - } - - if 
(local_sort) { - // OVER(ORDER BY...) - group_chunk.Reset(); - executor.Execute(input_chunk, group_chunk); - local_sort->SinkChunk(group_chunk, input_chunk); - - auto &hash_group = *gstate.hash_groups[0]; - hash_group.count += input_chunk.size(); - - if (local_sort->SizeInBytes() > gstate.memory_per_thread) { - auto &global_sort = *hash_group.global_sort; - local_sort->Sort(global_sort, true); - } - return; - } - - // OVER(...) - payload_chunk.Reset(); - auto &hash_vector = payload_chunk.data.back(); - Hash(input_chunk, hash_vector); - for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) { - payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]); - } - payload_chunk.SetCardinality(input_chunk); - - gstate.UpdateLocalPartition(local_partition, local_append); - local_partition->Append(*local_append, payload_chunk); -} - -void PartitionLocalSinkState::Combine() { - // OVER() - if (sort_cols == 0) { - // Only one partition again, so need a global lock. - lock_guard glock(gstate.lock); - if (gstate.rows) { - if (rows) { - gstate.rows->Merge(*rows); - gstate.strings->Merge(*strings); - rows.reset(); - strings.reset(); - } - } else { - gstate.rows = std::move(rows); - gstate.strings = std::move(strings); - } - return; - } - - if (local_sort) { - // OVER(ORDER BY...) - auto &hash_group = *gstate.hash_groups[0]; - auto &global_sort = *hash_group.global_sort; - global_sort.AddLocalState(*local_sort); - local_sort.reset(); - return; - } - - // OVER(...) 
- gstate.CombineLocalPartition(local_partition, local_append); -} - -PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data_p, - hash_t hash_bin) - : sink(sink), group_data(std::move(group_data_p)), group_idx(sink.hash_groups.size()), - memory_per_thread(sink.memory_per_thread), - num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), - stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { - - auto new_group = make_uniq(sink.context, sink.partitions, sink.orders, sink.payload_types, - sink.external); - sink.hash_groups.emplace_back(std::move(new_group)); - - hash_group = sink.hash_groups[group_idx].get(); - global_sort = sink.hash_groups[group_idx]->global_sort.get(); - - sink.bin_groups[hash_bin] = group_idx; - - column_ids.reserve(sink.payload_types.size()); - for (column_t i = 0; i < sink.payload_types.size(); ++i) { - column_ids.emplace_back(i); - } - group_data->InitializeScan(chunk_state, column_ids); -} - -PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink) - : sink(sink), group_idx(0), memory_per_thread(sink.memory_per_thread), - num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), - stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { - - const hash_t hash_bin = 0; - hash_group = sink.hash_groups[group_idx].get(); - global_sort = sink.hash_groups[group_idx]->global_sort.get(); - - sink.bin_groups[hash_bin] = group_idx; -} - -void PartitionLocalMergeState::Prepare() { - merge_state->group_data.reset(); - - auto &global_sort = *merge_state->global_sort; - global_sort.PrepareMergePhase(); -} - -void PartitionLocalMergeState::Merge() { - auto &global_sort = *merge_state->global_sort; - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); -} - -void PartitionLocalMergeState::Sorted() { - 
merge_state->sink.OnSortedPartition(merge_state->group_idx); -} - -void PartitionLocalMergeState::ExecuteTask() { - switch (stage) { - case PartitionSortStage::SCAN: - Scan(); - break; - case PartitionSortStage::PREPARE: - Prepare(); - break; - case PartitionSortStage::MERGE: - Merge(); - break; - case PartitionSortStage::SORTED: - Sorted(); - break; - default: - throw InternalException("Unexpected PartitionSortStage in ExecuteTask!"); - } - - merge_state->CompleteTask(); - finished = true; -} - -bool PartitionGlobalMergeState::AssignTask(PartitionLocalMergeState &local_state) { - lock_guard guard(lock); - - if (tasks_assigned >= total_tasks && !TryPrepareNextStage()) { - return false; - } - - local_state.merge_state = this; - local_state.stage = stage; - local_state.finished = false; - tasks_assigned++; - - return true; -} - -void PartitionGlobalMergeState::CompleteTask() { - lock_guard guard(lock); - - ++tasks_completed; -} - -bool PartitionGlobalMergeState::TryPrepareNextStage() { - if (tasks_completed < total_tasks) { - return false; - } - - tasks_assigned = tasks_completed = 0; - - switch (stage.load()) { - case PartitionSortStage::INIT: - // If the partitions are unordered, don't scan in parallel - // because it produces non-deterministic orderings. - // This can theoretically happen with ORDER BY, - // but that is something the query should be explicit about. - total_tasks = sink.orders.size() > sink.partitions.size() ? 
num_threads : 1; - stage = PartitionSortStage::SCAN; - return true; - - case PartitionSortStage::SCAN: - total_tasks = 1; - stage = PartitionSortStage::PREPARE; - return true; - - case PartitionSortStage::PREPARE: - if (!(global_sort->sorted_blocks.size() / 2)) { - break; - } - stage = PartitionSortStage::MERGE; - global_sort->InitializeMergeRound(); - total_tasks = num_threads; - return true; - - case PartitionSortStage::MERGE: - global_sort->CompleteMergeRound(true); - if (!(global_sort->sorted_blocks.size() / 2)) { - break; - } - global_sort->InitializeMergeRound(); - total_tasks = num_threads; - return true; - - case PartitionSortStage::SORTED: - stage = PartitionSortStage::FINISHED; - total_tasks = 0; - return false; - - case PartitionSortStage::FINISHED: - return false; - } - - stage = PartitionSortStage::SORTED; - total_tasks = 1; - - return true; -} - -PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) { - // Schedule all the sorts for maximum thread utilisation - if (sink.grouping_data) { - auto &partitions = sink.grouping_data->GetPartitions(); - sink.bin_groups.resize(partitions.size(), partitions.size()); - for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { - auto &group_data = partitions[hash_bin]; - // Prepare for merge sort phase - if (group_data->Count()) { - auto state = make_uniq(sink, std::move(group_data), hash_bin); - states.emplace_back(std::move(state)); - } - } - } else { - // OVER(ORDER BY...) 
- // Already sunk into the single global sort, so set up single merge with no data - sink.bin_groups.resize(1, 1); - auto state = make_uniq(sink); - states.emplace_back(std::move(state)); - } - - sink.OnBeginMerge(); -} - -class PartitionMergeTask : public ExecutorTask { -public: - PartitionMergeTask(shared_ptr event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p, - PartitionGlobalSinkState &gstate, const PhysicalOperator &op) - : ExecutorTask(context_p, std::move(event_p), op), local_state(gstate), hash_groups(hash_groups_p) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - - string TaskType() const override { - return "PartitionMergeTask"; - } - -private: - struct ExecutorCallback : public PartitionGlobalMergeStates::Callback { - explicit ExecutorCallback(Executor &executor) : executor(executor) { - } - - bool HasError() const override { - return executor.HasError(); - } - - Executor &executor; - }; - - PartitionLocalMergeState local_state; - PartitionGlobalMergeStates &hash_groups; -}; - -bool PartitionGlobalMergeStates::ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback) { - // Loop until all hash groups are done - size_t sorted = 0; - while (sorted < states.size()) { - // First check if there is an unfinished task for this thread - if (callback.HasError()) { - return false; - } - if (!local_state.TaskFinished()) { - local_state.ExecuteTask(); - continue; - } - - // Thread is done with its assigned task, try to fetch new work - for (auto group = sorted; group < states.size(); ++group) { - auto &global_state = states[group]; - if (global_state->IsFinished()) { - // This hash group is done - // Update the high water mark of densely completed groups - if (sorted == group) { - ++sorted; - } - continue; - } - - // Try to assign work for this hash group to this thread - if (global_state->AssignTask(local_state)) { - // We assigned a task to this thread! 
- // Break out of this loop to re-enter the top-level loop and execute the task - break; - } - - // We were able to prepare the next merge round, - // but we were not able to assign a task for it to this thread - // The tasks were assigned to other threads while this thread waited for the lock - // Go to the next iteration to see if another hash group has a task - } - } - - return true; -} - -TaskExecutionResult PartitionMergeTask::ExecuteTask(TaskExecutionMode mode) { - ExecutorCallback callback(executor); - - if (!hash_groups.ExecuteTask(local_state, callback)) { - return TaskExecutionResult::TASK_ERROR; - } - - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -void PartitionMergeEvent::Schedule() { - auto &context = pipeline->GetClientContext(); - - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); - auto num_threads = NumericCast(ts.NumberOfThreads()); - - vector> merge_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - merge_tasks.emplace_back(make_uniq(shared_from_this(), context, merge_states, gstate, op)); - } - SetTasks(std::move(merge_tasks)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/hashed_sort.cpp b/src/duckdb/src/common/sorting/hashed_sort.cpp index 5571e0bc3..cdf5dde5b 100644 --- a/src/duckdb/src/common/sorting/hashed_sort.cpp +++ b/src/duckdb/src/common/sorting/hashed_sort.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/sorting/hashed_sort.hpp" +#include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/radix_partitioning.hpp" #include "duckdb/parallel/base_pipeline_event.hpp" #include "duckdb/parallel/thread_context.hpp" @@ -26,7 +27,9 @@ class HashedSortGroup { // Source atomic tasks_completed; unique_ptr sort_source; - unique_ptr sorted; + + unique_ptr columns; + unique_ptr run; }; HashedSortGroup::HashedSortGroup(ClientContext &client, optional_ptr sort, idx_t group_idx) @@ -53,6 
+56,7 @@ class HashedSortGlobalSinkState : public GlobalSinkState { // OVER(PARTITION BY...) (hash grouping) unique_ptr CreatePartition(idx_t new_bits) const; + void SyncPartitioning(const HashedSortGlobalSinkState &other); void UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &partition_append); void CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); ProgressData GetSinkProgress(ClientContext &context, const ProgressData source_progress) const; @@ -172,6 +176,15 @@ void HashedSortGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_pa SyncLocalPartition(local_partition, partition_append); } +void HashedSortGlobalSinkState::SyncPartitioning(const HashedSortGlobalSinkState &other) { + fixed_bits = other.grouping_data ? other.grouping_data->GetRadixBits() : 0; + + const auto old_bits = grouping_data ? grouping_data->GetRadixBits() : 0; + if (fixed_bits != old_bits) { + grouping_data = CreatePartition(fixed_bits); + } +} + void HashedSortGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { if (!local_partition) { @@ -347,6 +360,12 @@ HashedSortLocalSinkState::HashedSortLocalSinkState(ExecutionContext &context, co } } +void HashedSort::Synchronize(const GlobalSinkState &source, GlobalSinkState &target) const { + auto &src = source.Cast(); + auto &tgt = target.Cast(); + tgt.SyncPartitioning(src); +} + void HashedSortLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { const auto count = input_chunk.size(); D_ASSERT(group_chunk.ColumnCount() > 0); @@ -393,6 +412,15 @@ SinkResultType HashedSort::Sink(ExecutionContext &context, DataChunk &input_chun payload_chunk.data[input_chunk.ColumnCount() + i].Reference(sort_chunk.data[i]); } } + + // Append a forced payload column + if (force_payload) { + auto &vec = payload_chunk.data[input_chunk.ColumnCount() + sort_chunk.ColumnCount()]; + D_ASSERT(vec.GetType().id() == LogicalTypeId::BOOLEAN); 
+ vec.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(vec, true); + } + payload_chunk.SetCardinality(input_chunk); // OVER(ORDER BY...) @@ -435,14 +463,14 @@ SinkCombineResultType HashedSort::Combine(ExecutionContext &context, OperatorSin auto &hash_groups = gstate.hash_groups; if (!hash_groups.empty()) { D_ASSERT(hash_groups.size() == 1); - auto &unsorted = *hash_groups[0]->sorted; + auto &unsorted = *hash_groups[0]->columns; if (lstate.unsorted) { unsorted.Combine(*lstate.unsorted); lstate.unsorted.reset(); } } else { auto new_group = make_uniq(context.client, sort, idx_t(0)); - new_group->sorted = std::move(lstate.unsorted); + new_group->columns = std::move(lstate.unsorted); hash_groups.emplace_back(std::move(new_group)); } return SinkCombineResultType::FINISHED; @@ -508,7 +536,7 @@ SinkCombineResultType HashedSort::Combine(ExecutionContext &context, OperatorSin class HashedSortMaterializeTask : public ExecutorTask { public: HashedSortMaterializeTask(Pipeline &pipeline, shared_ptr event, const PhysicalOperator &op, - HashedSortGroup &hash_group, idx_t tasks_scheduled); + HashedSortGroup &hash_group, idx_t tasks_scheduled, bool build_runs); TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; @@ -520,13 +548,14 @@ class HashedSortMaterializeTask : public ExecutorTask { Pipeline &pipeline; HashedSortGroup &hash_group; const idx_t tasks_scheduled; + const bool build_runs; }; HashedSortMaterializeTask::HashedSortMaterializeTask(Pipeline &pipeline, shared_ptr event, const PhysicalOperator &op, HashedSortGroup &hash_group, - idx_t tasks_scheduled) + idx_t tasks_scheduled, bool build_runs) : ExecutorTask(pipeline.GetClientContext(), std::move(event), op), pipeline(pipeline), hash_group(hash_group), - tasks_scheduled(tasks_scheduled) { + tasks_scheduled(tasks_scheduled), build_runs(build_runs) { } TaskExecutionResult HashedSortMaterializeTask::ExecuteTask(TaskExecutionMode mode) { @@ -536,9 +565,20 @@ TaskExecutionResult 
HashedSortMaterializeTask::ExecuteTask(TaskExecutionMode mod auto sort_local = sort.GetLocalSourceState(execution, sort_global); InterruptState interrupt((weak_ptr(shared_from_this()))); OperatorSourceInput input {sort_global, *sort_local, interrupt}; - sort.MaterializeColumnData(execution, input); + if (build_runs) { + sort.MaterializeSortedRun(execution, input); + } else { + sort.MaterializeColumnData(execution, input); + } if (++hash_group.tasks_completed == tasks_scheduled) { - hash_group.sorted = sort.GetColumnData(input); + if (build_runs) { + hash_group.run = sort.GetSortedRun(sort_global); + if (!hash_group.run) { + hash_group.run = make_uniq(execution.client, sort, false); + } + } else { + hash_group.columns = sort.GetColumnData(input); + } } event->FinishTask(); @@ -551,18 +591,20 @@ TaskExecutionResult HashedSortMaterializeTask::ExecuteTask(TaskExecutionMode mod // Formerly PartitionMergeEvent class HashedSortMaterializeEvent : public BasePipelineEvent { public: - HashedSortMaterializeEvent(HashedSortGlobalSinkState &gstate, Pipeline &pipeline, const PhysicalOperator &op); + HashedSortMaterializeEvent(HashedSortGlobalSinkState &gstate, Pipeline &pipeline, const PhysicalOperator &op, + bool build_runs); HashedSortGlobalSinkState &gstate; const PhysicalOperator &op; + const bool build_runs; public: void Schedule() override; }; HashedSortMaterializeEvent::HashedSortMaterializeEvent(HashedSortGlobalSinkState &gstate, Pipeline &pipeline, - const PhysicalOperator &op) - : BasePipelineEvent(pipeline), gstate(gstate), op(op) { + const PhysicalOperator &op, bool build_runs) + : BasePipelineEvent(pipeline), gstate(gstate), op(op), build_runs(build_runs) { } void HashedSortMaterializeEvent::Schedule() { @@ -573,7 +615,7 @@ void HashedSortMaterializeEvent::Schedule() { const auto num_threads = NumericCast(ts.NumberOfThreads()); auto &sort = *gstate.hashed_sort.sort; - vector> merge_tasks; + vector> tasks; for (auto &hash_group : gstate.hash_groups) { if 
(!hash_group) { continue; @@ -582,12 +624,12 @@ void HashedSortMaterializeEvent::Schedule() { hash_group->sort_source = sort.GetGlobalSourceState(client, global_sink); const auto tasks_scheduled = MinValue(num_threads, hash_group->sort_source->MaxThreads()); for (idx_t t = 0; t < tasks_scheduled; ++t) { - merge_tasks.emplace_back( - make_uniq(*pipeline, shared_from_this(), op, *hash_group, tasks_scheduled)); + tasks.emplace_back(make_uniq(*pipeline, shared_from_this(), op, *hash_group, + tasks_scheduled, build_runs)); } } - SetTasks(std::move(merge_tasks)); + SetTasks(std::move(tasks)); } //===--------------------------------------------------------------------===// @@ -596,22 +638,26 @@ void HashedSortMaterializeEvent::Schedule() { class HashedSortGlobalSourceState : public GlobalSourceState { public: using HashGroupPtr = unique_ptr; + using SortedRunPtr = unique_ptr; HashedSortGlobalSourceState(ClientContext &client, HashedSortGlobalSinkState &gsink) { if (!gsink.count) { return; } - hash_groups.resize(gsink.hash_groups.size()); + columns.resize(gsink.hash_groups.size()); + runs.resize(gsink.hash_groups.size()); for (auto &hash_group : gsink.hash_groups) { if (!hash_group) { continue; } const auto group_idx = hash_group->group_idx; - hash_groups[group_idx] = std::move(hash_group->sorted); + columns[group_idx] = std::move(hash_group->columns); + runs[group_idx] = std::move(hash_group->run); } } - vector hash_groups; + vector columns; + vector runs; }; //===--------------------------------------------------------------------===// @@ -642,7 +688,8 @@ void HashedSort::GenerateOrderings(Orders &partitions, Orders &orders, HashedSort::HashedSort(ClientContext &client, const vector> &partition_bys, const vector &order_bys, const Types &input_types, - const vector> &partition_stats, idx_t estimated_cardinality) + const vector> &partition_stats, idx_t estimated_cardinality, + bool require_payload) : client(client), estimated_cardinality(estimated_cardinality), 
payload_types(input_types) { GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats); @@ -673,6 +720,15 @@ HashedSort::HashedSort(ClientContext &client, const vector sort_set(sort_ids.begin(), sort_ids.end()); + force_payload = (sort_set.size() >= payload_types.size()); + if (force_payload) { + payload_types.emplace_back(LogicalType::BOOLEAN); + } + } vector projection_map; sort = make_uniq(client, orders, payload_types, projection_map); } @@ -697,7 +753,7 @@ unique_ptr HashedSort::GetLocalSourceState(ExecutionContext &c vector &HashedSort::GetHashGroups(GlobalSourceState &gstate) const { auto &gsource = gstate.Cast(); - return gsource.hash_groups; + return gsource.columns; } SinkFinalizeType HashedSort::MaterializeHashGroups(Pipeline &pipeline, Event &event, const PhysicalOperator &op, @@ -707,7 +763,7 @@ SinkFinalizeType HashedSort::MaterializeHashGroups(Pipeline &pipeline, Event &ev // OVER() if (sort_col_count == 0) { auto &hash_group = *gsink.hash_groups[0]; - auto &unsorted = *hash_group.sorted; + auto &unsorted = *hash_group.columns; if (!unsorted.Count()) { return SinkFinalizeType::NO_OUTPUT_POSSIBLE; } @@ -715,10 +771,36 @@ SinkFinalizeType HashedSort::MaterializeHashGroups(Pipeline &pipeline, Event &ev } // Schedule all the sorts for maximum thread utilisation - auto sort_event = make_shared_ptr(gsink, pipeline, op); + auto sort_event = make_shared_ptr(gsink, pipeline, op, false); event.InsertEvent(std::move(sort_event)); return SinkFinalizeType::READY; } +SinkFinalizeType HashedSort::MaterializeSortedRuns(Pipeline &pipeline, Event &event, const PhysicalOperator &op, + OperatorSinkFinalizeInput &finalize) const { + auto &gsink = finalize.global_state.Cast(); + + // OVER() + if (sort_col_count == 0) { + auto &hash_group = *gsink.hash_groups[0]; + auto &unsorted = *hash_group.columns; + if (!unsorted.Count()) { + return SinkFinalizeType::NO_OUTPUT_POSSIBLE; + } + return SinkFinalizeType::READY; + } + + // Schedule all the sorts for 
maximum thread utilisation + auto sort_event = make_shared_ptr(gsink, pipeline, op, true); + event.InsertEvent(std::move(sort_event)); + + return SinkFinalizeType::READY; +} + +vector &HashedSort::GetSortedRuns(GlobalSourceState &gstate) const { + auto &gsource = gstate.Cast(); + return gsource.runs; +} + } // namespace duckdb diff --git a/src/duckdb/src/common/types/geometry.cpp b/src/duckdb/src/common/types/geometry.cpp index 2ec9ac53a..e05816546 100644 --- a/src/duckdb/src/common/types/geometry.cpp +++ b/src/duckdb/src/common/types/geometry.cpp @@ -770,4 +770,116 @@ string_t Geometry::ToString(Vector &result, const string_t &geom) { return StringVector::AddString(result, buffer.data(), buffer.size()); } +pair Geometry::GetType(const string_t &wkb) { + BlobReader reader(wkb.GetData(), static_cast(wkb.GetSize())); + + // Read the byte order (should always be 1 for little-endian) + const auto byte_order = reader.Read(); + if (byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", byte_order); + } + + const auto meta = reader.Read(); + const auto type_id = meta % 1000; + const auto flag_id = meta / 1000; + + if (type_id < 1 || type_id > 7) { + throw InvalidInputException("Unsupported geometry type %d in WKB", type_id); + } + if (flag_id > 3) { + throw InvalidInputException("Unsupported geometry flag %d in WKB", flag_id); + } + + const auto geom_type = static_cast(type_id); + const auto vert_type = static_cast(flag_id); + + return {geom_type, vert_type}; +} + +template +static uint32_t ParseVerticesInternal(BlobReader &reader, GeometryExtent &extent, uint32_t vert_count, bool check_nan) { + uint32_t count = 0; + + // Issue a single .Reserve() for all vertices, to minimize bounds checking overhead + const auto ptr = const_data_ptr_cast(reader.Reserve(vert_count * sizeof(VERTEX_TYPE))); + + for (uint32_t vert_idx = 0; vert_idx < vert_count; vert_idx++) { + VERTEX_TYPE vertex = Load(ptr + vert_idx * sizeof(VERTEX_TYPE)); + if (check_nan && 
vertex.AllNan()) { + continue; + } + + extent.Extend(vertex); + count++; + } + return count; +} + +static uint32_t ParseVertices(BlobReader &reader, GeometryExtent &extent, uint32_t vert_count, VertexType type, + bool check_nan) { + switch (type) { + case VertexType::XY: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + case VertexType::XYZ: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + case VertexType::XYM: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + case VertexType::XYZM: + return ParseVerticesInternal(reader, extent, vert_count, check_nan); + default: + throw InvalidInputException("Unsupported vertex type %d in WKB", static_cast(type)); + } +} + +uint32_t Geometry::GetExtent(const string_t &wkb, GeometryExtent &extent) { + BlobReader reader(wkb.GetData(), static_cast(wkb.GetSize())); + + uint32_t vertex_count = 0; + + while (!reader.IsAtEnd()) { + const auto byte_order = reader.Read(); + if (byte_order != 1) { + throw InvalidInputException("Unsupported byte order %d in WKB", byte_order); + } + const auto meta = reader.Read(); + const auto type_id = meta % 1000; + const auto flag_id = meta / 1000; + if (type_id < 1 || type_id > 7) { + throw InvalidInputException("Unsupported geometry type %d in WKB", type_id); + } + if (flag_id > 3) { + throw InvalidInputException("Unsupported geometry flag %d in WKB", flag_id); + } + const auto geom_type = static_cast(type_id); + const auto vert_type = static_cast(flag_id); + + switch (geom_type) { + case GeometryType::POINT: { + vertex_count += ParseVertices(reader, extent, 1, vert_type, true); + } break; + case GeometryType::LINESTRING: { + const auto vert_count = reader.Read(); + vertex_count += ParseVertices(reader, extent, vert_count, vert_type, false); + } break; + case GeometryType::POLYGON: { + const auto ring_count = reader.Read(); + for (uint32_t ring_idx = 0; ring_idx < ring_count; ring_idx++) { + const auto vert_count = reader.Read(); + 
vertex_count += ParseVertices(reader, extent, vert_count, vert_type, false); + } + } break; + case GeometryType::MULTIPOINT: + case GeometryType::MULTILINESTRING: + case GeometryType::MULTIPOLYGON: + case GeometryType::GEOMETRYCOLLECTION: { + // Skip count. We don't need it for extent calculation. + reader.Skip(sizeof(uint32_t)); + } break; + default: + throw InvalidInputException("Unsupported geometry type %d in WKB", static_cast(geom_type)); + } + } + return vertex_count; +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp index 54c224b9f..ad974a475 100644 --- a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp @@ -1,11 +1,14 @@ #include "duckdb/execution/operator/join/physical_asof_join.hpp" #include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/partition_state.hpp" -#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sorting/hashed_sort.hpp" +#include "duckdb/common/sorting/sort_key.hpp" +#include "duckdb/common/sorting/sorted_run.hpp" +#include "duckdb/common/types/row/block_iterator.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/operator/join/outer_join_marker.hpp" +#include "duckdb/execution/operator/join/physical_range_join.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parallel/event.hpp" #include "duckdb/parallel/meta_pipeline.hpp" @@ -75,65 +78,54 @@ PhysicalAsOfJoin::PhysicalAsOfJoin(PhysicalPlan &physical_plan, LogicalCompariso //===--------------------------------------------------------------------===// class AsOfGlobalSinkState : public GlobalSinkState { public: - using PartitionSinkPtr = unique_ptr; + using HashedSortPtr = unique_ptr; + using HashedSinkPtr = unique_ptr; using PartitionMarkers = 
vector; - using LocalBuffers = vector>; + using HashGroupPtr = unique_ptr; + using HashGroups = vector; - AsOfGlobalSinkState(ClientContext &context, const PhysicalAsOfJoin &op) : is_outer(IsRightOuterJoin(op.join_type)) { + AsOfGlobalSinkState(ClientContext &client, const PhysicalAsOfJoin &op) : is_outer(IsRightOuterJoin(op.join_type)) { // Set up partitions for both sides - partition_sinks.reserve(2); + hashed_sorts.reserve(2); + hashed_sinks.reserve(2); const vector> partitions_stats; auto &lhs = op.children[0].get(); - auto sink = make_uniq(context, op.lhs_partitions, op.lhs_orders, lhs.GetTypes(), - partitions_stats, lhs.estimated_cardinality); - partition_sinks.emplace_back(std::move(sink)); - auto &rhs = op.children[1].get(); - sink = make_uniq(context, op.rhs_partitions, op.rhs_orders, rhs.GetTypes(), - partitions_stats, rhs.estimated_cardinality); - partition_sinks.emplace_back(std::move(sink)); - - local_buffers.resize(2); - } + auto sort = make_uniq(client, op.lhs_partitions, op.lhs_orders, lhs.GetTypes(), partitions_stats, + lhs.estimated_cardinality, true); + hashed_sinks.emplace_back(sort->GetGlobalSinkState(client)); + hashed_sorts.emplace_back(std::move(sort)); - idx_t Count() const { - return partition_sinks[child]->count; - } + auto &rhs = op.children[1].get(); + sort = make_uniq(client, op.rhs_partitions, op.rhs_orders, rhs.GetTypes(), partitions_stats, + rhs.estimated_cardinality, true); + hashed_sinks.emplace_back(sort->GetGlobalSinkState(client)); + hashed_sorts.emplace_back(std::move(sort)); - PartitionLocalSinkState *RegisterBuffer(ClientContext &context) { - lock_guard guard(lock); - auto &buffers = local_buffers[child]; - buffers.emplace_back(make_uniq(context, *partition_sinks[child])); - return buffers.back().get(); + hash_groups.resize(2); } //! The child that is being materialised (right/1 then left/0) size_t child = 1; + //! The child's partitioning description + vector hashed_sorts; //! 
The child's partitioning buffer - vector partition_sinks; + vector hashed_sinks; + //! The child's hash groups + vector hash_groups; //! Whether the right side is outer const bool is_outer; //! The right outer join markers (one per partition) vector right_outers; - - mutex lock; - vector local_buffers; }; class AsOfLocalSinkState : public LocalSinkState { public: - AsOfLocalSinkState(ClientContext &context, AsOfGlobalSinkState &gsink) - : local_partition(context, *gsink.partition_sinks[gsink.child]) { - } - - void Sink(DataChunk &input_chunk) { - local_partition.Sink(input_chunk); + AsOfLocalSinkState(ExecutionContext &context, AsOfGlobalSinkState &gsink) { + auto &hashed_sort = *gsink.hashed_sorts[gsink.child]; + local_partition = hashed_sort.GetLocalSinkState(context); } - void Combine() { - local_partition.Combine(); - } - - PartitionLocalSinkState local_partition; + unique_ptr local_partition; }; unique_ptr PhysicalAsOfJoin::GetGlobalSinkState(ClientContext &context) const { @@ -142,53 +134,64 @@ unique_ptr PhysicalAsOfJoin::GetGlobalSinkState(ClientContext & unique_ptr PhysicalAsOfJoin::GetLocalSinkState(ExecutionContext &context) const { auto &gsink = sink_state->Cast(); - return make_uniq(context.client, gsink); + return make_uniq(context, gsink); } -SinkResultType PhysicalAsOfJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); +SinkResultType PhysicalAsOfJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &sink) const { + auto &gstate = sink.global_state.Cast(); + auto &lstate = sink.local_state.Cast(); - lstate.Sink(chunk); + auto &hashed_sort = *gstate.hashed_sorts[gstate.child]; + auto &gsink = *gstate.hashed_sinks[gstate.child]; + auto &lsink = *lstate.local_partition; - return SinkResultType::NEED_MORE_INPUT; + OperatorSinkInput hsink {gsink, lsink, sink.interrupt_state}; + return hashed_sort.Sink(context, chunk, hsink); } -SinkCombineResultType 
PhysicalAsOfJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.Combine(); - return SinkCombineResultType::FINISHED; +SinkCombineResultType PhysicalAsOfJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &combine) const { + auto &gstate = combine.global_state.Cast(); + auto &lstate = combine.local_state.Cast(); + + auto &hashed_sort = *gstate.hashed_sorts[gstate.child]; + auto &gsink = *gstate.hashed_sinks[gstate.child]; + auto &lsink = *lstate.local_partition; + + OperatorSinkCombineInput hcombine {gsink, lsink, combine.interrupt_state}; + return hashed_sort.Combine(context, hcombine); } //===--------------------------------------------------------------------===// // Finalize //===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); +SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, + OperatorSinkFinalizeInput &finalize) const { + auto &gstate = finalize.global_state.Cast(); // The data is all in so we can synchronise the left partitioning. 
- auto result = SinkFinalizeType::READY; - auto &partition_sink = *gstate.partition_sinks[gstate.child]; + auto &hashed_sort = *gstate.hashed_sorts[gstate.child]; + auto &hashed_sink = *gstate.hashed_sinks[gstate.child]; + OperatorSinkFinalizeInput hfinalize {hashed_sink, finalize.interrupt_state}; if (gstate.child == 1) { - gstate.partition_sinks[1 - gstate.child]->SyncPartitioning(partition_sink); + auto &lhs_groups = *gstate.hashed_sinks[1 - gstate.child]; + auto &rhs_groups = hashed_sink; + hashed_sort.Synchronize(rhs_groups, lhs_groups); - // Find the first group to sort - if (!partition_sink.HasMergeTasks() && EmptyResultIfRHSIsEmpty()) { + auto result = hashed_sort.Finalize(client, hfinalize); + if (result != SinkFinalizeType::READY && EmptyResultIfRHSIsEmpty()) { // Empty input! - result = SinkFinalizeType::NO_OUTPUT_POSSIBLE; + gstate.child = 1 - gstate.child; + return result; } - } - - // Schedule all the sorts for maximum thread utilisation - if (partition_sink.HasMergeTasks()) { - auto new_event = make_shared_ptr(partition_sink, pipeline, *this); - event.InsertEvent(std::move(new_event)); + } else { + hashed_sort.Finalize(client, hfinalize); } // Switch sides gstate.child = 1 - gstate.child; - return result; + // Schedule all the sorts for maximum thread utilisation + return hashed_sort.MaterializeSortedRuns(pipeline, event, *this, hfinalize); } OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, @@ -199,18 +202,146 @@ OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// -class AsOfProbeBuffer { +class AsOfPayloadScanner { public: - using Orders = vector; + using Types = vector; + using Columns = vector; + + AsOfPayloadScanner(const SortedRun &sorted_run, const HashedSort &hashed_sort); + idx_t 
Base() const { + return base; + } + idx_t Scanned() const { + return scanned; + } + idx_t Remaining() const { + return count - scanned; + } + bool Scan(DataChunk &chunk) { + // Free the previous blocks + block_state.SetKeepPinned(true); + block_state.SetPinPayload(true); + + base = scanned; + const auto result = (this->*scan_func)(); + chunk.ReferenceColumns(scan_chunk, scan_ids); + scanned += scan_chunk.size(); + ++chunk_idx; + return result; + } + +private: + template + bool TemplatedScan() { + using SORT_KEY = SortKey; + using BLOCK_ITERATOR = block_iterator_t; + BLOCK_ITERATOR itr(block_state, chunk_idx, 0); + + const auto sort_keys = FlatVector::GetData(sort_key_pointers); + const auto result_count = MinValue(Remaining(), STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < result_count; ++i) { + const auto idx = block_state.GetIndex(chunk_idx, i); + sort_keys[i] = &itr[idx]; + } + + // Scan + scan_chunk.Reset(); + scan_state.Scan(sorted_run, sort_key_pointers, result_count, scan_chunk); + return scan_chunk.size() > 0; + } + + // Only figure out the scan function once. 
+ using scan_t = bool (duckdb::AsOfPayloadScanner::*)(); + scan_t scan_func; + + const SortedRun &sorted_run; + ExternalBlockIteratorState block_state; + Vector sort_key_pointers = Vector(LogicalType::POINTER); + SortedRunScanState scan_state; + const Columns scan_ids; + DataChunk scan_chunk; + const idx_t count; + idx_t base = 0; + idx_t scanned = 0; + idx_t chunk_idx = 0; +}; - static bool IsExternal(ClientContext &context) { - return ClientConfig::GetConfig(context).force_external; +AsOfPayloadScanner::AsOfPayloadScanner(const SortedRun &sorted_run, const HashedSort &hashed_sort) + : sorted_run(sorted_run), block_state(*sorted_run.key_data, sorted_run.payload_data.get()), + scan_state(sorted_run.context, sorted_run.sort), scan_ids(hashed_sort.scan_ids), count(sorted_run.Count()) { + + scan_chunk.Initialize(sorted_run.context, hashed_sort.payload_types); + const auto sort_key_type = sorted_run.key_data->GetLayout().GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_FIXED_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_FIXED_16: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_FIXED_24: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_FIXED_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + scan_func = &AsOfPayloadScanner::TemplatedScan; + break; + default: + throw NotImplementedException("AsOfPayloadScanner for %s", EnumUtil::ToString(sort_key_type)); } +} + +class 
AsOfProbeBuffer { +public: + using Orders = vector; AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op); public: - void ResolveJoin(bool *found_matches, idx_t *matches = nullptr); + // Comparison utilities + static bool IsStrictComparison(ExpressionType comparison) { + switch (comparison) { + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_GREATERTHAN: + return true; + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return false; + default: + throw NotImplementedException("Unsupported comparison type for ASOF join"); + } + } + + //! Is left cmp right? + template + static inline bool Compare(const T &lhs, const T &rhs, const bool strict) { + const bool less_than = lhs < rhs; + if (!less_than && !strict) { + return !(rhs < lhs); + } + return less_than; + } + + template + void ResolveJoin(bool *found_matches, idx_t *matches); + + using resolve_join_t = void (duckdb::AsOfProbeBuffer::*)(bool *, idx_t *); + resolve_join_t resolve_join_func; + bool Scanning() const { return lhs_scanner.get(); } @@ -218,6 +349,20 @@ class AsOfProbeBuffer { bool NextLeft(); void EndLeftScan(); + //! Create a new iterator for the sorted run + static unique_ptr CreateIteratorState(SortedRun &sorted) { + auto state = make_uniq(*sorted.key_data, sorted.payload_data.get()); + + // Unless we do this, we will only get values from the first chunk + Repin(*state); + + return state; + } + //! Reset the pins for an iterator so we release memory in a timely manner + static void Repin(ExternalBlockIteratorState &iter) { + // Don't pin the payload because we are not using it here. 
+ iter.SetKeepPinned(true); + } // resolve joins that output max N elements (SEMI, ANTI, MARK) void ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk); // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) @@ -230,34 +375,37 @@ class AsOfProbeBuffer { ClientContext &client; const PhysicalAsOfJoin &op; - BufferManager &buffer_manager; - const bool force_external; - const idx_t memory_per_thread; - Orders lhs_orders; + //! Is the inequality strict? + const bool strict; // LHS scanning SelectionVector lhs_scan_sel; - optional_ptr left_hash; + optional_ptr left_group; OuterJoinMarker left_outer; - unique_ptr left_itr; - unique_ptr lhs_scanner; + unique_ptr left_itr; + unique_ptr lhs_scanner; DataChunk lhs_scanned; DataChunk lhs_payload; ExpressionExecutor lhs_executor; DataChunk lhs_keys; ValidityMask lhs_valid_mask; - idx_t left_group = 0; + idx_t left_bin = 0; SelectionVector lhs_match_sel; // RHS scanning - optional_ptr right_hash; + optional_ptr right_group; optional_ptr right_outer; - unique_ptr right_itr; - unique_ptr rhs_scanner; + unique_ptr right_itr; + idx_t right_pos; // ExternalBlockIteratorState doesn't know this... + unique_ptr rhs_scanner; DataChunk rhs_payload; - idx_t right_group = 0; + ExpressionExecutor rhs_executor; + DataChunk rhs_input; + DataChunk rhs_keys; + idx_t right_bin = 0; // Predicate evaluation + SelectionVector tail_sel; SelectionVector filter_sel; ExpressionExecutor filterer; @@ -266,28 +414,32 @@ class AsOfProbeBuffer { }; AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin &op) - : client(client), op(op), buffer_manager(BufferManager::GetBufferManager(client)), - force_external(IsExternal(client)), memory_per_thread(op.GetMaxThreadMemory(client)), - left_outer(IsLeftOuterJoin(op.join_type)), lhs_executor(client), filterer(client), fetch_next_left(true) { - vector> partition_stats; - Orders partitions; // Not used. 
- PartitionGlobalSinkState::GenerateOrderings(partitions, lhs_orders, op.lhs_partitions, op.lhs_orders, - partition_stats); + : client(client), op(op), strict(IsStrictComparison(op.comparison_type)), left_outer(IsLeftOuterJoin(op.join_type)), + lhs_executor(client), rhs_executor(client), filterer(client), fetch_next_left(true) { lhs_keys.Initialize(client, op.join_key_types); for (const auto &cond : op.conditions) { lhs_executor.AddExpression(*cond.left); } - // We sort the row numbers of the incoming block, not the rows lhs_scanned.Initialize(client, op.children[0].get().GetTypes()); lhs_payload.Initialize(client, op.children[0].get().GetTypes()); rhs_payload.Initialize(client, op.children[1].get().GetTypes()); + rhs_input.Initialize(client, op.children[1].get().GetTypes()); lhs_scan_sel.Initialize(); lhs_match_sel.Initialize(); left_outer.Initialize(STANDARD_VECTOR_SIZE); + // If we have equality predicates, we need some more buffers. + if (op.conditions.size() > 1) { + tail_sel.Initialize(); + rhs_keys.Initialize(client, op.join_key_types); + for (const auto &cond : op.conditions) { + rhs_executor.AddExpression(*cond.right); + } + } + if (op.predicate) { filter_sel.Initialize(); filterer.AddExpression(*op.predicate); @@ -297,74 +449,86 @@ AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin & void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { auto &gsink = op.sink_state->Cast(); - // Always set right_group too for memory management - auto &rhs_sink = *gsink.partition_sinks[1]; - if (scan_bin < rhs_sink.bin_groups.size()) { - right_group = rhs_sink.bin_groups[scan_bin]; + // Always set right_bin too for memory management + auto &rhs_groups = gsink.hash_groups[1]; + if (scan_bin < rhs_groups.size()) { + right_bin = scan_bin; } else { - right_group = rhs_sink.bin_groups.size(); + right_bin = rhs_groups.size(); } - auto &lhs_sink = *gsink.partition_sinks[0]; - left_group = lhs_sink.bin_groups[scan_bin]; - if (scan_bin < 
lhs_sink.bin_groups.size()) { - left_group = lhs_sink.bin_groups[scan_bin]; + auto &lhs_groups = gsink.hash_groups[0]; + if (scan_bin < lhs_groups.size()) { + left_bin = scan_bin; } else { - left_group = lhs_sink.bin_groups.size(); + left_bin = lhs_groups.size(); + } + + if (left_bin >= lhs_groups.size()) { + return; } - if (left_group >= lhs_sink.bin_groups.size()) { + left_group = lhs_groups[left_bin].get(); + if (!left_group || !left_group->Count()) { return; } - auto iterator_comp = ExpressionType::INVALID; - switch (op.comparison_type) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - iterator_comp = ExpressionType::COMPARE_LESSTHANOREQUALTO; + // Set up function pointer for sort type + const auto sort_key_type = left_group->key_data->GetLayout().GetSortKeyType(); + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; - case ExpressionType::COMPARE_GREATERTHAN: - iterator_comp = ExpressionType::COMPARE_LESSTHAN; + case SortKeyType::NO_PAYLOAD_FIXED_16: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - iterator_comp = ExpressionType::COMPARE_GREATERTHANOREQUALTO; + case SortKeyType::NO_PAYLOAD_FIXED_24: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; break; - case ExpressionType::COMPARE_LESSTHAN: - iterator_comp = ExpressionType::COMPARE_GREATERTHAN; + case SortKeyType::NO_PAYLOAD_FIXED_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_FIXED_16: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_FIXED_24: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_FIXED_32: + resolve_join_func = &AsOfProbeBuffer::ResolveJoin; + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + resolve_join_func = 
&AsOfProbeBuffer::ResolveJoin; break; default: throw NotImplementedException("Unsupported comparison type for ASOF join"); } - left_hash = lhs_sink.hash_groups[left_group].get(); - auto &left_sort = *(left_hash->global_sort); - if (left_sort.sorted_blocks.empty()) { - return; - } - lhs_scanner = make_uniq(left_sort, false); - left_itr = make_uniq(left_sort, iterator_comp); + lhs_scanner = make_uniq(*left_group, *gsink.hashed_sorts[0]); + left_itr = CreateIteratorState(*left_group); // We are only probing the corresponding right side bin, which may be empty // If it is empty, we leave the iterator as null so we can emit left matches - if (right_group < rhs_sink.bin_groups.size()) { - right_hash = rhs_sink.hash_groups[right_group].get(); - right_outer = gsink.right_outers.data() + right_group; - auto &right_sort = *(right_hash->global_sort); - if (!right_sort.sorted_blocks.empty()) { - right_itr = make_uniq(right_sort, iterator_comp); - rhs_scanner = make_uniq(right_sort, false); + right_pos = 0; + if (right_bin < rhs_groups.size()) { + right_group = rhs_groups[right_bin].get(); + right_outer = gsink.right_outers.data() + right_bin; + if (right_group && right_group->Count()) { + right_itr = CreateIteratorState(*right_group); + rhs_scanner = make_uniq(*right_group, *gsink.hashed_sorts[1]); } } } bool AsOfProbeBuffer::NextLeft() { - if (!HasMoreData()) { - return false; - } - // Scan the next sorted chunk lhs_scanned.Reset(); - left_itr->SetIndex(lhs_scanner->Scanned()); - lhs_scanner->Scan(lhs_scanned); + if (!lhs_scanner || !lhs_scanner->Scan(lhs_scanned)) { + return false; + } // Compute the join keys lhs_keys.Reset(); @@ -418,54 +582,62 @@ bool AsOfProbeBuffer::NextLeft() { void AsOfProbeBuffer::EndLeftScan() { auto &gsink = op.sink_state->Cast(); - right_hash = nullptr; + right_group = nullptr; right_itr.reset(); rhs_scanner.reset(); right_outer = nullptr; - auto &rhs_sink = *gsink.partition_sinks[1]; - if (!gsink.is_outer && right_group < 
rhs_sink.bin_groups.size()) { - rhs_sink.hash_groups[right_group].reset(); + auto &rhs_groups = gsink.hash_groups[1]; + if (!gsink.is_outer && right_bin < rhs_groups.size()) { + rhs_groups[right_bin].reset(); } - left_hash = nullptr; + left_group = nullptr; left_itr.reset(); lhs_scanner.reset(); - auto &lhs_sink = *gsink.partition_sinks[0]; - if (left_group < lhs_sink.bin_groups.size()) { - lhs_sink.hash_groups[left_group].reset(); + auto &lhs_groups = gsink.hash_groups[0]; + if (left_bin < lhs_groups.size()) { + lhs_groups[left_bin].reset(); } } +template void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { + using SORT_KEY = SortKey; + using BLOCKS_ITERATOR = block_iterator_t; + // If there was no right partition, there are no matches lhs_match_count = 0; if (!right_itr) { return; } + Repin(*left_itr); + BLOCKS_ITERATOR left_key(*left_itr); + + Repin(*right_itr); + BLOCKS_ITERATOR right_key(*right_itr); + const auto count = lhs_payload.size(); - const auto left_base = left_itr->GetIndex(); + const auto left_base = lhs_scanner->Base(); + auto lhs_sel = (count == lhs_scanned.size()) ? 
FlatVector::IncrementalSelectionVector() : &lhs_scan_sel; // Searching for right <= left for (idx_t i = 0; i < count; ++i) { - left_itr->SetIndex(left_base + i); - // If right > left, then there is no match - if (!right_itr->Compare(*left_itr)) { + const auto left_pos = left_base + lhs_sel->get_index(i); + if (!Compare(right_key[right_pos], left_key[left_pos], strict)) { continue; } // Exponential search forward for a non-matching value using radix iterators // (We use exponential search to avoid thrashing the block manager on large probes) idx_t bound = 1; - idx_t begin = right_itr->GetIndex(); - right_itr->SetIndex(begin + bound); - while (right_itr->GetIndex() < right_hash->count) { - if (right_itr->Compare(*left_itr)) { + idx_t begin = right_pos; + while (begin + bound < right_group->Count()) { + if (Compare(right_key[begin + bound], left_key[left_pos], strict)) { // If right <= left, jump ahead bound *= 2; - right_itr->SetIndex(begin + bound); } else { break; } @@ -474,23 +646,22 @@ void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { // Binary search for the first non-matching value using radix iterators // The previous value (which we know exists) is the match auto first = begin + bound / 2; - auto last = MinValue(begin + bound, right_hash->count); + auto last = MinValue(begin + bound, right_group->Count()); while (first < last) { const auto mid = first + (last - first) / 2; - right_itr->SetIndex(mid); - if (right_itr->Compare(*left_itr)) { + if (Compare(right_key[mid], left_key[left_pos], strict)) { // If right <= left, new lower bound first = mid + 1; } else { last = mid; } } - right_itr->SetIndex(--first); + right_pos = --first; - // Check partitions for strict equality - if (right_hash->ComparePartitions(*left_itr, *right_itr)) { - continue; - } + // TODO: Check partitions for strict equality + // if (right_group->ComparePartitions(*left_itr, *right_itr)) { + // continue; + // } // Emit match data if (found_match) { @@ -506,7 +677,7 @@ 
void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk) { // perform the actual join bool found_match[STANDARD_VECTOR_SIZE] = {false}; - ResolveJoin(found_match); + (this->*resolve_join_func)(found_match, nullptr); // now construct the result based on the join result switch (op.join_type) { @@ -521,11 +692,22 @@ void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &ch } } +static idx_t SliceSelectionVector(SelectionVector &target, const SelectionVector &source, const idx_t count) { + idx_t result = 0; + for (idx_t i = 0; i < count; ++i) { + target.set_index(result++, target.get_index(source.get_index(i))); + } + + return result; +} + void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk) { // perform the actual join idx_t matches[STANDARD_VECTOR_SIZE]; - ResolveJoin(nullptr, matches); + (this->*resolve_join_func)(nullptr, matches); + // Extract the rhs input columns from the match + rhs_input.Reset(); for (idx_t i = 0; i < lhs_match_count; ++i) { const auto idx = lhs_match_sel[i]; const auto match_pos = matches[idx]; @@ -537,30 +719,81 @@ void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &c // Append the individual values // TODO: Batch the copies const auto source_offset = match_pos - (rhs_scanner->Scanned() - rhs_payload.size()); - for (column_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { - const auto rhs_idx = op.right_projection_map[col_idx]; - auto &source = rhs_payload.data[rhs_idx]; - auto &target = chunk.data[lhs_payload.ColumnCount() + col_idx]; + for (column_t col_idx = 0; col_idx < rhs_payload.data.size(); ++col_idx) { + auto &source = rhs_payload.data[col_idx]; + auto &target = rhs_input.data[col_idx]; VectorOperations::Copy(source, target, source_offset + 1, source_offset, i); } } + rhs_input.SetCardinality(lhs_match_count); // Slice the left 
payload into the result for (column_t i = 0; i < lhs_payload.ColumnCount(); ++i) { chunk.data[i].Slice(lhs_payload.data[i], lhs_match_sel, lhs_match_count); } + + // Reference the projected right payload into the result + for (column_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { + const auto rhs_idx = op.right_projection_map[col_idx]; + auto &source = rhs_input.data[rhs_idx]; + auto &target = chunk.data[lhs_payload.ColumnCount() + col_idx]; + target.Reference(source); + } chunk.SetCardinality(lhs_match_count); - auto match_sel = &lhs_match_sel; + + // Filter out partition mismatches + const auto equal_cols = op.conditions.size() - 1; + if (equal_cols) { + // Prepare the lhs keys + if (lhs_match_count < lhs_keys.size()) { + lhs_keys.Slice(lhs_match_sel, lhs_match_count); + } + + rhs_keys.Reset(); + rhs_executor.Execute(rhs_input, rhs_keys); + + auto sel = FlatVector::IncrementalSelectionVector(); + auto tail_count = lhs_match_count; + for (size_t cmp_idx = 0; cmp_idx < equal_cols; ++cmp_idx) { + auto &left = lhs_keys.data[cmp_idx]; + auto &right = rhs_keys.data[cmp_idx]; + if (tail_count < rhs_keys.size()) { + left.Slice(*sel, tail_count); + right.Slice(*sel, tail_count); + } + tail_count = PhysicalRangeJoin::SelectJoinTail(op.conditions[cmp_idx].comparison, left, right, sel, + tail_count, &tail_sel); + sel = &tail_sel; + } + + // Did anything get filtered out? 
+ if (tail_count < lhs_match_count) { + if (tail_count == 0) { + // Need to reset here otherwise we may use the non-flat chunk when constructing LEFT/OUTER + chunk.Reset(); + lhs_match_count = tail_count; + } else { + chunk.Slice(*sel, tail_count); + // Slice lhs_match_sel to the remaining lhs rows + lhs_match_count = SliceSelectionVector(lhs_match_sel, *sel, tail_count); + } + } + } + + // Apply the predicate filter + // TODO: This is wrong - we have to search for a match if (filterer.expressions.size() == 1) { - lhs_match_count = filterer.SelectExpression(chunk, filter_sel); - chunk.Slice(filter_sel, lhs_match_count); - match_sel = &filter_sel; + const auto filter_count = filterer.SelectExpression(chunk, filter_sel); + if (filter_count < chunk.size()) { + chunk.Slice(filter_sel, filter_count); + lhs_match_count = SliceSelectionVector(lhs_match_sel, filter_sel, filter_count); + } } // Update the match masks for the rows we ended up with left_outer.Reset(); for (idx_t i = 0; i < lhs_match_count; ++i) { - const auto idx = match_sel->get_index(i); + const auto idx = lhs_match_sel.get_index(i); left_outer.SetMatch(idx); const auto first = matches[idx]; right_outer->SetMatch(first); @@ -608,20 +841,7 @@ void AsOfProbeBuffer::GetData(ExecutionContext &context, DataChunk &chunk) { class AsOfGlobalSourceState : public GlobalSourceState { public: - explicit AsOfGlobalSourceState(AsOfGlobalSinkState &gsink_p) - : gsink(gsink_p), next_left(0), flushed(0), next_right(0) { - - if (gsink.child == 1) { - // for FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple - auto &rhs_partition = *gsink.partition_sinks[gsink.child]; - auto &right_outers = gsink.right_outers; - right_outers.reserve(rhs_partition.hash_groups.size()); - for (const auto &hash_group : rhs_partition.hash_groups) { - right_outers.emplace_back(OuterJoinMarker(gsink.is_outer)); - right_outers.back().Initialize(hash_group->count); - } - } - } + AsOfGlobalSourceState(ClientContext &client, 
AsOfGlobalSinkState &gsink_p); AsOfGlobalSinkState &gsink; //! The next buffer to flush @@ -633,18 +853,45 @@ class AsOfGlobalSourceState : public GlobalSourceState { public: idx_t MaxThreads() override { - return gsink.local_buffers[0].size(); + return gsink.hash_groups[1].size(); } }; -unique_ptr PhysicalAsOfJoin::GetGlobalSourceState(ClientContext &context) const { +AsOfGlobalSourceState::AsOfGlobalSourceState(ClientContext &client, AsOfGlobalSinkState &gsink_p) + : gsink(gsink_p), next_left(0), flushed(0), next_right(0) { + + // Take ownership of the hash groups + for (idx_t child = 0; child < 2; ++child) { + auto &hashed_sort = *gsink.hashed_sorts[child]; + auto &hashed_sink = *gsink.hashed_sinks[child]; + auto hashed_source = hashed_sort.GetGlobalSourceState(client, hashed_sink); + auto &sorted_runs = hashed_sort.GetSortedRuns(*hashed_source); + auto &hash_groups = gsink.hash_groups[child]; + hash_groups.resize(sorted_runs.size()); + + for (idx_t group_idx = 0; group_idx < sorted_runs.size(); ++group_idx) { + hash_groups[group_idx] = std::move(sorted_runs[group_idx]); + } + } + + // for FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple + auto &rhs_partition = gsink.hash_groups[1]; + auto &right_outers = gsink.right_outers; + right_outers.reserve(rhs_partition.size()); + for (const auto &hash_group : rhs_partition) { + right_outers.emplace_back(OuterJoinMarker(gsink.is_outer)); + right_outers.back().Initialize(hash_group ? 
hash_group->Count() : 0); + } +} + +unique_ptr PhysicalAsOfJoin::GetGlobalSourceState(ClientContext &client) const { auto &gsink = sink_state->Cast(); - return make_uniq(gsink); + return make_uniq(client, gsink); } class AsOfLocalSourceState : public LocalSourceState { public: - using HashGroupPtr = unique_ptr; + using HashGroupPtr = unique_ptr; AsOfLocalSourceState(ExecutionContext &context, AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op); @@ -660,7 +907,7 @@ class AsOfLocalSourceState : public LocalSourceState { idx_t hash_bin; HashGroupPtr hash_group; //! The read cursor - unique_ptr scanner; + unique_ptr scanner; //! Pointer to the right marker const bool *rhs_matches = {}; }; @@ -673,12 +920,16 @@ AsOfLocalSourceState::AsOfLocalSourceState(ExecutionContext &context, AsOfGlobal idx_t AsOfLocalSourceState::BeginRightScan(const idx_t hash_bin_p) { hash_bin = hash_bin_p; - auto &rhs_sink = *gsource.gsink.partition_sinks[1]; - hash_group = std::move(rhs_sink.hash_groups[hash_bin]); - if (hash_group->global_sort->sorted_blocks.empty()) { + auto &rhs_groups = gsource.gsink.hash_groups[1]; + if (hash_bin >= rhs_groups.size()) { return 0; } - scanner = make_uniq(*hash_group->global_sort); + + hash_group = std::move(rhs_groups[hash_bin]); + if (!hash_group || !hash_group->Count()) { + return 0; + } + scanner = make_uniq(*hash_group, *gsource.gsink.hashed_sorts[1]); rhs_matches = gsource.gsink.right_outers[hash_bin].GetMatches(); @@ -695,12 +946,12 @@ SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk OperatorSourceInput &input) const { auto &gsource = input.global_state.Cast(); auto &lsource = input.local_state.Cast(); - auto &rhs_sink = *gsource.gsink.partition_sinks[1]; + auto &rhs_groups = gsource.gsink.hash_groups[1]; auto &client = context.client; // Step 1: Join the partitions - auto &lhs_sink = *gsource.gsink.partition_sinks[0]; - const auto left_bins = lhs_sink.grouping_data ? 
lhs_sink.grouping_data->GetPartitions().size() : 1; + auto &lhs_groups = gsource.gsink.hash_groups[0]; + const auto left_bins = lhs_groups.size(); while (gsource.flushed < left_bins) { // Make sure we have something to flush if (!lsource.probe_buffer.Scanning()) { @@ -736,11 +987,8 @@ SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk return SourceResultType::FINISHED; } - auto &hash_groups = rhs_sink.hash_groups; - const auto right_groups = hash_groups.size(); - DataChunk rhs_chunk; - rhs_chunk.Initialize(context.client, rhs_sink.payload_types); + rhs_chunk.Initialize(context.client, children[1].get().GetTypes()); SelectionVector rsel(STANDARD_VECTOR_SIZE); while (chunk.size() == 0) { @@ -749,12 +997,12 @@ SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk lsource.scanner.reset(); lsource.hash_group.reset(); auto hash_bin = gsource.next_right++; - if (hash_bin >= right_groups) { + if (hash_bin >= rhs_groups.size()) { return SourceResultType::FINISHED; } - for (; hash_bin < hash_groups.size(); hash_bin = gsource.next_right++) { - if (hash_groups[hash_bin]) { + for (; hash_bin < rhs_groups.size(); hash_bin = gsource.next_right++) { + if (rhs_groups[hash_bin]) { break; } } diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index f797f6910..9eb76671a 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev732" +#define DUCKDB_PATCH_VERSION "0-dev828" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 5 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.5.0-dev732" +#define DUCKDB_VERSION "v1.5.0-dev828" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "f793ea27c6" +#define DUCKDB_SOURCE_ID "353406bd7f" 
#endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/include/duckdb/common/enum_util.hpp b/src/duckdb/src/include/duckdb/common/enum_util.hpp index b43824c46..9854e49f4 100644 --- a/src/duckdb/src/include/duckdb/common/enum_util.hpp +++ b/src/duckdb/src/include/duckdb/common/enum_util.hpp @@ -202,7 +202,7 @@ enum class FunctionStability : uint8_t; enum class GateStatus : uint8_t; -enum class GeometryType : uint32_t; +enum class GeometryType : uint8_t; enum class HLLStorageType : uint8_t; @@ -296,8 +296,6 @@ enum class ParseInfoType : uint8_t; enum class ParserExtensionResultType : uint8_t; -enum class PartitionSortStage : uint8_t; - enum class PartitionedColumnDataType : uint8_t; enum class PartitionedTupleDataType : uint8_t; @@ -442,6 +440,8 @@ enum class VerificationType : uint8_t; enum class VerifyExistenceType : uint8_t; +enum class VertexType : uint8_t; + enum class WALType : uint8_t; enum class WindowAggregationMode : uint32_t; @@ -849,9 +849,6 @@ const char* EnumUtil::ToChars(ParseInfoType value); template<> const char* EnumUtil::ToChars(ParserExtensionResultType value); -template<> -const char* EnumUtil::ToChars(PartitionSortStage value); - template<> const char* EnumUtil::ToChars(PartitionedColumnDataType value); @@ -1068,6 +1065,9 @@ const char* EnumUtil::ToChars(VerificationType value); template<> const char* EnumUtil::ToChars(VerifyExistenceType value); +template<> +const char* EnumUtil::ToChars(VertexType value); + template<> const char* EnumUtil::ToChars(WALType value); @@ -1480,9 +1480,6 @@ ParseInfoType EnumUtil::FromString(const char *value); template<> ParserExtensionResultType EnumUtil::FromString(const char *value); -template<> -PartitionSortStage EnumUtil::FromString(const char *value); - template<> PartitionedColumnDataType EnumUtil::FromString(const char *value); @@ -1699,6 +1696,9 @@ VerificationType EnumUtil::FromString(const char *value); template<> VerifyExistenceType 
EnumUtil::FromString(const char *value); +template<> +VertexType EnumUtil::FromString(const char *value); + template<> WALType EnumUtil::FromString(const char *value); diff --git a/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp b/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp index bb0760897..d0d82b4c2 100644 --- a/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/metric_type.hpp @@ -64,6 +64,7 @@ enum class MetricsType : uint8_t { OPTIMIZER_BUILD_SIDE_PROBE_SIDE, OPTIMIZER_LIMIT_PUSHDOWN, OPTIMIZER_TOP_N, + OPTIMIZER_TOP_N_WINDOW_ELIMINATION, OPTIMIZER_COMPRESSED_MATERIALIZATION, OPTIMIZER_DUPLICATE_GROUPS, OPTIMIZER_REORDER_FILTER, diff --git a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp index 36bb672c9..82675c7d5 100644 --- a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp @@ -33,6 +33,7 @@ enum class OptimizerType : uint32_t { BUILD_SIDE_PROBE_SIDE, LIMIT_PUSHDOWN, TOP_N, + TOP_N_WINDOW_ELIMINATION, COMPRESSED_MATERIALIZATION, DUPLICATE_GROUPS, REORDER_FILTER, diff --git a/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp b/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp deleted file mode 100644 index 8170875e8..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/partition_state.hpp +++ /dev/null @@ -1,245 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/partition_state.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" - -namespace duckdb { - -class PartitionGlobalHashGroup { -public: - using GlobalSortStatePtr 
= unique_ptr; - using Orders = vector; - using Types = vector; - using OrderMasks = unordered_map; - - PartitionGlobalHashGroup(ClientContext &context, const Orders &partitions, const Orders &orders, - const Types &payload_types, bool external); - - inline int ComparePartitions(const SBIterator &left, const SBIterator &right) { - int part_cmp = 0; - if (partition_layout.all_constant) { - part_cmp = FastMemcmp(left.entry_ptr, right.entry_ptr, partition_layout.comparison_size); - } else { - part_cmp = Comparators::CompareTuple(left.scan, right.scan, left.entry_ptr, right.entry_ptr, - partition_layout, left.external); - } - return part_cmp; - } - - void ComputeMasks(ValidityMask &partition_mask, OrderMasks &order_masks); - - GlobalSortStatePtr global_sort; - atomic count; - - // Mask computation - SortLayout partition_layout; -}; - -class PartitionGlobalSinkState { -public: - using HashGroupPtr = unique_ptr; - using Orders = vector; - using Types = vector; - - using GroupingPartition = unique_ptr; - using GroupingAppend = unique_ptr; - - static void GenerateOrderings(Orders &partitions, Orders &orders, - const vector> &partition_bys, const Orders &order_bys, - const vector> &partitions_stats); - - PartitionGlobalSinkState(ClientContext &context, const vector> &partition_bys, - const vector &order_bys, const Types &payload_types, - const vector> &partitions_stats, idx_t estimated_cardinality); - virtual ~PartitionGlobalSinkState() = default; - - bool HasMergeTasks() const; - - unique_ptr CreatePartition(idx_t new_bits) const; - void SyncPartitioning(const PartitionGlobalSinkState &other); - - void UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); - void CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); - - virtual void OnBeginMerge() {}; - virtual void OnSortedPartition(const idx_t hash_bin_p) {}; - - ClientContext &context; - BufferManager &buffer_manager; - Allocator &allocator; - mutex 
lock; - - // OVER(PARTITION BY...) (hash grouping) - unique_ptr grouping_data; - //! Payload plus hash column - shared_ptr grouping_types_ptr; - //! The number of radix bits if this partition is being synced with another - idx_t fixed_bits; - - // OVER(...) (sorting) - Orders partitions; - Orders orders; - const Types payload_types; - vector hash_groups; - bool external; - // Reverse lookup from hash bins to non-empty hash groups - vector bin_groups; - - // OVER() (no sorting) - unique_ptr rows; - unique_ptr strings; - - // Threading - idx_t memory_per_thread; - idx_t max_bits; - atomic count; - -private: - void ResizeGroupingData(idx_t cardinality); - void SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); -}; - -class PartitionLocalSinkState { -public: - using LocalSortStatePtr = unique_ptr; - - PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p); - - // Global state - PartitionGlobalSinkState &gstate; - Allocator &allocator; - - // Shared expression evaluation - ExpressionExecutor executor; - DataChunk group_chunk; - DataChunk payload_chunk; - size_t sort_cols; - - // OVER(PARTITION BY...) (hash grouping) - unique_ptr local_partition; - unique_ptr local_append; - - // OVER(ORDER BY...) (only sorting) - LocalSortStatePtr local_sort; - - // OVER() (no sorting) - RowLayout payload_layout; - unique_ptr rows; - unique_ptr strings; - - //! Compute the hash values - void Hash(DataChunk &input_chunk, Vector &hash_vector); - //! Sink an input chunk - void Sink(DataChunk &input_chunk); - //! Merge the state into the global state. - void Combine(); -}; - -enum class PartitionSortStage : uint8_t { INIT, SCAN, PREPARE, MERGE, SORTED, FINISHED }; - -class PartitionLocalMergeState; - -class PartitionGlobalMergeState { -public: - using GroupDataPtr = unique_ptr; - - // OVER(PARTITION BY...) 
- PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data, hash_t hash_bin); - - // OVER(ORDER BY...) - explicit PartitionGlobalMergeState(PartitionGlobalSinkState &sink); - - bool IsFinished() const { - return stage == PartitionSortStage::FINISHED; - } - - bool AssignTask(PartitionLocalMergeState &local_state); - bool TryPrepareNextStage(); - void CompleteTask(); - - PartitionGlobalSinkState &sink; - GroupDataPtr group_data; - PartitionGlobalHashGroup *hash_group; - const idx_t group_idx; - vector column_ids; - TupleDataParallelScanState chunk_state; - GlobalSortState *global_sort; - const idx_t memory_per_thread; - const idx_t num_threads; - -private: - mutable mutex lock; - atomic stage; - idx_t total_tasks; - idx_t tasks_assigned; - idx_t tasks_completed; -}; - -class PartitionLocalMergeState { -public: - explicit PartitionLocalMergeState(PartitionGlobalSinkState &gstate); - - bool TaskFinished() { - return finished; - } - - void Prepare(); - void Scan(); - void Merge(); - void Sorted(); - - void ExecuteTask(); - - PartitionGlobalMergeState *merge_state; - PartitionSortStage stage; - atomic finished; - - // Sorting buffers - ExpressionExecutor executor; - DataChunk sort_chunk; - DataChunk payload_chunk; -}; - -class PartitionGlobalMergeStates { -public: - struct Callback { - virtual ~Callback() = default; - - virtual bool HasError() const { - return false; - } - }; - - using PartitionGlobalMergeStatePtr = unique_ptr; - - explicit PartitionGlobalMergeStates(PartitionGlobalSinkState &sink); - - bool ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback); - - vector states; -}; - -class PartitionMergeEvent : public BasePipelineEvent { -public: - PartitionMergeEvent(PartitionGlobalSinkState &gstate_p, Pipeline &pipeline_p, const PhysicalOperator &op_p) - : BasePipelineEvent(pipeline_p), gstate(gstate_p), merge_states(gstate_p), op(op_p) { - } - - PartitionGlobalSinkState &gstate; - PartitionGlobalMergeStates merge_states; - 
const PhysicalOperator &op; - -public: - void Schedule() override; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp b/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp index 374133692..022b4c606 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp @@ -17,6 +17,7 @@ class HashedSort { using Orders = vector; using Types = vector; using HashGroupPtr = unique_ptr; + using SortedRunPtr = unique_ptr; static void GenerateOrderings(Orders &partitions, Orders &orders, const vector> &partition_bys, const Orders &order_bys, @@ -24,7 +25,8 @@ class HashedSort { HashedSort(ClientContext &context, const vector> &partition_bys, const vector &order_bys, const Types &payload_types, - const vector> &partitions_stats, idx_t estimated_cardinality); + const vector> &partitions_stats, idx_t estimated_cardinality, + bool require_payload = false); public: //===--------------------------------------------------------------------===// @@ -37,6 +39,7 @@ class HashedSort { SinkFinalizeType Finalize(ClientContext &client, OperatorSinkFinalizeInput &finalize) const; ProgressData GetSinkProgress(ClientContext &context, GlobalSinkState &gstate, const ProgressData source_progress) const; + void Synchronize(const GlobalSinkState &source, GlobalSinkState &target) const; public: //===--------------------------------------------------------------------===// @@ -53,6 +56,10 @@ class HashedSort { OperatorSinkFinalizeInput &finalize) const; vector &GetHashGroups(GlobalSourceState &global_state) const; + SinkFinalizeType MaterializeSortedRuns(Pipeline &pipeline, Event &event, const PhysicalOperator &op, + OperatorSinkFinalizeInput &finalize) const; + vector &GetSortedRuns(GlobalSourceState &global_state) const; + public: ClientContext &client; //! 
The host's estimated row count @@ -63,6 +70,8 @@ class HashedSort { Orders orders; idx_t sort_col_count; Types payload_types; + //! Are we creating a dummy payload column? + bool force_payload = false; // Input columns in the sorted output vector scan_ids; // Key columns in the sorted output diff --git a/src/duckdb/src/include/duckdb/common/types/geometry.hpp b/src/duckdb/src/include/duckdb/common/types/geometry.hpp index 2cca6fe29..5b9bcd1f8 100644 --- a/src/duckdb/src/include/duckdb/common/types/geometry.hpp +++ b/src/duckdb/src/include/duckdb/common/types/geometry.hpp @@ -10,10 +10,15 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/types.hpp" +#include "duckdb/common/pair.hpp" +#include +#include namespace duckdb { -enum class GeometryType : uint32_t { +struct GeometryStatsData; + +enum class GeometryType : uint8_t { INVALID = 0, POINT = 1, LINESTRING = 2, @@ -24,12 +29,176 @@ enum class GeometryType : uint32_t { GEOMETRYCOLLECTION = 7, }; +enum class VertexType : uint8_t { XY = 0, XYZ = 1, XYM = 2, XYZM = 3 }; + +struct VertexXY { + static constexpr auto TYPE = VertexType::XY; + static constexpr auto HAS_Z = false; + static constexpr auto HAS_M = false; + + double x; + double y; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y); + } +}; + +struct VertexXYZ { + static constexpr auto TYPE = VertexType::XYZ; + static constexpr auto HAS_Z = true; + static constexpr auto HAS_M = false; + + double x; + double y; + double z; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y) && std::isnan(z); + } +}; +struct VertexXYM { + static constexpr auto TYPE = VertexType::XYM; + static constexpr auto HAS_M = true; + static constexpr auto HAS_Z = false; + + double x; + double y; + double m; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y) && std::isnan(m); + } +}; + +struct VertexXYZM { + static constexpr auto TYPE = VertexType::XYZM; + static constexpr auto HAS_Z = true; + static constexpr auto HAS_M = true; + + 
double x; + double y; + double z; + double m; + + bool AllNan() const { + return std::isnan(x) && std::isnan(y) && std::isnan(z) && std::isnan(m); + } +}; + +class GeometryExtent { +public: + static constexpr auto UNKNOWN_MIN = -std::numeric_limits::infinity(); + static constexpr auto UNKNOWN_MAX = +std::numeric_limits::infinity(); + + static constexpr auto EMPTY_MIN = +std::numeric_limits::infinity(); + static constexpr auto EMPTY_MAX = -std::numeric_limits::infinity(); + + // "Unknown" extent means we don't know the bounding box. + // Merging with an unknown extent results in an unknown extent. + // Everything intersects with an unknown extent. + static GeometryExtent Unknown() { + return GeometryExtent {UNKNOWN_MIN, UNKNOWN_MIN, UNKNOWN_MIN, UNKNOWN_MIN, + UNKNOWN_MAX, UNKNOWN_MAX, UNKNOWN_MAX, UNKNOWN_MAX}; + } + + // "Empty" extent means the smallest possible bounding box. + // Merging with an empty extent has no effect. + // Nothing intersects with an empty extent. + static GeometryExtent Empty() { + return GeometryExtent {EMPTY_MIN, EMPTY_MIN, EMPTY_MIN, EMPTY_MIN, EMPTY_MAX, EMPTY_MAX, EMPTY_MAX, EMPTY_MAX}; + } + + // Does this extent have any X/Y values set? + // In other words, is the range of the x/y axes not empty and not unknown? + bool HasXY() const { + return std::isfinite(x_min) && std::isfinite(y_min) && std::isfinite(x_max) && std::isfinite(y_max); + } + // Does this extent have any Z values set? + // In other words, is the range of the Z-axis not empty and not unknown? + bool HasZ() const { + return std::isfinite(z_min) && std::isfinite(z_max); + } + // Does this extent have any M values set? + // In other words, is the range of the M-axis not empty and not unknown? 
+ bool HasM() const { + return std::isfinite(m_min) && std::isfinite(m_max); + } + + void Extend(const VertexXY &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + } + + void Extend(const VertexXYZ &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + z_min = MinValue(z_min, vertex.z); + z_max = MaxValue(z_max, vertex.z); + } + + void Extend(const VertexXYM &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + m_min = MinValue(m_min, vertex.m); + m_max = MaxValue(m_max, vertex.m); + } + + void Extend(const VertexXYZM &vertex) { + x_min = MinValue(x_min, vertex.x); + x_max = MaxValue(x_max, vertex.x); + y_min = MinValue(y_min, vertex.y); + y_max = MaxValue(y_max, vertex.y); + z_min = MinValue(z_min, vertex.z); + z_max = MaxValue(z_max, vertex.z); + m_min = MinValue(m_min, vertex.m); + m_max = MaxValue(m_max, vertex.m); + } + + void Merge(const GeometryExtent &other) { + x_min = MinValue(x_min, other.x_min); + y_min = MinValue(y_min, other.y_min); + z_min = MinValue(z_min, other.z_min); + m_min = MinValue(m_min, other.m_min); + + x_max = MaxValue(x_max, other.x_max); + y_max = MaxValue(y_max, other.y_max); + z_max = MaxValue(z_max, other.z_max); + m_max = MaxValue(m_max, other.m_max); + } + + double x_min; + double y_min; + double z_min; + double m_min; + + double x_max; + double y_max; + double z_max; + double m_max; +}; + class Geometry { public: static constexpr auto MAX_RECURSION_DEPTH = 16; + //! Convert from WKT DUCKDB_API static bool FromString(const string_t &wkt_text, string_t &result, Vector &result_vector, bool strict); + + //! Convert to WKT DUCKDB_API static string_t ToString(Vector &result, const string_t &geom); + + //! 
Get the geometry type and vertex type from the WKB + DUCKDB_API static pair GetType(const string_t &wkb); + + //! Update the bounding box, return number of vertices processed + DUCKDB_API static uint32_t GetExtent(const string_t &wkb, GeometryExtent &extent); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp index 85463b954..1edb36ed4 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp @@ -1,7 +1,7 @@ //===----------------------------------------------------------------------===// // DuckDB // -// duckdb/execution/operator/join/physical_piecewise_merge_join.hpp +// duckdb/execution/operator/join/physical_range_join.hpp // // //===----------------------------------------------------------------------===// diff --git a/src/duckdb/src/include/duckdb/main/config.hpp b/src/duckdb/src/include/duckdb/main/config.hpp index 9a685f560..8eb2e5576 100644 --- a/src/duckdb/src/include/duckdb/main/config.hpp +++ b/src/duckdb/src/include/duckdb/main/config.hpp @@ -110,6 +110,8 @@ struct DBConfigOptions { #else bool autoinstall_known_extensions = false; #endif + //! Setting for the parser override registered by extensions. Allowed options: "default, "fallback", "strict" + string allow_parser_override_extension = "default"; //! Override for the default extension repository string custom_extension_repo = ""; //! 
Override for the default autoload extension repository diff --git a/src/duckdb/src/include/duckdb/main/settings.hpp b/src/duckdb/src/include/duckdb/main/settings.hpp index 383d5533b..1c96771ee 100644 --- a/src/duckdb/src/include/duckdb/main/settings.hpp +++ b/src/duckdb/src/include/duckdb/main/settings.hpp @@ -95,6 +95,18 @@ struct AllowExtensionsMetadataMismatchSetting { static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; +struct AllowParserOverrideExtensionSetting { + using RETURN_TYPE = string; + static constexpr const char *Name = "allow_parser_override_extension"; + static constexpr const char *Description = "Allow extensions to override the current parser"; + static constexpr const char *InputType = "VARCHAR"; + static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); + static void ResetGlobal(DatabaseInstance *db, DBConfig &config); + static bool OnGlobalSet(DatabaseInstance *db, DBConfig &config, const Value &input); + static bool OnGlobalReset(DatabaseInstance *db, DBConfig &config); + static Value GetSetting(const ClientContext &context); +}; + struct AllowPersistentSecretsSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "allow_persistent_secrets"; diff --git a/src/duckdb/src/include/duckdb/optimizer/topn_window_elimination.hpp b/src/duckdb/src/include/duckdb/optimizer/topn_window_elimination.hpp new file mode 100644 index 000000000..fcb50bac5 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/topn_window_elimination.hpp @@ -0,0 +1,68 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/topn_window_elimination.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/main/client_context.hpp" +#include "duckdb/optimizer/column_binding_replacer.hpp" +#include "duckdb/optimizer/remove_unused_columns.hpp" + +namespace duckdb { + +enum class 
TopNPayloadType { SINGLE_COLUMN, STRUCT_PACK }; + +struct TopNWindowEliminationParameters { + //! Whether the sort is ASCENDING or DESCENDING + OrderType order_type; + //! The number of values in the LIMIT clause + int64_t limit; + //! How we fetch the payload columns + TopNPayloadType payload_type; + //! Whether to include row numbers + bool include_row_number; +}; + +class TopNWindowElimination : public BaseColumnPruner { +public: + explicit TopNWindowElimination(ClientContext &context, Optimizer &optimizer, + optional_ptr>> stats_p); + + unique_ptr Optimize(unique_ptr op); + +private: + bool CanOptimize(LogicalOperator &op); + unique_ptr OptimizeInternal(unique_ptr op, ColumnBindingReplacer &replacer); + + unique_ptr CreateAggregateOperator(LogicalWindow &window, vector> args, + const TopNWindowEliminationParameters ¶ms) const; + unique_ptr TryCreateUnnestOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms) const; + unique_ptr CreateProjectionOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms, + const map &group_idxs) const; + + vector> GenerateAggregatePayload(const vector &bindings, + const LogicalWindow &window, map &group_idxs); + vector TraverseProjectionBindings(const std::vector &old_bindings, + LogicalOperator *&op); + unique_ptr CreateAggregateExpression(vector> aggregate_params, bool requires_arg, + OrderType order_type) const; + unique_ptr CreateRowNumberGenerator(unique_ptr aggregate_column_ref) const; + void AddStructExtractExprs(vector> &exprs, const LogicalType &struct_type, + const unique_ptr &aggregate_column_ref) const; + static void UpdateTopmostBindings(idx_t window_idx, const unique_ptr &op, + const map &group_idxs, + const vector &topmost_bindings, + vector &new_bindings, ColumnBindingReplacer &replacer); + +private: + ClientContext &context; + Optimizer &optimizer; + optional_ptr>> stats; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp 
b/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp index 5540d38a2..12c4f77ca 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/vacuum_info.hpp @@ -10,7 +10,6 @@ #include "duckdb/parser/parsed_data/parse_info.hpp" #include "duckdb/parser/tableref.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/optional_ptr.hpp" #include "duckdb/catalog/dependency_list.hpp" diff --git a/src/duckdb/src/include/duckdb/parser/parser_options.hpp b/src/duckdb/src/include/duckdb/parser/parser_options.hpp index d388fb116..d9a42632a 100644 --- a/src/duckdb/src/include/duckdb/parser/parser_options.hpp +++ b/src/duckdb/src/include/duckdb/parser/parser_options.hpp @@ -18,6 +18,7 @@ struct ParserOptions { bool integer_division = false; idx_t max_expression_depth = 1000; const vector *extensions = nullptr; + string parser_override_setting = "default"; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp b/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp index be997eb5d..c8586a559 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/bound_ref_wrapper.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/parser/tableref.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/binder.hpp" namespace duckdb { @@ -20,10 +19,10 @@ class BoundRefWrapper : public TableRef { static constexpr const TableReferenceType TYPE = TableReferenceType::BOUND_TABLE_REF; public: - BoundRefWrapper(unique_ptr bound_ref_p, shared_ptr binder_p); + BoundRefWrapper(BoundStatement bound_ref_p, shared_ptr binder_p); //! The bound reference object - unique_ptr bound_ref; + BoundStatement bound_ref; //! 
The binder that was used to bind this table ref shared_ptr binder; diff --git a/src/duckdb/src/include/duckdb/planner/bind_context.hpp b/src/duckdb/src/include/duckdb/planner/bind_context.hpp index d4d487400..b17f2666a 100644 --- a/src/duckdb/src/include/duckdb/planner/bind_context.hpp +++ b/src/duckdb/src/include/duckdb/planner/bind_context.hpp @@ -43,9 +43,6 @@ class BindContext { public: explicit BindContext(Binder &binder); - //! Keep track of recursive CTE references - case_insensitive_map_t> cte_references; - public: //! Given a column name, find the matching table it belongs to. Throws an //! exception if no table has a column of the given name. @@ -122,8 +119,6 @@ class BindContext { void AddCTEBinding(idx_t index, const string &alias, const vector &names, const vector &types, bool using_key = false); - void RemoveCTEBinding(const string &alias); - //! Add an implicit join condition (e.g. USING (x)) void AddUsingBinding(const string &column_name, UsingColumnSet &set); @@ -146,13 +141,6 @@ class BindContext { string GetActualColumnName(const BindingAlias &binding_alias, const string &column_name); string GetActualColumnName(Binding &binding, const string &column_name); - case_insensitive_map_t> GetCTEBindings() { - return cte_bindings; - } - void SetCTEBindings(case_insensitive_map_t> bindings) { - cte_bindings = std::move(bindings); - } - //! Alias a set of column names for the specified table, using the original names if there are not enough aliases //! specified. static vector AliasColumnNames(const string &table_name, const vector &names, @@ -184,10 +172,7 @@ class BindContext { vector> bindings_list; //! The set of columns used in USING join conditions case_insensitive_map_t> using_columns; - //! Using column sets - vector> using_column_sets; - //! 
The set of CTE bindings - case_insensitive_map_t> cte_bindings; + case_insensitive_map_t> cte_bindings; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index 8f1608112..5603ab36b 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -27,7 +27,6 @@ #include "duckdb/planner/joinside.hpp" #include "duckdb/planner/bound_constraint.hpp" #include "duckdb/planner/logical_operator.hpp" -#include "duckdb/planner/tableref/bound_delimgetref.hpp" #include "duckdb/common/enums/copy_option_mode.hpp" //! fwd declare @@ -160,6 +159,30 @@ struct CorrelatedColumns { idx_t delim_index; }; +//! GlobalBinderState is state shared over the ENTIRE query, including subqueries, views, etc +struct GlobalBinderState { + //! The count of bound_tables + idx_t bound_tables = 0; + //! Statement properties + StatementProperties prop; + //! Binding mode + BindingMode mode = BindingMode::STANDARD_BINDING; + //! Table names extracted for BindingMode::EXTRACT_NAMES or BindingMode::EXTRACT_QUALIFIED_NAMES. + unordered_set table_names; + //! Replacement Scans extracted for BindingMode::EXTRACT_REPLACEMENT_SCANS + case_insensitive_map_t> replacement_scans; + //! Using column sets + vector> using_column_sets; +}; + +// QueryBinderState is state shared WITHIN a query, a new query-binder state is created when binding inside e.g. a view +struct QueryBinderState { + //! The vector of active binders + vector> active_binders; + //! The set of parameter expressions bound by this binder + optional_ptr parameters; +}; + //! Bind the parsed query tree to the actual columns present in the catalog. /*! The binder is responsible for binding tables and columns to actual physical @@ -183,8 +206,6 @@ class Binder : public enable_shared_from_this { //! The set of correlated columns bound by this binder (FIXME: this should probably be an unordered_set and not a //! 
vector) CorrelatedColumns correlated_columns; - //! The set of parameter expressions bound by this binder - optional_ptr parameters; //! The alias for the currently processing subquery, if it exists string alias; //! Macro parameter bindings (if any) @@ -231,8 +252,7 @@ class Binder : public enable_shared_from_this { QueryErrorContext &error_context, string &func_name); unique_ptr BindPragma(PragmaInfo &info, QueryErrorContext error_context); - unique_ptr Bind(TableRef &ref); - unique_ptr CreatePlan(BoundTableRef &ref); + BoundStatement Bind(TableRef &ref); //! Generates an unused index for a table idx_t GenerateTableIndex(); @@ -243,7 +263,7 @@ class Binder : public enable_shared_from_this { //! Add a common table expression to the binder void AddCTE(const string &name); //! Find all candidate common table expression by name; returns empty vector if none exists - vector> FindCTE(const string &name, bool skip = false); + optional_ptr GetCTEBinding(const string &name); bool CTEExists(const string &name); @@ -288,12 +308,11 @@ class Binder : public enable_shared_from_this { void AddReplacementScan(const string &table_name, unique_ptr replacement); const unordered_set &GetTableNames(); case_insensitive_map_t> &GetReplacementScans(); - optional_ptr GetRootStatement() { - return root_statement; - } CatalogEntryRetriever &EntryRetriever() { return entry_retriever; } + optional_ptr GetParameters(); + void SetParameters(BoundParameterMap ¶meters); //! Returns a ColumnRefExpression after it was resolved (i.e. past the STAR expression/USING clauses) static optional_ptr GetResolvedColumnExpression(ParsedExpression &root_expr); @@ -310,42 +329,28 @@ class Binder : public enable_shared_from_this { private: //! The parent binder (if any) shared_ptr parent; - //! The vector of active binders - vector> active_binders; - //! The count of bound_tables - idx_t bound_tables; + //! 
What kind of node we are binding using this binder + BinderType binder_type = BinderType::REGULAR_BINDER; + //! Global binder state + shared_ptr global_binder_state; + //! Query binder state + shared_ptr query_binder_state; //! Whether or not the binder has any unplanned dependent joins that still need to be planned/flattened bool has_unplanned_dependent_joins = false; //! Whether or not outside dependent joins have been planned and flattened bool is_outside_flattened = true; - //! What kind of node we are binding using this binder - BinderType binder_type = BinderType::REGULAR_BINDER; //! Whether or not the binder can contain NULLs as the root of expressions bool can_contain_nulls = false; - //! The root statement of the query that is currently being parsed - optional_ptr root_statement; - //! Binding mode - BindingMode mode = BindingMode::STANDARD_BINDING; - //! Table names extracted for BindingMode::EXTRACT_NAMES or BindingMode::EXTRACT_QUALIFIED_NAMES. - unordered_set table_names; - //! Replacement Scans extracted for BindingMode::EXTRACT_REPLACEMENT_SCANS - case_insensitive_map_t> replacement_scans; //! The set of bound views reference_set_t bound_views; //! Used to retrieve CatalogEntry's CatalogEntryRetriever entry_retriever; //! Unnamed subquery index idx_t unnamed_subquery_index = 1; - //! Statement properties - StatementProperties prop; - //! Root binder - Binder &root_binder; //! Binder depth idx_t depth; private: - //! Get the root binder (binder with no parent) - Binder &GetRootBinder(); //! Determine the depth of the binder idx_t GetBinderDepth() const; //! Increase the depth of the binder @@ -363,7 +368,7 @@ class Binder : public enable_shared_from_this { void MoveCorrelatedExpressions(Binder &other); //! 
Tries to bind the table name with replacement scans - unique_ptr BindWithReplacementScan(ClientContext &context, BaseTableRef &ref); + BoundStatement BindWithReplacementScan(ClientContext &context, BaseTableRef &ref); template BoundStatement BindWithCTE(T &statement); @@ -423,24 +428,24 @@ class Binder : public enable_shared_from_this { BoundSetOpChild BindSetOpChild(QueryNode &child); unique_ptr BindSetOpNode(SetOperationNode &statement); - unique_ptr BindJoin(Binder &parent, TableRef &ref); - unique_ptr Bind(BaseTableRef &ref); - unique_ptr Bind(BoundRefWrapper &ref); - unique_ptr Bind(JoinRef &ref); - unique_ptr Bind(SubqueryRef &ref); - unique_ptr Bind(TableFunctionRef &ref); - unique_ptr Bind(EmptyTableRef &ref); - unique_ptr Bind(DelimGetRef &ref); - unique_ptr Bind(ExpressionListRef &ref); - unique_ptr Bind(ColumnDataRef &ref); - unique_ptr Bind(PivotRef &expr); - unique_ptr Bind(ShowRef &ref); + BoundStatement BindJoin(Binder &parent, TableRef &ref); + BoundStatement Bind(BaseTableRef &ref); + BoundStatement Bind(BoundRefWrapper &ref); + BoundStatement Bind(JoinRef &ref); + BoundStatement Bind(SubqueryRef &ref); + BoundStatement Bind(TableFunctionRef &ref); + BoundStatement Bind(EmptyTableRef &ref); + BoundStatement Bind(DelimGetRef &ref); + BoundStatement Bind(ExpressionListRef &ref); + BoundStatement Bind(ColumnDataRef &ref); + BoundStatement Bind(PivotRef &expr); + BoundStatement Bind(ShowRef &ref); unique_ptr BindPivot(PivotRef &expr, vector> all_columns); unique_ptr BindUnpivot(Binder &child_binder, PivotRef &expr, vector> all_columns, unique_ptr &where_clause); - unique_ptr BindBoundPivot(PivotRef &expr); + BoundStatement BindBoundPivot(PivotRef &expr); void ExtractUnpivotEntries(Binder &child_binder, PivotColumnEntry &entry, vector &unpivot_entries); void ExtractUnpivotColumnName(ParsedExpression &expr, vector &result); @@ -449,26 +454,14 @@ class Binder : public enable_shared_from_this { bool BindTableFunctionParameters(TableFunctionCatalogEntry 
&table_function, vector> &expressions, vector &arguments, vector ¶meters, named_parameter_map_t &named_parameters, - unique_ptr &subquery, ErrorData &error); - void BindTableInTableOutFunction(vector> &expressions, - unique_ptr &subquery); - unique_ptr BindTableFunction(TableFunction &function, vector parameters); - unique_ptr BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, - vector parameters, - named_parameter_map_t named_parameters, - vector input_table_types, - vector input_table_names); - - unique_ptr CreatePlan(BoundBaseTableRef &ref); + BoundStatement &subquery, ErrorData &error); + void BindTableInTableOutFunction(vector> &expressions, BoundStatement &subquery); + BoundStatement BindTableFunction(TableFunction &function, vector parameters); + BoundStatement BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, + vector parameters, named_parameter_map_t named_parameters, + vector input_table_types, vector input_table_names); + unique_ptr CreatePlan(BoundJoinRef &ref); - unique_ptr CreatePlan(BoundSubqueryRef &ref); - unique_ptr CreatePlan(BoundTableFunction &ref); - unique_ptr CreatePlan(BoundEmptyTableRef &ref); - unique_ptr CreatePlan(BoundExpressionListRef &ref); - unique_ptr CreatePlan(BoundColumnDataRef &ref); - unique_ptr CreatePlan(BoundCTERef &ref); - unique_ptr CreatePlan(BoundPivotRef &ref); - unique_ptr CreatePlan(BoundDelimGetRef &ref); BoundStatement BindCopyTo(CopyStatement &stmt, const CopyFunction &function, CopyToType copy_to_type); BoundStatement BindCopyFrom(CopyStatement &stmt, const CopyFunction &function); @@ -503,8 +496,6 @@ class Binder : public enable_shared_from_this { BindingAlias RetrieveUsingBinding(Binder ¤t_binder, optional_ptr current_set, const string &column_name, const string &join_side); - void AddCTEMap(CommonTableExpressionMap &cte_map); - void ExpandStarExpressions(vector> &select_list, vector> &new_select_list); void ExpandStarExpression(unique_ptr 
expr, vector> &new_select_list); @@ -525,16 +516,16 @@ class Binder : public enable_shared_from_this { LogicalType BindLogicalTypeInternal(const LogicalType &type, optional_ptr catalog, const string &schema); - BoundStatement BindSelectNode(SelectNode &statement, unique_ptr from_table); + BoundStatement BindSelectNode(SelectNode &statement, BoundStatement from_table); unique_ptr BindSelectNodeInternal(SelectNode &statement); - unique_ptr BindSelectNodeInternal(SelectNode &statement, unique_ptr from_table); + unique_ptr BindSelectNodeInternal(SelectNode &statement, BoundStatement from_table); unique_ptr BindCopyDatabaseSchema(Catalog &source_catalog, const string &target_database_name); unique_ptr BindCopyDatabaseData(Catalog &source_catalog, const string &target_database_name); - unique_ptr BindShowQuery(ShowRef &ref); - unique_ptr BindShowTable(ShowRef &ref); - unique_ptr BindSummarize(ShowRef &ref); + BoundStatement BindShowQuery(ShowRef &ref); + BoundStatement BindShowTable(ShowRef &ref); + BoundStatement BindSummarize(ShowRef &ref); void BindInsertColumnList(TableCatalogEntry &table, vector &columns, bool default_values, vector &named_column_map, vector &expected_types, diff --git a/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp b/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp deleted file mode 100644 index 0a831c54a..000000000 --- a/src/duckdb/src/include/duckdb/planner/bound_tableref.hpp +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/bound_tableref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/common.hpp" -#include "duckdb/common/enums/tableref_type.hpp" -#include "duckdb/parser/parsed_data/sample_options.hpp" - -namespace duckdb { - -class BoundTableRef { -public: - explicit BoundTableRef(TableReferenceType type) : type(type) { - } - virtual 
~BoundTableRef() { - } - - //! The type of table reference - TableReferenceType type; - //! The sample options (if any) - unique_ptr sample; - -public: - template - TARGET &Cast() { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound table ref to type - table ref type mismatch"); - } - return reinterpret_cast(*this); - } - - template - const TARGET &Cast() const { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound table ref to type - table ref type mismatch"); - } - return reinterpret_cast(*this); - } -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp index bd75aac19..ac8aef099 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp @@ -45,18 +45,7 @@ class BoundWindowExpression; //===--------------------------------------------------------------------===// // TableRefs //===--------------------------------------------------------------------===// -class BoundTableRef; - -class BoundBaseTableRef; class BoundJoinRef; -class BoundSubqueryRef; -class BoundTableFunction; -class BoundEmptyTableRef; -class BoundExpressionListRef; -class BoundColumnDataRef; -class BoundCTERef; -class BoundPivotRef; - class BoundMergeIntoAction; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp index d11d94731..956c66fab 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/select_bind_state.hpp @@ -11,7 +11,6 @@ #include "duckdb/planner/bound_query_node.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/parser/expression_map.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include 
"duckdb/parser/parsed_data/sample_options.hpp" #include "duckdb/parser/group_by_node.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp b/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp index 804adc56f..52599a1c8 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_iterator.hpp @@ -14,7 +14,6 @@ #include namespace duckdb { -class BoundTableRef; class ExpressionIterator { public: diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp index b3a22966a..d94195698 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp @@ -11,7 +11,6 @@ #include "duckdb/planner/bound_query_node.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/parser/expression_map.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/parser/parsed_data/sample_options.hpp" #include "duckdb/parser/group_by_node.hpp" #include "duckdb/planner/expression_binder/select_bind_state.hpp" @@ -47,7 +46,7 @@ class BoundSelectNode : public BoundQueryNode { //! The projection list vector> select_list; //! The FROM clause - unique_ptr from_table; + BoundStatement from_table; //! The WHERE clause unique_ptr where_clause; //! 
list of groups diff --git a/src/duckdb/src/include/duckdb/planner/table_binding.hpp b/src/duckdb/src/include/duckdb/planner/table_binding.hpp index 9aedc7e70..671c5abed 100644 --- a/src/duckdb/src/include/duckdb/planner/table_binding.hpp +++ b/src/duckdb/src/include/duckdb/planner/table_binding.hpp @@ -26,11 +26,10 @@ class SubqueryRef; class LogicalGet; class TableCatalogEntry; class TableFunctionCatalogEntry; -class BoundTableFunction; class StandardEntry; struct ColumnBinding; -enum class BindingType { BASE, TABLE, DUMMY, CATALOG_ENTRY }; +enum class BindingType { BASE, TABLE, DUMMY, CATALOG_ENTRY, CTE }; //! A Binding represents a binding to a table, table-producing function or subquery with a specified table index. struct Binding { @@ -149,4 +148,14 @@ struct DummyBinding : public Binding { unique_ptr ParamToArg(ColumnRefExpression &col_ref); }; +struct CTEBinding : public Binding { +public: + static constexpr const BindingType TYPE = BindingType::CTE; + +public: + CTEBinding(BindingAlias alias, vector types, vector names, idx_t index); + + idx_t reference_count; +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp deleted file mode 100644 index b1f7f6f46..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_basetableref.hpp +++ /dev/null @@ -1,30 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_basetableref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/logical_operator.hpp" - -namespace duckdb { -class TableCatalogEntry; - -//! 
Represents a TableReference to a base table in the schema -class BoundBaseTableRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::BASE_TABLE; - -public: - BoundBaseTableRef(TableCatalogEntry &table, unique_ptr get) - : BoundTableRef(TableReferenceType::BASE_TABLE), table(table), get(std::move(get)) { - } - - TableCatalogEntry &table; - unique_ptr get; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_column_data_ref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_column_data_ref.hpp deleted file mode 100644 index 025bc4712..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_column_data_ref.hpp +++ /dev/null @@ -1,30 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_column_data_ref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/common/optionally_owned_ptr.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" - -namespace duckdb { -//! Represents a TableReference to a base table in the schema -class BoundColumnDataRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::COLUMN_DATA; - -public: - explicit BoundColumnDataRef(optionally_owned_ptr collection) - : BoundTableRef(TableReferenceType::COLUMN_DATA), collection(std::move(collection)) { - } - //! The (optionally owned) materialized column data to scan - optionally_owned_ptr collection; - //! 
The index in the bind context - idx_t bind_index; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp deleted file mode 100644 index 781402fbe..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_cteref.hpp +++ /dev/null @@ -1,40 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_cteref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/common/enums/cte_materialize.hpp" - -namespace duckdb { - -class BoundCTERef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::CTE; - -public: - BoundCTERef(idx_t bind_index, idx_t cte_index) - : BoundTableRef(TableReferenceType::CTE), bind_index(bind_index), cte_index(cte_index) { - } - - BoundCTERef(idx_t bind_index, idx_t cte_index, bool is_recurring) - : BoundTableRef(TableReferenceType::CTE), bind_index(bind_index), cte_index(cte_index), - is_recurring(is_recurring) { - } - //! The set of columns bound to this base table reference - vector bound_columns; - //! The types of the values list - vector types; - //! The index in the bind context - idx_t bind_index; - //! The index of the cte - idx_t cte_index; - //! 
Is this a reference to the recurring table of a CTE - bool is_recurring = false; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_delimgetref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_delimgetref.hpp deleted file mode 100644 index 7b1022482..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_delimgetref.hpp +++ /dev/null @@ -1,26 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_delimgetref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -class BoundDelimGetRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::DELIM_GET; - -public: - BoundDelimGetRef(idx_t bind_index, const vector &column_types_p) - : BoundTableRef(TableReferenceType::DELIM_GET), bind_index(bind_index), column_types(column_types_p) { - } - idx_t bind_index; - vector column_types; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp deleted file mode 100644 index 3a68f5166..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_dummytableref.hpp +++ /dev/null @@ -1,26 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_dummytableref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -//! 
Represents a cross product -class BoundEmptyTableRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::EMPTY_FROM; - -public: - explicit BoundEmptyTableRef(idx_t bind_index) - : BoundTableRef(TableReferenceType::EMPTY_FROM), bind_index(bind_index) { - } - idx_t bind_index; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp deleted file mode 100644 index 7fc563dda..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_expressionlistref.hpp +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_expressionlistref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/expression.hpp" - -namespace duckdb { -//! Represents a TableReference to a base table in the schema -class BoundExpressionListRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::EXPRESSION_LIST; - -public: - BoundExpressionListRef() : BoundTableRef(TableReferenceType::EXPRESSION_LIST) { - } - - //! The bound VALUES list - vector>> values; - //! The generated names of the values list - vector names; - //! The types of the values list - vector types; - //! 
The index in the bind context - idx_t bind_index; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp index 299189624..87976ba30 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_joinref.hpp @@ -11,19 +11,14 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/common/enums/join_type.hpp" #include "duckdb/common/enums/joinref_type.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/expression.hpp" namespace duckdb { //! Represents a join -class BoundJoinRef : public BoundTableRef { +class BoundJoinRef { public: - static constexpr const TableReferenceType TYPE = TableReferenceType::JOIN; - -public: - explicit BoundJoinRef(JoinRefType ref_type) - : BoundTableRef(TableReferenceType::JOIN), type(JoinType::INNER), ref_type(ref_type), lateral(false) { + explicit BoundJoinRef(JoinRefType ref_type) : type(JoinType::INNER), ref_type(ref_type), lateral(false) { } //! The binder used to bind the LHS of the join @@ -31,9 +26,9 @@ class BoundJoinRef : public BoundTableRef { //! The binder used to bind the RHS of the join shared_ptr right_binder; //! The left hand side of the join - unique_ptr left; + BoundStatement left; //! The right hand side of the join - unique_ptr right; + BoundStatement right; //! The join condition unique_ptr condition; //! 
Duplicate Eliminated Columns (if any) diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp index 3219f6307..5a2d68aa1 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/bound_pivotref.hpp @@ -9,7 +9,6 @@ #pragma once #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/expression.hpp" #include "duckdb/parser/tableref/pivotref.hpp" #include "duckdb/function/aggregate_function.hpp" @@ -30,19 +29,13 @@ struct BoundPivotInfo { static BoundPivotInfo Deserialize(Deserializer &deserializer); }; -class BoundPivotRef : public BoundTableRef { +class BoundPivotRef { public: - static constexpr const TableReferenceType TYPE = TableReferenceType::PIVOT; - -public: - explicit BoundPivotRef() : BoundTableRef(TableReferenceType::PIVOT) { - } - idx_t bind_index; //! The binder used to bind the child of the pivot shared_ptr child_binder; //! The child node of the pivot - unique_ptr child; + BoundStatement child; //! The bound pivot info BoundPivotInfo bound_pivot; }; diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp deleted file mode 100644 index 4cb057e41..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_pos_join_ref.hpp +++ /dev/null @@ -1,38 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_pos_join_ref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -//! 
Represents a positional join -class BoundPositionalJoinRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::POSITIONAL_JOIN; - -public: - BoundPositionalJoinRef() : BoundTableRef(TableReferenceType::POSITIONAL_JOIN), lateral(false) { - } - - //! The binder used to bind the LHS of the positional join - shared_ptr left_binder; - //! The binder used to bind the RHS of the positional join - shared_ptr right_binder; - //! The left hand side of the positional join - unique_ptr left; - //! The right hand side of the positional join - unique_ptr right; - //! Whether or not this is a lateral positional join - bool lateral; - //! The correlated columns of the right-side with the left-side - vector correlated_columns; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp deleted file mode 100644 index a07994f8a..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_subqueryref.hpp +++ /dev/null @@ -1,32 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_subqueryref.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" -#include "duckdb/planner/bound_tableref.hpp" - -namespace duckdb { - -//! Represents a cross product -class BoundSubqueryRef : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::SUBQUERY; - -public: - BoundSubqueryRef(shared_ptr binder_p, BoundStatement subquery) - : BoundTableRef(TableReferenceType::SUBQUERY), binder(std::move(binder_p)), subquery(std::move(subquery)) { - } - - //! The binder used to bind the subquery - shared_ptr binder; - //! 
The bound subquery node (if any) - BoundStatement subquery; -}; -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp b/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp deleted file mode 100644 index 6aafe2b36..000000000 --- a/src/duckdb/src/include/duckdb/planner/tableref/bound_table_function.hpp +++ /dev/null @@ -1,31 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/tableref/bound_table_function.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/logical_operator.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" - -namespace duckdb { - -//! Represents a reference to a table-producing function call -class BoundTableFunction : public BoundTableRef { -public: - static constexpr const TableReferenceType TYPE = TableReferenceType::TABLE_FUNCTION; - -public: - explicit BoundTableFunction(unique_ptr get) - : BoundTableRef(TableReferenceType::TABLE_FUNCTION), get(std::move(get)) { - } - - unique_ptr get; - unique_ptr subquery; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/tableref/list.hpp b/src/duckdb/src/include/duckdb/planner/tableref/list.hpp index 79a00ce62..dbc8394df 100644 --- a/src/duckdb/src/include/duckdb/planner/tableref/list.hpp +++ b/src/duckdb/src/include/duckdb/planner/tableref/list.hpp @@ -1,11 +1,2 @@ -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/tableref/bound_cteref.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" -#include "duckdb/planner/tableref/bound_expressionlistref.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" -#include "duckdb/planner/tableref/bound_column_data_ref.hpp" -#include 
"duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/planner/tableref/bound_pivotref.hpp" -#include "duckdb/parser/tableref/delimgetref.hpp" -#include "duckdb/planner/tableref/bound_delimgetref.hpp" diff --git a/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp b/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp index 2101bcb31..e37879f61 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/base_statistics.hpp @@ -15,6 +15,7 @@ #include "duckdb/common/types/value.hpp" #include "duckdb/storage/statistics/numeric_stats.hpp" #include "duckdb/storage/statistics/string_stats.hpp" +#include "duckdb/storage/statistics/geometry_stats.hpp" namespace duckdb { struct SelectionVector; @@ -33,7 +34,15 @@ enum class StatsInfo : uint8_t { CAN_HAVE_NULL_AND_VALID_VALUES = 4 }; -enum class StatisticsType : uint8_t { NUMERIC_STATS, STRING_STATS, LIST_STATS, STRUCT_STATS, BASE_STATS, ARRAY_STATS }; +enum class StatisticsType : uint8_t { + NUMERIC_STATS, + STRING_STATS, + LIST_STATS, + STRUCT_STATS, + BASE_STATS, + ARRAY_STATS, + GEOMETRY_STATS +}; class BaseStatistics { friend struct NumericStats; @@ -41,6 +50,7 @@ class BaseStatistics { friend struct StructStats; friend struct ListStats; friend struct ArrayStats; + friend struct GeometryStats; public: DUCKDB_API ~BaseStatistics(); @@ -146,6 +156,8 @@ class BaseStatistics { NumericStatsData numeric_data; //! String stats data, for string stats StringStatsData string_data; + //! Geometry stats data, for geometry stats + GeometryStatsData geometry_data; } stats_union; //! 
Child stats (for LIST and STRUCT) unsafe_unique_array child_stats; diff --git a/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp new file mode 100644 index 000000000..e8db7285b --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp @@ -0,0 +1,144 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/geometry_stats.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/enums/expression_type.hpp" +#include "duckdb/common/enums/filter_propagate_result.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/operator/comparison_operators.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/array_ptr.hpp" +#include "duckdb/common/types/geometry.hpp" + +namespace duckdb { +class BaseStatistics; +struct SelectionVector; + +class GeometryTypeSet { +public: + static constexpr auto VERT_TYPES = 4; + static constexpr auto PART_TYPES = 8; + + static GeometryTypeSet Unknown() { + GeometryTypeSet result; + for (idx_t i = 0; i < VERT_TYPES; i++) { + result.sets[i] = 0xFF; + } + return result; + } + static GeometryTypeSet Empty() { + GeometryTypeSet result; + for (idx_t i = 0; i < VERT_TYPES; i++) { + result.sets[i] = 0; + } + return result; + } + + bool IsEmpty() const { + for (idx_t i = 0; i < VERT_TYPES; i++) { + if (sets[i] != 0) { + return false; + } + } + return true; + } + + bool IsUnknown() const { + for (idx_t i = 0; i < VERT_TYPES; i++) { + if (sets[i] != 0xFF) { + return false; + } + } + return true; + } + + void Add(GeometryType geom_type, VertexType vert_type) { + const auto vert_idx = static_cast(vert_type); + const auto geom_idx = static_cast(geom_type); + D_ASSERT(vert_idx < VERT_TYPES); + D_ASSERT(geom_idx < 
PART_TYPES); + sets[vert_idx] |= (1 << geom_idx); + } + + void Merge(const GeometryTypeSet &other) { + for (idx_t i = 0; i < VERT_TYPES; i++) { + sets[i] |= other.sets[i]; + } + } + + vector ToWKBList() const { + vector result; + for (uint8_t vert_idx = 0; vert_idx < VERT_TYPES; vert_idx++) { + for (uint8_t geom_idx = 1; geom_idx < PART_TYPES; geom_idx++) { + if (sets[vert_idx] & (1 << geom_idx)) { + result.push_back(geom_idx + vert_idx * 1000); + } + } + } + return result; + } + + vector ToString(bool snake_case) const; + + uint8_t sets[VERT_TYPES]; +}; + +struct GeometryStatsData { + + GeometryTypeSet types; + GeometryExtent extent; + + void SetEmpty() { + types = GeometryTypeSet::Empty(); + extent = GeometryExtent::Empty(); + } + + void SetUnknown() { + types = GeometryTypeSet::Unknown(); + extent = GeometryExtent::Unknown(); + } + + void Merge(const GeometryStatsData &other) { + types.Merge(other.types); + extent.Merge(other.extent); + } + + void Update(const string_t &geom_blob) { + + // Parse type + const auto type_info = Geometry::GetType(geom_blob); + types.Add(type_info.first, type_info.second); + + // Update extent + Geometry::GetExtent(geom_blob, extent); + } +}; + +struct GeometryStats { + //! Unknown statistics + DUCKDB_API static BaseStatistics CreateUnknown(LogicalType type); + //! 
Empty statistics + DUCKDB_API static BaseStatistics CreateEmpty(LogicalType type); + + DUCKDB_API static void Serialize(const BaseStatistics &stats, Serializer &serializer); + DUCKDB_API static void Deserialize(Deserializer &deserializer, BaseStatistics &base); + + DUCKDB_API static string ToString(const BaseStatistics &stats); + + DUCKDB_API static void Update(BaseStatistics &stats, const string_t &value); + DUCKDB_API static void Merge(BaseStatistics &stats, const BaseStatistics &other); + DUCKDB_API static void Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count); + +private: + static GeometryStatsData &GetDataUnsafe(BaseStatistics &stats); + static const GeometryStatsData &GetDataUnsafe(const BaseStatistics &stats); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp index 7ca16e38e..755e99339 100644 --- a/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp +++ b/src/duckdb/src/include/duckdb/storage/string_uncompressed.hpp @@ -201,7 +201,11 @@ struct UncompressedStringStorage { public: static inline void UpdateStringStats(SegmentStatistics &stats, const string_t &new_value) { - StringStats::Update(stats.statistics, new_value); + if (stats.statistics.GetStatsType() == StatisticsType::GEOMETRY_STATS) { + GeometryStats::Update(stats.statistics, new_value); + } else { + StringStats::Update(stats.statistics, new_value); + } } static void SetDictionary(ColumnSegment &segment, BufferHandle &handle, StringDictionaryContainer dict); diff --git a/src/duckdb/src/main/client_context.cpp b/src/duckdb/src/main/client_context.cpp index f52fbabdd..831f82370 100644 --- a/src/duckdb/src/main/client_context.cpp +++ b/src/duckdb/src/main/client_context.cpp @@ -1443,6 +1443,7 @@ ParserOptions ClientContext::GetParserOptions() const { options.integer_division = DBConfig::GetSetting(*this); options.max_expression_depth 
= client_config.max_expression_depth; options.extensions = &DBConfig::GetConfig(*this).parser_extensions; + options.parser_override_setting = DBConfig::GetConfig(*this).options.allow_parser_override_extension; return options; } diff --git a/src/duckdb/src/main/config.cpp b/src/duckdb/src/main/config.cpp index 78b174902..48c1cacdf 100644 --- a/src/duckdb/src/main/config.cpp +++ b/src/duckdb/src/main/config.cpp @@ -63,6 +63,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(AllocatorFlushThresholdSetting), DUCKDB_GLOBAL(AllowCommunityExtensionsSetting), DUCKDB_SETTING(AllowExtensionsMetadataMismatchSetting), + DUCKDB_GLOBAL(AllowParserOverrideExtensionSetting), DUCKDB_GLOBAL(AllowPersistentSecretsSetting), DUCKDB_GLOBAL(AllowUnredactedSecretsSetting), DUCKDB_GLOBAL(AllowUnsignedExtensionsSetting), @@ -179,12 +180,12 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(ZstdMinStringLengthSetting), FINAL_SETTING}; -static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("memory_limit", 83), - DUCKDB_SETTING_ALIAS("null_order", 33), - DUCKDB_SETTING_ALIAS("profiling_output", 102), - DUCKDB_SETTING_ALIAS("user", 117), - DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 20), - DUCKDB_SETTING_ALIAS("worker_threads", 116), +static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("memory_limit", 84), + DUCKDB_SETTING_ALIAS("null_order", 34), + DUCKDB_SETTING_ALIAS("profiling_output", 103), + DUCKDB_SETTING_ALIAS("user", 118), + DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 21), + DUCKDB_SETTING_ALIAS("worker_threads", 117), FINAL_ALIAS}; vector DBConfig::GetOptions() { diff --git a/src/duckdb/src/main/query_profiler.cpp b/src/duckdb/src/main/query_profiler.cpp index eb4116cca..2552248b4 100644 --- a/src/duckdb/src/main/query_profiler.cpp +++ b/src/duckdb/src/main/query_profiler.cpp @@ -780,8 +780,10 @@ static yyjson_mut_val *ToJSONRecursive(yyjson_mut_doc *doc, ProfilingNode &node) auto result_obj = 
yyjson_mut_obj(doc); auto &profiling_info = node.GetProfilingInfo(); - profiling_info.metrics[MetricsType::EXTRA_INFO] = - QueryProfiler::JSONSanitize(profiling_info.metrics.at(MetricsType::EXTRA_INFO)); + if (profiling_info.Enabled(profiling_info.settings, MetricsType::EXTRA_INFO)) { + profiling_info.metrics[MetricsType::EXTRA_INFO] = + QueryProfiler::JSONSanitize(profiling_info.metrics.at(MetricsType::EXTRA_INFO)); + } profiling_info.WriteMetricsToJSON(doc, result_obj); diff --git a/src/duckdb/src/main/settings/autogenerated_settings.cpp b/src/duckdb/src/main/settings/autogenerated_settings.cpp index 96c3065f2..989dd6e7a 100644 --- a/src/duckdb/src/main/settings/autogenerated_settings.cpp +++ b/src/duckdb/src/main/settings/autogenerated_settings.cpp @@ -78,6 +78,28 @@ Value AllowCommunityExtensionsSetting::GetSetting(const ClientContext &context) return Value::BOOLEAN(config.options.allow_community_extensions); } +//===----------------------------------------------------------------------===// +// Allow Parser Override Extension +//===----------------------------------------------------------------------===// +void AllowParserOverrideExtensionSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { + if (!OnGlobalSet(db, config, input)) { + return; + } + config.options.allow_parser_override_extension = input.GetValue(); +} + +void AllowParserOverrideExtensionSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { + if (!OnGlobalReset(db, config)) { + return; + } + config.options.allow_parser_override_extension = DBConfigOptions().allow_parser_override_extension; +} + +Value AllowParserOverrideExtensionSetting::GetSetting(const ClientContext &context) { + auto &config = DBConfig::GetConfig(context); + return Value(config.options.allow_parser_override_extension); +} + //===----------------------------------------------------------------------===// // Allow Unredacted Secrets 
//===----------------------------------------------------------------------===// diff --git a/src/duckdb/src/main/settings/custom_settings.cpp b/src/duckdb/src/main/settings/custom_settings.cpp index 8e9b491e3..819e1383d 100644 --- a/src/duckdb/src/main/settings/custom_settings.cpp +++ b/src/duckdb/src/main/settings/custom_settings.cpp @@ -150,6 +150,24 @@ bool AllowCommunityExtensionsSetting::OnGlobalReset(DatabaseInstance *db, DBConf return true; } +//===----------------------------------------------------------------------===// +// Allow Parser Override +//===----------------------------------------------------------------------===// +bool AllowParserOverrideExtensionSetting::OnGlobalSet(DatabaseInstance *db, DBConfig &config, const Value &input) { + auto new_value = input.GetValue(); + if (!StringUtil::CIEquals(new_value, "default") && !StringUtil::CIEquals(new_value, "fallback") && + !StringUtil::CIEquals(new_value, "strict")) { + throw InvalidInputException("Unrecognized value for parser override setting. 
Valid options are: \"default\", " + "\"fallback\", \"strict\"."); + } + return true; +} + +bool AllowParserOverrideExtensionSetting::OnGlobalReset(DatabaseInstance *db, DBConfig &config) { + config.options.allow_parser_override_extension = "default"; + return true; +} + //===----------------------------------------------------------------------===// // Allow Persistent Secrets //===----------------------------------------------------------------------===// diff --git a/src/duckdb/src/optimizer/optimizer.cpp b/src/duckdb/src/optimizer/optimizer.cpp index 3007fa9ac..dfbbbd901 100644 --- a/src/duckdb/src/optimizer/optimizer.cpp +++ b/src/duckdb/src/optimizer/optimizer.cpp @@ -32,6 +32,7 @@ #include "duckdb/optimizer/statistics_propagator.hpp" #include "duckdb/optimizer/sum_rewriter.hpp" #include "duckdb/optimizer/topn_optimizer.hpp" +#include "duckdb/optimizer/topn_window_elimination.hpp" #include "duckdb/optimizer/unnest_rewriter.hpp" #include "duckdb/optimizer/late_materialization.hpp" #include "duckdb/optimizer/common_subplan_optimizer.hpp" @@ -264,6 +265,12 @@ void Optimizer::RunBuiltInOptimizers() { statistics_map = propagator.GetStatisticsMap(); }); + // rewrite row_number window function + filter on row_number to aggregate + RunOptimizer(OptimizerType::TOP_N_WINDOW_ELIMINATION, [&]() { + TopNWindowElimination topn_window_elimination(context, *this, &statistics_map); + plan = topn_window_elimination.Optimize(std::move(plan)); + }); + // remove duplicate aggregates RunOptimizer(OptimizerType::COMMON_AGGREGATE, [&]() { CommonAggregateOptimizer common_aggregate; diff --git a/src/duckdb/src/optimizer/topn_window_elimination.cpp b/src/duckdb/src/optimizer/topn_window_elimination.cpp new file mode 100644 index 000000000..a06cdf830 --- /dev/null +++ b/src/duckdb/src/optimizer/topn_window_elimination.cpp @@ -0,0 +1,592 @@ +#include "duckdb/optimizer/topn_window_elimination.hpp" + +#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" +#include 
"duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_unnest.hpp" +#include "duckdb/planner/operator/logical_window.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" +#include "duckdb/function/scalar/struct_functions.hpp" +#include "duckdb/optimizer/optimizer.hpp" +#include "duckdb/parser/expression/function_expression.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_unnest_expression.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/function/function_binder.hpp" + +namespace duckdb { + +namespace { + +idx_t GetGroupIdx(const unique_ptr &op) { + if (op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + return op->Cast().group_index; + } + return op->GetTableIndex()[0]; +} + +idx_t GetAggregateIdx(const unique_ptr &op) { + if (op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { + return op->Cast().aggregate_index; + } + return op->GetTableIndex()[0]; +} + +LogicalType GetAggregateType(const unique_ptr &op) { + switch (op->type) { + case LogicalOperatorType::LOGICAL_UNNEST: { + const auto &logical_unnest = op->Cast(); + const idx_t unnest_offset = logical_unnest.children[0]->types.size(); + return logical_unnest.types[unnest_offset]; + } + case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: { + const auto &logical_aggregate = op->Cast(); + const idx_t aggregate_column_idx = 
logical_aggregate.groups.size(); + return logical_aggregate.types[aggregate_column_idx]; + } + default: { + throw InternalException("Unnest or aggregate expected to extract aggregate type."); + } + } +} + +vector ExtractReturnTypes(const vector> &exprs) { + vector types; + types.reserve(exprs.size()); + for (const auto &expr : exprs) { + types.push_back(expr->return_type); + } + return types; +} + +bool BindingsReferenceRowNumber(const vector &bindings, const LogicalWindow &window) { + for (const auto &binding : bindings) { + if (binding.table_index == window.window_index) { + return true; + } + } + return false; +} +// Window, Filter, new_bindings, aggregate_payload +TopNWindowEliminationParameters ExtractOptimizerParameters(const LogicalWindow &window, const LogicalFilter &filter, + const vector &bindings, + const vector> &aggregate_payload) { + TopNWindowEliminationParameters params; + + auto &limit_expr = filter.expressions[0]->Cast().right; + params.limit = limit_expr->Cast().value.GetValue(); + params.include_row_number = BindingsReferenceRowNumber(bindings, window); + params.payload_type = aggregate_payload.size() > 1 ? 
TopNPayloadType::STRUCT_PACK : TopNPayloadType::SINGLE_COLUMN; + params.order_type = window.expressions[0]->Cast().orders[0].type; + + return params; +} + +} // namespace + +TopNWindowElimination::TopNWindowElimination(ClientContext &context_p, Optimizer &optimizer, + optional_ptr>> stats_p) + : context(context_p), optimizer(optimizer), stats(stats_p) { +} + +unique_ptr TopNWindowElimination::Optimize(unique_ptr op) { + ColumnBindingReplacer replacer; + op = OptimizeInternal(std::move(op), replacer); + if (!replacer.replacement_bindings.empty()) { + replacer.VisitOperator(*op); + } + return op; +} + +unique_ptr TopNWindowElimination::OptimizeInternal(unique_ptr op, + ColumnBindingReplacer &replacer) { + if (!CanOptimize(*op)) { + // Traverse through query plan to find grouped top-n pattern + if (op->children.size() > 1) { + // If an operator has multiple children, we do not want them to overwrite each other's stop operator. + // Thus, first update only the column binding in op, then set op as the new stop operator. 
+ for (auto &child : op->children) { + ColumnBindingReplacer r2; + child = OptimizeInternal(std::move(child), r2); + + if (!r2.replacement_bindings.empty()) { + r2.VisitOperator(*op); + replacer.replacement_bindings.insert(replacer.replacement_bindings.end(), + r2.replacement_bindings.begin(), + r2.replacement_bindings.end()); + replacer.stop_operator = op; + } + } + } else if (!op->children.empty()) { + op->children[0] = OptimizeInternal(std::move(op->children[0]), replacer); + } + + return op; + } + // We have made sure that this is an operator sequence of filter -> N optional projections -> window + auto &filter = op->Cast(); + auto *child = filter.children[0].get(); + + // Get bindings and types from filter to use in top-most operator later + const auto topmost_bindings = filter.GetColumnBindings(); + auto new_bindings = TraverseProjectionBindings(topmost_bindings, child); + + D_ASSERT(child->type == LogicalOperatorType::LOGICAL_WINDOW); + auto &window = child->Cast(); + const idx_t window_idx = window.window_index; + + // Map the input column offsets of the group columns to the output offset if there are projections on the group + // We use an ordered map here because we need to iterate over them in order later + map group_projection_idxs; + auto aggregate_payload = GenerateAggregatePayload(new_bindings, window, group_projection_idxs); + const auto params = ExtractOptimizerParameters(window, filter, new_bindings, aggregate_payload); + + // Optimize window children + window.children[0] = Optimize(std::move(window.children[0])); + + op = CreateAggregateOperator(window, std::move(aggregate_payload), params); + op = TryCreateUnnestOperator(std::move(op), params); + op = CreateProjectionOperator(std::move(op), params, group_projection_idxs); + + D_ASSERT(op->type != LogicalOperatorType::LOGICAL_UNNEST); + + UpdateTopmostBindings(window_idx, op, group_projection_idxs, topmost_bindings, new_bindings, replacer); + replacer.stop_operator = op.get(); + + return 
unique_ptr(std::move(op)); +} + +unique_ptr TopNWindowElimination::CreateAggregateExpression(vector> aggregate_params, + const bool requires_arg, + const OrderType order_type) const { + auto &catalog = Catalog::GetSystemCatalog(context); + FunctionBinder function_binder(context); + + D_ASSERT(order_type == OrderType::ASCENDING || order_type == OrderType::DESCENDING); + string fun_name = requires_arg ? "arg_" : ""; + fun_name += order_type == OrderType::ASCENDING ? "min" : "max"; + + auto &fun_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, fun_name); + const auto fun = fun_entry.functions.GetFunctionByArguments(context, ExtractReturnTypes(aggregate_params)); + return function_binder.BindAggregateFunction(fun, std::move(aggregate_params)); +} + +unique_ptr +TopNWindowElimination::CreateAggregateOperator(LogicalWindow &window, vector> args, + const TopNWindowEliminationParameters ¶ms) const { + auto &window_expr = window.expressions[0]->Cast(); + D_ASSERT(window_expr.orders.size() == 1); + + vector> aggregate_params; + aggregate_params.reserve(3); + + const bool use_arg = !args.empty(); + if (args.size() == 1) { + aggregate_params.push_back(std::move(args[0])); + } else if (args.size() > 1) { + // For more than one arg, we must use struct pack + auto &catalog = Catalog::GetSystemCatalog(context); + FunctionBinder function_binder(context); + auto &struct_pack_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "struct_pack"); + const auto struct_pack_fun = + struct_pack_entry.functions.GetFunctionByArguments(context, ExtractReturnTypes(args)); + auto struct_pack_expr = function_binder.BindScalarFunction(struct_pack_fun, std::move(args)); + aggregate_params.push_back(std::move(struct_pack_expr)); + } + + aggregate_params.push_back(std::move(window_expr.orders[0].expression)); + if (params.limit > 1) { + aggregate_params.push_back(std::move(make_uniq(Value::BIGINT(params.limit)))); + } + + auto aggregate_expr = CreateAggregateExpression(std::move(aggregate_params), 
use_arg, params.order_type); + + vector> select_list; + select_list.push_back(std::move(aggregate_expr)); + + auto aggregate = make_uniq(optimizer.binder.GenerateTableIndex(), + optimizer.binder.GenerateTableIndex(), std::move(select_list)); + aggregate->groupings_index = optimizer.binder.GenerateTableIndex(); + aggregate->groups = std::move(window_expr.partitions); + aggregate->children.push_back(std::move(window.children[0])); + aggregate->ResolveOperatorTypes(); + + return unique_ptr(std::move(aggregate)); +} + +unique_ptr +TopNWindowElimination::CreateRowNumberGenerator(unique_ptr aggregate_column_ref) const { + // Create unnest(generate_series(1, array_length(column_ref, 1))) function to generate row ids + FunctionBinder function_binder(context); + auto &catalog = Catalog::GetSystemCatalog(context); + + // array_length + auto &array_length_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, "array_length"); + vector> array_length_exprs; + array_length_exprs.push_back(std::move(aggregate_column_ref)); + array_length_exprs.push_back(make_uniq(1)); + + const auto array_length_fun = array_length_entry.functions.GetFunctionByArguments( + context, {array_length_exprs[0]->return_type, array_length_exprs[1]->return_type}); + auto bound_array_length_fun = function_binder.BindScalarFunction(array_length_fun, std::move(array_length_exprs)); + + // generate_series + auto &generate_series_entry = + catalog.GetEntry(context, DEFAULT_SCHEMA, "generate_series"); + + vector> generate_series_exprs; + generate_series_exprs.push_back(make_uniq(1)); + generate_series_exprs.push_back(std::move(bound_array_length_fun)); + + const auto generate_series_fun = generate_series_entry.functions.GetFunctionByArguments( + context, {generate_series_exprs[0]->return_type, generate_series_exprs[1]->return_type}); + auto bound_generate_series_fun = + function_binder.BindScalarFunction(generate_series_fun, std::move(generate_series_exprs)); + + // unnest + auto unnest_row_number_expr = 
make_uniq(LogicalType::BIGINT); + unnest_row_number_expr->alias = "row_number"; + unnest_row_number_expr->child = std::move(bound_generate_series_fun); + + return unique_ptr(std::move(unnest_row_number_expr)); +} + +unique_ptr +TopNWindowElimination::TryCreateUnnestOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms) const { + D_ASSERT(op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY); + + auto &logical_aggregate = op->Cast(); + const idx_t aggregate_column_idx = logical_aggregate.groups.size(); + LogicalType aggregate_type = logical_aggregate.types[aggregate_column_idx]; + + if (params.limit <= 1) { + // LIMIT 1 -> we do not need to unnest + return std::move(op); + } + + // Create unnest expression for aggregate args + const auto aggregate_bindings = logical_aggregate.GetColumnBindings(); + auto aggregate_column_ref = + make_uniq(aggregate_type, aggregate_bindings[aggregate_column_idx]); + + vector> unnest_exprs; + + auto unnest_aggregate = make_uniq(ListType::GetChildType(aggregate_type)); + unnest_aggregate->child = aggregate_column_ref->Copy(); + unnest_exprs.push_back(std::move(unnest_aggregate)); + + if (params.include_row_number) { + // Create row number expression + unnest_exprs.push_back(CreateRowNumberGenerator(std::move(aggregate_column_ref))); + } + + auto unnest = make_uniq(optimizer.binder.GenerateTableIndex()); + unnest->expressions = std::move(unnest_exprs); + unnest->children.push_back(std::move(op)); + unnest->ResolveOperatorTypes(); + + return unique_ptr(std::move(unnest)); +} + +void TopNWindowElimination::AddStructExtractExprs( + vector> &exprs, const LogicalType &struct_type, + const unique_ptr &aggregate_column_ref) const { + FunctionBinder function_binder(context); + auto &catalog = Catalog::GetSystemCatalog(context); + auto &struct_extract_entry = + catalog.GetEntry(context, DEFAULT_SCHEMA, "struct_extract"); + const auto struct_extract_fun = + struct_extract_entry.functions.GetFunctionByArguments(context, 
{struct_type, LogicalType::VARCHAR}); + + const auto &child_types = StructType::GetChildTypes(struct_type); + for (idx_t i = 0; i < child_types.size(); i++) { + const auto &alias = child_types[i].first; + + vector> fun_args(2); + fun_args[0] = aggregate_column_ref->Copy(); + fun_args[1] = make_uniq(alias); + + auto bound_function = function_binder.BindScalarFunction(struct_extract_fun, std::move(fun_args)); + bound_function->alias = alias; + exprs.push_back(std::move(bound_function)); + } +} + +unique_ptr +TopNWindowElimination::CreateProjectionOperator(unique_ptr op, + const TopNWindowEliminationParameters ¶ms, + const map &group_idxs) const { + const auto aggregate_type = GetAggregateType(op); + const idx_t aggregate_table_idx = GetAggregateIdx(op); + const auto op_column_bindings = op->GetColumnBindings(); + + vector> proj_exprs; + // Only project necessary group columns + for (const auto &group_idx : group_idxs) { + proj_exprs.push_back( + make_uniq(op->types[group_idx.second], op_column_bindings[group_idx.second])); + } + + auto aggregate_column_ref = + make_uniq(aggregate_type, ColumnBinding(aggregate_table_idx, 0)); + + if (params.payload_type == TopNPayloadType::STRUCT_PACK) { + AddStructExtractExprs(proj_exprs, aggregate_type, aggregate_column_ref); + } else { + // No need for struct_unpack! 
Just reference the aggregate column + proj_exprs.push_back(std::move(aggregate_column_ref)); + } + + if (params.include_row_number) { + // If aggregate (i.e., limit 1): constant, if unnest: expect there to be a second column + if (op->type == LogicalOperatorType::LOGICAL_UNNEST) { + const idx_t row_number_offset = op->children[0]->types.size() + 1; + D_ASSERT(op->types.size() == row_number_offset + 1); // Row number should have been generated previously + proj_exprs.push_back(make_uniq(op->types[row_number_offset], + op_column_bindings[row_number_offset])); + } else { + proj_exprs.push_back(make_uniq(Value::BIGINT(1))); + } + } + + auto logical_projection = + make_uniq(optimizer.binder.GenerateTableIndex(), std::move(proj_exprs)); + logical_projection->children.push_back(std::move(op)); + logical_projection->ResolveOperatorTypes(); + + return unique_ptr(std::move(logical_projection)); +} + +bool TopNWindowElimination::CanOptimize(LogicalOperator &op) { + if (!stats) { + return false; + } + + if (op.type != LogicalOperatorType::LOGICAL_FILTER) { + return false; + } + + const auto &filter = op.Cast(); + if (filter.expressions.size() != 1) { + return false; + } + + if (filter.expressions[0]->type != ExpressionType::COMPARE_LESSTHANOREQUALTO) { + return false; + } + + auto &filter_comparison = filter.expressions[0]->Cast(); + if (filter_comparison.right->type != ExpressionType::VALUE_CONSTANT) { + return false; + } + auto &filter_value = filter_comparison.right->Cast(); + if (filter_value.value.type() != LogicalType::BIGINT) { + return false; + } + if (filter_value.value.GetValue() < 1) { + return false; + } + + if (filter_comparison.left->type != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + VisitExpression(&filter_comparison.left); + + auto *child = filter.children[0].get(); + while (child->type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &projection = child->Cast(); + if (column_references.size() != 1) { + column_references.clear(); + return 
false; + } + + const auto current_column_ref = column_references.begin()->first; + column_references.clear(); + D_ASSERT(current_column_ref.table_index == projection.table_index); + VisitExpression(&projection.expressions[current_column_ref.column_index]); + + child = child->children[0].get(); + } + + if (column_references.size() != 1) { + column_references.clear(); + return false; + } + const auto filter_col_idx = column_references.begin()->first.table_index; + column_references.clear(); + + if (child->type != LogicalOperatorType::LOGICAL_WINDOW) { + return false; + } + const auto &window = child->Cast(); + if (window.window_index != filter_col_idx) { + return false; + } + if (window.expressions.size() != 1) { + for (idx_t i = 1; i < window.expressions.size(); ++i) { + if (!window.expressions[i]->Equals(*window.expressions[0])) { + return false; + } + } + } + if (window.expressions[0]->type != ExpressionType::WINDOW_ROW_NUMBER) { + return false; + } + auto &window_expr = window.expressions[0]->Cast(); + + if (window_expr.orders.size() != 1) { + return false; + } + if (window_expr.orders[0].type != OrderType::DESCENDING && window_expr.orders[0].type != OrderType::ASCENDING) { + return false; + } + + VisitExpression(&window_expr.orders[0].expression); + for (const auto &column_ref : column_references) { + const auto &column_stats = stats->find(column_ref.first); + if (column_stats == stats->end() || column_stats->second->CanHaveNull()) { + return false; + } + } + column_references.clear(); + + // We have found a grouped top-n window construct! 
+ return true; +} + +vector> TopNWindowElimination::GenerateAggregatePayload(const vector &bindings, + const LogicalWindow &window, + map &group_idxs) { + vector> aggregate_args; + aggregate_args.reserve(bindings.size()); + + window.children[0]->ResolveOperatorTypes(); + const auto &window_child_types = window.children[0]->types; + const auto window_child_bindings = window.children[0]->GetColumnBindings(); + auto &window_expr = window.expressions[0]->Cast(); + + // Remember order of group columns to recreate that order in new bindings later + column_binding_map_t group_bindings; + for (idx_t i = 0; i < window_expr.partitions.size(); i++) { + auto &expr = window_expr.partitions[i]; + VisitExpression(&expr); + group_bindings[column_references.begin()->first] = i; + column_references.clear(); + } + + for (idx_t i = 0; i < bindings.size(); i++) { + const auto &binding = bindings[i]; + const auto group_binding = group_bindings.find(binding); + if (group_binding != group_bindings.end()) { + group_idxs[i] = group_binding->second; + continue; + } + if (binding.table_index == window.window_index) { + continue; + } + auto column_id = to_string(binding.column_index); // Use idx as struct pack/extract identifier + auto column_type = window_child_types[binding.column_index]; + const auto &column_binding = window_child_bindings[binding.column_index]; + + aggregate_args.push_back(make_uniq(column_id, column_type, column_binding)); + } + + if (aggregate_args.size() == 1) { + // If we only project the aggregate value itself, we do not need it as an arg + VisitExpression(&window_expr.orders[0].expression); + const auto aggregate_value_binding = column_references.begin()->first; + column_references.clear(); + + if (window_expr.orders[0].expression->type == ExpressionType::BOUND_COLUMN_REF && + aggregate_args[0]->Cast().binding == aggregate_value_binding) { + return {}; + } + } + + return aggregate_args; +} + +vector TopNWindowElimination::TraverseProjectionBindings(const std::vector 
&old_bindings, + LogicalOperator *&op) { + auto new_bindings = old_bindings; + + // Traverse child projections to retrieve projections on window output + while (op->type == LogicalOperatorType::LOGICAL_PROJECTION) { + auto &projection = op->Cast(); + + for (idx_t i = 0; i < new_bindings.size(); i++) { + auto &new_binding = new_bindings[i]; + D_ASSERT(new_binding.table_index == projection.table_index); + VisitExpression(&projection.expressions[new_binding.column_index]); + new_binding = column_references.begin()->first; + column_references.clear(); + } + op = op->children[0].get(); + } + + return new_bindings; +} + +void TopNWindowElimination::UpdateTopmostBindings(const idx_t window_idx, const unique_ptr &op, + const map &group_idxs, + const vector &topmost_bindings, + vector &new_bindings, + ColumnBindingReplacer &replacer) { + // The top-most operator's column order is [group][aggregate args][row number]. Now, set the new resulting bindings. + D_ASSERT(topmost_bindings.size() == new_bindings.size()); + replacer.replacement_bindings.reserve(new_bindings.size()); + set row_id_binding_idxs; + + const idx_t group_table_idx = GetGroupIdx(op); + const idx_t aggregate_table_idx = GetAggregateIdx(op); + + // Project the group columns + idx_t current_column_idx = 0; + for (auto group_idx : group_idxs) { + const idx_t group_referencing_idx = group_idx.first; + new_bindings[group_referencing_idx].table_index = group_table_idx; + new_bindings[group_referencing_idx].column_index = group_idx.second; + replacer.replacement_bindings.emplace_back(topmost_bindings[group_referencing_idx], + new_bindings[group_referencing_idx]); + current_column_idx++; + } + + if (group_table_idx != aggregate_table_idx) { + // If the topmost operator is not a projection, the table indexes are different, and we start back from 0 + current_column_idx = 0; + } + + // Project the args/value + for (idx_t i = 0; i < new_bindings.size(); i++) { + auto &binding = new_bindings[i]; + if (group_idxs.find(i) != 
group_idxs.end()) { + continue; + } + if (binding.table_index == window_idx) { + row_id_binding_idxs.insert(i); + continue; + } + binding.column_index = current_column_idx++; + binding.table_index = aggregate_table_idx; + replacer.replacement_bindings.emplace_back(topmost_bindings[i], binding); + } + + // Project the row number + for (const auto row_id_binding_idx : row_id_binding_idxs) { + // Let all projections on row id point to the last output column + auto &binding = new_bindings[row_id_binding_idx]; + binding.table_index = aggregate_table_idx; + binding.column_index = op->types.size() - 1; + replacer.replacement_bindings.emplace_back(topmost_bindings[row_id_binding_idx], binding); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/parser/parser.cpp b/src/duckdb/src/parser/parser.cpp index 3695f75dc..9ee2a675b 100644 --- a/src/duckdb/src/parser/parser.cpp +++ b/src/duckdb/src/parser/parser.cpp @@ -194,6 +194,7 @@ void Parser::ParseQuery(const string &query) { Transformer transformer(options); string parser_error; optional_idx parser_error_location; + string parser_override_option = StringUtil::Lower(options.parser_override_setting); { // check if there are any unicode spaces in the string string new_query; @@ -209,12 +210,24 @@ void Parser::ParseQuery(const string &query) { if (!ext.parser_override) { continue; } + if (StringUtil::CIEquals(parser_override_option, "default")) { + continue; + } auto result = ext.parser_override(ext.parser_info.get(), query); if (result.type == ParserExtensionResultType::PARSE_SUCCESSFUL) { statements = std::move(result.statements); return; - } else if (result.type == ParserExtensionResultType::DISPLAY_EXTENSION_ERROR) { - throw ParserException(result.error); + } + if (StringUtil::CIEquals(parser_override_option, "strict")) { + if (result.type == ParserExtensionResultType::DISPLAY_ORIGINAL_ERROR) { + throw ParserException( + "Parser override failed to return a valid statement. 
Consider restarting the database and " + "using the setting \"set allow_parser_override_extension=fallback\" to fallback to the " + "default parser."); + } + if (result.type == ParserExtensionResultType::DISPLAY_EXTENSION_ERROR) { + throw ParserException(result.error); + } } } } diff --git a/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp b/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp index a2ecb5086..6bf31269c 100644 --- a/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp +++ b/src/duckdb/src/parser/tableref/bound_ref_wrapper.cpp @@ -2,7 +2,7 @@ namespace duckdb { -BoundRefWrapper::BoundRefWrapper(unique_ptr bound_ref_p, shared_ptr binder_p) +BoundRefWrapper::BoundRefWrapper(BoundStatement bound_ref_p, shared_ptr binder_p) : TableRef(TableReferenceType::BOUND_TABLE_REF), bound_ref(std::move(bound_ref_p)), binder(std::move(binder_p)) { } diff --git a/src/duckdb/src/planner/bind_context.cpp b/src/duckdb/src/planner/bind_context.cpp index 10f32432a..038eddb42 100644 --- a/src/duckdb/src/planner/bind_context.cpp +++ b/src/duckdb/src/planner/bind_context.cpp @@ -77,10 +77,6 @@ void BindContext::AddUsingBinding(const string &column_name, UsingColumnSet &set using_columns[column_name].insert(set); } -void BindContext::AddUsingBindingSet(unique_ptr set) { - using_column_sets.push_back(std::move(set)); -} - optional_ptr BindContext::GetUsingBinding(const string &column_name) { auto entry = using_columns.find(column_name); if (entry == using_columns.end()) { @@ -714,30 +710,16 @@ void BindContext::AddGenericBinding(idx_t index, const string &alias, const vect void BindContext::AddCTEBinding(idx_t index, const string &alias, const vector &names, const vector &types, bool using_key) { - auto binding = make_shared_ptr(BindingType::BASE, BindingAlias(alias), types, names, index); + auto binding = make_uniq(BindingAlias(alias), types, names, index); if (cte_bindings.find(alias) != cte_bindings.end()) { throw BinderException("Duplicate CTE binding \"%s\" in 
query!", alias); } cte_bindings[alias] = std::move(binding); - cte_references[alias] = make_shared_ptr(0); if (using_key) { auto recurring_alias = "recurring." + alias; - cte_bindings[recurring_alias] = - make_shared_ptr(BindingType::BASE, BindingAlias(recurring_alias), types, names, index); - cte_references[recurring_alias] = make_shared_ptr(0); - } -} - -void BindContext::RemoveCTEBinding(const std::string &alias) { - auto it = cte_bindings.find(alias); - if (it != cte_bindings.end()) { - cte_bindings.erase(it); - } - auto it2 = cte_references.find(alias); - if (it2 != cte_references.end()) { - cte_references.erase(it2); + cte_bindings[recurring_alias] = make_uniq(BindingAlias(recurring_alias), types, names, index); } } diff --git a/src/duckdb/src/planner/binder.cpp b/src/duckdb/src/planner/binder.cpp index 440152476..07d225378 100644 --- a/src/duckdb/src/planner/binder.cpp +++ b/src/duckdb/src/planner/binder.cpp @@ -28,10 +28,6 @@ namespace duckdb { -Binder &Binder::GetRootBinder() { - return root_binder; -} - idx_t Binder::GetBinderDepth() const { return depth; } @@ -50,9 +46,11 @@ shared_ptr Binder::CreateBinder(ClientContext &context, optional_ptr parent_p, BinderType binder_type) - : context(context), bind_context(*this), parent(std::move(parent_p)), bound_tables(0), binder_type(binder_type), - entry_retriever(context), root_binder(parent ? parent->GetRootBinder() : *this), - depth(parent ? parent->GetBinderDepth() : 1) { + : context(context), bind_context(*this), parent(std::move(parent_p)), binder_type(binder_type), + global_binder_state(parent ? parent->global_binder_state : make_shared_ptr()), + query_binder_state(parent && binder_type == BinderType::REGULAR_BINDER ? parent->query_binder_state + : make_shared_ptr()), + entry_retriever(context), depth(parent ? 
parent->GetBinderDepth() : 1) { IncreaseDepth(); if (parent) { entry_retriever.Inherit(parent->entry_retriever); @@ -60,13 +58,6 @@ Binder::Binder(ClientContext &context, shared_ptr parent_p, BinderType b // We have to inherit macro and lambda parameter bindings and from the parent binder, if there is a parent. macro_binding = parent->macro_binding; lambda_bindings = parent->lambda_bindings; - - if (binder_type == BinderType::REGULAR_BINDER) { - // We have to inherit CTE bindings from the parent bind_context, if there is a parent. - bind_context.SetCTEBindings(parent->bind_context.GetCTEBindings()); - bind_context.cte_references = parent->bind_context.cte_references; - parameters = parent->parameters; - } } } @@ -83,7 +74,7 @@ BoundStatement Binder::BindWithCTE(T &statement) { auto &cte_entry = cte.second; auto mat_cte = make_uniq(); mat_cte->ctename = cte.first; - mat_cte->query = cte_entry->query->node->Copy(); + mat_cte->query = std::move(cte_entry->query->node); mat_cte->aliases = cte_entry->aliases; mat_cte->materialized = cte_entry->materialized; materialized_ctes.push_back(std::move(mat_cte)); @@ -93,18 +84,15 @@ BoundStatement Binder::BindWithCTE(T &statement) { while (!materialized_ctes.empty()) { unique_ptr node_result; node_result = std::move(materialized_ctes.back()); - node_result->cte_map = cte_map.Copy(); node_result->child = std::move(cte_root); cte_root = std::move(node_result); materialized_ctes.pop_back(); } - AddCTEMap(cte_map); return Bind(*cte_root); } BoundStatement Binder::Bind(SQLStatement &statement) { - root_statement = &statement; switch (statement.type) { case StatementType::SELECT_STATEMENT: return Bind(statement.Cast()); @@ -164,15 +152,7 @@ BoundStatement Binder::Bind(SQLStatement &statement) { } // LCOV_EXCL_STOP } -void Binder::AddCTEMap(CommonTableExpressionMap &cte_map) { - for (auto &cte_it : cte_map.map) { - AddCTE(cte_it.first); - } -} - BoundStatement Binder::BindNode(QueryNode &node) { - // first we visit the set of CTEs and 
add them to the bind context - AddCTEMap(node.cte_map); // now we bind the node switch (node.type) { case QueryNodeType::SELECT_NODE: @@ -194,8 +174,8 @@ BoundStatement Binder::Bind(QueryNode &node) { return BindNode(node); } -unique_ptr Binder::Bind(TableRef &ref) { - unique_ptr result; +BoundStatement Binder::Bind(TableRef &ref) { + BoundStatement result; switch (ref.type) { case TableReferenceType::BASE_TABLE: result = Bind(ref.Cast()); @@ -235,52 +215,10 @@ unique_ptr Binder::Bind(TableRef &ref) { default: throw InternalException("Unknown table ref type (%s)", EnumUtil::ToString(ref.type)); } - result->sample = std::move(ref.sample); - return result; -} - -unique_ptr Binder::CreatePlan(BoundTableRef &ref) { - unique_ptr root; - switch (ref.type) { - case TableReferenceType::BASE_TABLE: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::SUBQUERY: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::JOIN: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::TABLE_FUNCTION: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::EMPTY_FROM: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::EXPRESSION_LIST: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::COLUMN_DATA: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::CTE: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::PIVOT: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::DELIM_GET: - root = CreatePlan(ref.Cast()); - break; - case TableReferenceType::INVALID: - default: - throw InternalException("Unsupported bound table ref type (%s)", EnumUtil::ToString(ref.type)); - } - // plan the sample clause if (ref.sample) { - root = make_uniq(std::move(ref.sample), std::move(root)); + result.plan = make_uniq(std::move(ref.sample), std::move(result.plan)); } - return root; + return result; } void Binder::AddCTE(const string &name) { @@ -288,17 +226,19 @@ void 
Binder::AddCTE(const string &name) { CTE_bindings.insert(name); } -vector> Binder::FindCTE(const string &name, bool skip) { - auto entry = bind_context.GetCTEBinding(name); - vector> ctes; - if (entry) { - ctes.push_back(*entry.get()); - } - if (parent && binder_type == BinderType::REGULAR_BINDER) { - auto parent_ctes = parent->FindCTE(name, name == alias); - ctes.insert(ctes.end(), parent_ctes.begin(), parent_ctes.end()); +optional_ptr Binder::GetCTEBinding(const string &name) { + reference current_binder(*this); + while (true) { + auto ¤t = current_binder.get(); + auto entry = current.bind_context.GetCTEBinding(name); + if (entry) { + return entry; + } + if (!current.parent || current.binder_type != BinderType::REGULAR_BINDER) { + return nullptr; + } + current_binder = *current.parent; } - return ctes; } bool Binder::CTEExists(const string &name) { @@ -324,13 +264,19 @@ void Binder::AddBoundView(ViewCatalogEntry &view) { } idx_t Binder::GenerateTableIndex() { - auto &root_binder = GetRootBinder(); - return root_binder.bound_tables++; + return global_binder_state->bound_tables++; } StatementProperties &Binder::GetStatementProperties() { - auto &root_binder = GetRootBinder(); - return root_binder.prop; + return global_binder_state->prop; +} + +optional_ptr Binder::GetParameters() { + return query_binder_state->parameters; +} + +void Binder::SetParameters(BoundParameterMap ¶meters) { + query_binder_state->parameters = parameters; } void Binder::PushExpressionBinder(ExpressionBinder &binder) { @@ -356,17 +302,11 @@ bool Binder::HasActiveBinder() { } vector> &Binder::GetActiveBinders() { - reference root = *this; - while (root.get().parent && root.get().binder_type == BinderType::REGULAR_BINDER) { - root = *root.get().parent; - } - auto &root_binder = root.get(); - return root_binder.active_binders; + return query_binder_state->active_binders; } void Binder::AddUsingBindingSet(unique_ptr set) { - auto &root_binder = GetRootBinder(); - 
root_binder.bind_context.AddUsingBindingSet(std::move(set)); + global_binder_state->using_column_sets.push_back(std::move(set)); } void Binder::MoveCorrelatedExpressions(Binder &other) { @@ -414,13 +354,11 @@ optional_ptr Binder::GetMatchingBinding(const string &catalog_name, con } void Binder::SetBindingMode(BindingMode mode) { - auto &root_binder = GetRootBinder(); - root_binder.mode = mode; + global_binder_state->mode = mode; } BindingMode Binder::GetBindingMode() { - auto &root_binder = GetRootBinder(); - return root_binder.mode; + return global_binder_state->mode; } void Binder::SetCanContainNulls(bool can_contain_nulls_p) { @@ -433,30 +371,26 @@ void Binder::SetAlwaysRequireRebind() { } void Binder::AddTableName(string table_name) { - auto &root_binder = GetRootBinder(); - root_binder.table_names.insert(std::move(table_name)); + global_binder_state->table_names.insert(std::move(table_name)); } void Binder::AddReplacementScan(const string &table_name, unique_ptr replacement) { - auto &root_binder = GetRootBinder(); - auto it = root_binder.replacement_scans.find(table_name); + auto it = global_binder_state->replacement_scans.find(table_name); replacement->column_name_alias.clear(); replacement->alias.clear(); - if (it == root_binder.replacement_scans.end()) { - root_binder.replacement_scans[table_name] = std::move(replacement); + if (it == global_binder_state->replacement_scans.end()) { + global_binder_state->replacement_scans[table_name] = std::move(replacement); } else { // A replacement scan by this name was previously registered, we can just use it } } const unordered_set &Binder::GetTableNames() { - auto &root_binder = GetRootBinder(); - return root_binder.table_names; + return global_binder_state->table_names; } case_insensitive_map_t> &Binder::GetReplacementScans() { - auto &root_binder = GetRootBinder(); - return root_binder.replacement_scans; + return global_binder_state->replacement_scans; } // FIXME: this is extremely naive diff --git 
a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp index 109c0ecbd..3fe02467e 100644 --- a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp @@ -8,19 +8,19 @@ namespace duckdb { BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t depth) { - if (!binder.parameters) { + auto parameters = binder.GetParameters(); + if (!parameters) { throw BinderException("Unexpected prepared parameter. This type of statement can't be prepared!"); } auto parameter_id = expr.identifier; - D_ASSERT(binder.parameters); // Check if a parameter value has already been supplied - auto ¶meter_data = binder.parameters->GetParameterData(); + auto ¶meter_data = parameters->GetParameterData(); auto param_data_it = parameter_data.find(parameter_id); if (param_data_it != parameter_data.end()) { // it has! emit a constant directly auto &data = param_data_it->second; - auto return_type = binder.parameters->GetReturnType(parameter_id); + auto return_type = parameters->GetReturnType(parameter_id); bool is_literal = return_type.id() == LogicalTypeId::INTEGER_LITERAL || return_type.id() == LogicalTypeId::STRING_LITERAL; auto constant = make_uniq(data.GetValue()); @@ -32,7 +32,7 @@ BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t dep return BindResult(std::move(cast)); } - auto bound_parameter = binder.parameters->BindParameterExpression(expr); + auto bound_parameter = parameters->BindParameterExpression(expr); return BindResult(std::move(bound_parameter)); } diff --git a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp index 8e8ab74ad..92ca383c7 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp @@ -63,7 +63,6 
@@ BoundStatement Binder::BindCTE(CTENode &statement) { // If there is already a binding for the CTE, we need to remove it first // as we are binding a CTE currently, we take precendence over the existing binding. // This implements the CTE shadowing behavior. - result.child_binder->bind_context.RemoveCTEBinding(statement.ctename); result.child_binder->bind_context.AddCTEBinding(result.setop_index, statement.ctename, names, result.types); if (statement.child) { diff --git a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp index 18f22ae50..8cb62ab73 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp @@ -40,10 +40,6 @@ BoundStatement Binder::BindNode(RecursiveCTENode &statement) { result.right_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context - // If there is already a binding for the CTE, we need to remove it first - // as we are binding a CTE currently, we take precendence over the existing binding. - // This implements the CTE shadowing behavior. 
- result.right_binder->bind_context.RemoveCTEBinding(statement.ctename); result.right_binder->bind_context.AddCTEBinding(result.setop_index, statement.ctename, result.names, result.types, !statement.key_targets.empty()); diff --git a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp index 6edddac51..a02f878a9 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp @@ -412,9 +412,8 @@ void Binder::BindWhereStarExpression(unique_ptr &expr) { } } -unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement, - unique_ptr from_table) { - D_ASSERT(from_table); +unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement, BoundStatement from_table) { + D_ASSERT(from_table.plan); D_ASSERT(!statement.from_table); auto result_ptr = make_uniq(); auto &result = *result_ptr; @@ -692,7 +691,7 @@ unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement return result_ptr; } -BoundStatement Binder::BindSelectNode(SelectNode &statement, unique_ptr from_table) { +BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from_table) { auto result = BindSelectNodeInternal(statement, std::move(from_table)); BoundStatement result_statement; diff --git a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp index dccbae4cd..f51a03c50 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp @@ -24,17 +24,16 @@ unique_ptr Binder::CreatePlan(BoundRecursiveCTENode &node) { left_node = CastLogicalOperatorToTypes(node.left.types, node.types, std::move(left_node)); right_node = CastLogicalOperatorToTypes(node.right.types, node.types, std::move(right_node)); - bool ref_recurring = 
node.right_binder->bind_context.cte_references["recurring." + node.ctename] && - *node.right_binder->bind_context.cte_references["recurring." + node.ctename] != 0; - + auto recurring_binding = node.right_binder->GetCTEBinding("recurring." + node.ctename); + bool ref_recurring = recurring_binding && recurring_binding->Cast().reference_count > 0; if (node.key_targets.empty() && ref_recurring) { throw InvalidInputException("RECURRING can only be used with USING KEY in recursive CTE."); } // Check if there is a reference to the recursive or recurring table, if not create a set operator. - if ((!node.right_binder->bind_context.cte_references[node.ctename] || - *node.right_binder->bind_context.cte_references[node.ctename] == 0) && - !ref_recurring) { + auto cte_binding = node.right_binder->GetCTEBinding(node.ctename); + bool ref_cte = cte_binding && cte_binding->Cast().reference_count > 0; + if (!ref_cte && !ref_recurring) { auto root = make_uniq(node.setop_index, node.types.size(), std::move(left_node), std::move(right_node), LogicalOperatorType::LOGICAL_UNION, node.union_all); diff --git a/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp index 46e5d2e12..10b206f24 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_select_node.cpp @@ -16,10 +16,8 @@ unique_ptr Binder::PlanFilter(unique_ptr condition, } unique_ptr Binder::CreatePlan(BoundSelectNode &statement) { - unique_ptr root; - D_ASSERT(statement.from_table); - root = CreatePlan(*statement.from_table); - D_ASSERT(root); + D_ASSERT(statement.from_table.plan); + auto root = std::move(statement.from_table.plan); // plan the sample clause if (statement.sample_options) { diff --git a/src/duckdb/src/planner/binder/statement/bind_attach.cpp b/src/duckdb/src/planner/binder/statement/bind_attach.cpp index 0e8655d2f..8ec2beca3 100644 --- 
a/src/duckdb/src/planner/binder/statement/bind_attach.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_attach.cpp @@ -1,7 +1,6 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/statement/attach_statement.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/planner/operator/logical_simple.hpp" #include "duckdb/planner/expression_binder/table_function_binder.hpp" #include "duckdb/execution/expression_executor.hpp" diff --git a/src/duckdb/src/planner/binder/statement/bind_call.cpp b/src/duckdb/src/planner/binder/statement/bind_call.cpp index ba96927e8..46be806cf 100644 --- a/src/duckdb/src/planner/binder/statement/bind_call.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_call.cpp @@ -1,8 +1,6 @@ #include "duckdb/parser/statement/call_statement.hpp" #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/expression/star_expression.hpp" diff --git a/src/duckdb/src/planner/binder/statement/bind_create.cpp b/src/duckdb/src/planner/binder/statement/bind_create.cpp index f1ffd7496..c63e9bf10 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create.cpp @@ -39,7 +39,6 @@ #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/parsed_data/bound_create_table_info.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/storage/storage_extension.hpp" #include "duckdb/common/extension_type_info.hpp" #include "duckdb/common/type_visitor.hpp" @@ -544,23 +543,21 @@ BoundStatement Binder::Bind(CreateStatement &stmt) { create_index_info.table); auto table_ref = 
make_uniq(table_description); auto bound_table = Bind(*table_ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + auto plan = std::move(bound_table.plan); + if (plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("can only create an index on a base table"); + } + auto &get = plan->Cast(); + auto table_ptr = get.GetTable(); + if (!table_ptr) { throw BinderException("can only create an index on a base table"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; + auto &table = *table_ptr; if (table.temporary) { stmt.info->temporary = true; } properties.RegisterDBModify(table.catalog, context); - - // create a plan over the bound table - auto plan = CreatePlan(*bound_table); - if (plan->type != LogicalOperatorType::LOGICAL_GET) { - throw BinderException("Cannot create index on a view!"); - } - result.plan = table.catalog.BindCreateIndex(*this, stmt, table, std::move(plan)); break; } diff --git a/src/duckdb/src/planner/binder/statement/bind_delete.cpp b/src/duckdb/src/planner/binder/statement/bind_delete.cpp index e83a62ae3..6b2b9d6fd 100644 --- a/src/duckdb/src/planner/binder/statement/bind_delete.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_delete.cpp @@ -5,8 +5,6 @@ #include "duckdb/planner/operator/logical_delete.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/planner/operator/logical_cross_product.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" @@ -15,38 +13,34 @@ namespace duckdb { BoundStatement Binder::Bind(DeleteStatement &stmt) { // visit the table reference auto bound_table = Bind(*stmt.table); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only delete from base table!"); + auto root = std::move(bound_table.plan); + if (root->type != 
LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Can only delete from base table"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; - - auto root = CreatePlan(*bound_table); auto &get = root->Cast(); - D_ASSERT(root->type == LogicalOperatorType::LOGICAL_GET); - + auto table_ptr = get.GetTable(); + if (!table_ptr) { + throw BinderException("Can only delete from base table"); + } + auto &table = *table_ptr; if (!table.temporary) { // delete from persistent table: not read only! auto &properties = GetStatementProperties(); properties.RegisterDBModify(table.catalog, context); } - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); - // plan any tables from the various using clauses if (!stmt.using_clauses.empty()) { unique_ptr child_operator; for (auto &using_clause : stmt.using_clauses) { // bind the using clause auto using_binder = Binder::CreateBinder(context, this); - auto bound_node = using_binder->Bind(*using_clause); - auto op = CreatePlan(*bound_node); + auto op = using_binder->Bind(*using_clause); if (child_operator) { // already bound a child: create a cross product to unify the two - child_operator = LogicalCrossProduct::Create(std::move(child_operator), std::move(op)); + child_operator = LogicalCrossProduct::Create(std::move(child_operator), std::move(op.plan)); } else { - child_operator = std::move(op); + child_operator = std::move(op.plan); } bind_context.AddContext(std::move(using_binder->bind_context)); } diff --git a/src/duckdb/src/planner/binder/statement/bind_drop.cpp b/src/duckdb/src/planner/binder/statement/bind_drop.cpp index f40a86c61..9239e812e 100644 --- a/src/duckdb/src/planner/binder/statement/bind_drop.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_drop.cpp @@ -1,6 +1,5 @@ #include "duckdb/parser/statement/drop_statement.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/operator/logical_simple.hpp" #include 
"duckdb/catalog/catalog.hpp" #include "duckdb/catalog/standard_entry.hpp" diff --git a/src/duckdb/src/planner/binder/statement/bind_execute.cpp b/src/duckdb/src/planner/binder/statement/bind_execute.cpp index cceb6796c..1202b01fa 100644 --- a/src/duckdb/src/planner/binder/statement/bind_execute.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_execute.cpp @@ -79,7 +79,7 @@ BoundStatement Binder::Bind(ExecuteStatement &stmt) { prepared = prepared_planner.PrepareSQLStatement(entry->second->unbound_statement->Copy()); rebound_plan = std::move(prepared_planner.plan); D_ASSERT(prepared->properties.bound_all_parameters); - this->bound_tables = prepared_planner.binder->bound_tables; + global_binder_state->bound_tables = prepared_planner.binder->global_binder_state->bound_tables; } // copy the properties of the prepared statement into the planner auto &properties = GetStatementProperties(); diff --git a/src/duckdb/src/planner/binder/statement/bind_extension.cpp b/src/duckdb/src/planner/binder/statement/bind_extension.cpp index b4fc0e86b..6569315f7 100644 --- a/src/duckdb/src/planner/binder/statement/bind_extension.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_extension.cpp @@ -5,8 +5,6 @@ namespace duckdb { BoundStatement Binder::Bind(ExtensionStatement &stmt) { - BoundStatement result; - // perform the planning of the function D_ASSERT(stmt.extension.plan_function); auto parse_result = @@ -18,11 +16,9 @@ BoundStatement Binder::Bind(ExtensionStatement &stmt) { properties.return_type = parse_result.return_type; // create the plan as a scan of the given table function - result.plan = BindTableFunction(parse_result.function, std::move(parse_result.parameters)); + auto result = BindTableFunction(parse_result.function, std::move(parse_result.parameters)); D_ASSERT(result.plan->type == LogicalOperatorType::LOGICAL_GET); auto &get = result.plan->Cast(); - result.names = get.names; - result.types = get.returned_types; get.ClearColumnIds(); for (idx_t i = 0; i < 
get.returned_types.size(); i++) { get.AddColumnId(i); diff --git a/src/duckdb/src/planner/binder/statement/bind_insert.cpp b/src/duckdb/src/planner/binder/statement/bind_insert.cpp index f2c8db644..bddfdbb5e 100644 --- a/src/duckdb/src/planner/binder/statement/bind_insert.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_insert.cpp @@ -22,9 +22,6 @@ #include "duckdb/planner/expression/bound_default_expression.hpp" #include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/planner/bound_tableref.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/parser/tableref/basetableref.hpp" @@ -519,8 +516,6 @@ BoundStatement Binder::Bind(InsertStatement &stmt) { } auto insert = make_uniq(table, GenerateTableIndex()); - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); auto values_list = stmt.GetValuesList(); diff --git a/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp b/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp index 5b187c8e3..3108b76aa 100644 --- a/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_logical_plan.cpp @@ -32,7 +32,7 @@ BoundStatement Binder::Bind(LogicalPlanStatement &stmt) { if (parent) { throw InternalException("LogicalPlanStatement should be bound in root binder"); } - bound_tables = GetMaxTableIndex(*result.plan) + 1; + global_binder_state->bound_tables = GetMaxTableIndex(*result.plan) + 1; return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp index b52a04cf2..1dd59c480 100644 --- a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp +++ 
b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp @@ -1,6 +1,5 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/statement/merge_into_statement.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/expression_binder/where_binder.hpp" @@ -178,11 +177,14 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { auto target_binder = Binder::CreateBinder(context, this); string table_alias = stmt.target->alias; auto bound_table = target_binder->Bind(*stmt.target); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { throw BinderException("Can only merge into base tables!"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; + auto table_ptr = bound_table.plan->Cast().GetTable(); + if (!table_ptr) { + throw BinderException("Can only merge into base tables!"); + } + auto &table = *table_ptr; if (!table.temporary) { // update of persistent table: not read only! 
auto &properties = GetStatementProperties(); @@ -231,7 +233,7 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { } auto bound_join_node = Bind(join); - auto root = CreatePlan(*bound_join_node); + auto root = std::move(bound_join_node.plan); auto join_ref = reference(*root); while (join_ref.get().children.size() == 1) { join_ref = *join_ref.get().children[0]; diff --git a/src/duckdb/src/planner/binder/statement/bind_prepare.cpp b/src/duckdb/src/planner/binder/statement/bind_prepare.cpp index cbb338dfc..66fb40d61 100644 --- a/src/duckdb/src/planner/binder/statement/bind_prepare.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_prepare.cpp @@ -8,7 +8,7 @@ namespace duckdb { BoundStatement Binder::Bind(PrepareStatement &stmt) { Planner prepared_planner(context); auto prepared_data = prepared_planner.PrepareSQLStatement(std::move(stmt.statement)); - this->bound_tables = prepared_planner.binder->bound_tables; + global_binder_state->bound_tables = prepared_planner.binder->global_binder_state->bound_tables; if (prepared_planner.properties.always_require_rebind) { // we always need to rebind - don't keep the plan around diff --git a/src/duckdb/src/planner/binder/statement/bind_simple.cpp b/src/duckdb/src/planner/binder/statement/bind_simple.cpp index 942f6784c..46758e416 100644 --- a/src/duckdb/src/planner/binder/statement/bind_simple.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_simple.cpp @@ -60,16 +60,15 @@ BoundStatement Binder::BindAlterAddIndex(BoundStatement &result, CatalogEntry &e TableDescription table_description(table_info.catalog, table_info.schema, table_info.name); auto table_ref = make_uniq(table_description); auto bound_table = Bind(*table_ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { throw BinderException("can only add an index to a base table"); } - auto plan = CreatePlan(*bound_table); - auto &get = plan->Cast(); + auto &get = 
bound_table.plan->Cast(); get.names = column_list.GetColumnNames(); auto alter_table_info = unique_ptr_cast(std::move(alter_info)); - result.plan = table.catalog.BindAlterAddIndex(*this, table, std::move(plan), std::move(create_index_info), - std::move(alter_table_info)); + result.plan = table.catalog.BindAlterAddIndex(*this, table, std::move(bound_table.plan), + std::move(create_index_info), std::move(alter_table_info)); return std::move(result); } diff --git a/src/duckdb/src/planner/binder/statement/bind_summarize.cpp b/src/duckdb/src/planner/binder/statement/bind_summarize.cpp index 45b2b2f25..f8a68ae4c 100644 --- a/src/duckdb/src/planner/binder/statement/bind_summarize.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_summarize.cpp @@ -9,7 +9,6 @@ #include "duckdb/parser/tableref/showref.hpp" #include "duckdb/parser/tableref/basetableref.hpp" #include "duckdb/parser/expression/star_expression.hpp" -#include "duckdb/planner/bound_tableref.hpp" namespace duckdb { @@ -78,7 +77,7 @@ static unique_ptr SummarizeCreateNullPercentage(string column_ return make_uniq(LogicalType::DECIMAL(9, 2), std::move(case_expr)); } -unique_ptr Binder::BindSummarize(ShowRef &ref) { +BoundStatement Binder::BindSummarize(ShowRef &ref) { unique_ptr query; if (ref.query) { query = std::move(ref.query); diff --git a/src/duckdb/src/planner/binder/statement/bind_update.cpp b/src/duckdb/src/planner/binder/statement/bind_update.cpp index 650b23b89..fafcbaf02 100644 --- a/src/duckdb/src/planner/binder/statement/bind_update.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_update.cpp @@ -2,7 +2,6 @@ #include "duckdb/parser/statement/update_statement.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/constraints/bound_check_constraint.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_default_expression.hpp" @@ 
-12,7 +11,6 @@ #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/operator/logical_update.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" #include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" #include "duckdb/storage/data_table.hpp" @@ -110,14 +108,15 @@ BoundStatement Binder::Bind(UpdateStatement &stmt) { // visit the table reference auto bound_table = Bind(*stmt.table); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only update base table!"); + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Can only update base table"); } - auto &table_binding = bound_table->Cast(); - auto &table = table_binding.table; - - // Add CTEs as bindable - AddCTEMap(stmt.cte_map); + auto &bound_table_get = bound_table.plan->Cast(); + auto table_ptr = bound_table_get.GetTable(); + if (!table_ptr) { + throw BinderException("Can only update base table"); + } + auto &table = *table_ptr; optional_ptr get; if (stmt.from_table) { @@ -129,7 +128,7 @@ BoundStatement Binder::Bind(UpdateStatement &stmt) { get = &root->children[0]->Cast(); bind_context.AddContext(std::move(from_binder->bind_context)); } else { - root = CreatePlan(*bound_table); + root = std::move(bound_table.plan); get = &root->Cast(); } diff --git a/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp b/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp index 93e70fe5b..026f682b0 100644 --- a/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_vacuum.cpp @@ -15,12 +15,18 @@ void Binder::BindVacuumTable(LogicalVacuum &vacuum, unique_ptr } D_ASSERT(vacuum.column_id_map.empty()); + auto bound_table = Bind(*info.ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw InvalidInputException("can only vacuum or analyze base tables"); + if (bound_table.plan->type != 
LogicalOperatorType::LOGICAL_GET) { + throw BinderException("Can only vacuum or analyze base tables"); + } + auto table_scan = std::move(bound_table.plan); + auto &get = table_scan->Cast(); + auto table_ptr = get.GetTable(); + if (!table_ptr) { + throw BinderException("Can only vacuum or analyze base tables"); } - auto ref = unique_ptr_cast(std::move(bound_table)); - auto &table = ref->table; + auto &table = *table_ptr; vacuum.SetTable(table); vector> select_list; @@ -60,11 +66,6 @@ void Binder::BindVacuumTable(LogicalVacuum &vacuum, unique_ptr } info.columns = std::move(non_generated_column_names); - auto table_scan = CreatePlan(*ref); - D_ASSERT(table_scan->type == LogicalOperatorType::LOGICAL_GET); - - auto &get = table_scan->Cast(); - auto &column_ids = get.GetColumnIds(); D_ASSERT(select_list.size() == column_ids.size()); D_ASSERT(info.columns.size() == column_ids.size()); diff --git a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp index 304ec793b..de775a198 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp @@ -11,15 +11,13 @@ #include "duckdb/parser/tableref/table_function_ref.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/tableref/bound_cteref.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" +#include "duckdb/planner/operator/logical_cteref.hpp" #include "duckdb/planner/expression_binder/constant_binder.hpp" #include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/planner/tableref/bound_at_clause.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/parser/query_node/cte_node.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" namespace duckdb { @@ 
-48,10 +46,10 @@ static bool TryLoadExtensionForReplacementScan(ClientContext &context, const str return false; } -unique_ptr Binder::BindWithReplacementScan(ClientContext &context, BaseTableRef &ref) { +BoundStatement Binder::BindWithReplacementScan(ClientContext &context, BaseTableRef &ref) { auto &config = DBConfig::GetConfig(context); if (!context.config.use_replacement_scans) { - return nullptr; + return BoundStatement(); } for (auto &scan : config.replacement_scans) { ReplacementScanInput input(ref.catalog_name, ref.schema_name, ref.table_name); @@ -80,7 +78,7 @@ unique_ptr Binder::BindWithReplacementScan(ClientContext &context } return Bind(*replacement_function); } - return nullptr; + return BoundStatement(); } unique_ptr Binder::BindAtClause(optional_ptr at_clause) { @@ -116,62 +114,60 @@ static vector ExchangeAllNullTypes(const vector &types return result; } -unique_ptr Binder::Bind(BaseTableRef &ref) { +BoundStatement Binder::Bind(BaseTableRef &ref) { QueryErrorContext error_context(ref.query_location); // CTEs and views are also referred to using BaseTableRefs, hence need to distinguish here // check if the table name refers to a CTE // CTE name should never be qualified (i.e. schema_name should be empty) // unless we want to refer to the recurring table of "using key". - vector> found_ctes; - if (ref.schema_name.empty() || ref.schema_name == "recurring") { - found_ctes = FindCTE(ref.table_name, false); - } - - if (!found_ctes.empty()) { - // Check if there is a CTE binding in the BindContext - auto ctebinding = bind_context.GetCTEBinding(ref.table_name); - if (ctebinding) { - // There is a CTE binding in the BindContext. - // This can only be the case if there is a recursive CTE, - // or a materialized CTE present. - auto index = GenerateTableIndex(); - - if (ref.schema_name == "recurring") { - auto recurring_bindings = FindCTE("recurring." 
+ ref.table_name, false); - if (recurring_bindings.empty()) { - throw BinderException(error_context, - "There is a WITH item named \"%s\", but the recurring table cannot be " - "referenced from this part of the query." - " Hint: RECURRING can only be used with USING KEY in recursive CTE.", - ref.table_name); - } + auto ctebinding = GetCTEBinding(ref.table_name); + if (ctebinding) { + // There is a CTE binding in the BindContext. + // This can only be the case if there is a recursive CTE, + // or a materialized CTE present. + auto index = GenerateTableIndex(); + + if (ref.schema_name == "recurring") { + auto recurring_bindings = GetCTEBinding("recurring." + ref.table_name); + if (!recurring_bindings) { + throw BinderException(error_context, + "There is a WITH item named \"%s\", but the recurring table cannot be " + "referenced from this part of the query." + " Hint: RECURRING can only be used with USING KEY in recursive CTE.", + ref.table_name); } + } - auto result = make_uniq(index, ctebinding->index, ref.schema_name == "recurring"); - auto alias = ref.alias.empty() ? ref.table_name : ref.alias; - auto names = BindContext::AliasColumnNames(alias, ctebinding->names, ref.column_name_alias); - - bind_context.AddGenericBinding(index, alias, names, ctebinding->types); - - auto cte_reference = ref.schema_name.empty() ? ref.table_name : ref.schema_name + "." + ref.table_name; + auto alias = ref.alias.empty() ? ref.table_name : ref.alias; + auto names = BindContext::AliasColumnNames(alias, ctebinding->names, ref.column_name_alias); - // Update references to CTE - auto cteref = bind_context.cte_references[cte_reference]; + bind_context.AddGenericBinding(index, alias, names, ctebinding->types); - if (cteref == nullptr && ref.schema_name == "recurring") { + auto cte_ref = reference(ctebinding->Cast()); + if (!ref.schema_name.empty()) { + auto cte_reference = ref.schema_name + "." 
+ ref.table_name; + auto recurring_ref = GetCTEBinding(cte_reference); + if (!recurring_ref) { throw BinderException(error_context, "There is a WITH item named \"%s\", but the recurring table cannot be " "referenced from this part of the query.", ref.table_name); } - - (*cteref)++; - - result->types = ctebinding->types; - result->bound_columns = std::move(names); - return std::move(result); + cte_ref = reference(recurring_ref->Cast()); } + + // Update references to CTE + cte_ref.get().reference_count++; + bool is_recurring = ref.schema_name == "recurring"; + + BoundStatement result; + result.types = ctebinding->types; + result.names = names; + result.plan = + make_uniq(index, ctebinding->index, ctebinding->types, std::move(names), is_recurring); + return result; + ; } // not a CTE @@ -198,14 +194,19 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { vector types {LogicalType::INTEGER}; vector names {"__dummy_col" + to_string(table_index)}; bind_context.AddGenericBinding(table_index, ref_alias, names, types); - return make_uniq_base(table_index); + + BoundStatement result; + result.types = std::move(types); + result.names = std::move(names); + result.plan = make_uniq(table_index); + return result; } } if (!table_or_view) { // table could not be found: try to bind a replacement scan // Try replacement scan bind auto replacement_scan_bind_result = BindWithReplacementScan(context, ref); - if (replacement_scan_bind_result) { + if (replacement_scan_bind_result.plan) { return replacement_scan_bind_result; } @@ -214,7 +215,7 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { auto extension_loaded = TryLoadExtensionForReplacementScan(context, full_path); if (extension_loaded) { replacement_scan_bind_result = BindWithReplacementScan(context, ref); - if (replacement_scan_bind_result) { + if (replacement_scan_bind_result.plan) { return replacement_scan_bind_result; } } @@ -233,7 +234,7 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { // remember that we did not find a CTE, but there 
is a CTE with the same name // this means that there is a circular reference // Otherwise, re-throw the original exception - if (found_ctes.empty() && ref.schema_name.empty() && CTEExists(ref.table_name)) { + if (!ctebinding && ref.schema_name.empty() && CTEExists(ref.table_name)) { throw BinderException( error_context, "Circular reference to CTE \"%s\", There are two possible solutions. \n1. use WITH RECURSIVE to " @@ -251,7 +252,7 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { switch (table_or_view->type) { case CatalogType::TABLE_ENTRY: { - // base table: create the BoundBaseTableRef node + // base table auto table_index = GenerateTableIndex(); auto &table = table_or_view->Cast(); @@ -294,7 +295,11 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { } else { bind_context.AddBaseTable(table_index, ref.alias, table_names, table_types, col_ids, *table_entry); } - return make_uniq_base(table, std::move(logical_get)); + BoundStatement result; + result.types = table_types; + result.names = table_names; + result.plan = std::move(logical_get); + return result; } case CatalogType::VIEW_ENTRY: { // the node is a view: get the query that the view represents @@ -355,15 +360,13 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { throw BinderException("Contents of view were altered - view bound correlated columns"); } - D_ASSERT(bound_child->type == TableReferenceType::SUBQUERY); // verify that the types and names match up with the expected types and names if the view has type info defined - auto &bound_subquery = bound_child->Cast(); if (GetBindingMode() != BindingMode::EXTRACT_NAMES && GetBindingMode() != BindingMode::EXTRACT_QUALIFIED_NAMES && view_catalog_entry.HasTypes()) { // we bind the view subquery and the original view with different "can_contain_nulls", // but we don't want to throw an error when SQLNULL does not match up with INTEGER, // so we exchange all SQLNULL with INTEGER here before comparing - auto bound_types = ExchangeAllNullTypes(bound_subquery.subquery.types); 
+ auto bound_types = ExchangeAllNullTypes(bound_child.types); auto view_types = ExchangeAllNullTypes(view_catalog_entry.types); if (bound_types != view_types) { auto actual_types = StringUtil::ToString(bound_types, ", "); @@ -372,17 +375,17 @@ unique_ptr Binder::Bind(BaseTableRef &ref) { "Contents of view were altered: types don't match! Expected [%s], but found [%s] instead", expected_types, actual_types); } - if (bound_subquery.subquery.names.size() == view_catalog_entry.names.size() && - bound_subquery.subquery.names != view_catalog_entry.names) { - auto actual_names = StringUtil::Join(bound_subquery.subquery.names, ", "); + if (bound_child.names.size() == view_catalog_entry.names.size() && + bound_child.names != view_catalog_entry.names) { + auto actual_names = StringUtil::Join(bound_child.names, ", "); auto expected_names = StringUtil::Join(view_catalog_entry.names, ", "); throw BinderException( "Contents of view were altered: names don't match! Expected [%s], but found [%s] instead", expected_names, actual_names); } } - bind_context.AddView(bound_subquery.subquery.plan->GetRootIndex(), subquery.alias, subquery, - bound_subquery.subquery, view_catalog_entry); + bind_context.AddView(bound_child.plan->GetRootIndex(), subquery.alias, subquery, bound_child, + view_catalog_entry); return bound_child; } default: diff --git a/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp b/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp index e31c2e83c..ace531ccf 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_bound_table_ref.cpp @@ -2,8 +2,8 @@ namespace duckdb { -unique_ptr Binder::Bind(BoundRefWrapper &ref) { - if (!ref.binder || !ref.bound_ref) { +BoundStatement Binder::Bind(BoundRefWrapper &ref) { + if (!ref.binder || !ref.bound_ref.plan) { throw InternalException("Rebinding bound ref that was already bound"); } 
bind_context.AddContext(std::move(ref.binder->bind_context)); diff --git a/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp b/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp index 635d23f71..d3c5ea4a2 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_column_data_ref.cpp @@ -1,20 +1,25 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/tableref/column_data_ref.hpp" -#include "duckdb/planner/tableref/bound_column_data_ref.hpp" #include "duckdb/planner/operator/logical_column_data_get.hpp" namespace duckdb { -unique_ptr Binder::Bind(ColumnDataRef &ref) { +BoundStatement Binder::Bind(ColumnDataRef &ref) { auto &collection = *ref.collection; auto types = collection.Types(); - auto result = make_uniq(std::move(ref.collection)); - result->bind_index = GenerateTableIndex(); - for (idx_t i = ref.expected_names.size(); i < types.size(); i++) { - ref.expected_names.push_back("col" + to_string(i + 1)); + + BoundStatement result; + result.names = std::move(ref.expected_names); + for (idx_t i = result.names.size(); i < types.size(); i++) { + result.names.push_back("col" + to_string(i + 1)); } - bind_context.AddGenericBinding(result->bind_index, ref.alias, ref.expected_names, types); - return unique_ptr_cast(std::move(result)); + result.types = types; + auto bind_index = GenerateTableIndex(); + bind_context.AddGenericBinding(bind_index, ref.alias, result.names, types); + + result.plan = + make_uniq_base(bind_index, std::move(types), std::move(ref.collection)); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp b/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp index f280404f9..18c27cccf 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_delimgetref.cpp @@ -1,16 +1,21 @@ #include 
"duckdb/parser/tableref/delimgetref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_delimgetref.hpp" +#include "duckdb/planner/operator/logical_delim_get.hpp" namespace duckdb { -unique_ptr Binder::Bind(DelimGetRef &ref) { +BoundStatement Binder::Bind(DelimGetRef &ref) { // Have to add bindings idx_t tbl_idx = GenerateTableIndex(); string internal_name = "__internal_delim_get_ref_" + std::to_string(tbl_idx); - bind_context.AddGenericBinding(tbl_idx, internal_name, ref.internal_aliases, ref.types); - return make_uniq(tbl_idx, ref.types); + BoundStatement result; + result.types = std::move(ref.types); + result.names = std::move(ref.internal_aliases); + result.plan = make_uniq(tbl_idx, result.types); + + bind_context.AddGenericBinding(tbl_idx, internal_name, result.names, result.types); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp index fe0e96f3d..b6ea93ab8 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_emptytableref.cpp @@ -1,11 +1,13 @@ #include "duckdb/parser/tableref/emptytableref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" namespace duckdb { -unique_ptr Binder::Bind(EmptyTableRef &ref) { - return make_uniq(GenerateTableIndex()); +BoundStatement Binder::Bind(EmptyTableRef &ref) { + BoundStatement result; + result.plan = make_uniq(GenerateTableIndex()); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp index 7176fb682..139f94670 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp +++ 
b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp @@ -1,72 +1,87 @@ #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_expressionlistref.hpp" #include "duckdb/parser/tableref/expressionlistref.hpp" #include "duckdb/planner/expression_binder/insert_binder.hpp" #include "duckdb/common/to_string.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/operator/logical_expression_get.hpp" +#include "duckdb/planner/operator/logical_dummy_scan.hpp" namespace duckdb { -unique_ptr Binder::Bind(ExpressionListRef &expr) { - auto result = make_uniq(); - result->types = expr.expected_types; - result->names = expr.expected_names; +BoundStatement Binder::Bind(ExpressionListRef &expr) { + BoundStatement result; + result.types = expr.expected_types; + result.names = expr.expected_names; + + vector>> values; auto prev_can_contain_nulls = this->can_contain_nulls; // bind value list InsertBinder binder(*this, context); binder.target_type = LogicalType(LogicalTypeId::INVALID); for (idx_t list_idx = 0; list_idx < expr.values.size(); list_idx++) { auto &expression_list = expr.values[list_idx]; - if (result->names.empty()) { + if (result.names.empty()) { // no names provided, generate them for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { - result->names.push_back("col" + to_string(val_idx)); + result.names.push_back("col" + to_string(val_idx)); } } this->can_contain_nulls = true; vector> list; for (idx_t val_idx = 0; val_idx < expression_list.size(); val_idx++) { - if (!result->types.empty()) { - D_ASSERT(result->types.size() == expression_list.size()); - binder.target_type = result->types[val_idx]; + if (!result.types.empty()) { + D_ASSERT(result.types.size() == expression_list.size()); + binder.target_type = result.types[val_idx]; } auto bound_expr = binder.Bind(expression_list[val_idx]); list.push_back(std::move(bound_expr)); } - result->values.push_back(std::move(list)); + 
values.push_back(std::move(list)); this->can_contain_nulls = prev_can_contain_nulls; } - if (result->types.empty() && !expr.values.empty()) { + if (result.types.empty() && !expr.values.empty()) { // there are no types specified // we have to figure out the result types // for each column, we iterate over all of the expressions and select the max logical type // we initialize all types to SQLNULL - result->types.resize(expr.values[0].size(), LogicalType::SQLNULL); + result.types.resize(expr.values[0].size(), LogicalType::SQLNULL); // now loop over the lists and select the max logical type - for (idx_t list_idx = 0; list_idx < result->values.size(); list_idx++) { - auto &list = result->values[list_idx]; + for (idx_t list_idx = 0; list_idx < values.size(); list_idx++) { + auto &list = values[list_idx]; for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { - auto ¤t_type = result->types[val_idx]; + auto ¤t_type = result.types[val_idx]; auto next_type = ExpressionBinder::GetExpressionReturnType(*list[val_idx]); - result->types[val_idx] = LogicalType::MaxLogicalType(context, current_type, next_type); + result.types[val_idx] = LogicalType::MaxLogicalType(context, current_type, next_type); } } - for (auto &type : result->types) { + for (auto &type : result.types) { type = LogicalType::NormalizeType(type); } // finally do another loop over the expressions and add casts where required - for (idx_t list_idx = 0; list_idx < result->values.size(); list_idx++) { - auto &list = result->values[list_idx]; + for (idx_t list_idx = 0; list_idx < values.size(); list_idx++) { + auto &list = values[list_idx]; for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { list[val_idx] = - BoundCastExpression::AddCastToType(context, std::move(list[val_idx]), result->types[val_idx]); + BoundCastExpression::AddCastToType(context, std::move(list[val_idx]), result.types[val_idx]); } } } - result->bind_index = GenerateTableIndex(); - bind_context.AddGenericBinding(result->bind_index, 
expr.alias, result->names, result->types); - return std::move(result); + auto bind_index = GenerateTableIndex(); + bind_context.AddGenericBinding(bind_index, expr.alias, result.names, result.types); + + // values list, first plan any subqueries in the list + auto root = make_uniq_base(GenerateTableIndex()); + for (auto &expr_list : values) { + for (auto &expr : expr_list) { + PlanSubqueries(expr, root); + } + } + + auto expr_get = make_uniq(bind_index, result.types, std::move(values)); + expr_get->AddChild(std::move(root)); + result.plan = std::move(expr_get); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp index 257e275be..258bd3331 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp @@ -122,14 +122,14 @@ static vector RemoveDuplicateUsingColumns(const vector &using_co return result; } -unique_ptr Binder::BindJoin(Binder &parent_binder, TableRef &ref) { +BoundStatement Binder::BindJoin(Binder &parent_binder, TableRef &ref) { unnamed_subquery_index = parent_binder.unnamed_subquery_index; auto result = Bind(ref); parent_binder.unnamed_subquery_index = unnamed_subquery_index; return result; } -unique_ptr Binder::Bind(JoinRef &ref) { +BoundStatement Binder::Bind(JoinRef &ref) { auto result = make_uniq(ref.ref_type); result->left_binder = Binder::CreateBinder(context, this); result->right_binder = Binder::CreateBinder(context, this); @@ -351,7 +351,13 @@ unique_ptr Binder::Bind(JoinRef &ref) { bind_context.RemoveContext(left_bindings); } - return std::move(result); + BoundStatement result_stmt; + result_stmt.types.insert(result_stmt.types.end(), result->left.types.begin(), result->left.types.end()); + result_stmt.types.insert(result_stmt.types.end(), result->right.types.begin(), result->right.types.end()); + result_stmt.names.insert(result_stmt.names.end(), 
result->left.names.begin(), result->left.names.end()); + result_stmt.names.insert(result_stmt.names.end(), result->right.names.begin(), result->right.names.end()); + result_stmt.plan = CreatePlan(*result); + return result_stmt; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp index e77e93ce0..51cfae932 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp @@ -13,7 +13,6 @@ #include "duckdb/common/types/value_map.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/parser/expression/operator_expression.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" #include "duckdb/planner/tableref/bound_pivotref.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/main/client_config.hpp" @@ -21,6 +20,7 @@ #include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" #include "duckdb/main/query_result.hpp" #include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_pivot.hpp" #include "duckdb/main/settings.hpp" namespace duckdb { @@ -383,12 +383,8 @@ static unique_ptr PivotFinalOperator(PivotBindState &bind_state, Piv return final_pivot_operator; } -void ExtractPivotAggregates(BoundTableRef &node, vector> &aggregates) { - if (node.type != TableReferenceType::SUBQUERY) { - throw InternalException("Pivot - Expected a subquery"); - } - auto &subq = node.Cast(); - reference op(*subq.subquery.plan); +void ExtractPivotAggregates(BoundStatement &node, vector> &aggregates) { + reference op(*node.plan); bool found_first_aggregate = false; while (true) { if (op.get().type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY) { @@ -420,15 +416,15 @@ string GetPivotAggregateName(const PivotValueElement &pivot_value, const string return name; } -unique_ptr Binder::BindBoundPivot(PivotRef &ref) { 
+BoundStatement Binder::BindBoundPivot(PivotRef &ref) { // bind the child table in a child binder - auto result = make_uniq(); - result->bind_index = GenerateTableIndex(); - result->child_binder = Binder::CreateBinder(context, this); - result->child = result->child_binder->Bind(*ref.source); + BoundPivotRef result; + result.bind_index = GenerateTableIndex(); + result.child_binder = Binder::CreateBinder(context, this); + result.child = result.child_binder->Bind(*ref.source); - auto &aggregates = result->bound_pivot.aggregates; - ExtractPivotAggregates(*result->child, aggregates); + auto &aggregates = result.bound_pivot.aggregates; + ExtractPivotAggregates(result.child, aggregates); if (aggregates.size() != ref.bound_aggregate_names.size()) { throw InternalException("Pivot aggregate count mismatch (expected %llu, found %llu)", ref.bound_aggregate_names.size(), aggregates.size()); @@ -436,7 +432,7 @@ unique_ptr Binder::BindBoundPivot(PivotRef &ref) { vector child_names; vector child_types; - result->child_binder->bind_context.GetTypesAndNames(child_names, child_types); + result.child_binder->bind_context.GetTypesAndNames(child_names, child_types); vector names; vector types; @@ -461,19 +457,23 @@ unique_ptr Binder::BindBoundPivot(PivotRef &ref) { pivot_str += "_" + str; } } - result->bound_pivot.pivot_values.push_back(std::move(pivot_str)); + result.bound_pivot.pivot_values.push_back(std::move(pivot_str)); names.push_back(std::move(name)); types.push_back(aggr->return_type); } } - result->bound_pivot.group_count = ref.bound_group_names.size(); - result->bound_pivot.types = types; + result.bound_pivot.group_count = ref.bound_group_names.size(); + result.bound_pivot.types = types; auto subquery_alias = ref.alias.empty() ? 
"__unnamed_pivot" : ref.alias; QueryResult::DeduplicateColumns(names); - bind_context.AddGenericBinding(result->bind_index, subquery_alias, names, types); + bind_context.AddGenericBinding(result.bind_index, subquery_alias, names, types); + + MoveCorrelatedExpressions(*result.child_binder); - MoveCorrelatedExpressions(*result->child_binder); - return std::move(result); + BoundStatement result_statement; + result_statement.plan = + make_uniq(result.bind_index, std::move(result.child.plan), std::move(result.bound_pivot)); + return result_statement; } unique_ptr Binder::BindPivot(PivotRef &ref, vector> all_columns) { @@ -835,7 +835,7 @@ unique_ptr Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, return result_node; } -unique_ptr Binder::Bind(PivotRef &ref) { +BoundStatement Binder::Bind(PivotRef &ref) { if (!ref.source) { throw InternalException("Pivot without a source!?"); } @@ -866,11 +866,10 @@ unique_ptr Binder::Bind(PivotRef &ref) { } // bind the generated select node auto child_binder = Binder::CreateBinder(context, this); - auto bound_select_node = child_binder->BindNode(*select_node); - auto root_index = bound_select_node.plan->GetRootIndex(); + auto result = child_binder->BindNode(*select_node); + auto root_index = result.plan->GetRootIndex(); MoveCorrelatedExpressions(*child_binder); - auto result = make_uniq(std::move(child_binder), std::move(bound_select_node)); auto subquery_alias = ref.alias.empty() ? 
"__unnamed_pivot" : ref.alias; SubqueryRef subquery_ref(nullptr, subquery_alias); subquery_ref.column_name_alias = std::move(ref.column_name_alias); @@ -878,16 +877,15 @@ unique_ptr Binder::Bind(PivotRef &ref) { // if a WHERE clause was provided - bind a subquery holding the WHERE clause // we need to bind a new subquery here because the WHERE clause has to be applied AFTER the unnest child_binder = Binder::CreateBinder(context, this); - child_binder->bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, result->subquery); + child_binder->bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, result); auto where_query = make_uniq(); where_query->select_list.push_back(make_uniq()); where_query->where_clause = std::move(where_clause); - bound_select_node = child_binder->BindSelectNode(*where_query, std::move(result)); - root_index = bound_select_node.plan->GetRootIndex(); - result = make_uniq(std::move(child_binder), std::move(bound_select_node)); + result = child_binder->BindSelectNode(*where_query, std::move(result)); + root_index = result.plan->GetRootIndex(); } - bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, result->subquery); - return std::move(result); + bind_context.AddSubquery(root_index, subquery_ref.alias, subquery_ref, result); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_showref.cpp b/src/duckdb/src/planner/binder/tableref/bind_showref.cpp index b23456cab..d2d91c3af 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_showref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_showref.cpp @@ -5,12 +5,10 @@ #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/operator/logical_column_data_get.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include 
"duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_search_path.hpp" #include "duckdb/main/client_data.hpp" #include "duckdb/main/client_context.hpp" @@ -89,7 +87,7 @@ BaseTableColumnInfo FindBaseTableColumn(LogicalOperator &op, idx_t column_index) return FindBaseTableColumn(op, bindings[column_index]); } -unique_ptr Binder::BindShowQuery(ShowRef &ref) { +BoundStatement Binder::BindShowQuery(ShowRef &ref) { // bind the child plan of the DESCRIBE statement auto child_binder = Binder::CreateBinder(context, this); auto plan = child_binder->Bind(*ref.query); @@ -142,12 +140,17 @@ unique_ptr Binder::BindShowQuery(ShowRef &ref) { } collection->Append(append_state, output); - auto show = make_uniq(GenerateTableIndex(), return_types, std::move(collection)); - bind_context.AddGenericBinding(show->table_index, "__show_select", return_names, return_types); - return make_uniq(std::move(show)); + auto table_index = GenerateTableIndex(); + + BoundStatement result; + result.names = return_names; + result.types = return_types; + result.plan = make_uniq(table_index, return_types, std::move(collection)); + bind_context.AddGenericBinding(table_index, "__show_select", return_names, return_types); + return result; } -unique_ptr Binder::BindShowTable(ShowRef &ref) { +BoundStatement Binder::BindShowTable(ShowRef &ref) { auto lname = StringUtil::Lower(ref.table_name); string sql; @@ -193,7 +196,7 @@ unique_ptr Binder::BindShowTable(ShowRef &ref) { return Bind(*subquery); } -unique_ptr Binder::Bind(ShowRef &ref) { +BoundStatement Binder::Bind(ShowRef &ref) { if (ref.show_type == ShowType::SUMMARY) { return BindSummarize(ref); } diff --git a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp index 503299d4e..cfa727927 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp +++ 
b/src/duckdb/src/planner/binder/tableref/bind_subqueryref.cpp @@ -1,10 +1,9 @@ #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" namespace duckdb { -unique_ptr Binder::Bind(SubqueryRef &ref) { +BoundStatement Binder::Bind(SubqueryRef &ref) { auto binder = Binder::CreateBinder(context, this); binder->can_contain_nulls = true; auto subquery = binder->BindNode(*ref.subquery->node); @@ -21,10 +20,14 @@ unique_ptr Binder::Bind(SubqueryRef &ref) { } else { subquery_alias = ref.alias; } - auto result = make_uniq(std::move(binder), std::move(subquery)); - bind_context.AddSubquery(bind_index, subquery_alias, ref, result->subquery); - MoveCorrelatedExpressions(*result->binder); - return std::move(result); + binder->is_outside_flattened = is_outside_flattened; + if (binder->has_unplanned_dependent_joins) { + has_unplanned_dependent_joins = true; + } + bind_context.AddSubquery(bind_index, subquery_alias, ref, subquery); + MoveCorrelatedExpressions(*binder); + + return subquery; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp index 57ef40c3f..abb057011 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp @@ -13,9 +13,6 @@ #include "duckdb/planner/expression_binder/table_function_binder.hpp" #include "duckdb/planner/expression_binder/select_binder.hpp" #include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" @@ -79,32 +76,28 @@ 
static TableFunctionBindType GetTableFunctionBindType(TableFunctionCatalogEntry : TableFunctionBindType::STANDARD_TABLE_FUNCTION; } -void Binder::BindTableInTableOutFunction(vector> &expressions, - unique_ptr &subquery) { +void Binder::BindTableInTableOutFunction(vector> &expressions, BoundStatement &subquery) { auto binder = Binder::CreateBinder(this->context, this); - unique_ptr subquery_node; // generate a subquery and bind that (i.e. UNNEST([1,2,3]) becomes UNNEST((SELECT [1,2,3])) auto select_node = make_uniq(); select_node->select_list = std::move(expressions); select_node->from_table = make_uniq(); - subquery_node = std::move(select_node); binder->can_contain_nulls = true; - auto node = binder->BindNode(*subquery_node); - subquery = make_uniq(std::move(binder), std::move(node)); - MoveCorrelatedExpressions(*subquery->binder); + subquery = binder->BindNode(*select_node); + MoveCorrelatedExpressions(*binder); } bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_function, vector> &expressions, vector &arguments, vector ¶meters, - named_parameter_map_t &named_parameters, - unique_ptr &subquery, ErrorData &error) { + named_parameter_map_t &named_parameters, BoundStatement &subquery, + ErrorData &error) { auto bind_type = GetTableFunctionBindType(table_function, expressions); if (bind_type == TableFunctionBindType::TABLE_IN_OUT_FUNCTION) { // bind table in-out function BindTableInTableOutFunction(expressions, subquery); // fetch the arguments from the subquery - arguments = subquery->subquery.types; + arguments = subquery.types; return true; } bool seen_subquery = false; @@ -142,9 +135,8 @@ bool Binder::BindTableFunctionParameters(TableFunctionCatalogEntry &table_functi auto binder = Binder::CreateBinder(this->context, this); binder->can_contain_nulls = true; auto &se = child->Cast(); - auto node = binder->BindNode(*se.subquery->node); - subquery = make_uniq(std::move(binder), std::move(node)); - MoveCorrelatedExpressions(*subquery->binder); 
+ subquery = binder->BindNode(*se.subquery->node); + MoveCorrelatedExpressions(*binder); seen_subquery = true; arguments.emplace_back(LogicalTypeId::TABLE); parameters.emplace_back(Value()); @@ -188,11 +180,10 @@ static string GetAlias(const TableFunctionRef &ref) { return string(); } -unique_ptr Binder::BindTableFunctionInternal(TableFunction &table_function, - const TableFunctionRef &ref, vector parameters, - named_parameter_map_t named_parameters, - vector input_table_types, - vector input_table_names) { +BoundStatement Binder::BindTableFunctionInternal(TableFunction &table_function, const TableFunctionRef &ref, + vector parameters, named_parameter_map_t named_parameters, + vector input_table_types, + vector input_table_names) { auto function_name = GetAlias(ref); auto &column_name_alias = ref.column_name_alias; auto bind_index = GenerateTableIndex(); @@ -221,8 +212,12 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab table_function.name); } } + BoundStatement result; bind_context.AddGenericBinding(bind_index, function_name, return_names, new_plan->types); - return new_plan; + result.names = return_names; + result.types = new_plan->types; + result.plan = std::move(new_plan); + return result; } } if (table_function.bind_replace) { @@ -234,7 +229,7 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab if (!ref.column_name_alias.empty()) { new_plan->column_name_alias = ref.column_name_alias; } - return CreatePlan(*Bind(*new_plan)); + return Bind(*new_plan); } } if (!table_function.bind) { @@ -343,16 +338,24 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab return_types.push_back(LogicalType::BIGINT); bind_context.AddGenericBinding(projection_index, function_name, return_names, return_types); - return std::move(projection); + BoundStatement result; + result.names = std::move(return_names); + result.types = std::move(return_types); + result.plan = std::move(projection); + return result; } // now add the table function to 
the bind context so its columns can be bound + BoundStatement result; bind_context.AddTableFunction(bind_index, function_name, return_names, return_types, get->GetMutableColumnIds(), get->GetTable().get(), std::move(virtual_columns)); - return std::move(get); + result.names = std::move(return_names); + result.types = std::move(return_types); + result.plan = std::move(get); + return result; } -unique_ptr Binder::BindTableFunction(TableFunction &function, vector parameters) { +BoundStatement Binder::BindTableFunction(TableFunction &function, vector parameters) { named_parameter_map_t named_parameters; vector input_table_types; vector input_table_names; @@ -364,7 +367,7 @@ unique_ptr Binder::BindTableFunction(TableFunction &function, v std::move(input_table_types), std::move(input_table_names)); } -unique_ptr Binder::Bind(TableFunctionRef &ref) { +BoundStatement Binder::Bind(TableFunctionRef &ref) { QueryErrorContext error_context(ref.query_location); D_ASSERT(ref.function->GetExpressionType() == ExpressionType::FUNCTION); @@ -401,11 +404,10 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { // string alias; string alias = (ref.alias.empty() ? 
"unnamed_query" + to_string(bind_index) : ref.alias); - auto result = make_uniq(std::move(binder), std::move(query)); // remember ref here is TableFunctionRef and NOT base class - bind_context.AddSubquery(bind_index, alias, ref, result->subquery); - MoveCorrelatedExpressions(*result->binder); - return std::move(result); + bind_context.AddSubquery(bind_index, alias, ref, query); + MoveCorrelatedExpressions(*binder); + return query; } D_ASSERT(func_catalog.type == CatalogType::TABLE_FUNCTION_ENTRY); auto &function = func_catalog.Cast(); @@ -414,7 +416,7 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { vector arguments; vector parameters; named_parameter_map_t named_parameters; - unique_ptr subquery; + BoundStatement subquery; ErrorData error; if (!BindTableFunctionParameters(function, fexpr.children, arguments, parameters, named_parameters, subquery, error)) { @@ -437,9 +439,9 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { vector input_table_types; vector input_table_names; - if (subquery) { - input_table_types = subquery->subquery.types; - input_table_names = subquery->subquery.names; + if (subquery.plan) { + input_table_types = subquery.types; + input_table_names = subquery.names; } else if (table_function.in_out_function) { for (auto ¶m : parameters) { input_table_types.push_back(param.type()); @@ -457,7 +459,7 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { parameters[i] = parameters[i].CastAs(context, target_type); } } - } else if (subquery) { + } else if (subquery.plan) { for (idx_t i = 0; i < arguments.size(); i++) { auto target_type = i < table_function.arguments.size() ? 
table_function.arguments[i] : table_function.varargs; @@ -469,7 +471,7 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { } } - unique_ptr get; + BoundStatement get; try { get = BindTableFunctionInternal(table_function, ref, std::move(parameters), std::move(named_parameters), std::move(input_table_types), std::move(input_table_names)); @@ -478,9 +480,30 @@ unique_ptr Binder::Bind(TableFunctionRef &ref) { error.AddQueryLocation(ref); error.Throw(); } - auto table_function_ref = make_uniq(std::move(get)); - table_function_ref->subquery = std::move(subquery); - return std::move(table_function_ref); + + if (subquery.plan) { + auto child_node = std::move(subquery.plan); + + reference node = *get.plan; + + while (!node.get().children.empty()) { + D_ASSERT(node.get().children.size() == 1); + if (node.get().children.size() != 1) { + throw InternalException( + "Binder::CreatePlan: linear path expected, but found node with %d children", + node.get().children.size()); + } + node = *node.get().children[0]; + } + + D_ASSERT(node.get().type == LogicalOperatorType::LOGICAL_GET); + node.get().children.push_back(std::move(child_node)); + } + BoundStatement result_statement; + result_statement.names = get.names; + result_statement.types = get.types; + result_statement.plan = std::move(get.plan); + return result_statement; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp deleted file mode 100644 index 085498fbb..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_basetableref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundBaseTableRef &ref) { - return std::move(ref.get); -} - -} // namespace duckdb diff --git 
a/src/duckdb/src/planner/binder/tableref/plan_column_data_ref.cpp b/src/duckdb/src/planner/binder/tableref/plan_column_data_ref.cpp deleted file mode 100644 index 83e965b5e..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_column_data_ref.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_column_data_ref.hpp" -#include "duckdb/planner/operator/logical_column_data_get.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundColumnDataRef &ref) { - auto types = ref.collection->Types(); - // Create a (potentially owning) LogicalColumnDataGet - auto root = make_uniq_base(ref.bind_index, std::move(types), - std::move(ref.collection)); - return root; -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp b/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp deleted file mode 100644 index 4ee2b9a76..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_cteref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_cteref.hpp" -#include "duckdb/planner/tableref/bound_cteref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundCTERef &ref) { - return make_uniq(ref.bind_index, ref.cte_index, ref.types, ref.bound_columns, ref.is_recurring); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_delimgetref.cpp b/src/duckdb/src/planner/binder/tableref/plan_delimgetref.cpp deleted file mode 100644 index b674b43df..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_delimgetref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/tableref/bound_basetableref.hpp" -#include "duckdb/planner/operator/logical_delim_get.hpp" -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundDelimGetRef &ref) { - return make_uniq(ref.bind_index, 
ref.column_types); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp b/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp deleted file mode 100644 index f31fc929b..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_dummytableref.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_dummy_scan.hpp" -#include "duckdb/planner/tableref/bound_dummytableref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundEmptyTableRef &ref) { - return make_uniq(ref.bind_index); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp b/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp deleted file mode 100644 index ba6253bce..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_expressionlistref.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_expressionlistref.hpp" -#include "duckdb/planner/operator/logical_expression_get.hpp" -#include "duckdb/planner/operator/logical_dummy_scan.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundExpressionListRef &ref) { - auto root = make_uniq_base(GenerateTableIndex()); - // values list, first plan any subqueries in the list - for (auto &expr_list : ref.values) { - for (auto &expr : expr_list) { - PlanSubqueries(expr, root); - } - } - // now create a LogicalExpressionGet from the set of expressions - // fetch the types - vector types; - for (auto &expr : ref.values[0]) { - types.push_back(expr->return_type); - } - auto expr_get = make_uniq(ref.bind_index, types, std::move(ref.values)); - expr_get->AddChild(std::move(root)); - return std::move(expr_get); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp b/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp index 9de5829f2..aa6193b13 100644 --- 
a/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp +++ b/src/duckdb/src/planner/binder/tableref/plan_joinref.cpp @@ -298,8 +298,8 @@ unique_ptr Binder::CreatePlan(BoundJoinRef &ref) { // Set the flag to ensure that children do not flatten before the root is_outside_flattened = false; } - auto left = CreatePlan(*ref.left); - auto right = CreatePlan(*ref.right); + auto left = std::move(ref.left.plan); + auto right = std::move(ref.right.plan); is_outside_flattened = old_is_outside_flattened; // For joins, depth of the bindings will be one higher on the right because of the lateral binder diff --git a/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp b/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp deleted file mode 100644 index 4d9482e5b..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_pivotref.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "duckdb/planner/tableref/bound_pivotref.hpp" -#include "duckdb/planner/operator/logical_pivot.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundPivotRef &ref) { - auto subquery = ref.child_binder->CreatePlan(*ref.child); - - auto result = make_uniq(ref.bind_index, std::move(subquery), std::move(ref.bound_pivot)); - return std::move(result); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp b/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp deleted file mode 100644 index 745bff555..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_subqueryref.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundSubqueryRef &ref) { - // generate the logical plan for the subquery - // this happens separately from the current LogicalPlan generation - ref.binder->is_outside_flattened = is_outside_flattened; - auto subquery = std::move(ref.subquery.plan); - if 
(ref.binder->has_unplanned_dependent_joins) { - has_unplanned_dependent_joins = true; - } - return subquery; -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp b/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp deleted file mode 100644 index 6c2f9957a..000000000 --- a/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/tableref/bound_table_function.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundTableFunction &ref) { - if (ref.subquery) { - auto child_node = CreatePlan(*ref.subquery); - - reference node = *ref.get; - - while (!node.get().children.empty()) { - D_ASSERT(node.get().children.size() == 1); - if (node.get().children.size() != 1) { - throw InternalException( - "Binder::CreatePlan: linear path expected, but found node with %d children", - node.get().children.size()); - } - node = *node.get().children[0]; - } - - D_ASSERT(node.get().type == LogicalOperatorType::LOGICAL_GET); - node.get().children.push_back(std::move(child_node)); - } - return std::move(ref.get); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/operator/logical_vacuum.cpp b/src/duckdb/src/planner/operator/logical_vacuum.cpp index 36352a0ea..ce4a76951 100644 --- a/src/duckdb/src/planner/operator/logical_vacuum.cpp +++ b/src/duckdb/src/planner/operator/logical_vacuum.cpp @@ -1,5 +1,5 @@ #include "duckdb/planner/operator/logical_vacuum.hpp" - +#include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" @@ -46,11 +46,14 @@ unique_ptr LogicalVacuum::Deserialize(Deserializer &deserialize auto &context = deserializer.Get(); auto binder = Binder::CreateBinder(context); auto bound_table = binder->Bind(*info.ref); - if (bound_table->type != 
TableReferenceType::BASE_TABLE) { + if (bound_table.plan->type != LogicalOperatorType::LOGICAL_GET) { + throw InvalidInputException("can only vacuum or analyze base tables"); + } + auto table_ptr = bound_table.plan->Cast().GetTable(); + if (!table_ptr) { throw InvalidInputException("can only vacuum or analyze base tables"); } - auto ref = unique_ptr_cast(std::move(bound_table)); - auto &table = ref->table; + auto &table = *table_ptr; result->SetTable(table); // FIXME: we should probably verify that the 'column_id_map' and 'columns' are the same on the bound table after // deserialization? diff --git a/src/duckdb/src/planner/planner.cpp b/src/duckdb/src/planner/planner.cpp index 78bca8a02..a38bc2a6c 100644 --- a/src/duckdb/src/planner/planner.cpp +++ b/src/duckdb/src/planner/planner.cpp @@ -41,7 +41,7 @@ void Planner::CreatePlan(SQLStatement &statement) { bool parameters_resolved = true; try { profiler.StartPhase(MetricsType::PLANNER_BINDING); - binder->parameters = &bound_parameters; + binder->SetParameters(bound_parameters); auto bound_statement = binder->Bind(statement); profiler.EndPhase(); diff --git a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp index d376eec81..10004d8f7 100644 --- a/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp +++ b/src/duckdb/src/planner/subquery/rewrite_correlated_expressions.cpp @@ -9,7 +9,6 @@ #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/tableref/bound_joinref.hpp" #include "duckdb/planner/operator/logical_dependent_join.hpp" -#include "duckdb/planner/tableref/bound_subqueryref.hpp" namespace duckdb { diff --git a/src/duckdb/src/planner/table_binding.cpp b/src/duckdb/src/planner/table_binding.cpp index d9bdd71c7..c60513cc9 100644 --- a/src/duckdb/src/planner/table_binding.cpp +++ b/src/duckdb/src/planner/table_binding.cpp @@ -304,4 +304,8 @@ unique_ptr DummyBinding::ParamToArg(ColumnRefExpression 
&colre return arg; } +CTEBinding::CTEBinding(BindingAlias alias, vector types, vector names, idx_t index) + : Binding(BindingType::CTE, std::move(alias), std::move(types), std::move(names), index), reference_count(0) { +} + } // namespace duckdb diff --git a/src/duckdb/src/storage/checkpoint_manager.cpp b/src/duckdb/src/storage/checkpoint_manager.cpp index af361d4bf..7e7b1c0c9 100644 --- a/src/duckdb/src/storage/checkpoint_manager.cpp +++ b/src/duckdb/src/storage/checkpoint_manager.cpp @@ -21,7 +21,6 @@ #include "duckdb/parser/parsed_data/create_schema_info.hpp" #include "duckdb/parser/parsed_data/create_view_info.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_tableref.hpp" #include "duckdb/planner/parsed_data/bound_create_table_info.hpp" #include "duckdb/storage/block_manager.hpp" #include "duckdb/storage/checkpoint/table_data_reader.hpp" diff --git a/src/duckdb/src/storage/compression/dict_fsst.cpp b/src/duckdb/src/storage/compression/dict_fsst.cpp index 636f5db6d..1b3e73852 100644 --- a/src/duckdb/src/storage/compression/dict_fsst.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst.cpp @@ -116,8 +116,10 @@ unique_ptr DictFSSTCompressionStorage::StringInitScan(const Qu auto &buffer_manager = BufferManager::GetBufferManager(segment.db); auto state = make_uniq(segment, buffer_manager.Pin(segment.block)); state->Initialize(true); - if (StringStats::HasMaxStringLength(segment.stats.statistics)) { - state->all_values_inlined = StringStats::MaxStringLength(segment.stats.statistics) <= string_t::INLINE_LENGTH; + + const auto &stats = segment.stats.statistics; + if (stats.GetStatsType() == StatisticsType::STRING_STATS && StringStats::HasMaxStringLength(stats)) { + state->all_values_inlined = StringStats::MaxStringLength(stats) <= string_t::INLINE_LENGTH; } return std::move(state); } diff --git a/src/duckdb/src/storage/compression/fsst.cpp b/src/duckdb/src/storage/compression/fsst.cpp index 7e07f7f6f..eaf1c5e78 100644 --- 
a/src/duckdb/src/storage/compression/fsst.cpp +++ b/src/duckdb/src/storage/compression/fsst.cpp @@ -585,8 +585,9 @@ unique_ptr FSSTStorage::StringInitScan(const QueryContext &con } state->duckdb_fsst_decoder_ptr = state->duckdb_fsst_decoder.get(); - if (StringStats::HasMaxStringLength(segment.stats.statistics)) { - state->all_values_inlined = StringStats::MaxStringLength(segment.stats.statistics) <= string_t::INLINE_LENGTH; + const auto &stats = segment.stats.statistics; + if (stats.GetStatsType() == StatisticsType::STRING_STATS && StringStats::HasMaxStringLength(stats)) { + state->all_values_inlined = StringStats::MaxStringLength(stats) <= string_t::INLINE_LENGTH; } return std::move(state); diff --git a/src/duckdb/src/storage/statistics/base_statistics.cpp b/src/duckdb/src/storage/statistics/base_statistics.cpp index 89ae9cb61..9eac3b9aa 100644 --- a/src/duckdb/src/storage/statistics/base_statistics.cpp +++ b/src/duckdb/src/storage/statistics/base_statistics.cpp @@ -62,6 +62,9 @@ StatisticsType BaseStatistics::GetStatsType(const LogicalType &type) { if (type.id() == LogicalTypeId::SQLNULL) { return StatisticsType::BASE_STATS; } + if (type.id() == LogicalTypeId::GEOMETRY) { + return StatisticsType::GEOMETRY_STATS; + } switch (type.InternalType()) { case PhysicalType::BOOL: case PhysicalType::INT8: @@ -153,6 +156,9 @@ void BaseStatistics::Merge(const BaseStatistics &other) { case StatisticsType::ARRAY_STATS: ArrayStats::Merge(*this, other); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Merge(*this, other); + break; default: break; } @@ -174,6 +180,8 @@ BaseStatistics BaseStatistics::CreateUnknownType(LogicalType type) { return StructStats::CreateUnknown(std::move(type)); case StatisticsType::ARRAY_STATS: return ArrayStats::CreateUnknown(std::move(type)); + case StatisticsType::GEOMETRY_STATS: + return GeometryStats::CreateUnknown(std::move(type)); default: return BaseStatistics(std::move(type)); } @@ -191,6 +199,8 @@ BaseStatistics 
BaseStatistics::CreateEmptyType(LogicalType type) { return StructStats::CreateEmpty(std::move(type)); case StatisticsType::ARRAY_STATS: return ArrayStats::CreateEmpty(std::move(type)); + case StatisticsType::GEOMETRY_STATS: + return GeometryStats::CreateEmpty(std::move(type)); default: return BaseStatistics(std::move(type)); } @@ -329,6 +339,9 @@ void BaseStatistics::Serialize(Serializer &serializer) const { case StatisticsType::ARRAY_STATS: ArrayStats::Serialize(*this, serializer); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Serialize(*this, serializer); + break; default: break; } @@ -367,6 +380,9 @@ BaseStatistics BaseStatistics::Deserialize(Deserializer &deserializer) { case StatisticsType::ARRAY_STATS: ArrayStats::Deserialize(obj, stats); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Deserialize(obj, stats); + break; default: break; } @@ -397,6 +413,9 @@ string BaseStatistics::ToString() const { case StatisticsType::ARRAY_STATS: result = ArrayStats::ToString(*this) + result; break; + case StatisticsType::GEOMETRY_STATS: + result = GeometryStats::ToString(*this) + result; + break; default: break; } @@ -421,6 +440,9 @@ void BaseStatistics::Verify(Vector &vector, const SelectionVector &sel, idx_t co case StatisticsType::ARRAY_STATS: ArrayStats::Verify(*this, vector, sel, count); break; + case StatisticsType::GEOMETRY_STATS: + GeometryStats::Verify(*this, vector, sel, count); + break; default: break; } @@ -505,6 +527,14 @@ BaseStatistics BaseStatistics::FromConstantType(const Value &input) { } return result; } + case StatisticsType::GEOMETRY_STATS: { + auto result = GeometryStats::CreateEmpty(input.type()); + if (!input.IsNull()) { + auto &string_value = StringValue::Get(input); + GeometryStats::Update(result, string_t(string_value)); + } + return result; + } default: return BaseStatistics(input.type()); } diff --git a/src/duckdb/src/storage/statistics/geometry_stats.cpp b/src/duckdb/src/storage/statistics/geometry_stats.cpp new 
file mode 100644 index 000000000..42f5efbd8 --- /dev/null +++ b/src/duckdb/src/storage/statistics/geometry_stats.cpp @@ -0,0 +1,171 @@ +#include "duckdb/storage/statistics/geometry_stats.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +vector GeometryTypeSet::ToString(bool snake_case) const { + vector result; + for (auto d = 0; d < VERT_TYPES; d++) { + for (auto i = 0; i < PART_TYPES; i++) { + if (sets[d] & (1 << i)) { + string str; + switch (i) { + case 1: + str = snake_case ? "point" : "Point"; + break; + case 2: + str = snake_case ? "linestring" : "LineString"; + break; + case 3: + str = snake_case ? "polygon" : "Polygon"; + break; + case 4: + str = snake_case ? "multipoint" : "MultiPoint"; + break; + case 5: + str = snake_case ? "multilinestring" : "MultiLineString"; + break; + case 6: + str = snake_case ? "multipolygon" : "MultiPolygon"; + break; + case 7: + str = snake_case ? "geometrycollection" : "GeometryCollection"; + break; + default: + str = snake_case ? "unknown" : "Unknown"; + break; + } + switch (d) { + case 1: + str += snake_case ? "_z" : " Z"; + break; + case 2: + str += snake_case ? "_m" : " M"; + break; + case 3: + str += snake_case ? 
"_zm" : " ZM"; + break; + default: + break; + } + + result.push_back(str); + } + } + } + return result; +} + +BaseStatistics GeometryStats::CreateUnknown(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeUnknown(); + GetDataUnsafe(result).SetUnknown(); + return result; +} + +BaseStatistics GeometryStats::CreateEmpty(LogicalType type) { + BaseStatistics result(std::move(type)); + result.InitializeEmpty(); + GetDataUnsafe(result).SetEmpty(); + return result; +} + +void GeometryStats::Serialize(const BaseStatistics &stats, Serializer &serializer) { + const auto &data = GetDataUnsafe(stats); + + // Write extent + serializer.WriteObject(200, "extent", [&](Serializer &extent) { + extent.WriteProperty(101, "x_min", data.extent.x_min); + extent.WriteProperty(102, "x_max", data.extent.x_max); + extent.WriteProperty(103, "y_min", data.extent.y_min); + extent.WriteProperty(104, "y_max", data.extent.y_max); + extent.WriteProperty(105, "z_min", data.extent.z_min); + extent.WriteProperty(106, "z_max", data.extent.z_max); + extent.WriteProperty(107, "m_min", data.extent.m_min); + extent.WriteProperty(108, "m_max", data.extent.m_max); + }); + + // Write types + serializer.WriteObject(201, "types", [&](Serializer &types) { + types.WriteProperty(101, "types_xy", data.types.sets[0]); + types.WriteProperty(102, "types_xyz", data.types.sets[1]); + types.WriteProperty(103, "types_xym", data.types.sets[2]); + types.WriteProperty(104, "types_xyzm", data.types.sets[3]); + }); +} + +void GeometryStats::Deserialize(Deserializer &deserializer, BaseStatistics &base) { + auto &data = GetDataUnsafe(base); + + // Read extent + deserializer.ReadObject(200, "extent", [&](Deserializer &extent) { + extent.ReadProperty(101, "x_min", data.extent.x_min); + extent.ReadProperty(102, "x_max", data.extent.x_max); + extent.ReadProperty(103, "y_min", data.extent.y_min); + extent.ReadProperty(104, "y_max", data.extent.y_max); + extent.ReadProperty(105, "z_min", 
data.extent.z_min); + extent.ReadProperty(106, "z_max", data.extent.z_max); + extent.ReadProperty(107, "m_min", data.extent.m_min); + extent.ReadProperty(108, "m_max", data.extent.m_max); + }); + + // Read types + deserializer.ReadObject(201, "types", [&](Deserializer &types) { + types.ReadProperty(101, "types_xy", data.types.sets[0]); + types.ReadProperty(102, "types_xyz", data.types.sets[1]); + types.ReadProperty(103, "types_xym", data.types.sets[2]); + types.ReadProperty(104, "types_xyzm", data.types.sets[3]); + }); +} + +string GeometryStats::ToString(const BaseStatistics &stats) { + const auto &data = GetDataUnsafe(stats); + string result; + + result += "["; + result += StringUtil::Format("Extent: [X: [%f, %f], Y: [%f, %f], Z: [%f, %f], M: [%f, %f]", data.extent.x_min, + data.extent.x_max, data.extent.y_min, data.extent.y_max, data.extent.z_min, + data.extent.z_max, data.extent.m_min, data.extent.m_max); + result += StringUtil::Format("], Types: [%s]", StringUtil::Join(data.types.ToString(true), ", ")); + + result += "]"; + return result; +} + +void GeometryStats::Update(BaseStatistics &stats, const string_t &value) { + auto &data = GetDataUnsafe(stats); + data.Update(value); +} + +void GeometryStats::Merge(BaseStatistics &stats, const BaseStatistics &other) { + if (other.GetType().id() == LogicalTypeId::VALIDITY) { + return; + } + if (other.GetType().id() == LogicalTypeId::SQLNULL) { + return; + } + + auto &target = GetDataUnsafe(stats); + auto &source = GetDataUnsafe(other); + target.Merge(source); +} + +void GeometryStats::Verify(const BaseStatistics &stats, Vector &vector, const SelectionVector &sel, idx_t count) { + // TODO: Verify stats +} + +const GeometryStatsData &GeometryStats::GetDataUnsafe(const BaseStatistics &stats) { + D_ASSERT(stats.GetStatsType() == StatisticsType::GEOMETRY_STATS); + return stats.stats_union.geometry_data; +} + +GeometryStatsData &GeometryStats::GetDataUnsafe(BaseStatistics &stats) { + D_ASSERT(stats.GetStatsType() == 
StatisticsType::GEOMETRY_STATS); + return stats.stats_union.geometry_data; +} + +} // namespace duckdb diff --git a/src/duckdb/ub_src_common_sort.cpp b/src/duckdb/ub_src_common_sort.cpp index e472e71ff..bcddfcb4e 100644 --- a/src/duckdb/ub_src_common_sort.cpp +++ b/src/duckdb/ub_src_common_sort.cpp @@ -2,8 +2,6 @@ #include "src/common/sort/merge_sorter.cpp" -#include "src/common/sort/partition_state.cpp" - #include "src/common/sort/radix_sort.cpp" #include "src/common/sort/sort_state.cpp" diff --git a/src/duckdb/ub_src_optimizer.cpp b/src/duckdb/ub_src_optimizer.cpp index cc2d15d70..7a57137c7 100644 --- a/src/duckdb/ub_src_optimizer.cpp +++ b/src/duckdb/ub_src_optimizer.cpp @@ -50,6 +50,8 @@ #include "src/optimizer/topn_optimizer.cpp" +#include "src/optimizer/topn_window_elimination.cpp" + #include "src/optimizer/unnest_rewriter.cpp" #include "src/optimizer/sampling_pushdown.cpp" diff --git a/src/duckdb/ub_src_planner_binder_tableref.cpp b/src/duckdb/ub_src_planner_binder_tableref.cpp index b06304d78..641fd88f6 100644 --- a/src/duckdb/ub_src_planner_binder_tableref.cpp +++ b/src/duckdb/ub_src_planner_binder_tableref.cpp @@ -22,23 +22,5 @@ #include "src/planner/binder/tableref/bind_named_parameters.cpp" -#include "src/planner/binder/tableref/plan_basetableref.cpp" - -#include "src/planner/binder/tableref/plan_delimgetref.cpp" - -#include "src/planner/binder/tableref/plan_dummytableref.cpp" - -#include "src/planner/binder/tableref/plan_expressionlistref.cpp" - -#include "src/planner/binder/tableref/plan_column_data_ref.cpp" - #include "src/planner/binder/tableref/plan_joinref.cpp" -#include "src/planner/binder/tableref/plan_subqueryref.cpp" - -#include "src/planner/binder/tableref/plan_table_function.cpp" - -#include "src/planner/binder/tableref/plan_cteref.cpp" - -#include "src/planner/binder/tableref/plan_pivotref.cpp" - diff --git a/src/duckdb/ub_src_storage_statistics.cpp b/src/duckdb/ub_src_storage_statistics.cpp index 637a311d7..5f8380c90 100644 --- 
a/src/duckdb/ub_src_storage_statistics.cpp +++ b/src/duckdb/ub_src_storage_statistics.cpp @@ -16,3 +16,5 @@ #include "src/storage/statistics/struct_stats.cpp" +#include "src/storage/statistics/geometry_stats.cpp" + From bcf2f7f575a0628e4670a647942a0d40b4f826a6 Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Sat, 11 Oct 2025 05:03:15 +0000 Subject: [PATCH 5/6] Update vendored DuckDB sources to 9d77bcf518 --- CMakeLists.txt | 3 +- .../extension/parquet/column_writer.cpp | 66 +- src/duckdb/extension/parquet/geo_parquet.cpp | 27 +- .../parquet/include/column_writer.hpp | 9 +- .../extension/parquet/include/geo_parquet.hpp | 29 +- .../parquet/include/parquet_field_id.hpp | 39 + .../parquet/include/parquet_shredding.hpp | 49 ++ .../parquet/include/parquet_writer.hpp | 37 +- .../variant/variant_shredded_conversion.hpp | 9 +- .../include/writer/variant_column_writer.hpp | 23 +- .../extension/parquet/parquet_extension.cpp | 311 ++++---- .../extension/parquet/parquet_field_id.cpp | 178 +++++ .../extension/parquet/parquet_shredding.cpp | 81 ++ .../extension/parquet/parquet_writer.cpp | 52 +- .../variant/variant_shredded_conversion.cpp | 75 +- .../parquet/reader/variant_column_reader.cpp | 2 +- .../extension/parquet/serialize_parquet.cpp | 27 +- .../writer/primitive_column_writer.cpp | 22 +- .../writer/variant/convert_variant.cpp | 633 +++++++++++++++- .../parquet/writer/variant_column_writer.cpp | 131 ---- src/duckdb/src/catalog/catalog_set.cpp | 3 +- .../common/row_operations/row_external.cpp | 157 ---- .../src/common/row_operations/row_gather.cpp | 176 ----- .../common/row_operations/row_heap_gather.cpp | 276 ------- .../row_operations/row_heap_scatter.cpp | 581 -------------- .../row_operations/row_radix_scatter.cpp | 360 --------- .../src/common/row_operations/row_scatter.cpp | 230 ------ src/duckdb/src/common/sort/comparators.cpp | 507 ------------- .../common/{sorting => sort}/hashed_sort.cpp | 0 src/duckdb/src/common/sort/merge_sorter.cpp | 667 
---------------- src/duckdb/src/common/sort/radix_sort.cpp | 352 --------- .../src/common/{sorting => sort}/sort.cpp | 31 +- src/duckdb/src/common/sort/sort_state.cpp | 487 ------------ src/duckdb/src/common/sort/sorted_block.cpp | 387 ---------- .../common/{sorting => sort}/sorted_run.cpp | 0 .../{sorting => sort}/sorted_run_merger.cpp | 1 + src/duckdb/src/common/string_util.cpp | 8 +- .../common/types/row/row_data_collection.cpp | 141 ---- .../types/row/row_data_collection_scanner.cpp | 330 -------- .../src/common/types/row/row_layout.cpp | 62 -- .../src/common/types/selection_vector.cpp | 8 + .../src/execution/expression_executor.cpp | 2 + .../operator/join/physical_asof_join.cpp | 33 +- .../join/physical_nested_loop_join.cpp | 41 +- .../physical_plan/plan_asof_join.cpp | 42 +- .../function/cast/variant/from_variant.cpp | 29 +- .../src/function/cast/variant/to_variant.cpp | 3 + src/duckdb/src/function/macro_function.cpp | 2 +- .../src/function/scalar/create_sort_key.cpp | 8 +- .../function/scalar/operator/arithmetic.cpp | 2 +- .../scalar/variant/variant_extract.cpp | 27 +- .../function/scalar/variant/variant_utils.cpp | 41 +- .../function/table/version/pragma_version.cpp | 6 +- src/duckdb/src/include/duckdb.h | 19 + .../src/include/duckdb/common/hugeint.hpp | 2 +- .../common/operator/comparison_operators.hpp | 11 - .../common/row_operations/row_operations.hpp | 78 +- .../duckdb/common/sort/comparators.hpp | 65 -- .../duckdb/common/sort/duckdb_pdqsort.hpp | 710 ------------------ .../src/include/duckdb/common/sort/sort.hpp | 290 ------- .../duckdb/common/sort/sorted_block.hpp | 165 ---- .../include/duckdb/common/string_map_set.hpp | 17 + .../src/include/duckdb/common/string_util.hpp | 1 + .../include/duckdb/common/types/hugeint.hpp | 12 +- .../duckdb/common/types/selection_vector.hpp | 1 + .../include/duckdb/common/types/variant.hpp | 11 +- .../operator/join/physical_asof_join.hpp | 3 - .../join/physical_nested_loop_join.hpp | 7 +- 
.../duckdb/function/scalar/variant_utils.hpp | 8 +- .../duckdb/main/capi/extension_api.hpp | 7 + .../main/database_file_path_manager.hpp | 6 +- .../src/include/duckdb/main/relation.hpp | 3 +- .../main/relation/create_table_relation.hpp | 2 + .../main/relation/create_view_relation.hpp | 2 + .../duckdb/main/relation/delete_relation.hpp | 2 + .../duckdb/main/relation/explain_relation.hpp | 2 + .../duckdb/main/relation/insert_relation.hpp | 2 + .../duckdb/main/relation/query_relation.hpp | 1 + .../duckdb/main/relation/update_relation.hpp | 2 + .../main/relation/write_csv_relation.hpp | 2 + .../main/relation/write_parquet_relation.hpp | 2 + .../src/include/duckdb/main/settings.hpp | 2 +- .../duckdb/optimizer/filter_pushdown.hpp | 5 +- .../src/include/duckdb/parser/query_node.hpp | 2 - .../duckdb/parser/query_node/cte_node.hpp | 4 - .../parser/query_node/recursive_cte_node.hpp | 4 - .../duckdb/parser/query_node/select_node.hpp | 4 - .../parser/query_node/set_operation_node.hpp | 2 - .../parser/query_node/statement_node.hpp | 1 - .../src/include/duckdb/planner/binder.hpp | 2 - .../duckdb/planner/bound_query_node.hpp | 24 +- .../include/duckdb/planner/bound_tokens.hpp | 2 - .../planner/query_node/bound_cte_node.hpp | 46 -- .../query_node/bound_recursive_cte_node.hpp | 49 -- .../planner/query_node/bound_select_node.hpp | 6 - .../query_node/bound_set_operation_node.hpp | 7 - .../duckdb/planner/query_node/list.hpp | 2 - .../duckdb/storage/caching_file_system.hpp | 3 +- .../storage/table/array_column_data.hpp | 8 +- .../duckdb/storage/table/column_data.hpp | 12 +- .../duckdb/storage/table/list_column_data.hpp | 8 +- .../duckdb/storage/table/row_group.hpp | 8 +- .../storage/table/row_group_collection.hpp | 8 +- .../storage/table/row_id_column_data.hpp | 8 +- .../storage/table/standard_column_data.hpp | 8 +- .../storage/table/struct_column_data.hpp | 8 +- .../duckdb/storage/table/update_segment.hpp | 4 +- .../duckdb/transaction/duck_transaction.hpp | 3 +- 
.../duckdb/transaction/update_info.hpp | 5 +- .../duckdb/transaction/wal_write_state.hpp | 2 +- .../verification/statement_verifier.hpp | 2 + src/duckdb/src/include/duckdb_extension.h | 11 + .../src/main/capi/table_description-c.cpp | 33 +- .../src/main/database_file_path_manager.cpp | 39 +- src/duckdb/src/main/relation.cpp | 4 +- .../main/relation/create_table_relation.cpp | 8 + .../main/relation/create_view_relation.cpp | 8 + .../src/main/relation/delete_relation.cpp | 8 + .../src/main/relation/explain_relation.cpp | 8 + .../src/main/relation/insert_relation.cpp | 8 + .../src/main/relation/query_relation.cpp | 4 + .../src/main/relation/update_relation.cpp | 8 + .../src/main/relation/write_csv_relation.cpp | 8 + .../main/relation/write_parquet_relation.cpp | 8 + src/duckdb/src/optimizer/filter_combiner.cpp | 7 + src/duckdb/src/optimizer/filter_pushdown.cpp | 12 +- .../src/optimizer/pushdown/pushdown_get.cpp | 5 +- .../optimizer/rule/regex_optimizations.cpp | 7 + .../parser/query_node/set_operation_node.cpp | 4 - .../src/parser/query_node/statement_node.cpp | 3 - .../parser/statement/relation_statement.cpp | 5 +- .../expression/transform_subquery.cpp | 1 - src/duckdb/src/planner/binder.cpp | 1 - .../expression/bind_macro_expression.cpp | 1 + .../expression/bind_subquery_expression.cpp | 4 - .../binder/query_node/bind_cte_node.cpp | 68 +- .../query_node/bind_recursive_cte_node.cpp | 86 ++- .../binder/query_node/bind_setop_node.cpp | 3 - .../binder/query_node/plan_cte_node.cpp | 26 - .../query_node/plan_recursive_cte_node.cpp | 50 -- .../planner/binder/statement/bind_pragma.cpp | 23 +- .../src/planner/expression_iterator.cpp | 2 - .../src/storage/caching_file_system.cpp | 58 +- src/duckdb/src/storage/data_table.cpp | 4 +- src/duckdb/src/storage/local_storage.cpp | 2 +- .../src/storage/table/array_column_data.cpp | 9 +- src/duckdb/src/storage/table/column_data.cpp | 18 +- .../src/storage/table/list_column_data.cpp | 9 +- src/duckdb/src/storage/table/row_group.cpp | 22 
+- .../storage/table/row_group_collection.cpp | 12 +- .../src/storage/table/row_id_column_data.cpp | 9 +- .../storage/table/standard_column_data.cpp | 17 +- .../src/storage/table/struct_column_data.cpp | 19 +- .../src/storage/table/update_segment.cpp | 22 +- src/duckdb/src/transaction/commit_state.cpp | 18 + .../src/transaction/duck_transaction.cpp | 12 +- .../src/transaction/wal_write_state.cpp | 10 +- .../src/verification/statement_verifier.cpp | 19 +- src/duckdb/ub_extension_parquet_writer.cpp | 2 - src/duckdb/ub_src_common_row_operations.cpp | 12 - src/duckdb/ub_src_common_sort.cpp | 10 +- src/duckdb/ub_src_common_sorting.cpp | 8 - src/duckdb/ub_src_common_types_row.cpp | 6 - .../ub_src_planner_binder_query_node.cpp | 4 - 164 files changed, 2116 insertions(+), 7134 deletions(-) create mode 100644 src/duckdb/extension/parquet/include/parquet_field_id.hpp create mode 100644 src/duckdb/extension/parquet/include/parquet_shredding.hpp create mode 100644 src/duckdb/extension/parquet/parquet_field_id.cpp create mode 100644 src/duckdb/extension/parquet/parquet_shredding.cpp delete mode 100644 src/duckdb/extension/parquet/writer/variant_column_writer.cpp delete mode 100644 src/duckdb/src/common/row_operations/row_external.cpp delete mode 100644 src/duckdb/src/common/row_operations/row_gather.cpp delete mode 100644 src/duckdb/src/common/row_operations/row_heap_gather.cpp delete mode 100644 src/duckdb/src/common/row_operations/row_heap_scatter.cpp delete mode 100644 src/duckdb/src/common/row_operations/row_radix_scatter.cpp delete mode 100644 src/duckdb/src/common/row_operations/row_scatter.cpp delete mode 100644 src/duckdb/src/common/sort/comparators.cpp rename src/duckdb/src/common/{sorting => sort}/hashed_sort.cpp (100%) delete mode 100644 src/duckdb/src/common/sort/merge_sorter.cpp delete mode 100644 src/duckdb/src/common/sort/radix_sort.cpp rename src/duckdb/src/common/{sorting => sort}/sort.cpp (96%) delete mode 100644 src/duckdb/src/common/sort/sort_state.cpp 
delete mode 100644 src/duckdb/src/common/sort/sorted_block.cpp rename src/duckdb/src/common/{sorting => sort}/sorted_run.cpp (100%) rename src/duckdb/src/common/{sorting => sort}/sorted_run_merger.cpp (99%) delete mode 100644 src/duckdb/src/common/types/row/row_data_collection.cpp delete mode 100644 src/duckdb/src/common/types/row/row_data_collection_scanner.cpp delete mode 100644 src/duckdb/src/common/types/row/row_layout.cpp delete mode 100644 src/duckdb/src/include/duckdb/common/sort/comparators.hpp delete mode 100644 src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp delete mode 100644 src/duckdb/src/include/duckdb/common/sort/sort.hpp delete mode 100644 src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp delete mode 100644 src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp delete mode 100644 src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp delete mode 100644 src/duckdb/ub_src_common_sorting.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 280269485..74d5831f5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,7 +113,6 @@ set(DUCKDB_SRC_FILES src/duckdb/ub_src_common_row_operations.cpp src/duckdb/ub_src_common_serializer.cpp src/duckdb/ub_src_common_sort.cpp - src/duckdb/ub_src_common_sorting.cpp src/duckdb/ub_src_common_tree_renderer.cpp src/duckdb/ub_src_common_types.cpp src/duckdb/ub_src_common_types_column.cpp @@ -377,9 +376,11 @@ set(DUCKDB_SRC_FILES src/duckdb/extension/parquet/parquet_timestamp.cpp src/duckdb/extension/parquet/parquet_float16.cpp src/duckdb/extension/parquet/parquet_statistics.cpp + src/duckdb/extension/parquet/parquet_shredding.cpp src/duckdb/extension/parquet/parquet_multi_file_info.cpp src/duckdb/extension/parquet/column_reader.cpp src/duckdb/extension/parquet/geo_parquet.cpp + 
src/duckdb/extension/parquet/parquet_field_id.cpp src/duckdb/extension/parquet/parquet_extension.cpp src/duckdb/extension/parquet/column_writer.cpp src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp diff --git a/src/duckdb/extension/parquet/column_writer.cpp b/src/duckdb/extension/parquet/column_writer.cpp index 304e21751..0220d8b24 100644 --- a/src/duckdb/extension/parquet/column_writer.cpp +++ b/src/duckdb/extension/parquet/column_writer.cpp @@ -245,8 +245,9 @@ void ColumnWriter::HandleDefineLevels(ColumnWriterState &state, ColumnWriterStat //===--------------------------------------------------------------------===// ParquetColumnSchema ColumnWriter::FillParquetSchema(vector &schemas, - const LogicalType &type, const string &name, - optional_ptr field_ids, idx_t max_repeat, + const LogicalType &type, const string &name, bool allow_geometry, + optional_ptr field_ids, + optional_ptr shredding_types, idx_t max_repeat, idx_t max_define, bool can_have_nulls) { auto null_type = can_have_nulls ? 
FieldRepetitionType::OPTIONAL : FieldRepetitionType::REQUIRED; if (!can_have_nulls) { @@ -263,6 +264,10 @@ ParquetColumnSchema ColumnWriter::FillParquetSchema(vectorchild_field_ids; } } + optional_ptr shredding_type; + if (shredding_types) { + shredding_type = shredding_types->GetChild(name); + } if (type.id() == LogicalTypeId::STRUCT && type.GetAlias() == "PARQUET_VARIANT") { // variant type @@ -273,32 +278,53 @@ ParquetColumnSchema ColumnWriter::FillParquetSchema(vector] // } - const bool is_shredded = false; + const bool is_shredded = shredding_type != nullptr; + + child_list_t child_types; + child_types.emplace_back("metadata", LogicalType::BLOB); + child_types.emplace_back("value", LogicalType::BLOB); + if (is_shredded) { + auto &typed_value_type = shredding_type->type; + if (typed_value_type.id() != LogicalTypeId::ANY) { + child_types.emplace_back("typed_value", + VariantColumnWriter::TransformTypedValueRecursive(typed_value_type)); + } + } // variant group duckdb_parquet::SchemaElement top_element; top_element.repetition_type = null_type; - top_element.num_children = is_shredded ? 
3 : 2; + top_element.num_children = child_types.size(); top_element.logicalType.__isset.VARIANT = true; top_element.logicalType.VARIANT.__isset.specification_version = true; top_element.logicalType.VARIANT.specification_version = 1; top_element.__isset.logicalType = true; top_element.__isset.num_children = true; top_element.__isset.repetition_type = true; + top_element.name = name; schemas.push_back(std::move(top_element)); - child_list_t child_types; - child_types.emplace_back("metadata", LogicalType::BLOB); - child_types.emplace_back("value", LogicalType::BLOB); - if (is_shredded) { - throw NotImplementedException("Writing shredded VARIANT isn't supported for Parquet yet"); - } - ParquetColumnSchema variant_column(name, type, max_define, max_repeat, schema_idx, 0); variant_column.children.reserve(child_types.size()); for (auto &child_type : child_types) { + auto &child_name = child_type.first; + bool is_optional; + if (child_name == "metadata") { + is_optional = false; + } else if (child_name == "value") { + if (is_shredded) { + //! 
When shredding the variant, the 'value' becomes optional + is_optional = true; + } else { + is_optional = false; + } + } else { + D_ASSERT(child_name == "typed_value"); + is_optional = true; + } variant_column.children.emplace_back(FillParquetSchema(schemas, child_type.second, child_type.first, - child_field_ids, max_repeat, max_define + 1, false)); + allow_geometry, child_field_ids, shredding_type, + max_repeat, max_define + 1, is_optional)); } return variant_column; } @@ -324,7 +350,8 @@ ParquetColumnSchema ColumnWriter::FillParquetSchema(vectorfield_id; } - ParquetWriter::SetSchemaProperties(type, schema_element); + ParquetWriter::SetSchemaProperties(type, schema_element, allow_geometry); schemas.push_back(std::move(schema_element)); return ParquetColumnSchema(name, type, max_define, max_repeat, schema_idx, 0); } @@ -441,8 +469,6 @@ ColumnWriter::CreateWriterRecursive(ClientContext &context, ParquetWriter &write path_in_schema.push_back(schema.name); if (type.id() == LogicalTypeId::STRUCT && type.GetAlias() == "PARQUET_VARIANT") { - D_ASSERT(schema.children.size() == 2); //! 
NOTE: shredded variants not supported yet - vector> child_writers; child_writers.reserve(schema.children.size()); for (idx_t i = 0; i < schema.children.size(); i++) { diff --git a/src/duckdb/extension/parquet/geo_parquet.cpp b/src/duckdb/extension/parquet/geo_parquet.cpp index 48e2b047f..a2cc2a821 100644 --- a/src/duckdb/extension/parquet/geo_parquet.cpp +++ b/src/duckdb/extension/parquet/geo_parquet.cpp @@ -43,17 +43,19 @@ unique_ptr GeoParquetFileMetadata::TryRead(const duckdb_ throw InvalidInputException("Geoparquet metadata is not an object"); } - auto result = make_uniq(); + // We dont actually care about the version for now, as we only support V1+native + auto result = make_uniq(GeoParquetVersion::BOTH); // Check and parse the version const auto version_val = yyjson_obj_get(root, "version"); if (!yyjson_is_str(version_val)) { throw InvalidInputException("Geoparquet metadata does not have a version"); } - result->version = yyjson_get_str(version_val); - if (StringUtil::StartsWith(result->version, "2")) { - // Guard against a breaking future 2.0 version - throw InvalidInputException("Geoparquet version %s is not supported", result->version); + + auto version = yyjson_get_str(version_val); + if (StringUtil::StartsWith(version, "3")) { + // Guard against a breaking future 3.0 version + throw InvalidInputException("Geoparquet version %s is not supported", version); } // Check and parse the geometry columns @@ -177,7 +179,20 @@ void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) yyjson_mut_doc_set_root(doc, root); // Add the version - yyjson_mut_obj_add_strncpy(doc, root, "version", version.c_str(), version.size()); + switch (version) { + case GeoParquetVersion::V1: + case GeoParquetVersion::BOTH: + yyjson_mut_obj_add_strcpy(doc, root, "version", "1.0.0"); + break; + case GeoParquetVersion::V2: + yyjson_mut_obj_add_strcpy(doc, root, "version", "2.0.0"); + break; + case GeoParquetVersion::NONE: + default: + // Should never happen, we 
should not be writing anything + yyjson_mut_doc_free(doc); + throw InternalException("GeoParquetVersion::NONE should not write metadata"); + } // Add the primary column yyjson_mut_obj_add_strncpy(doc, root, "primary_column", primary_geometry_column.c_str(), diff --git a/src/duckdb/extension/parquet/include/column_writer.hpp b/src/duckdb/extension/parquet/include/column_writer.hpp index d208bddf6..929e94a11 100644 --- a/src/duckdb/extension/parquet/include/column_writer.hpp +++ b/src/duckdb/extension/parquet/include/column_writer.hpp @@ -18,6 +18,7 @@ class ParquetWriter; class ColumnWriterPageState; class PrimitiveColumnWriterState; struct ChildFieldIDs; +struct ShreddingType; class ResizeableBuffer; class ParquetBloomFilter; @@ -89,9 +90,11 @@ class ColumnWriter { } static ParquetColumnSchema FillParquetSchema(vector &schemas, - const LogicalType &type, const string &name, - optional_ptr field_ids, idx_t max_repeat = 0, - idx_t max_define = 1, bool can_have_nulls = true); + const LogicalType &type, const string &name, bool allow_geometry, + optional_ptr field_ids, + optional_ptr shredding_types, + idx_t max_repeat = 0, idx_t max_define = 1, + bool can_have_nulls = true); //! 
Create the column writer for a specific type recursively static unique_ptr CreateWriterRecursive(ClientContext &context, ParquetWriter &writer, const vector &parquet_schemas, diff --git a/src/duckdb/extension/parquet/include/geo_parquet.hpp b/src/duckdb/extension/parquet/include/geo_parquet.hpp index 424e7c324..0e236c73a 100644 --- a/src/duckdb/extension/parquet/include/geo_parquet.hpp +++ b/src/duckdb/extension/parquet/include/geo_parquet.hpp @@ -33,6 +33,31 @@ enum class GeoParquetColumnEncoding : uint8_t { MULTIPOLYGON, }; +enum class GeoParquetVersion : uint8_t { + // Write GeoParquet 1.0 metadata + // GeoParquet 1.0 has the widest support among readers and writers + V1, + + // Write GeoParquet 2.0 + // The GeoParquet 2.0 options is identical to GeoParquet 1.0 except the underlying storage + // of spatial columns is Parquet native geometry, where the Parquet writer will include + // native statistics according to the underlying Parquet options. Compared to 'BOTH', this will + // actually write the metadata as containing GeoParquet version 2.0.0 + // However, V2 isnt standardized yet, so this option is still a bit experimental + V2, + + // Write GeoParquet 1.0 metadata, with native Parquet geometry types + // This is a bit of a hold-over option for compatibility with systems that + // reject GeoParquet 2.0 metadata, but can read Parquet native geometry types as they simply ignore the extra + // logical type. DuckDB v1.4.0 falls into this category. + BOTH, + + // Do not write GeoParquet metadata + // This option suppresses GeoParquet metadata; however, spatial types will be written as + // Parquet native Geometry/Geography. 
+ NONE, +}; + struct GeoParquetColumnMetadata { // The encoding of the geometry column GeoParquetColumnEncoding geometry_encoding; @@ -49,6 +74,8 @@ struct GeoParquetColumnMetadata { class GeoParquetFileMetadata { public: + explicit GeoParquetFileMetadata(GeoParquetVersion geo_parquet_version) : version(geo_parquet_version) { + } void AddGeoParquetStats(const string &column_name, const LogicalType &type, const GeometryStatsData &stats); void Write(duckdb_parquet::FileMetaData &file_meta_data); @@ -68,8 +95,8 @@ class GeoParquetFileMetadata { private: mutex write_lock; - string version = "1.1.0"; unordered_map geometry_columns; + GeoParquetVersion version; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_field_id.hpp b/src/duckdb/extension/parquet/include/parquet_field_id.hpp new file mode 100644 index 000000000..9d5dd754c --- /dev/null +++ b/src/duckdb/extension/parquet/include/parquet_field_id.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/case_insensitive_map.hpp" + +namespace duckdb { + +struct FieldID; +struct ChildFieldIDs { + ChildFieldIDs(); + ChildFieldIDs Copy() const; + unique_ptr> ids; + + void Serialize(Serializer &serializer) const; + static ChildFieldIDs Deserialize(Deserializer &source); +}; + +struct FieldID { +public: + static constexpr const auto DUCKDB_FIELD_ID = "__duckdb_field_id"; + FieldID(); + explicit FieldID(int32_t field_id); + FieldID Copy() const; + bool set; + int32_t field_id; + ChildFieldIDs child_field_ids; + + void Serialize(Serializer &serializer) const; + static FieldID Deserialize(Deserializer &source); + +public: + static void GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, + const vector &sql_types); + static void GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, + unordered_set &unique_field_ids, + const case_insensitive_map_t &name_to_type_map); +}; + +} // 
namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_shredding.hpp b/src/duckdb/extension/parquet/include/parquet_shredding.hpp new file mode 100644 index 000000000..f43cbc42c --- /dev/null +++ b/src/duckdb/extension/parquet/include/parquet_shredding.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/types/variant.hpp" + +namespace duckdb { + +struct ShreddingType; + +struct ChildShreddingTypes { +public: + ChildShreddingTypes(); + +public: + ChildShreddingTypes Copy() const; + +public: + void Serialize(Serializer &serializer) const; + static ChildShreddingTypes Deserialize(Deserializer &source); + +public: + unique_ptr> types; +}; + +struct ShreddingType { +public: + ShreddingType(); + explicit ShreddingType(const LogicalType &type); + +public: + ShreddingType Copy() const; + +public: + void Serialize(Serializer &serializer) const; + static ShreddingType Deserialize(Deserializer &source); + +public: + static ShreddingType GetShreddingTypes(const Value &val); + void AddChild(const string &name, ShreddingType &&child); + optional_ptr GetChild(const string &name) const; + +public: + bool set = false; + LogicalType type; + ChildShreddingTypes children; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_writer.hpp b/src/duckdb/extension/parquet/include/parquet_writer.hpp index a2bfc3a80..46661215f 100644 --- a/src/duckdb/extension/parquet/include/parquet_writer.hpp +++ b/src/duckdb/extension/parquet/include/parquet_writer.hpp @@ -21,6 +21,8 @@ #include "parquet_statistics.hpp" #include "column_writer.hpp" +#include "parquet_field_id.hpp" +#include "parquet_shredding.hpp" #include "parquet_types.h" #include "geo_parquet.hpp" #include "writer/parquet_write_stats.hpp" @@ -43,29 +45,6 @@ struct PreparedRowGroup { vector> states; }; -struct FieldID; -struct ChildFieldIDs { - ChildFieldIDs(); - 
ChildFieldIDs Copy() const; - unique_ptr> ids; - - void Serialize(Serializer &serializer) const; - static ChildFieldIDs Deserialize(Deserializer &source); -}; - -struct FieldID { - static constexpr const auto DUCKDB_FIELD_ID = "__duckdb_field_id"; - FieldID(); - explicit FieldID(int32_t field_id); - FieldID Copy() const; - bool set; - int32_t field_id; - ChildFieldIDs child_field_ids; - - void Serialize(Serializer &serializer) const; - static FieldID Deserialize(Deserializer &source); -}; - struct ParquetBloomFilterEntry { unique_ptr bloom_filter; idx_t row_group_idx; @@ -81,11 +60,11 @@ class ParquetWriter { public: ParquetWriter(ClientContext &context, FileSystem &fs, string file_name, vector types, vector names, duckdb_parquet::CompressionCodec::type codec, ChildFieldIDs field_ids, - const vector> &kv_metadata, + ShreddingType shredding_types, const vector> &kv_metadata, shared_ptr encryption_config, optional_idx dictionary_size_limit, idx_t string_dictionary_page_size_limit, bool enable_bloom_filters, double bloom_filter_false_positive_ratio, int64_t compression_level, bool debug_use_openssl, - ParquetVersion parquet_version); + ParquetVersion parquet_version, GeoParquetVersion geoparquet_version); ~ParquetWriter(); public: @@ -95,7 +74,8 @@ class ParquetWriter { void Finalize(); static duckdb_parquet::Type::type DuckDBTypeToParquetType(const LogicalType &duckdb_type); - static void SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele); + static void SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele, + bool allow_geometry); ClientContext &GetContext() { return context; @@ -139,6 +119,9 @@ class ParquetWriter { ParquetVersion GetParquetVersion() const { return parquet_version; } + GeoParquetVersion GetGeoParquetVersion() const { + return geoparquet_version; + } const string &GetFileName() const { return file_name; } @@ -166,6 +149,7 @@ class ParquetWriter { vector column_names; 
duckdb_parquet::CompressionCodec::type codec; ChildFieldIDs field_ids; + ShreddingType shredding_types; shared_ptr encryption_config; optional_idx dictionary_size_limit; idx_t string_dictionary_page_size_limit; @@ -175,6 +159,7 @@ class ParquetWriter { bool debug_use_openssl; shared_ptr encryption_util; ParquetVersion parquet_version; + GeoParquetVersion geoparquet_version; vector column_schemas; unique_ptr writer; diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp index 27ece7d70..bbcf71792 100644 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp @@ -11,13 +11,14 @@ class VariantShreddedConversion { public: static vector Convert(Vector &metadata, Vector &group, idx_t offset, idx_t length, idx_t total_size, - bool is_field = false); + bool is_field); static vector ConvertShreddedLeaf(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, - idx_t length, idx_t total_size); + idx_t length, idx_t total_size, const bool is_field); static vector ConvertShreddedArray(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, - idx_t length, idx_t total_size); + idx_t length, idx_t total_size, const bool is_field); static vector ConvertShreddedObject(Vector &metadata, Vector &value, Vector &typed_value, - idx_t offset, idx_t length, idx_t total_size); + idx_t offset, idx_t length, idx_t total_size, + const bool is_field); }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp index d1c5af1cf..74fdda608 100644 --- a/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/variant_column_writer.hpp @@ -8,36 
+8,23 @@ #pragma once -#include "column_writer.hpp" +#include "struct_column_writer.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" namespace duckdb { -class VariantColumnWriter : public ColumnWriter { +class VariantColumnWriter : public StructColumnWriter { public: VariantColumnWriter(ParquetWriter &writer, const ParquetColumnSchema &column_schema, vector schema_path_p, vector> child_writers_p, bool can_have_nulls) - : ColumnWriter(writer, column_schema, std::move(schema_path_p), can_have_nulls), - child_writers(std::move(child_writers_p)) { + : StructColumnWriter(writer, column_schema, std::move(schema_path_p), std::move(child_writers_p), + can_have_nulls) { } ~VariantColumnWriter() override = default; - vector> child_writers; - -public: - unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; - bool HasAnalyze() override; - void Analyze(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override; - void FinalizeAnalyze(ColumnWriterState &state) override; - void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count, - bool vector_can_span_multiple_pages) override; - - void BeginWrite(ColumnWriterState &state) override; - void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; - void FinalizeWrite(ColumnWriterState &state) override; - public: static ScalarFunction GetTransformFunction(); + static LogicalType TransformTypedValueRecursive(const LogicalType &type); }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_extension.cpp b/src/duckdb/extension/parquet/parquet_extension.cpp index 6e8e6fe1b..1a02ae20d 100644 --- a/src/duckdb/extension/parquet/parquet_extension.cpp +++ b/src/duckdb/extension/parquet/parquet_extension.cpp @@ -12,6 +12,7 @@ #include "parquet_metadata.hpp" #include "parquet_reader.hpp" #include "parquet_writer.hpp" +#include "parquet_shredding.hpp" #include "reader/struct_column_reader.hpp" 
#include "zstd_file_system.hpp" #include "writer/primitive_column_writer.hpp" @@ -46,6 +47,7 @@ #include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/table/row_group.hpp" @@ -57,156 +59,6 @@ namespace duckdb { -static case_insensitive_map_t GetChildNameToTypeMap(const LogicalType &type) { - case_insensitive_map_t name_to_type_map; - switch (type.id()) { - case LogicalTypeId::LIST: - name_to_type_map.emplace("element", ListType::GetChildType(type)); - break; - case LogicalTypeId::MAP: - name_to_type_map.emplace("key", MapType::KeyType(type)); - name_to_type_map.emplace("value", MapType::ValueType(type)); - break; - case LogicalTypeId::STRUCT: - for (auto &child_type : StructType::GetChildTypes(type)) { - if (child_type.first == FieldID::DUCKDB_FIELD_ID) { - throw BinderException("Cannot have column named \"%s\" with FIELD_IDS", FieldID::DUCKDB_FIELD_ID); - } - name_to_type_map.emplace(child_type); - } - break; - default: // LCOV_EXCL_START - throw InternalException("Unexpected type in GetChildNameToTypeMap"); - } // LCOV_EXCL_STOP - return name_to_type_map; -} - -static void GetChildNamesAndTypes(const LogicalType &type, vector &child_names, - vector &child_types) { - switch (type.id()) { - case LogicalTypeId::LIST: - child_names.emplace_back("element"); - child_types.emplace_back(ListType::GetChildType(type)); - break; - case LogicalTypeId::MAP: - child_names.emplace_back("key"); - child_names.emplace_back("value"); - child_types.emplace_back(MapType::KeyType(type)); - child_types.emplace_back(MapType::ValueType(type)); - break; - case LogicalTypeId::STRUCT: - for (auto &child_type : StructType::GetChildTypes(type)) { - 
child_names.emplace_back(child_type.first); - child_types.emplace_back(child_type.second); - } - break; - default: // LCOV_EXCL_START - throw InternalException("Unexpected type in GetChildNamesAndTypes"); - } // LCOV_EXCL_STOP -} - -static void GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, - const vector &sql_types) { - D_ASSERT(names.size() == sql_types.size()); - for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { - const auto &col_name = names[col_idx]; - auto inserted = field_ids.ids->insert(make_pair(col_name, FieldID(UnsafeNumericCast(field_id++)))); - D_ASSERT(inserted.second); - - const auto &col_type = sql_types[col_idx]; - if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && - col_type.id() != LogicalTypeId::STRUCT) { - continue; - } - - // Cannot use GetChildNameToTypeMap here because we lose order, and we want to generate depth-first - vector child_names; - vector child_types; - GetChildNamesAndTypes(col_type, child_names, child_types); - - GenerateFieldIDs(inserted.first->second.child_field_ids, field_id, child_names, child_types); - } -} - -static void GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, - unordered_set &unique_field_ids, - const case_insensitive_map_t &name_to_type_map) { - const auto &struct_type = field_ids_value.type(); - if (struct_type.id() != LogicalTypeId::STRUCT) { - throw BinderException( - "Expected FIELD_IDS to be a STRUCT, e.g., {col1: 42, col2: {%s: 43, nested_col: 44}, col3: 44}", - FieldID::DUCKDB_FIELD_ID); - } - const auto &struct_children = StructValue::GetChildren(field_ids_value); - D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); - for (idx_t i = 0; i < struct_children.size(); i++) { - const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); - if (col_name == FieldID::DUCKDB_FIELD_ID) { - continue; - } - - auto it = name_to_type_map.find(col_name); - if (it == 
name_to_type_map.end()) { - string names; - for (const auto &name : name_to_type_map) { - if (!names.empty()) { - names += ", "; - } - names += name.first; - } - throw BinderException( - "Column name \"%s\" specified in FIELD_IDS not found. Consider using WRITE_PARTITION_COLUMNS if this " - "column is a partition column. Available column names: [%s]", - col_name, names); - } - D_ASSERT(field_ids.ids->find(col_name) == field_ids.ids->end()); // Caught by STRUCT - deduplicates keys - - const auto &child_value = struct_children[i]; - const auto &child_type = child_value.type(); - optional_ptr field_id_value; - optional_ptr child_field_ids_value; - - if (child_type.id() == LogicalTypeId::STRUCT) { - const auto &nested_children = StructValue::GetChildren(child_value); - D_ASSERT(StructType::GetChildTypes(child_type).size() == nested_children.size()); - for (idx_t nested_i = 0; nested_i < nested_children.size(); nested_i++) { - const auto &field_id_or_nested_col = StructType::GetChildName(child_type, nested_i); - if (field_id_or_nested_col == FieldID::DUCKDB_FIELD_ID) { - field_id_value = &nested_children[nested_i]; - } else { - child_field_ids_value = &child_value; - } - } - } else { - field_id_value = &child_value; - } - - FieldID field_id; - if (field_id_value) { - Value field_id_integer_value = field_id_value->DefaultCastAs(LogicalType::INTEGER); - const uint32_t field_id_int = IntegerValue::Get(field_id_integer_value); - if (!unique_field_ids.insert(field_id_int).second) { - throw BinderException("Duplicate field_id %s found in FIELD_IDS", field_id_integer_value.ToString()); - } - field_id = FieldID(UnsafeNumericCast(field_id_int)); - } - auto inserted = field_ids.ids->insert(make_pair(col_name, std::move(field_id))); - D_ASSERT(inserted.second); - - if (child_field_ids_value) { - const auto &col_type = it->second; - if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && - col_type.id() != LogicalTypeId::STRUCT) { - throw 
BinderException("Column \"%s\" with type \"%s\" cannot have a nested FIELD_IDS specification", - col_name, LogicalTypeIdToString(col_type.id())); - } - - GetFieldIDs(*child_field_ids_value, inserted.first->second.child_field_ids, unique_field_ids, - GetChildNameToTypeMap(col_type)); - } - } -} - struct ParquetWriteBindData : public TableFunctionData { vector sql_types; vector column_names; @@ -236,11 +88,15 @@ struct ParquetWriteBindData : public TableFunctionData { optional_idx row_groups_per_file; ChildFieldIDs field_ids; + ShreddingType shredding_types; //! The compression level, higher value is more int64_t compression_level = ZStdFileSystem::DefaultCompressionLevel(); //! Which encodings to include when writing ParquetVersion parquet_version = ParquetVersion::V1; + + //! Which geo-parquet version to use when writing + GeoParquetVersion geoparquet_version = GeoParquetVersion::V1; }; struct ParquetWriteGlobalState : public GlobalFunctionData { @@ -294,6 +150,8 @@ static void ParquetListCopyOptions(ClientContext &context, CopyOptionsInput &inp copy_options["binary_as_string"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); copy_options["file_row_number"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); copy_options["can_have_nan"] = CopyOption(LogicalType::BOOLEAN, CopyOptionMode::READ_ONLY); + copy_options["geoparquet_version"] = CopyOption(LogicalType::VARCHAR, CopyOptionMode::WRITE_ONLY); + copy_options["shredding"] = CopyOption(LogicalType::ANY, CopyOptionMode::WRITE_ONLY); } static unique_ptr ParquetWriteBind(ClientContext &context, CopyFunctionBindInput &input, @@ -345,7 +203,7 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun if (option.second[0].type().id() == LogicalTypeId::VARCHAR && StringUtil::Lower(StringValue::Get(option.second[0])) == "auto") { idx_t field_id = 0; - GenerateFieldIDs(bind_data->field_ids, field_id, names, sql_types); + FieldID::GenerateFieldIDs(bind_data->field_ids, field_id, names, 
sql_types); } else { unordered_set unique_field_ids; case_insensitive_map_t name_to_type_map; @@ -356,7 +214,57 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } name_to_type_map.emplace(names[col_idx], sql_types[col_idx]); } - GetFieldIDs(option.second[0], bind_data->field_ids, unique_field_ids, name_to_type_map); + FieldID::GetFieldIDs(option.second[0], bind_data->field_ids, unique_field_ids, name_to_type_map); + } + } else if (loption == "shredding") { + if (option.second[0].type().id() == LogicalTypeId::VARCHAR && + StringUtil::Lower(StringValue::Get(option.second[0])) == "auto") { + throw NotImplementedException("The 'auto' option is not yet implemented for 'shredding'"); + } else { + case_insensitive_set_t variant_names; + for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { + if (sql_types[col_idx].id() != LogicalTypeId::STRUCT) { + continue; + } + if (sql_types[col_idx].GetAlias() != "PARQUET_VARIANT") { + continue; + } + variant_names.emplace(names[col_idx]); + } + auto &shredding_types_value = option.second[0]; + if (shredding_types_value.type().id() != LogicalTypeId::STRUCT) { + throw BinderException("SHREDDING value should be a STRUCT of column names to types, i.e: {col1: " + "'INTEGER[]', col2: 'BOOLEAN'}"); + } + const auto &struct_type = shredding_types_value.type(); + const auto &struct_children = StructValue::GetChildren(shredding_types_value); + D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); + for (idx_t i = 0; i < struct_children.size(); i++) { + const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); + auto it = variant_names.find(col_name); + if (it == variant_names.end()) { + string names; + for (const auto &entry : variant_names) { + if (!names.empty()) { + names += ", "; + } + names += entry; + } + if (names.empty()) { + throw BinderException("VARIANT by name \"%s\" specified in SHREDDING not found. 
There are " + "no VARIANT columns present.", + col_name); + } else { + throw BinderException( + "VARIANT by name \"%s\" specified in SHREDDING not found. Consider using " + "WRITE_PARTITION_COLUMNS if this " + "column is a partition column. Available names of VARIANT columns: [%s]", + col_name, names); + } + } + const auto &child_value = struct_children[i]; + bind_data->shredding_types.AddChild(col_name, ShreddingType::GetShreddingTypes(child_value)); + } } } else if (loption == "kv_metadata") { auto &kv_struct = option.second[0]; @@ -429,6 +337,19 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } else { throw BinderException("Expected parquet_version 'V1' or 'V2'"); } + } else if (loption == "geoparquet_version") { + const auto roption = StringUtil::Upper(option.second[0].ToString()); + if (roption == "NONE") { + bind_data->geoparquet_version = GeoParquetVersion::NONE; + } else if (roption == "V1") { + bind_data->geoparquet_version = GeoParquetVersion::V1; + } else if (roption == "V2") { + bind_data->geoparquet_version = GeoParquetVersion::V2; + } else if (roption == "BOTH") { + bind_data->geoparquet_version = GeoParquetVersion::BOTH; + } else { + throw BinderException("Expected geoparquet_version 'NONE', 'V1', 'V2' or 'BOTH'"); + } } else { throw InternalException("Unrecognized option for PARQUET: %s", option.first.c_str()); } @@ -457,10 +378,11 @@ static unique_ptr ParquetWriteInitializeGlobal(ClientContext auto &fs = FileSystem::GetFileSystem(context); global_state->writer = make_uniq( context, fs, file_path, parquet_bind.sql_types, parquet_bind.column_names, parquet_bind.codec, - parquet_bind.field_ids.Copy(), parquet_bind.kv_metadata, parquet_bind.encryption_config, - parquet_bind.dictionary_size_limit, parquet_bind.string_dictionary_page_size_limit, - parquet_bind.enable_bloom_filters, parquet_bind.bloom_filter_false_positive_ratio, - parquet_bind.compression_level, parquet_bind.debug_use_openssl, parquet_bind.parquet_version); + 
parquet_bind.field_ids.Copy(), parquet_bind.shredding_types.Copy(), parquet_bind.kv_metadata, + parquet_bind.encryption_config, parquet_bind.dictionary_size_limit, + parquet_bind.string_dictionary_page_size_limit, parquet_bind.enable_bloom_filters, + parquet_bind.bloom_filter_false_positive_ratio, parquet_bind.compression_level, parquet_bind.debug_use_openssl, + parquet_bind.parquet_version, parquet_bind.geoparquet_version); return std::move(global_state); } @@ -629,6 +551,39 @@ ParquetVersion EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template <> +const char *EnumUtil::ToChars(GeoParquetVersion value) { + switch (value) { + case GeoParquetVersion::NONE: + return "NONE"; + case GeoParquetVersion::V1: + return "V1"; + case GeoParquetVersion::V2: + return "V2"; + case GeoParquetVersion::BOTH: + return "BOTH"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); + } +} + +template <> +GeoParquetVersion EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "NONE")) { + return GeoParquetVersion::NONE; + } + if (StringUtil::Equals(value, "V1")) { + return GeoParquetVersion::V1; + } + if (StringUtil::Equals(value, "V2")) { + return GeoParquetVersion::V2; + } + if (StringUtil::Equals(value, "BOTH")) { + return GeoParquetVersion::BOTH; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + static optional_idx SerializeCompressionLevel(const int64_t compression_level) { return compression_level < 0 ? 
NumericLimits::Maximum() - NumericCast(AbsValue(compression_level)) : NumericCast(compression_level); @@ -682,6 +637,9 @@ static void ParquetCopySerialize(Serializer &serializer, const FunctionData &bin serializer.WritePropertyWithDefault(115, "string_dictionary_page_size_limit", bind_data.string_dictionary_page_size_limit, default_value.string_dictionary_page_size_limit); + serializer.WritePropertyWithDefault(116, "geoparquet_version", bind_data.geoparquet_version, + default_value.geoparquet_version); + serializer.WriteProperty(117, "shredding_types", bind_data.shredding_types); } static unique_ptr ParquetCopyDeserialize(Deserializer &deserializer, CopyFunction &function) { @@ -714,6 +672,9 @@ static unique_ptr ParquetCopyDeserialize(Deserializer &deserialize deserializer.ReadPropertyWithExplicitDefault(114, "parquet_version", default_value.parquet_version); data->string_dictionary_page_size_limit = deserializer.ReadPropertyWithExplicitDefault( 115, "string_dictionary_page_size_limit", default_value.string_dictionary_page_size_limit); + data->geoparquet_version = + deserializer.ReadPropertyWithExplicitDefault(116, "geoparquet_version", default_value.geoparquet_version); + data->shredding_types = deserializer.ReadProperty(117, "shredding_types"); return std::move(data); } @@ -844,6 +805,38 @@ static bool IsGeometryType(const LogicalType &type, ClientContext &context) { return GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context); } +static string GetShredding(case_insensitive_map_t> &options, const string &col_name) { + //! At this point, the options haven't been parsed yet, so we have to parse them ourselves. 
+ auto it = options.find("shredding"); + if (it == options.end()) { + return string(); + } + auto &shredding = it->second; + if (shredding.empty()) { + return string(); + } + + auto &shredding_val = shredding[0]; + if (shredding_val.type().id() != LogicalTypeId::STRUCT) { + return string(); + } + + auto &shredded_variants = StructType::GetChildTypes(shredding_val.type()); + auto &values = StructValue::GetChildren(shredding_val); + for (idx_t i = 0; i < shredded_variants.size(); i++) { + auto &shredded_variant = shredded_variants[i]; + if (shredded_variant.first != col_name) { + continue; + } + auto &shredded_val = values[i]; + if (shredded_val.type().id() != LogicalTypeId::VARCHAR) { + return string(); + } + return shredded_val.GetValue(); + } + return string(); +} + static vector> ParquetWriteSelect(CopyToSelectInput &input) { auto &context = input.context; @@ -872,11 +865,17 @@ static vector> ParquetWriteSelect(CopyToSelectInput &inpu vector> arguments; arguments.push_back(std::move(expr)); + auto shredded_type_str = GetShredding(input.options, name); + if (!shredded_type_str.empty()) { + arguments.push_back(make_uniq(Value(shredded_type_str))); + } + auto transform_func = VariantColumnWriter::GetTransformFunction(); transform_func.bind(context, transform_func, arguments); auto func_expr = make_uniq(transform_func.return_type, transform_func, std::move(arguments), nullptr, false); + func_expr->SetAlias(name); result.push_back(std::move(func_expr)); any_change = true; } diff --git a/src/duckdb/extension/parquet/parquet_field_id.cpp b/src/duckdb/extension/parquet/parquet_field_id.cpp new file mode 100644 index 000000000..d1ff138cc --- /dev/null +++ b/src/duckdb/extension/parquet/parquet_field_id.cpp @@ -0,0 +1,178 @@ +#include "parquet_field_id.hpp" +#include "duckdb/common/exception/binder_exception.hpp" + +namespace duckdb { + +ChildFieldIDs::ChildFieldIDs() : ids(make_uniq>()) { +} + +ChildFieldIDs ChildFieldIDs::Copy() const { + ChildFieldIDs result; + for 
(const auto &id : *ids) { + result.ids->emplace(id.first, id.second.Copy()); + } + return result; +} + +FieldID::FieldID() : set(false) { +} + +FieldID::FieldID(int32_t field_id_p) : set(true), field_id(field_id_p) { +} + +FieldID FieldID::Copy() const { + auto result = set ? FieldID(field_id) : FieldID(); + result.child_field_ids = child_field_ids.Copy(); + return result; +} + +static case_insensitive_map_t GetChildNameToTypeMap(const LogicalType &type) { + case_insensitive_map_t name_to_type_map; + switch (type.id()) { + case LogicalTypeId::LIST: + name_to_type_map.emplace("element", ListType::GetChildType(type)); + break; + case LogicalTypeId::MAP: + name_to_type_map.emplace("key", MapType::KeyType(type)); + name_to_type_map.emplace("value", MapType::ValueType(type)); + break; + case LogicalTypeId::STRUCT: + for (auto &child_type : StructType::GetChildTypes(type)) { + if (child_type.first == FieldID::DUCKDB_FIELD_ID) { + throw BinderException("Cannot have column named \"%s\" with FIELD_IDS", FieldID::DUCKDB_FIELD_ID); + } + name_to_type_map.emplace(child_type); + } + break; + default: // LCOV_EXCL_START + throw InternalException("Unexpected type in GetChildNameToTypeMap"); + } // LCOV_EXCL_STOP + return name_to_type_map; +} + +static void GetChildNamesAndTypes(const LogicalType &type, vector &child_names, + vector &child_types) { + switch (type.id()) { + case LogicalTypeId::LIST: + child_names.emplace_back("element"); + child_types.emplace_back(ListType::GetChildType(type)); + break; + case LogicalTypeId::MAP: + child_names.emplace_back("key"); + child_names.emplace_back("value"); + child_types.emplace_back(MapType::KeyType(type)); + child_types.emplace_back(MapType::ValueType(type)); + break; + case LogicalTypeId::STRUCT: + for (auto &child_type : StructType::GetChildTypes(type)) { + child_names.emplace_back(child_type.first); + child_types.emplace_back(child_type.second); + } + break; + default: // LCOV_EXCL_START + throw InternalException("Unexpected type in 
GetChildNamesAndTypes"); + } // LCOV_EXCL_STOP +} + +void FieldID::GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, + const vector &sql_types) { + D_ASSERT(names.size() == sql_types.size()); + for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { + const auto &col_name = names[col_idx]; + auto inserted = field_ids.ids->insert(make_pair(col_name, FieldID(UnsafeNumericCast(field_id++)))); + D_ASSERT(inserted.second); + + const auto &col_type = sql_types[col_idx]; + if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && + col_type.id() != LogicalTypeId::STRUCT) { + continue; + } + + // Cannot use GetChildNameToTypeMap here because we lose order, and we want to generate depth-first + vector child_names; + vector child_types; + GetChildNamesAndTypes(col_type, child_names, child_types); + GenerateFieldIDs(inserted.first->second.child_field_ids, field_id, child_names, child_types); + } +} + +void FieldID::GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, + unordered_set &unique_field_ids, + const case_insensitive_map_t &name_to_type_map) { + const auto &struct_type = field_ids_value.type(); + if (struct_type.id() != LogicalTypeId::STRUCT) { + throw BinderException( + "Expected FIELD_IDS to be a STRUCT, e.g., {col1: 42, col2: {%s: 43, nested_col: 44}, col3: 44}", + FieldID::DUCKDB_FIELD_ID); + } + const auto &struct_children = StructValue::GetChildren(field_ids_value); + D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); + for (idx_t i = 0; i < struct_children.size(); i++) { + const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); + if (col_name == FieldID::DUCKDB_FIELD_ID) { + continue; + } + + auto it = name_to_type_map.find(col_name); + if (it == name_to_type_map.end()) { + string names; + for (const auto &name : name_to_type_map) { + if (!names.empty()) { + names += ", "; + } + names += name.first; + } + throw 
BinderException( + "Column name \"%s\" specified in FIELD_IDS not found. Consider using WRITE_PARTITION_COLUMNS if this " + "column is a partition column. Available column names: [%s]", + col_name, names); + } + D_ASSERT(field_ids.ids->find(col_name) == field_ids.ids->end()); // Caught by STRUCT - deduplicates keys + + const auto &child_value = struct_children[i]; + const auto &child_type = child_value.type(); + optional_ptr field_id_value; + optional_ptr child_field_ids_value; + + if (child_type.id() == LogicalTypeId::STRUCT) { + const auto &nested_children = StructValue::GetChildren(child_value); + D_ASSERT(StructType::GetChildTypes(child_type).size() == nested_children.size()); + for (idx_t nested_i = 0; nested_i < nested_children.size(); nested_i++) { + const auto &field_id_or_nested_col = StructType::GetChildName(child_type, nested_i); + if (field_id_or_nested_col == FieldID::DUCKDB_FIELD_ID) { + field_id_value = &nested_children[nested_i]; + } else { + child_field_ids_value = &child_value; + } + } + } else { + field_id_value = &child_value; + } + + FieldID field_id; + if (field_id_value) { + Value field_id_integer_value = field_id_value->DefaultCastAs(LogicalType::INTEGER); + const uint32_t field_id_int = IntegerValue::Get(field_id_integer_value); + if (!unique_field_ids.insert(field_id_int).second) { + throw BinderException("Duplicate field_id %s found in FIELD_IDS", field_id_integer_value.ToString()); + } + field_id = FieldID(UnsafeNumericCast(field_id_int)); + } + auto inserted = field_ids.ids->insert(make_pair(col_name, std::move(field_id))); + D_ASSERT(inserted.second); + + if (child_field_ids_value) { + const auto &col_type = it->second; + if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && + col_type.id() != LogicalTypeId::STRUCT) { + throw BinderException("Column \"%s\" with type \"%s\" cannot have a nested FIELD_IDS specification", + col_name, LogicalTypeIdToString(col_type.id())); + } + + 
GetFieldIDs(*child_field_ids_value, inserted.first->second.child_field_ids, unique_field_ids, + GetChildNameToTypeMap(col_type)); + } + } +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_shredding.cpp b/src/duckdb/extension/parquet/parquet_shredding.cpp new file mode 100644 index 000000000..b7ed673a8 --- /dev/null +++ b/src/duckdb/extension/parquet/parquet_shredding.cpp @@ -0,0 +1,81 @@ +#include "parquet_shredding.hpp" +#include "duckdb/common/exception/binder_exception.hpp" +#include "duckdb/common/type_visitor.hpp" + +namespace duckdb { + +ChildShreddingTypes::ChildShreddingTypes() : types(make_uniq>()) { +} + +ChildShreddingTypes ChildShreddingTypes::Copy() const { + ChildShreddingTypes result; + for (const auto &type : *types) { + result.types->emplace(type.first, type.second.Copy()); + } + return result; +} + +ShreddingType::ShreddingType() : set(false) { +} + +ShreddingType::ShreddingType(const LogicalType &type) : set(true), type(type) { +} + +ShreddingType ShreddingType::Copy() const { + auto result = set ? 
ShreddingType(type) : ShreddingType(); + result.children = children.Copy(); + return result; +} + +static ShreddingType ConvertShreddingTypeRecursive(const LogicalType &type) { + if (type.id() == LogicalTypeId::VARIANT) { + return ShreddingType(LogicalType(LogicalTypeId::ANY)); + } + if (!type.IsNested()) { + return ShreddingType(type); + } + + switch (type.id()) { + case LogicalTypeId::STRUCT: { + ShreddingType res(type); + auto &children = StructType::GetChildTypes(type); + for (auto &entry : children) { + res.AddChild(entry.first, ConvertShreddingTypeRecursive(entry.second)); + } + return res; + } + case LogicalTypeId::LIST: { + ShreddingType res(type); + const auto &child = ListType::GetChildType(type); + res.AddChild("element", ConvertShreddingTypeRecursive(child)); + return res; + } + default: + break; + } + throw BinderException("VARIANT can only be shredded on LIST/STRUCT/ANY/non-nested type, not %s", type.ToString()); +} + +void ShreddingType::AddChild(const string &name, ShreddingType &&child) { + children.types->emplace(name, std::move(child)); +} + +optional_ptr ShreddingType::GetChild(const string &name) const { + auto it = children.types->find(name); + if (it == children.types->end()) { + return nullptr; + } + return it->second; +} + +ShreddingType ShreddingType::GetShreddingTypes(const Value &val) { + if (val.type().id() != LogicalTypeId::VARCHAR) { + throw BinderException("SHREDDING value should be of type VARCHAR, a stringified type to use for the column"); + } + auto type_str = val.GetValue(); + auto logical_type = TransformStringToLogicalType(type_str); + + return ConvertShreddingTypeRecursive(logical_type); +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp index a2e34b9e5..82cb6d276 100644 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ b/src/duckdb/extension/parquet/parquet_writer.cpp @@ -3,6 +3,7 @@ #include "duckdb.hpp" #include 
"mbedtls_wrapper.hpp" #include "parquet_crypto.hpp" +#include "parquet_shredding.hpp" #include "parquet_timestamp.hpp" #include "resizable_buffer.hpp" #include "duckdb/common/file_system.hpp" @@ -35,29 +36,6 @@ using duckdb_parquet::PageType; using ParquetRowGroup = duckdb_parquet::RowGroup; using duckdb_parquet::Type; -ChildFieldIDs::ChildFieldIDs() : ids(make_uniq>()) { -} - -ChildFieldIDs ChildFieldIDs::Copy() const { - ChildFieldIDs result; - for (const auto &id : *ids) { - result.ids->emplace(id.first, id.second.Copy()); - } - return result; -} - -FieldID::FieldID() : set(false) { -} - -FieldID::FieldID(int32_t field_id_p) : set(true), field_id(field_id_p) { -} - -FieldID FieldID::Copy() const { - auto result = set ? FieldID(field_id) : FieldID(); - result.child_field_ids = child_field_ids.Copy(); - return result; -} - class MyTransport : public TTransport { public: explicit MyTransport(WriteStream &serializer) : serializer(serializer) { @@ -166,7 +144,8 @@ Type::type ParquetWriter::DuckDBTypeToParquetType(const LogicalType &duckdb_type throw NotImplementedException("Unimplemented type for Parquet \"%s\"", duckdb_type.ToString()); } -void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele) { +void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele, + bool allow_geometry) { if (duckdb_type.IsJSONType()) { schema_ele.converted_type = ConvertedType::JSON; schema_ele.__isset.converted_type = true; @@ -174,7 +153,7 @@ void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_p schema_ele.logicalType.__set_JSON(duckdb_parquet::JsonType()); return; } - if (duckdb_type.GetAlias() == "WKB_BLOB") { + if (duckdb_type.GetAlias() == "WKB_BLOB" && allow_geometry) { schema_ele.__isset.logicalType = true; schema_ele.logicalType.__isset.GEOMETRY = true; // TODO: Set CRS in the future @@ -352,18 +331,21 @@ class ParquetStatsAccumulator { 
ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file_name_p, vector types_p, vector names_p, CompressionCodec::type codec, ChildFieldIDs field_ids_p, - const vector> &kv_metadata, + ShreddingType shredding_types_p, const vector> &kv_metadata, shared_ptr encryption_config_p, optional_idx dictionary_size_limit_p, idx_t string_dictionary_page_size_limit_p, bool enable_bloom_filters_p, double bloom_filter_false_positive_ratio_p, - int64_t compression_level_p, bool debug_use_openssl_p, ParquetVersion parquet_version) + int64_t compression_level_p, bool debug_use_openssl_p, ParquetVersion parquet_version, + GeoParquetVersion geoparquet_version) : context(context), file_name(std::move(file_name_p)), sql_types(std::move(types_p)), column_names(std::move(names_p)), codec(codec), field_ids(std::move(field_ids_p)), - encryption_config(std::move(encryption_config_p)), dictionary_size_limit(dictionary_size_limit_p), + shredding_types(std::move(shredding_types_p)), encryption_config(std::move(encryption_config_p)), + dictionary_size_limit(dictionary_size_limit_p), string_dictionary_page_size_limit(string_dictionary_page_size_limit_p), enable_bloom_filters(enable_bloom_filters_p), bloom_filter_false_positive_ratio(bloom_filter_false_positive_ratio_p), compression_level(compression_level_p), - debug_use_openssl(debug_use_openssl_p), parquet_version(parquet_version), total_written(0), num_row_groups(0) { + debug_use_openssl(debug_use_openssl_p), parquet_version(parquet_version), geoparquet_version(geoparquet_version), + total_written(0), num_row_groups(0) { // initialize the file writer writer = make_uniq(fs, file_name.c_str(), @@ -416,10 +398,13 @@ ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file auto &unique_names = column_names; VerifyUniqueNames(unique_names); + // V1 GeoParquet stores geometries as blobs, no logical type + auto allow_geometry = geoparquet_version != GeoParquetVersion::V1; + // construct the child 
schemas for (idx_t i = 0; i < sql_types.size(); i++) { - auto child_schema = - ColumnWriter::FillParquetSchema(file_meta_data.schema, sql_types[i], unique_names[i], &field_ids); + auto child_schema = ColumnWriter::FillParquetSchema(file_meta_data.schema, sql_types[i], unique_names[i], + allow_geometry, &field_ids, &shredding_types); column_schemas.push_back(std::move(child_schema)); } // now construct the writers based on the schemas @@ -977,7 +962,8 @@ void ParquetWriter::Finalize() { } // Add geoparquet metadata to the file metadata - if (geoparquet_data && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { + if (geoparquet_data && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context) && + geoparquet_version != GeoParquetVersion::NONE) { geoparquet_data->Write(file_meta_data); } @@ -1007,7 +993,7 @@ void ParquetWriter::Finalize() { GeoParquetFileMetadata &ParquetWriter::GetGeoParquetData() { if (!geoparquet_data) { - geoparquet_data = make_uniq(); + geoparquet_data = make_uniq(geoparquet_version); } return *geoparquet_data; } diff --git a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp index 916e6e2cd..b96304d98 100644 --- a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp +++ b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp @@ -124,7 +124,7 @@ VariantValue ConvertShreddedValue::Convert(hugeint_t val) { template vector ConvertTypedValues(Vector &vec, Vector &metadata, Vector &blob, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { UnifiedVectorFormat metadata_format; metadata.ToUnifiedFormat(length, metadata_format); auto metadata_data = metadata_format.GetData(metadata_format); @@ -174,7 +174,12 @@ vector ConvertTypedValues(Vector &vec, Vector &metadata, Vector &b } else { ret[i] = OP::Convert(data[typed_index]); } - } else if 
(value_validity.RowIsValid(value_index)) { + } else { + if (is_field && !value_validity.RowIsValid(value_index)) { + //! Value is missing for this field + continue; + } + D_ASSERT(value_validity.RowIsValid(value_index)); auto metadata_value = metadata_data[metadata_format.sel->get_index(i)]; VariantMetadata variant_metadata(metadata_value); ret[i] = VariantBinaryDecoder::Decode(variant_metadata, @@ -187,7 +192,7 @@ vector ConvertTypedValues(Vector &vec, Vector &metadata, Vector &b vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { D_ASSERT(!typed_value.GetType().IsNested()); vector result; @@ -196,37 +201,37 @@ vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &meta //! boolean case LogicalTypeId::BOOLEAN: { return ConvertTypedValues, LogicalTypeId::BOOLEAN>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int8 case LogicalTypeId::TINYINT: { return ConvertTypedValues, LogicalTypeId::TINYINT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int16 case LogicalTypeId::SMALLINT: { return ConvertTypedValues, LogicalTypeId::SMALLINT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int32 case LogicalTypeId::INTEGER: { return ConvertTypedValues, LogicalTypeId::INTEGER>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! int64 case LogicalTypeId::BIGINT: { return ConvertTypedValues, LogicalTypeId::BIGINT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! 
float case LogicalTypeId::FLOAT: { return ConvertTypedValues, LogicalTypeId::FLOAT>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! double case LogicalTypeId::DOUBLE: { return ConvertTypedValues, LogicalTypeId::DOUBLE>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! decimal4/decimal8/decimal16 case LogicalTypeId::DECIMAL: { @@ -234,15 +239,15 @@ vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &meta switch (physical_type) { case PhysicalType::INT32: { return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } case PhysicalType::INT64: { return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } case PhysicalType::INT128: { return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } default: throw NotImplementedException("Decimal with PhysicalType (%s) not implemented for shredded Variant", @@ -252,42 +257,42 @@ vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &meta //! date case LogicalTypeId::DATE: { return ConvertTypedValues, LogicalTypeId::DATE>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! time case LogicalTypeId::TIME: { return ConvertTypedValues, LogicalTypeId::TIME>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! 
timestamptz(6) (timestamptz(9) not implemented in DuckDB) case LogicalTypeId::TIMESTAMP_TZ: { return ConvertTypedValues, LogicalTypeId::TIMESTAMP_TZ>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! timestampntz(6) case LogicalTypeId::TIMESTAMP: { return ConvertTypedValues, LogicalTypeId::TIMESTAMP>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! timestampntz(9) case LogicalTypeId::TIMESTAMP_NS: { return ConvertTypedValues, LogicalTypeId::TIMESTAMP_NS>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! binary case LogicalTypeId::BLOB: { return ConvertTypedValues, LogicalTypeId::BLOB>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! string case LogicalTypeId::VARCHAR: { return ConvertTypedValues, LogicalTypeId::VARCHAR>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } //! 
uuid case LogicalTypeId::UUID: { return ConvertTypedValues, LogicalTypeId::UUID>( - typed_value, metadata, value, offset, length, total_size); + typed_value, metadata, value, offset, length, total_size, is_field); } default: throw NotImplementedException("Variant shredding on type: '%s' is not implemented", type.ToString()); @@ -395,7 +400,7 @@ static VariantValue ConvertPartiallyShreddedObject(vector vector VariantShreddedConversion::ConvertShreddedObject(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { auto &type = typed_value.GetType(); D_ASSERT(type.id() == LogicalTypeId::STRUCT); auto &fields = StructType::GetChildTypes(type); @@ -445,7 +450,10 @@ vector VariantShreddedConversion::ConvertShreddedObject(Vector &me if (typed_validity.RowIsValid(typed_index)) { ret[i] = ConvertPartiallyShreddedObject(shredded_fields, metadata_format, value_format, i, offset); } else { - //! The value on this row is not an object, and guaranteed to be present + if (is_field && !validity.RowIsValid(value_index)) { + //! This object is a field in the parent object, the value is missing, skip it + continue; + } D_ASSERT(validity.RowIsValid(value_index)); auto &metadata_value = metadata_data[metadata_format.sel->get_index(i)]; VariantMetadata variant_metadata(metadata_value); @@ -463,7 +471,7 @@ vector VariantShreddedConversion::ConvertShreddedObject(Vector &me vector VariantShreddedConversion::ConvertShreddedArray(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { + idx_t total_size, const bool is_field) { auto &child = ListVector::GetEntry(typed_value); auto list_size = ListVector::GetListSize(typed_value); @@ -489,23 +497,26 @@ vector VariantShreddedConversion::ConvertShreddedArray(Vector &met //! 
We can be sure that none of the values are binary encoded for (idx_t i = 0; i < length; i++) { auto typed_index = list_format.sel->get_index(i + offset); - //! FIXME: next 4 lines duplicated below auto entry = list_data[typed_index]; Vector child_metadata(metadata.GetValue(i)); ret[i] = VariantValue(VariantValueType::ARRAY); - ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size); + ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size, false); } } else { for (idx_t i = 0; i < length; i++) { auto typed_index = list_format.sel->get_index(i + offset); auto value_index = value_format.sel->get_index(i + offset); if (validity.RowIsValid(typed_index)) { - //! FIXME: next 4 lines duplicate auto entry = list_data[typed_index]; Vector child_metadata(metadata.GetValue(i)); ret[i] = VariantValue(VariantValueType::ARRAY); - ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size); - } else if (value_validity.RowIsValid(value_index)) { + ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size, false); + } else { + if (is_field && !value_validity.RowIsValid(value_index)) { + //! 
Value is missing for this field + continue; + } + D_ASSERT(value_validity.RowIsValid(value_index)); auto metadata_value = metadata_data[metadata_format.sel->get_index(i)]; VariantMetadata variant_metadata(metadata_value); ret[i] = VariantBinaryDecoder::Decode(variant_metadata, @@ -547,11 +558,11 @@ vector VariantShreddedConversion::Convert(Vector &metadata, Vector auto &type = typed_value->GetType(); vector ret; if (type.id() == LogicalTypeId::STRUCT) { - return ConvertShreddedObject(metadata, *value, *typed_value, offset, length, total_size); + return ConvertShreddedObject(metadata, *value, *typed_value, offset, length, total_size, is_field); } else if (type.id() == LogicalTypeId::LIST) { - return ConvertShreddedArray(metadata, *value, *typed_value, offset, length, total_size); + return ConvertShreddedArray(metadata, *value, *typed_value, offset, length, total_size, is_field); } else { - return ConvertShreddedLeaf(metadata, *value, *typed_value, offset, length, total_size); + return ConvertShreddedLeaf(metadata, *value, *typed_value, offset, length, total_size, is_field); } } else { if (is_field) { diff --git a/src/duckdb/extension/parquet/reader/variant_column_reader.cpp b/src/duckdb/extension/parquet/reader/variant_column_reader.cpp index 402bcbb07..14fdb3987 100644 --- a/src/duckdb/extension/parquet/reader/variant_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/variant_column_reader.cpp @@ -92,7 +92,7 @@ idx_t VariantColumnReader::Read(uint64_t num_values, data_ptr_t define_out, data } } conversion_result = - VariantShreddedConversion::Convert(metadata_intermediate, intermediate_group, 0, num_values, num_values); + VariantShreddedConversion::Convert(metadata_intermediate, intermediate_group, 0, num_values, num_values, false); for (idx_t i = 0; i < conversion_result.size(); i++) { auto &variant = conversion_result[i]; diff --git a/src/duckdb/extension/parquet/serialize_parquet.cpp b/src/duckdb/extension/parquet/serialize_parquet.cpp index 
aa5632077..6f12d5d89 100644 --- a/src/duckdb/extension/parquet/serialize_parquet.cpp +++ b/src/duckdb/extension/parquet/serialize_parquet.cpp @@ -7,7 +7,8 @@ #include "duckdb/common/serializer/deserializer.hpp" #include "parquet_reader.hpp" #include "parquet_crypto.hpp" -#include "parquet_writer.hpp" +#include "parquet_field_id.hpp" +#include "parquet_shredding.hpp" namespace duckdb { @@ -21,6 +22,16 @@ ChildFieldIDs ChildFieldIDs::Deserialize(Deserializer &deserializer) { return result; } +void ChildShreddingTypes::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault>(100, "types", types.operator*()); +} + +ChildShreddingTypes ChildShreddingTypes::Deserialize(Deserializer &deserializer) { + ChildShreddingTypes result; + deserializer.ReadPropertyWithDefault>(100, "types", result.types.operator*()); + return result; +} + void FieldID::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(100, "set", set); serializer.WritePropertyWithDefault(101, "field_id", field_id); @@ -89,4 +100,18 @@ ParquetOptionsSerialization ParquetOptionsSerialization::Deserialize(Deserialize return result; } +void ShreddingType::Serialize(Serializer &serializer) const { + serializer.WritePropertyWithDefault(100, "set", set); + serializer.WriteProperty(101, "type", type); + serializer.WriteProperty(102, "children", children); +} + +ShreddingType ShreddingType::Deserialize(Deserializer &deserializer) { + ShreddingType result; + deserializer.ReadPropertyWithDefault(100, "set", result.set); + deserializer.ReadProperty(101, "type", result.type); + deserializer.ReadProperty(102, "children", result.children); + return result; +} + } // namespace duckdb diff --git a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp index 16189ab24..0b885dfa8 100644 --- a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp +++ 
b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp @@ -304,12 +304,24 @@ void PrimitiveColumnWriter::SetParquetStatistics(PrimitiveColumnWriterState &sta } if (state.stats_state->HasGeoStats()) { - column_chunk.meta_data.__isset.geospatial_statistics = true; - state.stats_state->WriteGeoStats(column_chunk.meta_data.geospatial_statistics); - // Add the geospatial statistics to the extra GeoParquet metadata - writer.GetGeoParquetData().AddGeoParquetStats(column_schema.name, column_schema.type, - *state.stats_state->GetGeoStats()); + auto gpq_version = writer.GetGeoParquetVersion(); + + const auto has_real_stats = gpq_version == GeoParquetVersion::NONE || gpq_version == GeoParquetVersion::BOTH || + gpq_version == GeoParquetVersion::V2; + const auto has_json_stats = gpq_version == GeoParquetVersion::V1 || gpq_version == GeoParquetVersion::BOTH || + gpq_version == GeoParquetVersion::V2; + + if (has_real_stats) { + // Write the parquet native geospatial statistics + column_chunk.meta_data.__isset.geospatial_statistics = true; + state.stats_state->WriteGeoStats(column_chunk.meta_data.geospatial_statistics); + } + if (has_json_stats) { + // Add the geospatial statistics to the extra GeoParquet metadata + writer.GetGeoParquetData().AddGeoParquetStats(column_schema.name, column_schema.type, + *state.stats_state->GetGeoStats()); + } } for (const auto &write_info : state.write_info) { diff --git a/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp b/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp index ba6c707dd..836229d19 100644 --- a/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp +++ b/src/duckdb/extension/parquet/writer/variant/convert_variant.cpp @@ -3,6 +3,8 @@ #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/function/scalar/variant_utils.hpp" #include "reader/variant/variant_binary_decoder.hpp" +#include "parquet_shredding.hpp" +#include "duckdb/common/types/decimal.hpp" 
#include "duckdb/common/types/uuid.hpp" namespace duckdb { @@ -97,8 +99,135 @@ static void CreateMetadata(UnifiedVariantVectorData &variant, Vector &metadata, } } +namespace { + +static unordered_set GetVariantType(const LogicalType &type) { + if (type.id() == LogicalTypeId::ANY) { + return {}; + } + switch (type.id()) { + case LogicalTypeId::STRUCT: + return {VariantLogicalType::OBJECT}; + case LogicalTypeId::LIST: + return {VariantLogicalType::ARRAY}; + case LogicalTypeId::BOOLEAN: + return {VariantLogicalType::BOOL_TRUE, VariantLogicalType::BOOL_FALSE}; + case LogicalTypeId::TINYINT: + return {VariantLogicalType::INT8}; + case LogicalTypeId::SMALLINT: + return {VariantLogicalType::INT16}; + case LogicalTypeId::INTEGER: + return {VariantLogicalType::INT32}; + case LogicalTypeId::BIGINT: + return {VariantLogicalType::INT64}; + case LogicalTypeId::FLOAT: + return {VariantLogicalType::FLOAT}; + case LogicalTypeId::DOUBLE: + return {VariantLogicalType::DOUBLE}; + case LogicalTypeId::DECIMAL: + return {VariantLogicalType::DECIMAL}; + case LogicalTypeId::DATE: + return {VariantLogicalType::DATE}; + case LogicalTypeId::TIME: + return {VariantLogicalType::TIME_MICROS}; + case LogicalTypeId::TIMESTAMP_TZ: + return {VariantLogicalType::TIMESTAMP_MICROS_TZ}; + case LogicalTypeId::TIMESTAMP: + return {VariantLogicalType::TIMESTAMP_MICROS}; + case LogicalTypeId::TIMESTAMP_NS: + return {VariantLogicalType::TIMESTAMP_NANOS}; + case LogicalTypeId::BLOB: + return {VariantLogicalType::BLOB}; + case LogicalTypeId::VARCHAR: + return {VariantLogicalType::VARCHAR}; + case LogicalTypeId::UUID: + return {VariantLogicalType::UUID}; + default: + throw BinderException("Type '%s' can't be translated to a VARIANT type", type.ToString()); + } +} + +struct ShreddingState { +public: + explicit ShreddingState(const LogicalType &type, idx_t total_count) + : type(type), shredded_sel(total_count), values_index_sel(total_count), result_sel(total_count) { + variant_types = GetVariantType(type); + } + 
+public: + bool ValueIsShredded(UnifiedVariantVectorData &variant, idx_t row, idx_t values_index) { + auto type_id = variant.GetTypeId(row, values_index); + if (!variant_types.count(type_id)) { + return false; + } + if (type_id == VariantLogicalType::DECIMAL) { + auto physical_type = type.InternalType(); + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, values_index); + auto decimal_physical_type = decimal_data.GetPhysicalType(); + return physical_type == decimal_physical_type; + } + return true; + } + void SetShredded(idx_t row, idx_t values_index, idx_t result_idx) { + shredded_sel[count] = row; + values_index_sel[count] = values_index; + result_sel[count] = result_idx; + count++; + } + case_insensitive_string_set_t ObjectFields() { + D_ASSERT(type.id() == LogicalTypeId::STRUCT); + case_insensitive_string_set_t res; + auto &child_types = StructType::GetChildTypes(type); + for (auto &entry : child_types) { + auto &type = entry.first; + res.emplace(string_t(type.c_str(), type.size())); + } + return res; + } + +public: + //! The type the field is shredded on + const LogicalType &type; + unordered_set variant_types; + //! row that is shredded + SelectionVector shredded_sel; + //! 'values_index' of the shredded value + SelectionVector values_index_sel; + //! result row of the shredded value + SelectionVector result_sel; + //! The amount of rows that are shredded on + idx_t count = 0; +}; + +} // namespace + +vector GetChildIndices(const UnifiedVariantVectorData &variant, idx_t row, const VariantNestedData &nested_data, + optional_ptr shredding_state) { + vector child_indices; + if (!shredding_state || shredding_state->type.id() != LogicalTypeId::STRUCT) { + for (idx_t i = 0; i < nested_data.child_count; i++) { + child_indices.push_back(i); + } + return child_indices; + } + //! 
FIXME: The variant spec says that field names should be case-sensitive, not insensitive + case_insensitive_string_set_t shredded_fields = shredding_state->ObjectFields(); + + for (idx_t i = 0; i < nested_data.child_count; i++) { + auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); + auto &key = variant.GetKey(row, keys_index); + + if (shredded_fields.count(key)) { + //! This field is shredded on, omit it from the value + continue; + } + child_indices.push_back(i); + } + return child_indices; +} + static idx_t AnalyzeValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, - vector &offsets) { + vector &offsets, optional_ptr shredding_state) { idx_t total_size = 0; //! Every value has at least a value header total_size++; @@ -115,21 +244,30 @@ static idx_t AnalyzeValueData(const UnifiedVariantVectorData &variant, idx_t row //! Calculate value and key offsets for all children idx_t total_offset = 0; uint32_t highest_keys_index = 0; - offsets.resize(offset_size + nested_data.child_count + 1); - for (idx_t i = 0; i < nested_data.child_count; i++) { + + auto child_indices = GetChildIndices(variant, row, nested_data, shredding_state); + if (nested_data.child_count && child_indices.empty()) { + //! 
All fields of the object are shredded, omit the object entirely + return 0; + } + + auto num_elements = child_indices.size(); + offsets.resize(offset_size + num_elements + 1); + + for (idx_t entry = 0; entry < child_indices.size(); entry++) { + auto i = child_indices[entry]; auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); - offsets[offset_size + i] = total_offset; + offsets[offset_size + entry] = total_offset; - total_offset += AnalyzeValueData(variant, row, values_index, offsets); + total_offset += AnalyzeValueData(variant, row, values_index, offsets, nullptr); highest_keys_index = MaxValue(highest_keys_index, keys_index); } - offsets[offset_size + nested_data.child_count] = total_offset; + offsets[offset_size + num_elements] = total_offset; //! Calculate the sizes for the objects value data auto field_id_size = CalculateByteLength(highest_keys_index); auto field_offset_size = CalculateByteLength(total_offset); - auto num_elements = nested_data.child_count; const bool is_large = num_elements > NumericLimits::Maximum(); //! 
Now add the sizes for the objects value data @@ -152,7 +290,7 @@ static idx_t AnalyzeValueData(const UnifiedVariantVectorData &variant, idx_t row auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); offsets[offset_size + i] = total_offset; - total_offset += AnalyzeValueData(variant, row, values_index, offsets); + total_offset += AnalyzeValueData(variant, row, values_index, offsets, nullptr); } offsets[offset_size + nested_data.child_count] = total_offset; @@ -421,7 +559,9 @@ static void WritePrimitiveValueData(const UnifiedVariantVectorData &variant, idx } static void WriteValueData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t values_index, - data_ptr_t &value_data, const vector &offsets, idx_t &offset_index) { + data_ptr_t &value_data, const vector &offsets, idx_t &offset_index, + optional_ptr shredding_state) { + VariantLogicalType type_id = VariantLogicalType::VARIANT_NULL; if (variant.RowIsValid(row)) { type_id = variant.GetTypeId(row, values_index); @@ -431,22 +571,28 @@ static void WriteValueData(const UnifiedVariantVectorData &variant, idx_t row, u //! -- Object value header -- + auto child_indices = GetChildIndices(variant, row, nested_data, shredding_state); + if (nested_data.child_count && child_indices.empty()) { + throw InternalException( + "The entire should be omitted, should have been handled by the Analyze step already"); + } + auto num_elements = child_indices.size(); + //! 
Determine the 'field_id_size' uint32_t highest_keys_index = 0; - for (idx_t i = 0; i < nested_data.child_count; i++) { + for (auto &i : child_indices) { auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); highest_keys_index = MaxValue(highest_keys_index, keys_index); } auto field_id_size = CalculateByteLength(highest_keys_index); uint32_t last_offset = 0; - if (nested_data.child_count) { - last_offset = offsets[offset_index + nested_data.child_count]; + if (num_elements) { + last_offset = offsets[offset_index + num_elements]; } - offset_index += nested_data.child_count + 1; + offset_index += num_elements + 1; auto field_offset_size = CalculateByteLength(last_offset); - auto num_elements = nested_data.child_count; const bool is_large = num_elements > NumericLimits::Maximum(); uint8_t value_header = 0; @@ -476,7 +622,7 @@ static void WriteValueData(const UnifiedVariantVectorData &variant, idx_t row, u } //! Write the 'field_id' entries - for (idx_t i = 0; i < nested_data.child_count; i++) { + for (auto &i : child_indices) { auto keys_index = variant.GetKeysIndex(row, i + nested_data.children_idx); memcpy(value_data, reinterpret_cast(&keys_index), field_id_size); value_data += field_id_size; @@ -485,13 +631,13 @@ static void WriteValueData(const UnifiedVariantVectorData &variant, idx_t row, u //! 
Write the 'field_offset' entries and the child 'value's auto children_ptr = value_data + ((num_elements + 1) * field_offset_size); idx_t total_offset = 0; - for (idx_t i = 0; i < nested_data.child_count; i++) { + for (auto &i : child_indices) { auto values_index = variant.GetValuesIndex(row, i + nested_data.children_idx); memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); value_data += field_offset_size; auto start_ptr = children_ptr; - WriteValueData(variant, row, values_index, children_ptr, offsets, offset_index); + WriteValueData(variant, row, values_index, children_ptr, offsets, offset_index, nullptr); total_offset += (children_ptr - start_ptr); } memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); @@ -546,7 +692,7 @@ static void WriteValueData(const UnifiedVariantVectorData &variant, idx_t row, u memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); value_data += field_offset_size; auto start_ptr = children_ptr; - WriteValueData(variant, row, values_index, children_ptr, offsets, offset_index); + WriteValueData(variant, row, values_index, children_ptr, offsets, offset_index, nullptr); total_offset += (children_ptr - start_ptr); } memcpy(value_data, reinterpret_cast(&total_offset), field_offset_size); @@ -558,28 +704,393 @@ static void WriteValueData(const UnifiedVariantVectorData &variant, idx_t row, u } } -static void CreateValues(UnifiedVariantVectorData &variant, Vector &value, idx_t count) { +static void CreateValues(UnifiedVariantVectorData &variant, Vector &value, optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, optional_ptr shredding_state, + idx_t count) { + auto &validity = FlatVector::Validity(value); auto value_data = FlatVector::GetData(value); - for (idx_t row = 0; row < count; row++) { + for (idx_t i = 0; i < count; i++) { + idx_t value_index = 0; + if (value_index_sel) { + value_index = value_index_sel->get_index(i); + } + + idx_t row = i; + if (sel) { + row = 
sel->get_index(i); + } + + idx_t result_index = i; + if (result_sel) { + result_index = result_sel->get_index(i); + } + + bool is_shredded = false; + if (variant.RowIsValid(row) && shredding_state && shredding_state->ValueIsShredded(variant, row, value_index)) { + shredding_state->SetShredded(row, value_index, result_index); + is_shredded = true; + if (shredding_state->type.id() != LogicalTypeId::STRUCT) { + //! Value is shredded, directly write a NULL to the 'value' if the type is not an OBJECT + //! When the type is OBJECT, all excess fields would still need to be written to the 'value' + validity.SetInvalid(result_index); + continue; + } + } + //! The (relative) offsets for each value, in the case of nesting vector offsets; //! Determine the size of this 'value' blob - idx_t blob_length = AnalyzeValueData(variant, row, 0, offsets); + idx_t blob_length = AnalyzeValueData(variant, row, value_index, offsets, shredding_state); if (!blob_length) { + //! This is only allowed to happen for a shredded OBJECT, where there are no excess fields to write for the + //! OBJECT + (void)is_shredded; + D_ASSERT(is_shredded); + validity.SetInvalid(result_index); continue; } - value_data[row] = StringVector::EmptyString(value, blob_length); - auto &value_blob = value_data[row]; + value_data[result_index] = StringVector::EmptyString(value, blob_length); + auto &value_blob = value_data[result_index]; auto value_blob_data = reinterpret_cast(value_blob.GetDataWriteable()); idx_t offset_index = 0; - WriteValueData(variant, row, 0, value_blob_data, offsets, offset_index); + WriteValueData(variant, row, value_index, value_blob_data, offsets, offset_index, shredding_state); D_ASSERT(data_ptr_cast(value_blob.GetDataWriteable() + blob_length) == value_blob_data); value_blob.SetSizeAndFinalize(blob_length, blob_length); } } +//! 
fwd-declare static method +static void WriteVariantValues(UnifiedVariantVectorData &variant, Vector &result, + optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, idx_t count); + +static void WriteTypedObjectValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto &type = result.GetType(); + D_ASSERT(type.id() == LogicalTypeId::STRUCT); + + auto &validity = FlatVector::Validity(result); + (void)validity; + + //! Collect the nested data for the objects + auto nested_data = make_unsafe_uniq_array_uninitialized(count); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + //! When we're shredding an object, the top-level struct of it should always be valid + D_ASSERT(validity.RowIsValid(result_sel[i])); + auto value_index = value_index_sel[i]; + D_ASSERT(variant.GetTypeId(row, value_index) == VariantLogicalType::OBJECT); + nested_data[i] = VariantUtils::DecodeNestedData(variant, row, value_index); + } + + auto &shredded_types = StructType::GetChildTypes(type); + auto &shredded_fields = StructVector::GetEntries(result); + D_ASSERT(shredded_types.size() == shredded_fields.size()); + + SelectionVector child_values_indexes; + SelectionVector child_row_sel; + SelectionVector child_result_sel; + child_values_indexes.Initialize(count); + child_row_sel.Initialize(count); + child_result_sel.Initialize(count); + + for (idx_t child_idx = 0; child_idx < shredded_types.size(); child_idx++) { + auto &child_vec = *shredded_fields[child_idx]; + D_ASSERT(child_vec.GetType() == shredded_types[child_idx].second); + + //! 
Prepare the path component to perform the lookup for + auto &key = shredded_types[child_idx].first; + VariantPathComponent path_component; + path_component.lookup_mode = VariantChildLookupMode::BY_KEY; + path_component.key = key; + + ValidityMask lookup_validity(count); + VariantUtils::FindChildValues(variant, path_component, sel, child_values_indexes, lookup_validity, + nested_data.get(), count); + + if (!lookup_validity.AllValid()) { + auto &child_variant_vectors = StructVector::GetEntries(child_vec); + + //! For some of the rows the field is missing, adjust the selection vector to exclude these rows. + idx_t child_count = 0; + for (idx_t i = 0; i < count; i++) { + if (!lookup_validity.RowIsValid(i)) { + //! The field is missing, set it to null + FlatVector::SetNull(*child_variant_vectors[0], result_sel[i], true); + if (child_variant_vectors.size() >= 2) { + FlatVector::SetNull(*child_variant_vectors[1], result_sel[i], true); + } + continue; + } + + child_row_sel[child_count] = sel[i]; + child_values_indexes[child_count] = child_values_indexes[i]; + child_result_sel[child_count] = result_sel[i]; + child_count++; + } + + if (child_count) { + //! 
If not all rows are missing this field, write the values for it + WriteVariantValues(variant, child_vec, child_row_sel, child_values_indexes, child_result_sel, + child_count); + } + } else { + WriteVariantValues(variant, child_vec, &sel, child_values_indexes, result_sel, count); + } + } +} + +static void WriteTypedArrayValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto list_data = FlatVector::GetData(result); + + auto nested_data = make_unsafe_uniq_array_uninitialized(count); + + idx_t total_offset = 0; + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto value_index = value_index_sel[i]; + auto result_row = result_sel[i]; + + D_ASSERT(variant.GetTypeId(row, value_index) == VariantLogicalType::ARRAY); + nested_data[i] = VariantUtils::DecodeNestedData(variant, row, value_index); + + list_entry_t list_entry; + list_entry.length = nested_data[i].child_count; + list_entry.offset = total_offset; + list_data[result_row] = list_entry; + + total_offset += nested_data[i].child_count; + } + ListVector::Reserve(result, total_offset); + ListVector::SetListSize(result, total_offset); + + SelectionVector child_sel; + child_sel.Initialize(total_offset); + + SelectionVector child_value_index_sel; + child_value_index_sel.Initialize(total_offset); + + SelectionVector child_result_sel; + child_result_sel.Initialize(total_offset); + + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + + auto &array_data = nested_data[i]; + auto &entry = list_data[result_row]; + for (idx_t j = 0; j < entry.length; j++) { + auto offset = entry.offset + j; + child_sel[offset] = row; + child_value_index_sel[offset] = variant.GetValuesIndex(row, array_data.children_idx + j); + child_result_sel[offset] = offset; + } + } + + auto &child_vector = ListVector::GetEntry(result); + WriteVariantValues(variant, child_vector, 
child_sel, child_value_index_sel, child_result_sel, total_offset); +} + +//! TODO: introduce a third selection vector, because we also need one to map to the result row to write +//! This becomes necessary when we introduce LISTs into the equation because lists are stored on the same VARIANT row, +//! but we're now going to write the flattened child vector +static void WriteShreddedPrimitive(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count, idx_t type_size) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row)); + + auto byte_offset = variant.GetByteOffset(row, value_index); + auto &data = variant.GetData(row); + auto value_ptr = data.GetData(); + auto result_offset = type_size * result_row; + memcpy(result_data + result_offset, value_ptr + byte_offset, type_size); + } +} + +template +static void WriteShreddedDecimal(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row) && variant.GetTypeId(row, value_index) == VariantLogicalType::DECIMAL); + + auto decimal_data = VariantUtils::DecodeDecimalData(variant, row, value_index); + D_ASSERT(decimal_data.width <= DecimalWidth::max); + auto result_offset = sizeof(T) * result_row; + memcpy(result_data + result_offset, decimal_data.value_ptr, sizeof(T)); + } +} + +static void WriteShreddedString(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector 
&value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row) && (variant.GetTypeId(row, value_index) == VariantLogicalType::VARCHAR || + variant.GetTypeId(row, value_index) == VariantLogicalType::BLOB)); + + auto string_data = VariantUtils::DecodeStringData(variant, row, value_index); + result_data[result_row] = StringVector::AddStringOrBlob(result, string_data); + } +} + +static void WriteShreddedBoolean(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto result_data = FlatVector::GetData(result); + for (idx_t i = 0; i < count; i++) { + auto row = sel[i]; + auto result_row = result_sel[i]; + auto value_index = value_index_sel[i]; + D_ASSERT(variant.RowIsValid(row)); + auto type_id = variant.GetTypeId(row, value_index); + D_ASSERT(type_id == VariantLogicalType::BOOL_FALSE || type_id == VariantLogicalType::BOOL_TRUE); + + result_data[result_row] = type_id == VariantLogicalType::BOOL_TRUE; + } +} + +static void WriteTypedPrimitiveValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, + idx_t count) { + auto &type = result.GetType(); + D_ASSERT(!type.IsNested()); + switch (type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::FLOAT: + case LogicalTypeId::DOUBLE: + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::UUID: { + const auto physical_type = type.InternalType(); + 
WriteShreddedPrimitive(variant, result, sel, value_index_sel, result_sel, count, GetTypeIdSize(physical_type)); + break; + } + case LogicalTypeId::DECIMAL: { + const auto physical_type = type.InternalType(); + switch (physical_type) { + //! DECIMAL4 + case PhysicalType::INT32: + WriteShreddedDecimal(variant, result, sel, value_index_sel, result_sel, count); + break; + //! DECIMAL8 + case PhysicalType::INT64: + WriteShreddedDecimal(variant, result, sel, value_index_sel, result_sel, count); + break; + //! DECIMAL16 + case PhysicalType::INT128: + WriteShreddedDecimal(variant, result, sel, value_index_sel, result_sel, count); + break; + default: + throw InvalidInputException("Can't shred on column of type '%s'", type.ToString()); + } + break; + } + case LogicalTypeId::BLOB: + case LogicalTypeId::VARCHAR: { + WriteShreddedString(variant, result, sel, value_index_sel, result_sel, count); + break; + } + case LogicalTypeId::BOOLEAN: + WriteShreddedBoolean(variant, result, sel, value_index_sel, result_sel, count); + break; + default: + throw InvalidInputException("Can't shred on type: %s", type.ToString()); + } +} + +static void WriteTypedValues(UnifiedVariantVectorData &variant, Vector &result, const SelectionVector &sel, + const SelectionVector &value_index_sel, const SelectionVector &result_sel, idx_t count) { + auto &type = result.GetType(); + + if (type.id() == LogicalTypeId::STRUCT) { + //! Shredded OBJECT + WriteTypedObjectValues(variant, result, sel, value_index_sel, result_sel, count); + } else if (type.id() == LogicalTypeId::LIST) { + //! Shredded ARRAY + WriteTypedArrayValues(variant, result, sel, value_index_sel, result_sel, count); + } else { + //! 
Primitive types + WriteTypedPrimitiveValues(variant, result, sel, value_index_sel, result_sel, count); + } +} + +static void WriteVariantValues(UnifiedVariantVectorData &variant, Vector &result, + optional_ptr sel, + optional_ptr value_index_sel, + optional_ptr result_sel, idx_t count) { + optional_ptr value; + optional_ptr typed_value; + + auto &result_type = result.GetType(); + D_ASSERT(result_type.id() == LogicalTypeId::STRUCT); + auto &child_types = StructType::GetChildTypes(result_type); + auto &child_vectors = StructVector::GetEntries(result); + D_ASSERT(child_types.size() == child_vectors.size()); + for (idx_t i = 0; i < child_types.size(); i++) { + auto &name = child_types[i].first; + if (name == "value") { + value = child_vectors[i].get(); + } else if (name == "typed_value") { + typed_value = child_vectors[i].get(); + } + } + + if (typed_value) { + ShreddingState shredding_state(typed_value->GetType(), count); + CreateValues(variant, *value, sel, value_index_sel, result_sel, &shredding_state, count); + + SelectionVector null_values; + if (shredding_state.count) { + WriteTypedValues(variant, *typed_value, shredding_state.shredded_sel, shredding_state.values_index_sel, + shredding_state.result_sel, shredding_state.count); + //! 'shredding_state.result_sel' will always be a subset of 'result_sel', set the rows not in the subset to + //! NULL + idx_t sel_idx = 0; + for (idx_t i = 0; i < count; i++) { + auto original_index = result_sel ? result_sel->get_index(i) : i; + if (sel_idx < shredding_state.count && shredding_state.result_sel[sel_idx] == original_index) { + sel_idx++; + continue; + } + FlatVector::SetNull(*typed_value, original_index, true); + } + } else { + //! Set all rows of the typed_value to NULL, nothing is shredded on + for (idx_t i = 0; i < count; i++) { + FlatVector::SetNull(*typed_value, result_sel ? 
result_sel->get_index(i) : i, true); + } + } + } else { + CreateValues(variant, *value, sel, value_index_sel, result_sel, nullptr, count); + } +} + static void ToParquetVariant(DataChunk &input, ExpressionState &state, Vector &result) { // DuckDB Variant: // - keys = VARCHAR[] @@ -599,15 +1110,57 @@ static void ToParquetVariant(DataChunk &input, ExpressionState &state, Vector &r UnifiedVariantVectorData variant(recursive_format); auto &result_vectors = StructVector::GetEntries(result); - CreateMetadata(variant, *result_vectors[0], count); - CreateValues(variant, *result_vectors[1], count); + auto &metadata = *result_vectors[0]; + CreateMetadata(variant, metadata, count); + WriteVariantValues(variant, result, nullptr, nullptr, nullptr, count); + + if (input.AllConstant()) { + result.SetVectorType(VectorType::CONSTANT_VECTOR); + } } -static LogicalType GetParquetVariantType(const LogicalType &type) { - (void)type; +LogicalType VariantColumnWriter::TransformTypedValueRecursive(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::STRUCT: { + //! 
Wrap all fields of the struct in a struct with 'value' and 'typed_value' fields + auto &child_types = StructType::GetChildTypes(type); + child_list_t replaced_types; + for (auto &entry : child_types) { + child_list_t child_children; + child_children.emplace_back("value", LogicalType::BLOB); + if (entry.second.id() != LogicalTypeId::VARIANT) { + child_children.emplace_back("typed_value", TransformTypedValueRecursive(entry.second)); + } + replaced_types.emplace_back(entry.first, LogicalType::STRUCT(child_children)); + } + return LogicalType::STRUCT(replaced_types); + } + case LogicalTypeId::LIST: { + auto &child_type = ListType::GetChildType(type); + child_list_t replaced_types; + replaced_types.emplace_back("value", LogicalType::BLOB); + if (child_type.id() != LogicalTypeId::VARIANT) { + replaced_types.emplace_back("typed_value", TransformTypedValueRecursive(child_type)); + } + return LogicalType::LIST(LogicalType::STRUCT(replaced_types)); + } + case LogicalTypeId::UNION: + case LogicalTypeId::MAP: + case LogicalTypeId::VARIANT: + case LogicalTypeId::ARRAY: + throw BinderException("'%s' can't appear inside the a 'typed_value' shredded type!", type.ToString()); + default: + return type; + } +} + +static LogicalType GetParquetVariantType(optional_ptr shredding = nullptr) { child_list_t children; children.emplace_back("metadata", LogicalType::BLOB); children.emplace_back("value", LogicalType::BLOB); + if (shredding) { + children.emplace_back("typed_value", VariantColumnWriter::TransformTypedValueRecursive(*shredding)); + } auto res = LogicalType::STRUCT(std::move(children)); res.SetAlias("PARQUET_VARIANT"); return res; @@ -619,7 +1172,29 @@ static unique_ptr BindTransform(ClientContext &context, ScalarFunc return nullptr; } auto type = ExpressionBinder::GetExpressionReturnType(*arguments[0]); - bound_function.return_type = GetParquetVariantType(type); + + if (arguments.size() == 2) { + auto &shredding = *arguments[1]; + auto expr_return_type = 
ExpressionBinder::GetExpressionReturnType(shredding); + expr_return_type = LogicalType::NormalizeType(expr_return_type); + if (expr_return_type.id() != LogicalTypeId::VARCHAR) { + throw BinderException("Optional second argument 'shredding' has to be of type VARCHAR, i.e: " + "'STRUCT(my_field BOOLEAN)', found type: '%s' instead", + expr_return_type); + } + if (!shredding.IsFoldable()) { + throw BinderException("Optional second argument 'shredding' has to be a constant expression"); + } + Value type_str = ExpressionExecutor::EvaluateScalar(context, shredding); + if (type_str.IsNull()) { + throw BinderException("Optional second argument 'shredding' can not be NULL"); + } + auto shredded_type = TransformStringToLogicalType(type_str.GetValue()); + bound_function.return_type = GetParquetVariantType(shredded_type); + } else { + bound_function.return_type = GetParquetVariantType(); + } + return nullptr; } diff --git a/src/duckdb/extension/parquet/writer/variant_column_writer.cpp b/src/duckdb/extension/parquet/writer/variant_column_writer.cpp deleted file mode 100644 index b4f401da8..000000000 --- a/src/duckdb/extension/parquet/writer/variant_column_writer.cpp +++ /dev/null @@ -1,131 +0,0 @@ -#include "writer/variant_column_writer.hpp" -#include "duckdb/common/types/variant.hpp" -#include "duckdb/common/helper.hpp" - -namespace duckdb { - -namespace { - -class VariantColumnWriterState : public ColumnWriterState { -public: - VariantColumnWriterState(duckdb_parquet::RowGroup &row_group, idx_t col_idx) - : row_group(row_group), col_idx(col_idx) { - } - ~VariantColumnWriterState() override = default; - - duckdb_parquet::RowGroup &row_group; - idx_t col_idx; - vector> child_states; -}; - -} // namespace - -unique_ptr VariantColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { - auto result = make_uniq(row_group, row_group.columns.size()); - - result->child_states.reserve(child_writers.size()); - for (auto &child_writer : child_writers) { - 
result->child_states.push_back(child_writer->InitializeWriteState(row_group)); - } - return std::move(result); -} - -bool VariantColumnWriter::HasAnalyze() { - for (auto &child_writer : child_writers) { - if (child_writer->HasAnalyze()) { - return true; - } - } - return false; -} - -void VariantColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - auto &child_vectors = StructVector::GetEntries(vector); - for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { - // Need to check again. It might be that just one child needs it but the rest not - if (child_writers[child_idx]->HasAnalyze()) { - child_writers[child_idx]->Analyze(*state.child_states[child_idx], &state_p, *child_vectors[child_idx], - count); - } - } -} - -void VariantColumnWriter::FinalizeAnalyze(ColumnWriterState &state_p) { - auto &state = state_p.Cast(); - for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { - // Need to check again. 
It might be that just one child needs it but the rest not - if (child_writers[child_idx]->HasAnalyze()) { - child_writers[child_idx]->FinalizeAnalyze(*state.child_states[child_idx]); - } - } -} - -void VariantColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count, - bool vector_can_span_multiple_pages) { - D_ASSERT(child_writers.size() == 2); - auto &metadata_writer = *child_writers[0]; - auto &value_writer = *child_writers[1]; - - auto &state = state_p.Cast(); - auto &metadata_state = *state.child_states[0]; - auto &value_state = *state.child_states[1]; - - auto &validity = FlatVector::Validity(vector); - if (parent) { - // propagate empty entries from the parent - if (state.is_empty.size() < parent->is_empty.size()) { - state.is_empty.insert(state.is_empty.end(), parent->is_empty.begin() + state.is_empty.size(), - parent->is_empty.end()); - } - } - HandleRepeatLevels(state_p, parent, count); - HandleDefineLevels(state_p, parent, validity, count, PARQUET_DEFINE_VALID, MaxDefine() - 1); - - auto &child_vectors = StructVector::GetEntries(vector); - metadata_writer.Prepare(metadata_state, &state_p, *child_vectors[0], count, vector_can_span_multiple_pages); - value_writer.Prepare(value_state, &state_p, *child_vectors[1], count, vector_can_span_multiple_pages); -} - -void VariantColumnWriter::BeginWrite(ColumnWriterState &state_p) { - D_ASSERT(child_writers.size() == 2); - auto &metadata_writer = *child_writers[0]; - auto &value_writer = *child_writers[1]; - - auto &state = state_p.Cast(); - auto &metadata_state = *state.child_states[0]; - auto &value_state = *state.child_states[1]; - - metadata_writer.BeginWrite(metadata_state); - value_writer.BeginWrite(value_state); -} - -void VariantColumnWriter::Write(ColumnWriterState &state_p, Vector &input, idx_t count) { - D_ASSERT(child_writers.size() == 2); - - auto &metadata_writer = *child_writers[0]; - auto &value_writer = *child_writers[1]; - - auto &state = 
state_p.Cast(); - auto &metadata_state = *state.child_states[0]; - auto &value_state = *state.child_states[1]; - - auto &child_vectors = StructVector::GetEntries(input); - metadata_writer.Write(metadata_state, *child_vectors[0], count); - value_writer.Write(value_state, *child_vectors[1], count); -} - -void VariantColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { - D_ASSERT(child_writers.size() == 2); - auto &metadata_writer = *child_writers[0]; - auto &value_writer = *child_writers[1]; - - auto &state = state_p.Cast(); - auto &metadata_state = *state.child_states[0]; - auto &value_state = *state.child_states[1]; - - metadata_writer.FinalizeWrite(metadata_state); - value_writer.FinalizeWrite(value_state); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_set.cpp b/src/duckdb/src/catalog/catalog_set.cpp index deff8daae..d374f6999 100644 --- a/src/duckdb/src/catalog/catalog_set.cpp +++ b/src/duckdb/src/catalog/catalog_set.cpp @@ -401,8 +401,6 @@ bool CatalogSet::DropEntryInternal(CatalogTransaction transaction, const string throw CatalogException("Cannot drop entry \"%s\" because it is an internal system entry", entry->name); } - entry->OnDrop(); - // create a new tombstone entry and replace the currently stored one // set the timestamp to the timestamp of the current transaction // and point it at the tombstone node @@ -454,6 +452,7 @@ void CatalogSet::VerifyExistenceOfDependency(transaction_t commit_id, CatalogEnt void CatalogSet::CommitDrop(transaction_t commit_id, transaction_t start_time, CatalogEntry &entry) { auto &duck_catalog = GetCatalog(); + entry.OnDrop(); // Make sure that we don't see any uncommitted changes auto transaction_id = MAX_TRANSACTION_ID; // This will allow us to see all committed changes made before this COMMIT happened diff --git a/src/duckdb/src/common/row_operations/row_external.cpp b/src/duckdb/src/common/row_operations/row_external.cpp deleted file mode 100644 index e4e3ec87d..000000000 --- 
a/src/duckdb/src/common/row_operations/row_external.cpp +++ /dev/null @@ -1,157 +0,0 @@ -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_layout.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -void RowOperations::SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Load heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = Load(heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string pointer with the within-row offset (if not inlined) - Store(UnsafeNumericCast(Load(string_ptr) - heap_row_ptrs[i]), - string_ptr); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data pointer with the within-row offset - Store(UnsafeNumericCast(Load(col_ptr) - heap_row_ptrs[i]), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -void RowOperations::SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - const idx_t count, const idx_t 
base_offset) { - const idx_t row_width = layout.GetRowWidth(); - row_ptr += layout.GetHeapOffset(); - idx_t cumulative_offset = 0; - for (idx_t i = 0; i < count; i++) { - Store(base_offset + cumulative_offset, row_ptr); - cumulative_offset += Load(heap_base_ptr + cumulative_offset); - row_ptr += row_width; - } -} - -void RowOperations::CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - data_ptr_t heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - const auto heap_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - // Figure out source and size - const auto source_heap_ptr = Load(row_ptr + heap_offset); - const auto size = Load(source_heap_ptr); - D_ASSERT(size >= sizeof(uint32_t)); - - // Copy and swizzle - memcpy(heap_ptr, source_heap_ptr, size); - Store(UnsafeNumericCast(heap_ptr - heap_base_ptr), row_ptr + heap_offset); - - // Increment for next iteration - row_ptr += row_width; - heap_ptr += size; - } -} - -void RowOperations::UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - data_ptr_t heap_ptr_ptr = base_row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - Store(base_heap_ptr + Load(heap_ptr_ptr), heap_ptr_ptr); - heap_ptr_ptr += row_width; - } -} - -static inline void VerifyUnswizzledString(const RowLayout &layout, const idx_t &col_idx, const data_ptr_t &row_ptr) { -#ifdef DEBUG - if (layout.GetTypes()[col_idx].id() != LogicalTypeId::VARCHAR) { - return; - } - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - ValidityBytes row_mask(row_ptr, layout.ColumnCount()); - if (row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - auto str = Load(row_ptr + layout.GetOffsets()[col_idx]); - str.Verify(); - } -#endif -} - -void 
RowOperations::UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Restore heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = base_heap_ptr + Load(heap_ptr_ptr); - Store(heap_row_ptrs[i], heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string offset with the pointer (if not inlined) - Store(heap_row_ptrs[i] + Load(string_ptr), string_ptr); - VerifyUnswizzledString(layout, col_idx, row_ptr + i * row_width); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data offset with the pointer - Store(heap_row_ptrs[i] + Load(col_ptr), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_gather.cpp b/src/duckdb/src/common/row_operations/row_gather.cpp deleted file mode 100644 index 8e5ed315b..000000000 --- a/src/duckdb/src/common/row_operations/row_gather.cpp +++ /dev/null @@ -1,176 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include 
"duckdb/common/operator/constant_operators.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/common/types/row/tuple_data_layout.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedGatherLoop(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - idx_t build_size) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - data[col_idx] = Load(row + col_offset); - ValidityBytes row_mask(row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. 
- col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } - } -} - -static void GatherVarchar(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - idx_t count, const RowLayout &layout, idx_t col_no, idx_t build_size, - data_ptr_t base_heap_ptr) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - auto col_ptr = row + col_offset; - data[col_idx] = Load(col_ptr); - ValidityBytes row_mask(row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. 
- col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } else if (base_heap_ptr && Load(col_ptr) > string_t::INLINE_LENGTH) { - // Not inline, so unswizzle the copied pointer the pointer - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - auto string_ptr = data_ptr_t(data + col_idx) + string_t::HEADER_SIZE; - Store(heap_row_ptr + Load(string_ptr), string_ptr); -#ifdef DEBUG - data[col_idx].Verify(); -#endif - } - } -} - -static void GatherNestedVector(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - data_ptr_t base_heap_ptr) { - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - auto ptrs = FlatVector::GetData(rows); - - // Build the gather locations - auto data_locations = make_unsafe_uniq_array_uninitialized(count); - auto mask_locations = make_unsafe_uniq_array_uninitialized(count); - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - mask_locations[i] = row; - auto col_ptr = ptrs[row_idx] + col_offset; - if (base_heap_ptr) { - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - data_locations[i] = heap_row_ptr + Load(col_ptr); - } else { - data_locations[i] = Load(col_ptr); - } - } - - // Deserialise into the selected locations - NestedValidity parent_validity(mask_locations.get(), col_no); - RowOperations::HeapGather(col, count, col_sel, data_locations.get(), &parent_validity); -} - -void RowOperations::Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size, - data_ptr_t heap_ptr) { - D_ASSERT(rows.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(rows.GetType().id() == 
LogicalTypeId::POINTER); // "Cannot gather from non-pointer type!" - - col.SetVectorType(VectorType::FLAT_VECTOR); - switch (col.GetType().InternalType()) { - case PhysicalType::UINT8: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT16: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT32: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT64: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::UINT128: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT16: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT32: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT64: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INT128: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::FLOAT: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::DOUBLE: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::INTERVAL: - TemplatedGatherLoop(rows, row_sel, col, col_sel, count, layout, col_no, build_size); - break; - case PhysicalType::VARCHAR: - GatherVarchar(rows, row_sel, col, col_sel, count, layout, col_no, build_size, heap_ptr); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case 
PhysicalType::ARRAY: - GatherNestedVector(rows, row_sel, col, col_sel, count, layout, col_no, heap_ptr); - break; - default: - throw InternalException("Unimplemented type for RowOperations::Gather"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_heap_gather.cpp b/src/duckdb/src/common/row_operations/row_heap_gather.cpp deleted file mode 100644 index fa433c64e..000000000 --- a/src/duckdb/src/common/row_operations/row_heap_gather.cpp +++ /dev/null @@ -1,276 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" - -namespace duckdb { - -using ValidityBytes = TemplatedValidityMask; - -template -static void TemplatedHeapGather(Vector &v, const idx_t count, const SelectionVector &sel, data_ptr_t *key_locations) { - auto target = FlatVector::GetData(v); - - for (idx_t i = 0; i < count; ++i) { - const auto col_idx = sel.get_index(i); - target[col_idx] = Load(key_locations[i]); - key_locations[i] += sizeof(T); - } -} - -static void HeapGatherStringVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - auto target = FlatVector::GetData(v); - - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - auto len = Load(key_locations[i]); - key_locations[i] += sizeof(uint32_t); - target[col_idx] = StringVector::AddStringOrBlob(v, string_t(const_char_ptr_cast(key_locations[i]), len)); - key_locations[i] += len; - } -} - -static void HeapGatherStructVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // struct must have a validitymask for its fields - auto &child_types = StructType::GetChildTypes(v.GetType()); - const idx_t struct_validitymask_size = (child_types.size() + 7) / 8; - data_ptr_t 
struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < vcount; i++) { - // use key_locations as the validitymask, and create struct_key_locations - struct_validitymask_locations[i] = key_locations[i]; - key_locations[i] += struct_validitymask_size; - } - - // now deserialize into the struct vectors - auto &children = StructVector::GetEntries(v); - for (idx_t i = 0; i < child_types.size(); i++) { - NestedValidity parent_validity(struct_validitymask_locations, i); - RowOperations::HeapGather(*children[i], vcount, sel, key_locations, &parent_validity); - } -} - -static void HeapGatherListVector(Vector &v, const idx_t vcount, const SelectionVector &sel, data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - - auto child_type = ListType::GetChildType(v.GetType()); - auto list_data = ListVector::GetData(v); - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - uint64_t entry_offset = ListVector::GetListSize(v); - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - // read list length - auto entry_remaining = Load(key_locations[i]); - key_locations[i] += sizeof(uint64_t); - // set list entry attributes - list_data[col_idx].length = entry_remaining; - list_data[col_idx].offset = entry_offset; - // skip over the validity mask - data_ptr_t validitymask_location = key_locations[i]; - idx_t offset_in_byte = 0; - key_locations[i] += (entry_remaining + 7) / 8; - // entry sizes - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type.InternalType())) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += entry_remaining * sizeof(idx_t); - } - - // now read the list data - while (entry_remaining > 0) { - auto next = MinValue(entry_remaining, (idx_t)STANDARD_VECTOR_SIZE); - - // initialize a new vector to append - Vector append_vector(v.GetType()); - append_vector.SetVectorType(v.GetVectorType()); - - auto 
&list_vec_to_append = ListVector::GetEntry(append_vector); - - // set validity - //! Since we are constructing the vector, this will always be a flat vector. - auto &append_validity = FlatVector::Validity(list_vec_to_append); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - append_validity.Set(entry_idx, *(validitymask_location) & (1 << offset_in_byte)); - if (++offset_in_byte == 8) { - validitymask_location++; - offset_in_byte = 0; - } - } - - // compute entry sizes and set locations where the list entries are - if (TypeIsConstantSize(child_type.InternalType())) { - // constant size list entries - const idx_t type_size = GetTypeIdSize(child_type.InternalType()); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now deserialize and add to listvector - RowOperations::HeapGather(list_vec_to_append, next, *FlatVector::IncrementalSelectionVector(), - list_entry_locations, nullptr); - ListVector::Append(v, list_vec_to_append, next); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -static void HeapGatherArrayVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // Setup - auto &child_type = ArrayType::GetChildType(v.GetType()); - auto array_size = ArrayType::GetSize(v.GetType()); - auto &child_vector = ArrayVector::GetEntry(v); - auto child_type_size = GetTypeIdSize(child_type.InternalType()); - auto child_type_is_var_size = !TypeIsConstantSize(child_type.InternalType()); - - data_ptr_t array_entry_locations[STANDARD_VECTOR_SIZE]; - - // array must have a validitymask for its elements - auto 
array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < vcount; i++) { - // Setup validity mask - data_ptr_t array_validitymask_location = key_locations[i]; - key_locations[i] += array_validitymask_size; - - NestedValidity parent_validity(array_validitymask_location); - - // The size of each variable size entry is stored after the validity mask - // (if the child type is variable size) - data_ptr_t var_entry_size_ptr = nullptr; - if (child_type_is_var_size) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += array_size * sizeof(idx_t); - } - - // row idx - const auto row_idx = sel.get_index(i); - - idx_t array_start = row_idx * array_size; - idx_t elem_remaining = array_size; - - while (elem_remaining > 0) { - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - SelectionVector array_sel(STANDARD_VECTOR_SIZE); - - if (child_type_is_var_size) { - // variable size list entries - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - array_sel.set_index(elem_idx, array_start + elem_idx); - } - } else { - // constant size list entries - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += child_type_size; - array_sel.set_index(elem_idx, array_start + elem_idx); - } - } - - // Pass on this array's validity mask to the child vector - RowOperations::HeapGather(child_vector, chunk_size, array_sel, array_entry_locations, &parent_validity); - - elem_remaining -= chunk_size; - array_start += chunk_size; - parent_validity.OffsetListBy(chunk_size); - } - } -} - -void RowOperations::HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, data_ptr_t *key_locations, - optional_ptr parent_validity) { - v.SetVectorType(VectorType::FLAT_VECTOR); - - auto &validity = 
FlatVector::Validity(v); - if (parent_validity) { - for (idx_t i = 0; i < vcount; i++) { - const auto valid = parent_validity->IsValid(i); - const auto col_idx = sel.get_index(i); - validity.Set(col_idx, valid); - } - } - - auto type = v.GetType().InternalType(); - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::FLOAT: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::DOUBLE: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INTERVAL: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::VARCHAR: - HeapGatherStringVector(v, vcount, sel, key_locations); - break; - case PhysicalType::STRUCT: - HeapGatherStructVector(v, vcount, sel, key_locations); - break; - case PhysicalType::LIST: - HeapGatherListVector(v, vcount, sel, key_locations); - break; - case PhysicalType::ARRAY: - HeapGatherArrayVector(v, vcount, sel, key_locations); - break; - default: - throw NotImplementedException("Unimplemented deserialize from row-format"); - } -} - -} // namespace duckdb 
diff --git a/src/duckdb/src/common/row_operations/row_heap_scatter.cpp b/src/duckdb/src/common/row_operations/row_heap_scatter.cpp deleted file mode 100644 index 01cf7b589..000000000 --- a/src/duckdb/src/common/row_operations/row_heap_scatter.cpp +++ /dev/null @@ -1,581 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = TemplatedValidityMask; - -NestedValidity::NestedValidity(data_ptr_t validitymask_location) - : list_validity_location(validitymask_location), struct_validity_locations(nullptr), entry_idx(0), idx_in_entry(0), - list_validity_offset(0) { -} - -NestedValidity::NestedValidity(data_ptr_t *validitymask_locations, idx_t child_vector_index) - : list_validity_location(nullptr), struct_validity_locations(validitymask_locations), entry_idx(0), idx_in_entry(0), - list_validity_offset(0) { - ValidityBytes::GetEntryIndex(child_vector_index, entry_idx, idx_in_entry); -} - -void NestedValidity::SetInvalid(idx_t idx) { - if (list_validity_location) { - // Is List - - idx = idx + list_validity_offset; - - idx_t list_entry_idx; - idx_t list_idx_in_entry; - ValidityBytes::GetEntryIndex(idx, list_entry_idx, list_idx_in_entry); - const auto bit = ~(1UL << list_idx_in_entry); - list_validity_location[list_entry_idx] &= bit; - } else { - // Is Struct - const auto bit = ~(1UL << idx_in_entry); - *(struct_validity_locations[idx] + entry_idx) &= bit; - } -} - -void NestedValidity::OffsetListBy(idx_t offset) { - list_validity_offset += offset; -} - -bool NestedValidity::IsValid(idx_t idx) { - if (list_validity_location) { - // Is List - - idx = idx + list_validity_offset; - - idx_t list_entry_idx; - idx_t list_idx_in_entry; - ValidityBytes::GetEntryIndex(idx, list_entry_idx, list_idx_in_entry); - const auto bit = (1UL << list_idx_in_entry); - return list_validity_location[list_entry_idx] & bit; - 
} else { - // Is Struct - const auto bit = (1UL << idx_in_entry); - return *(struct_validity_locations[idx] + entry_idx) & bit; - } -} - -static void ComputeStringEntrySizes(UnifiedVectorFormat &vdata, idx_t entry_sizes[], const idx_t ser_count, - const SelectionVector &sel, const idx_t offset) { - auto strings = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto str_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(str_idx)) { - entry_sizes[i] += sizeof(uint32_t) + strings[str_idx].GetSize(); - } - } -} - -static void ComputeStructEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - // obtain child vectors - idx_t num_children; - auto &children = StructVector::GetEntries(v); - num_children = children.size(); - // add struct validitymask size - const idx_t struct_validitymask_size = (num_children + 7) / 8; - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += struct_validitymask_size; - } - // compute size of child vectors - for (auto &struct_vector : children) { - RowOperations::ComputeEntrySizes(*struct_vector, entry_sizes, vcount, ser_count, sel, offset); - } -} - -static void ComputeListEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto list_entry = list_data[source_idx]; - - // make room for list length, list validitymask - entry_sizes[i] += sizeof(list_entry.length); - entry_sizes[i] += (list_entry.length + 7) / 8; - - // serialize size of each entry (if non-constant size) - if 
(!TypeIsConstantSize(ListType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += list_entry.length * sizeof(list_entry.length); - } - - // compute size of each the elements in list_entry and sum them - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // compute and add to the total - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t list_idx = 0; list_idx < next; list_idx++) { - entry_sizes[i] += list_entry_sizes[list_idx]; - } - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } - } -} - -static void ComputeArrayEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - - auto array_size = ArrayType::GetSize(v.GetType()); - auto child_vector = ArrayVector::GetEntry(v); - - idx_t array_entry_sizes[STANDARD_VECTOR_SIZE]; - const idx_t array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < ser_count; i++) { - - // Validity for the array elements - entry_sizes[i] += array_validitymask_size; - - // serialize size of each entry (if non-constant size) - if (!TypeIsConstantSize(ArrayType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += array_size * sizeof(idx_t); - } - - auto elem_idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(elem_idx + offset); - - auto array_start = source_idx * array_size; - auto elem_remaining = array_size; - - // the array could span multiple vectors, so we divide it into chunks - while (elem_remaining > 0) { - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - // compute and add to the total - 
std::fill_n(array_entry_sizes, chunk_size, 0); - RowOperations::ComputeEntrySizes(child_vector, array_entry_sizes, chunk_size, chunk_size, - *FlatVector::IncrementalSelectionVector(), array_start); - for (idx_t arr_elem_idx = 0; arr_elem_idx < chunk_size; arr_elem_idx++) { - entry_sizes[i] += array_entry_sizes[arr_elem_idx]; - } - // update for next iteration - elem_remaining -= chunk_size; - array_start += chunk_size; - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, - idx_t ser_count, const SelectionVector &sel, idx_t offset) { - const auto physical_type = v.GetType().InternalType(); - if (TypeIsConstantSize(physical_type)) { - const auto type_size = GetTypeIdSize(physical_type); - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += type_size; - } - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::STRUCT: - ComputeStructEntrySizes(v, entry_sizes, vcount, ser_count, sel, offset); - break; - case PhysicalType::LIST: - ComputeListEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::ARRAY: - ComputeArrayEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Column with variable size type %s cannot be serialized to row-format", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - ComputeEntrySizes(v, vdata, entry_sizes, vcount, ser_count, sel, offset); -} - -template -static void TemplatedHeapScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t 
offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (!parent_validity) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - - // set the validitymask - if (!vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - } - } -} - -static void HeapScatterStringVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto strings = UnifiedVectorFormat::GetData(vdata); - if (!parent_validity) { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(NumericCast(string_entry.GetSize()), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } - } - } else { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(NumericCast(string_entry.GetSize()), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), 
string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } else { - // set the validitymask - parent_validity->SetInvalid(i); - } - } - } -} - -static void HeapScatterStructVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto &children = StructVector::GetEntries(v); - idx_t num_children = children.size(); - - // struct must have a validitymask for its fields - const idx_t struct_validitymask_size = (num_children + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - // initialize the struct validity mask - struct_validitymask_locations[i] = key_locations[i]; - memset(struct_validitymask_locations[i], -1, struct_validitymask_size); - key_locations[i] += struct_validitymask_size; - - // set whether the whole struct is null - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - if (parent_validity && !vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - } - - // now serialize the struct vectors - for (idx_t i = 0; i < children.size(); i++) { - auto &struct_vector = *children[i]; - NestedValidity struct_validity(struct_validitymask_locations, i); - RowOperations::HeapScatter(struct_vector, vcount, sel, ser_count, key_locations, &struct_validity, offset); - } -} - -static void HeapScatterListVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - - UnifiedVectorFormat list_vdata; - child_vector.ToUnifiedFormat(ListVector::GetListSize(v), list_vdata); - auto child_type = 
ListType::GetChildType(v.GetType()).InternalType(); - - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (!vdata.validity.RowIsValid(source_idx)) { - if (parent_validity) { - // set the row validitymask for this column to invalid - parent_validity->SetInvalid(i); - } - continue; - } - auto list_entry = list_data[source_idx]; - - // store list length - Store(list_entry.length, key_locations[i]); - key_locations[i] += sizeof(list_entry.length); - - // make room for the validitymask - data_ptr_t list_validitymask_location = key_locations[i]; - idx_t entry_offset_in_byte = 0; - idx_t validitymask_size = (list_entry.length + 7) / 8; - memset(list_validitymask_location, -1, validitymask_size); - key_locations[i] += validitymask_size; - - // serialize size of each entry (if non-constant size) - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type)) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += list_entry.length * sizeof(idx_t); - } - - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // serialize list validity - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - auto list_idx = list_vdata.sel->get_index(entry_idx + entry_offset); - if (!list_vdata.validity.RowIsValid(list_idx)) { - *(list_validitymask_location) &= ~(1UL << entry_offset_in_byte); - } - if (++entry_offset_in_byte == 8) { - list_validitymask_location++; - entry_offset_in_byte = 0; - } - } - - if (TypeIsConstantSize(child_type)) { - // constant size list entries: set list entry locations - const idx_t type_size = GetTypeIdSize(child_type); - for (idx_t entry_idx = 0; entry_idx < next; 
entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries: compute entry sizes and set list entry locations - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += list_entry_sizes[entry_idx]; - Store(list_entry_sizes[entry_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now serialize to the locations - RowOperations::HeapScatter(child_vector, ListVector::GetListSize(v), - *FlatVector::IncrementalSelectionVector(), next, list_entry_locations, nullptr, - entry_offset); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -static void HeapScatterArrayVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - auto child_type = ArrayType::GetChildType(v.GetType()); - auto child_type_size = GetTypeIdSize(child_type.InternalType()); - auto child_type_is_var_size = !TypeIsConstantSize(child_type.InternalType()); - - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - UnifiedVectorFormat child_vdata; - child_vector.ToUnifiedFormat(ArrayVector::GetTotalSize(v), child_vdata); - - data_ptr_t array_entry_locations[STANDARD_VECTOR_SIZE]; - idx_t array_entry_sizes[STANDARD_VECTOR_SIZE]; - - // array must have a validitymask for its elements - auto array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < ser_count; i++) { - // Set if the whole array itself is null in the parent entry - auto source_idx = 
vdata.sel->get_index(sel.get_index(i) + offset); - if (parent_validity && !vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - - // Now we can serialize the array itself - // Every array starts with a validity mask for the children - data_ptr_t array_validitymask_location = key_locations[i]; - memset(array_validitymask_location, -1, array_validitymask_size); - key_locations[i] += array_validitymask_size; - - NestedValidity array_parent_validity(array_validitymask_location); - - // If the array contains variable size entries, we reserve spaces for them here - data_ptr_t var_entry_size_ptr = nullptr; - if (child_type_is_var_size) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += array_size * sizeof(idx_t); - } - - // Then comes the elements - auto array_start = source_idx * array_size; - auto elem_remaining = array_size; - - while (elem_remaining > 0) { - // the array elements can span multiple vectors, so we divide it into chunks - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - // Setup the locations for the elements - if (child_type_is_var_size) { - // The elements are variable sized - std::fill_n(array_entry_sizes, chunk_size, 0); - RowOperations::ComputeEntrySizes(child_vector, array_entry_sizes, chunk_size, chunk_size, - *FlatVector::IncrementalSelectionVector(), array_start); - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += array_entry_sizes[elem_idx]; - - // Now store the size of the entry - Store(array_entry_sizes[elem_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } else { - // The elements are constant sized - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += child_type_size; - } - } - - RowOperations::HeapScatter(child_vector, ArrayVector::GetTotalSize(v), - 
*FlatVector::IncrementalSelectionVector(), chunk_size, array_entry_locations, - &array_parent_validity, array_start); - - // update for next iteration - elem_remaining -= chunk_size; - array_start += chunk_size; - array_parent_validity.OffsetListBy(chunk_size); - } - } -} - -void RowOperations::HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, idx_t offset) { - if (TypeIsConstantSize(v.GetType().InternalType())) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - RowOperations::HeapScatterVData(vdata, v.GetType().InternalType(), sel, ser_count, key_locations, - parent_validity, offset); - } else { - switch (v.GetType().InternalType()) { - case PhysicalType::VARCHAR: - HeapScatterStringVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::STRUCT: - HeapScatterStructVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::LIST: - HeapScatterListVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::ARRAY: - HeapScatterArrayVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Serialization of variable length vector with type %s", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, - idx_t ser_count, data_ptr_t *key_locations, - optional_ptr parent_validity, idx_t offset) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT16: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT32: - 
TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT64: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT8: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT16: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT32: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT64: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT128: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT128: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::FLOAT: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::DOUBLE: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INTERVAL: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - default: - throw NotImplementedException("FIXME: Serialize to of constant type column to row-format"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp b/src/duckdb/src/common/row_operations/row_radix_scatter.cpp deleted file mode 100644 index a85a71997..000000000 --- a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp +++ /dev/null @@ -1,360 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { 
- -template -void TemplatedRadixScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeData(key_locations[i] + 1, source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < sizeof(T) + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', sizeof(T)); - } - key_locations[i] += sizeof(T) + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeData(key_locations[i], source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < sizeof(T); s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += sizeof(T); - } - } -} - -void RadixScatterStringVector(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeStringDataPrefix(key_locations[i] + 1, source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < prefix_len + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', prefix_len); - } - key_locations[i] += prefix_len + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeStringDataPrefix(key_locations[i], source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < prefix_len; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += prefix_len; - } - } -} - -void RadixScatterListVector(Vector &v, UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, const idx_t width, const idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - auto list_size = ListVector::GetListSize(v); - child_vector.Flatten(list_size); - - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - *key_location++ = valid; - auto &list_entry = list_data[source_idx]; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - *key_location++ = 1; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 2, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - *key_location++ = 0; - // mark rest of bits as empty - memset(key_location, '\0', width - 2); - key_location += width - 2; - } - // invert bits if desc - if (desc) { - // skip over validity byte, handled by nulls first/last - for (key_location = key_location_start + 1; key_location < key_location_start + width; - key_location++) { - *key_location = ~*key_location; - } - } - } else { - *key_location++ = invalid; - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - D_ASSERT(key_location == key_location_start + width); - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - auto &list_entry = list_data[source_idx]; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - *key_location++ = 1; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - *key_location++ = 0; - // mark rest of bits as 
empty - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - // invert bits if desc - if (desc) { - for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { - *key_location = ~*key_location; - } - } - D_ASSERT(key_location == key_location_start + width); - } - } -} - -void RadixScatterArrayVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - - if (validity.RowIsValid(source_idx)) { - *key_location++ = valid; - - auto array_offset = source_idx * array_size; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, array_offset); - - // invert bits if desc - if (desc) { - // skip over validity byte, handled by nulls first/last - for (key_location = key_location_start + 1; key_location < key_location_start + width; - key_location++) { - *key_location = ~*key_location; - } - } - } else { - *key_location++ = invalid; - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - D_ASSERT(key_location == key_location_start + width); - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - 
const data_ptr_t key_location_start = key_location; - - auto array_offset = source_idx * array_size; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width, array_offset); - // invert bits if desc - if (desc) { - for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { - *key_location = ~*key_location; - } - } - D_ASSERT(key_location == key_location_start + width); - } - } -} - -void RadixScatterStructVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', width - 1); - } - key_locations[i]++; - } - width--; - } - // serialize the struct - auto &child_vector = *StructVector::GetEntries(v)[0]; - RowOperations::RadixScatter(child_vector, vcount, *FlatVector::IncrementalSelectionVector(), add_count, - key_locations, false, true, false, prefix_len, width, offset); - // invert bits if desc - if (desc) { - for (idx_t i = 0; i < add_count; i++) { - for (idx_t s = 0; s < width; s++) { - *(key_locations[i] - width + s) = ~*(key_locations[i] - width + s); - } - } - } -} - -void RowOperations::RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, bool desc, bool has_null, bool nulls_first, - idx_t 
prefix_len, idx_t width, idx_t offset) { -#ifdef DEBUG - // initialize to verify written width later - auto key_locations_copy = make_uniq_array(ser_count); - for (idx_t i = 0; i < ser_count; i++) { - key_locations_copy[i] = key_locations[i]; - } -#endif - - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - switch (v.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT16: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT32: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT64: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT8: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT16: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT32: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT64: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::INT128: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::UINT128: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::FLOAT: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::DOUBLE: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, 
has_null, nulls_first, offset); - break; - case PhysicalType::INTERVAL: - TemplatedRadixScatter(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset); - break; - case PhysicalType::VARCHAR: - RadixScatterStringVector(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, offset); - break; - case PhysicalType::LIST: - RadixScatterListVector(v, vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, width, - offset); - break; - case PhysicalType::STRUCT: - RadixScatterStructVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - case PhysicalType::ARRAY: - RadixScatterArrayVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - default: - throw NotImplementedException("Cannot ORDER BY column with type %s", v.GetType().ToString()); - } - -#ifdef DEBUG - for (idx_t i = 0; i < ser_count; i++) { - D_ASSERT(key_locations[i] == key_locations_copy[i] + width); - } -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_scatter.cpp b/src/duckdb/src/common/row_operations/row_scatter.cpp deleted file mode 100644 index 1912d2484..000000000 --- a/src/duckdb/src/common/row_operations/row_scatter.cpp +++ /dev/null @@ -1,230 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/common/types/selection_vector.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedScatter(UnifiedVectorFormat &col, Vector &rows, const SelectionVector &sel, const 
idx_t count, - const idx_t col_offset, const idx_t col_no, const idx_t col_count) { - auto data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - if (!col.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - auto isnull = !col.validity.RowIsValid(col_idx); - T store_value = isnull ? NullValue() : data[col_idx]; - Store(store_value, row + col_offset); - if (isnull) { - ValidityBytes col_mask(ptrs[idx], col_count); - col_mask.SetInvalidUnsafe(col_no); - } - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - Store(data[col_idx], row + col_offset); - } - } -} - -static void ComputeStringEntrySizes(const UnifiedVectorFormat &col, idx_t entry_sizes[], const SelectionVector &sel, - const idx_t count, const idx_t offset = 0) { - auto data = UnifiedVectorFormat::GetData(col); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx) + offset; - const auto &str = data[col_idx]; - if (col.validity.RowIsValid(col_idx) && !str.IsInlined()) { - entry_sizes[i] += str.GetSize(); - } - } -} - -static void ScatterStringVector(UnifiedVectorFormat &col, Vector &rows, data_ptr_t str_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t col_count) { - auto string_data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - // Write out zero length to avoid swizzling problems. 
- const string_t null(nullptr, 0); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - if (!col.validity.RowIsValid(col_idx)) { - ValidityBytes col_mask(row, col_count); - col_mask.SetInvalidUnsafe(col_no); - Store(null, row + col_offset); - } else if (string_data[col_idx].IsInlined()) { - Store(string_data[col_idx], row + col_offset); - } else { - const auto &str = string_data[col_idx]; - string_t inserted(const_char_ptr_cast(str_locations[i]), UnsafeNumericCast(str.GetSize())); - memcpy(inserted.GetDataWriteable(), str.GetData(), str.GetSize()); - str_locations[i] += str.GetSize(); - inserted.Finalize(); - Store(inserted, row + col_offset); - } - } -} - -static void ScatterNestedVector(Vector &vec, UnifiedVectorFormat &col, Vector &rows, data_ptr_t data_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t vcount) { - // Store pointers to the data in the row - // Do this first because SerializeVector destroys the locations - auto ptrs = FlatVector::GetData(rows); - data_ptr_t validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto row = ptrs[idx]; - validitymask_locations[i] = row; - - Store(data_locations[i], row + col_offset); - } - - // Serialise the data - NestedValidity parent_validity(validitymask_locations, col_no); - RowOperations::HeapScatter(vec, vcount, sel, count, data_locations, &parent_validity); -} - -void RowOperations::Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, - RowDataCollection &string_heap, const SelectionVector &sel, idx_t count) { - if (count == 0) { - return; - } - - // Set the validity mask for each row before inserting data - idx_t column_count = layout.ColumnCount(); - auto ptrs = FlatVector::GetData(rows); - for (idx_t i = 0; i < count; ++i) { - auto row_idx = 
sel.get_index(i); - auto row = ptrs[row_idx]; - ValidityBytes(row, column_count).SetAllValid(layout.ColumnCount()); - } - - const auto vcount = columns.size(); - auto &offsets = layout.GetOffsets(); - auto &types = layout.GetTypes(); - - // Compute the entry size of the variable size columns - vector handles; - data_ptr_t data_locations[STANDARD_VECTOR_SIZE]; - if (!layout.AllConstant()) { - idx_t entry_sizes[STANDARD_VECTOR_SIZE]; - std::fill_n(entry_sizes, count, sizeof(uint32_t)); - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - if (TypeIsConstantSize(types[col_no].InternalType())) { - continue; - } - - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - switch (types[col_no].InternalType()) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(col, entry_sizes, sel, count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - RowOperations::ComputeEntrySizes(vec, col, entry_sizes, vcount, count, sel); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } - - // Build out the buffer space - handles = string_heap.Build(count, data_locations, entry_sizes); - - // Serialize information that is needed for swizzling if the computation goes out-of-core - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - // Pointer to this row in the heap block - Store(data_locations[i], row + heap_pointer_offset); - // Row size is stored in the heap in front of each row - Store(NumericCast(entry_sizes[i]), data_locations[i]); - data_locations[i] += sizeof(uint32_t); - } - } - - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - auto col_offset = offsets[col_no]; - - switch (types[col_no].InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - 
TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::FLOAT: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::DOUBLE: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INTERVAL: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::VARCHAR: - ScatterStringVector(col, rows, data_locations, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - ScatterNestedVector(vec, col, rows, data_locations, sel, count, col_offset, col_no, vcount); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/comparators.cpp 
b/src/duckdb/src/common/sort/comparators.cpp deleted file mode 100644 index 4df4cccc4..000000000 --- a/src/duckdb/src/common/sort/comparators.cpp +++ /dev/null @@ -1,507 +0,0 @@ -#include "duckdb/common/sort/comparators.hpp" - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -bool Comparators::TieIsBreakable(const idx_t &tie_col, const data_ptr_t &row_ptr, const SortLayout &sort_layout) { - const auto &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - // Check if the blob is NULL - ValidityBytes row_mask(row_ptr, sort_layout.column_count); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - // Can't break a NULL tie - return false; - } - auto &row_layout = sort_layout.blob_layout; - if (row_layout.GetTypes()[col_idx].InternalType() != PhysicalType::VARCHAR) { - // Nested type, must be broken - return true; - } - const auto &tie_col_offset = row_layout.GetOffsets()[col_idx]; - auto tie_string = Load(row_ptr + tie_col_offset); - if (tie_string.GetSize() < sort_layout.prefix_lengths[tie_col] && tie_string.GetSize() > 0) { - // No need to break the tie - we already compared the full string - return false; - } - return true; -} - -int Comparators::CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, - const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort) { - // Compare the sorting columns one by one - int comp_res = 0; - data_ptr_t l_ptr_offset = l_ptr; - data_ptr_t r_ptr_offset = r_ptr; - for (idx_t col_idx = 0; col_idx < sort_layout.column_count; col_idx++) { - comp_res = FastMemcmp(l_ptr_offset, r_ptr_offset, sort_layout.column_sizes[col_idx]); - if (comp_res == 0 && !sort_layout.constant_size[col_idx]) { - comp_res = BreakBlobTie(col_idx, left, right, sort_layout, 
external_sort); - } - if (comp_res != 0) { - break; - } - l_ptr_offset += sort_layout.column_sizes[col_idx]; - r_ptr_offset += sort_layout.column_sizes[col_idx]; - } - return comp_res; -} - -int Comparators::CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::VARCHAR: - return TemplatedCompareVal(l_ptr, r_ptr); - case PhysicalType::LIST: - case PhysicalType::ARRAY: - case PhysicalType::STRUCT: { - auto l_nested_ptr = Load(l_ptr); - auto r_nested_ptr = Load(r_ptr); - return CompareValAndAdvance(l_nested_ptr, r_nested_ptr, type, true); - } - default: - throw NotImplementedException("Unimplemented CompareVal for type %s", type.ToString()); - } -} - -int Comparators::BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, - const SortLayout &sort_layout, const bool &external) { - data_ptr_t l_data_ptr = left.DataPtr(*left.sb->blob_sorting_data); - data_ptr_t r_data_ptr = right.DataPtr(*right.sb->blob_sorting_data); - if (!TieIsBreakable(tie_col, l_data_ptr, sort_layout) && !TieIsBreakable(tie_col, r_data_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return 0; - } - // Align the pointers - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - l_data_ptr += tie_col_offset; - r_data_ptr += tie_col_offset; - // Do the comparison - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? 
-1 : 1; - const auto &type = sort_layout.blob_layout.GetTypes()[col_idx]; - int result; - if (external) { - // Store heap pointers - data_ptr_t l_heap_ptr = left.HeapPtr(*left.sb->blob_sorting_data); - data_ptr_t r_heap_ptr = right.HeapPtr(*right.sb->blob_sorting_data); - // Unswizzle offset to pointer - UnswizzleSingleValue(l_data_ptr, l_heap_ptr, type); - UnswizzleSingleValue(r_data_ptr, r_heap_ptr, type); - // Compare - result = CompareVal(l_data_ptr, r_data_ptr, type); - // Swizzle the pointers back to offsets - SwizzleSingleValue(l_data_ptr, l_heap_ptr, type); - SwizzleSingleValue(r_data_ptr, r_heap_ptr, type); - } else { - result = CompareVal(l_data_ptr, r_data_ptr, type); - } - return order * result; -} - -template -int Comparators::TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr) { - const auto left_val = Load(left_ptr); - const auto right_val = Load(right_ptr); - if (Equals::Operation(left_val, right_val)) { - return 0; - } else if (LessThan::Operation(left_val, right_val)) { - return -1; - } else { - return 1; - } -} - -int Comparators::CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT128: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT128: - 
return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::FLOAT: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::DOUBLE: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INTERVAL: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::VARCHAR: - return CompareStringAndAdvance(l_ptr, r_ptr, valid); - case PhysicalType::LIST: - return CompareListAndAdvance(l_ptr, r_ptr, ListType::GetChildType(type), valid); - case PhysicalType::STRUCT: - return CompareStructAndAdvance(l_ptr, r_ptr, StructType::GetChildTypes(type), valid); - case PhysicalType::ARRAY: - return CompareArrayAndAdvance(l_ptr, r_ptr, ArrayType::GetChildType(type), valid, ArrayType::GetSize(type)); - default: - throw NotImplementedException("Unimplemented CompareValAndAdvance for type %s", type.ToString()); - } -} - -template -int Comparators::TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr) { - auto result = TemplatedCompareVal(left_ptr, right_ptr); - left_ptr += sizeof(T); - right_ptr += sizeof(T); - return result; -} - -int Comparators::CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid) { - if (!valid) { - return 0; - } - uint32_t left_string_size = Load(left_ptr); - uint32_t right_string_size = Load(right_ptr); - left_ptr += sizeof(uint32_t); - right_ptr += sizeof(uint32_t); - auto memcmp_res = memcmp(const_char_ptr_cast(left_ptr), const_char_ptr_cast(right_ptr), - std::min(left_string_size, right_string_size)); - - left_ptr += left_string_size; - right_ptr += right_string_size; - - if (memcmp_res != 0) { - return memcmp_res; - } - if (left_string_size == right_string_size) { - return 0; - } - if (left_string_size < right_string_size) { - return -1; - } - return 1; -} - -int Comparators::CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const child_list_t &types, bool valid) { - idx_t count = types.size(); - // Load validity masks - ValidityBytes 
left_validity(left_ptr, types.size()); - ValidityBytes right_validity(right_ptr, types.size()); - left_ptr += (count + 7) / 8; - right_ptr += (count + 7) / 8; - // Initialize variables - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Compare - int comp_res = 0; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - auto &type = types[i].second; - if ((left_valid == right_valid) || TypeIsConstantSize(type.InternalType())) { - comp_res = CompareValAndAdvance(left_ptr, right_ptr, types[i].second, left_valid && valid); - } - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res = -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -int Comparators::CompareArrayAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid, idx_t array_size) { - if (!valid) { - return 0; - } - - // Load array validity masks - ValidityBytes left_validity(left_ptr, array_size); - ValidityBytes right_validity(right_ptr, array_size); - left_ptr += (array_size + 7) / 8; - right_ptr += (array_size + 7) / 8; - - int comp_res = 0; - if (TypeIsConstantSize(type.InternalType())) { - // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT16: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT32: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - 
break; - case PhysicalType::INT64: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT8: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT16: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT32: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT64: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT128: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INTERVAL: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // Variable-sized array entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. 
We need to skip over them - left_ptr += array_size * sizeof(idx_t); - right_ptr += array_size * sizeof(idx_t); - for (idx_t i = 0; i < array_size; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::ARRAY: - comp_res = CompareArrayAndAdvance(left_ptr, right_ptr, ArrayType::GetChildType(type), left_valid, - ArrayType::GetSize(type)); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareArrayAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - return comp_res; -} - -int Comparators::CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid) { - if (!valid) { - return 0; - } - // Load list lengths - auto left_len = Load(left_ptr); - auto right_len = Load(right_ptr); - left_ptr += sizeof(idx_t); - right_ptr += sizeof(idx_t); - // Load list validity masks - ValidityBytes left_validity(left_ptr, left_len); - ValidityBytes right_validity(right_ptr, right_len); - left_ptr += (left_len + 7) / 8; - right_ptr += (right_len + 7) / 8; - // Compare - int comp_res = 0; - idx_t count = MinValue(left_len, right_len); - if (TypeIsConstantSize(type.InternalType())) { 
- // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INTERVAL: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // 
Variable-sized list entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. We need to skip over them - left_ptr += left_len * sizeof(idx_t); - right_ptr += right_len * sizeof(idx_t); - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::ARRAY: - comp_res = CompareArrayAndAdvance(left_ptr, right_ptr, ArrayType::GetChildType(type), left_valid, - ArrayType::GetSize(type)); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareListAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - // All values that we looped over were equal - if (comp_res == 0 && left_len != right_len) { - // Smaller lists first - if (left_len < right_len) { - comp_res = -1; - } else { - comp_res = 1; - } - } - return comp_res; -} - -template -int Comparators::TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const ValidityBytes &left_validity, const ValidityBytes &right_validity, - const idx_t &count) { - int comp_res = 0; - 
bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - comp_res = TemplatedCompareAndAdvance(left_ptr, right_ptr); - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res = -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -void Comparators::UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(heap_ptr + Load(data_ptr), data_ptr); -} - -void Comparators::SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(UnsafeNumericCast(Load(data_ptr) - heap_ptr), data_ptr); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/hashed_sort.cpp b/src/duckdb/src/common/sort/hashed_sort.cpp similarity index 100% rename from src/duckdb/src/common/sorting/hashed_sort.cpp rename to src/duckdb/src/common/sort/hashed_sort.cpp diff --git a/src/duckdb/src/common/sort/merge_sorter.cpp b/src/duckdb/src/common/sort/merge_sorter.cpp deleted file mode 100644 index c670fd574..000000000 --- a/src/duckdb/src/common/sort/merge_sorter.cpp +++ /dev/null @@ -1,667 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -MergeSorter::MergeSorter(GlobalSortState &state, BufferManager &buffer_manager) - : state(state), buffer_manager(buffer_manager), sort_layout(state.sort_layout) { 
-} - -void MergeSorter::PerformInMergeRound() { - while (true) { - // Check for interrupts after merging a partition - if (state.context.interrupted) { - throw InterruptException(); - } - { - lock_guard pair_guard(state.lock); - if (state.pair_idx == state.num_pairs) { - break; - } - GetNextPartition(); - } - MergePartition(); - } -} - -void MergeSorter::MergePartition() { - auto &left_block = *left->sb; - auto &right_block = *right->sb; -#ifdef DEBUG - D_ASSERT(left_block.radix_sorting_data.size() == left_block.payload_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.payload_data->data_blocks.size()); - if (!state.payload_layout.AllConstant() && state.external) { - D_ASSERT(left_block.payload_data->data_blocks.size() == left_block.payload_data->heap_blocks.size()); - D_ASSERT(right_block.payload_data->data_blocks.size() == right_block.payload_data->heap_blocks.size()); - } - if (!sort_layout.all_constant) { - D_ASSERT(left_block.radix_sorting_data.size() == left_block.blob_sorting_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.blob_sorting_data->data_blocks.size()); - if (state.external) { - D_ASSERT(left_block.blob_sorting_data->data_blocks.size() == - left_block.blob_sorting_data->heap_blocks.size()); - D_ASSERT(right_block.blob_sorting_data->data_blocks.size() == - right_block.blob_sorting_data->heap_blocks.size()); - } - } -#endif - // Set up the write block - // Each merge task produces a SortedBlock with exactly state.block_capacity rows or less - result->InitializeWrite(); - // Initialize arrays to store merge data - bool left_smaller[STANDARD_VECTOR_SIZE]; - idx_t next_entry_sizes[STANDARD_VECTOR_SIZE]; - // Merge loop -#ifdef DEBUG - auto l_count = left->Remaining(); - auto r_count = right->Remaining(); -#endif - while (true) { - auto l_remaining = left->Remaining(); - auto r_remaining = right->Remaining(); - if (l_remaining + r_remaining == 0) { - // Done - break; - } - 
const idx_t next = MinValue(l_remaining + r_remaining, (idx_t)STANDARD_VECTOR_SIZE); - if (l_remaining != 0 && r_remaining != 0) { - // Compute the merge (not needed if one side is exhausted) - ComputeMerge(next, left_smaller); - } - // Actually merge the data (radix, blob, and payload) - MergeRadix(next, left_smaller); - if (!sort_layout.all_constant) { - MergeData(*result->blob_sorting_data, *left_block.blob_sorting_data, *right_block.blob_sorting_data, next, - left_smaller, next_entry_sizes, true); - D_ASSERT(result->radix_sorting_data.size() == result->blob_sorting_data->data_blocks.size()); - } - MergeData(*result->payload_data, *left_block.payload_data, *right_block.payload_data, next, left_smaller, - next_entry_sizes, false); - D_ASSERT(result->radix_sorting_data.size() == result->payload_data->data_blocks.size()); - } -#ifdef DEBUG - D_ASSERT(result->Count() == l_count + r_count); -#endif -} - -void MergeSorter::GetNextPartition() { - // Create result block - state.sorted_blocks_temp[state.pair_idx].push_back(make_uniq(buffer_manager, state)); - result = state.sorted_blocks_temp[state.pair_idx].back().get(); - // Determine which blocks must be merged - auto &left_block = *state.sorted_blocks[state.pair_idx * 2]; - auto &right_block = *state.sorted_blocks[state.pair_idx * 2 + 1]; - const idx_t l_count = left_block.Count(); - const idx_t r_count = right_block.Count(); - // Initialize left and right reader - left = make_uniq(buffer_manager, state); - right = make_uniq(buffer_manager, state); - // Compute the work that this thread must do using Merge Path - idx_t l_end; - idx_t r_end; - if (state.l_start + state.r_start + state.block_capacity < l_count + r_count) { - left->sb = state.sorted_blocks[state.pair_idx * 2].get(); - right->sb = state.sorted_blocks[state.pair_idx * 2 + 1].get(); - const idx_t intersection = state.l_start + state.r_start + state.block_capacity; - GetIntersection(intersection, l_end, r_end); - D_ASSERT(l_end <= l_count); - D_ASSERT(r_end 
<= r_count); - D_ASSERT(intersection == l_end + r_end); - } else { - l_end = l_count; - r_end = r_count; - } - // Create slices of the data that this thread must merge - left->SetIndices(0, 0); - right->SetIndices(0, 0); - left_input = left_block.CreateSlice(state.l_start, l_end, left->entry_idx); - right_input = right_block.CreateSlice(state.r_start, r_end, right->entry_idx); - left->sb = left_input.get(); - right->sb = right_input.get(); - state.l_start = l_end; - state.r_start = r_end; - D_ASSERT(left->Remaining() + right->Remaining() == state.block_capacity || (l_end == l_count && r_end == r_count)); - // Update global state - if (state.l_start == l_count && state.r_start == r_count) { - // Delete references to previous pair - state.sorted_blocks[state.pair_idx * 2] = nullptr; - state.sorted_blocks[state.pair_idx * 2 + 1] = nullptr; - // Advance pair - state.pair_idx++; - state.l_start = 0; - state.r_start = 0; - } -} - -int MergeSorter::CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx) { - D_ASSERT(l_idx < l.sb->Count()); - D_ASSERT(r_idx < r.sb->Count()); - - // Easy comparison using the previous result (intersections must increase monotonically) - if (l_idx < state.l_start) { - return -1; - } - if (r_idx < state.r_start) { - return 1; - } - - l.sb->GlobalToLocalIndex(l_idx, l.block_idx, l.entry_idx); - r.sb->GlobalToLocalIndex(r_idx, r.block_idx, r.entry_idx); - - l.PinRadix(l.block_idx); - r.PinRadix(r.block_idx); - data_ptr_t l_ptr = l.radix_handle.Ptr() + l.entry_idx * sort_layout.entry_size; - data_ptr_t r_ptr = r.radix_handle.Ptr() + r.entry_idx * sort_layout.entry_size; - - int comp_res; - if (sort_layout.all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, sort_layout.comparison_size); - } else { - l.PinData(*l.sb->blob_sorting_data); - r.PinData(*r.sb->blob_sorting_data); - comp_res = Comparators::CompareTuple(l, r, l_ptr, r_ptr, sort_layout, state.external); - } - return comp_res; -} - -void 
MergeSorter::GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx) { - const idx_t l_count = left->sb->Count(); - const idx_t r_count = right->sb->Count(); - // Cover some edge cases - // Code coverage off because these edge cases cannot happen unless other code changes - // Edge cases have been tested extensively while developing Merge Path in a script - // LCOV_EXCL_START - if (diagonal >= l_count + r_count) { - l_idx = l_count; - r_idx = r_count; - return; - } else if (diagonal == 0) { - l_idx = 0; - r_idx = 0; - return; - } else if (l_count == 0) { - l_idx = 0; - r_idx = diagonal; - return; - } else if (r_count == 0) { - r_idx = 0; - l_idx = diagonal; - return; - } - // LCOV_EXCL_STOP - // Determine offsets for the binary search - const idx_t l_offset = MinValue(l_count, diagonal); - const idx_t r_offset = diagonal > l_count ? diagonal - l_count : 0; - D_ASSERT(l_offset + r_offset == diagonal); - const idx_t search_space = diagonal > MaxValue(l_count, r_count) ? l_count + r_count - diagonal - : MinValue(diagonal, MinValue(l_count, r_count)); - // Double binary search - idx_t li = 0; - idx_t ri = search_space - 1; - idx_t middle; - int comp_res; - while (li <= ri) { - middle = (li + ri) / 2; - l_idx = l_offset - middle; - r_idx = r_offset + middle; - if (l_idx == l_count || r_idx == 0) { - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (comp_res > 0) { - l_idx--; - r_idx++; - } else { - return; - } - if (l_idx == 0 || r_idx == r_count) { - // This case is incredibly difficult to cover as it is dependent on parallelism randomness - // But it has been tested extensively during development in a script - // LCOV_EXCL_START - return; - // LCOV_EXCL_STOP - } else { - break; - } - } - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx); - if (comp_res > 0) { - li = middle + 1; - } else { - ri = middle - 1; - } - } - int l_r_min1 = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx - 1); - int l_min1_r = 
CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (l_r_min1 > 0 && l_min1_r < 0) { - return; - } else if (l_r_min1 > 0) { - l_idx--; - r_idx++; - } else if (l_min1_r < 0) { - l_idx++; - r_idx--; - } -} - -void MergeSorter::ComputeMerge(const idx_t &count, bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - auto &l_sorted_block = *l.sb; - auto &r_sorted_block = *r.sb; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - // Data pointers for both sides - data_ptr_t l_radix_ptr; - data_ptr_t r_radix_ptr; - // Compute the merge of the next 'count' tuples - idx_t compared = 0; - while (compared < count) { - // Move to the next block (if needed) - if (l.block_idx < l_sorted_block.radix_sorting_data.size() && - l.entry_idx == l_sorted_block.radix_sorting_data[l.block_idx]->count) { - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_sorted_block.radix_sorting_data.size() && - r.entry_idx == r_sorted_block.radix_sorting_data[r.block_idx]->count) { - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_sorted_block.radix_sorting_data.size(); - const bool r_done = r.block_idx == r_sorted_block.radix_sorting_data.size(); - if (l_done || r_done) { - // One of the sides is exhausted, no need to compare - break; - } - // Pin the radix sorting data - left->PinRadix(l.block_idx); - l_radix_ptr = left->RadixPtr(); - right->PinRadix(r.block_idx); - r_radix_ptr = right->RadixPtr(); - - const idx_t l_count = l_sorted_block.radix_sorting_data[l.block_idx]->count; - const idx_t r_count = r_sorted_block.radix_sorting_data[r.block_idx]->count; - // Compute the merge - if (sort_layout.all_constant) { - // All sorting columns are constant size - for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = 
FastMemcmp(l_radix_ptr, r_radix_ptr, sort_layout.comparison_size) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } else { - // Pin the blob data - left->PinData(*l_sorted_block.blob_sorting_data); - right->PinData(*r_sorted_block.blob_sorting_data); - // Merge with variable size sorting columns - for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = - Comparators::CompareTuple(*left, *right, l_radix_ptr, r_radix_ptr, sort_layout, state.external) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeRadix(const idx_t &count, const bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - auto &l_blocks = l.sb->radix_sorting_data; - auto &r_blocks = r.sb->radix_sorting_data; - RowDataBlock *l_block = nullptr; - RowDataBlock *r_block = nullptr; - - data_ptr_t l_ptr; - data_ptr_t r_ptr; - - RowDataBlock *result_block = result->radix_sorting_data.back().get(); - auto result_handle = buffer_manager.Pin(result_block->block); - data_ptr_t result_ptr = result_handle.Ptr() + 
result_block->count * sort_layout.entry_size; - - idx_t copied = 0; - while (copied < count) { - // Move to the next block (if needed) - if (l.block_idx < l_blocks.size() && l.entry_idx == l_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_blocks[l.block_idx]->block = nullptr; - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_blocks.size() && r.entry_idx == r_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_blocks[r.block_idx]->block = nullptr; - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_blocks.size(); - const bool r_done = r.block_idx == r_blocks.size(); - // Pin the radix sortable blocks - idx_t l_count; - if (!l_done) { - l_block = l_blocks[l.block_idx].get(); - left->PinRadix(l.block_idx); - l_ptr = l.RadixPtr(); - l_count = l_block->count; - } else { - l_count = 0; - } - idx_t r_count; - if (!r_done) { - r_block = r_blocks[r.block_idx].get(); - r.PinRadix(r.block_idx); - r_ptr = r.RadixPtr(); - r_count = r_block->count; - } else { - r_count = 0; - } - // Copy using computed merge - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_block, result_ptr, - sort_layout.entry_size, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, - const bool left_smaller[], idx_t next_entry_sizes[], bool 
reset_indices) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - const auto &layout = result_data.layout; - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - - // Left and right row data to merge - data_ptr_t l_ptr; - data_ptr_t r_ptr; - // Accompanying left and right heap data (if needed) - data_ptr_t l_heap_ptr; - data_ptr_t r_heap_ptr; - - // Result rows to write to - RowDataBlock *result_data_block = result_data.data_blocks.back().get(); - auto result_data_handle = buffer_manager.Pin(result_data_block->block); - data_ptr_t result_data_ptr = result_data_handle.Ptr() + result_data_block->count * row_width; - // Result heap to write to (if needed) - RowDataBlock *result_heap_block = nullptr; - BufferHandle result_heap_handle; - data_ptr_t result_heap_ptr; - if (!layout.AllConstant() && state.external) { - result_heap_block = result_data.heap_blocks.back().get(); - result_heap_handle = buffer_manager.Pin(result_heap_block->block); - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - - idx_t copied = 0; - while (copied < count) { - // Move to new data blocks (if needed) - if (l.block_idx < l_data.data_blocks.size() && l.entry_idx == l_data.data_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_data.data_blocks[l.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - l_data.heap_blocks[l.block_idx]->block = nullptr; - } - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_data.data_blocks.size() && r.entry_idx == r_data.data_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_data.data_blocks[r.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - 
r_data.heap_blocks[r.block_idx]->block = nullptr; - } - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_data.data_blocks.size(); - const bool r_done = r.block_idx == r_data.data_blocks.size(); - // Pin the row data blocks - if (!l_done) { - l.PinData(l_data); - l_ptr = l.DataPtr(l_data); - } - if (!r_done) { - r.PinData(r_data); - r_ptr = r.DataPtr(r_data); - } - const idx_t &l_count = !l_done ? l_data.data_blocks[l.block_idx]->count : 0; - const idx_t &r_count = !r_done ? r_data.data_blocks[r.block_idx]->count : 0; - // Perform the merge - if (layout.AllConstant() || !state.external) { - // If all constant size, or if we are doing an in-memory sort, we do not need to touch the heap - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, - row_width, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_data_block, result_data_ptr, row_width, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, row_width, copied, count); - } - } else { - // External sorting with variable size data. 
Pin the heap blocks too - if (!l_done) { - l_heap_ptr = l.BaseHeapPtr(l_data) + Load(l_ptr + heap_pointer_offset); - D_ASSERT(l_heap_ptr - l.BaseHeapPtr(l_data) >= 0); - D_ASSERT((idx_t)(l_heap_ptr - l.BaseHeapPtr(l_data)) < l_data.heap_blocks[l.block_idx]->byte_offset); - } - if (!r_done) { - r_heap_ptr = r.BaseHeapPtr(r_data) + Load(r_ptr + heap_pointer_offset); - D_ASSERT(r_heap_ptr - r.BaseHeapPtr(r_data) >= 0); - D_ASSERT((idx_t)(r_heap_ptr - r.BaseHeapPtr(r_data)) < r_data.heap_blocks[r.block_idx]->byte_offset); - } - // Both the row and heap data need to be dealt with - if (!l_done && !r_done) { - // Both sides have data - merge - idx_t l_idx_copy = l.entry_idx; - idx_t r_idx_copy = r.entry_idx; - data_ptr_t result_data_ptr_copy = result_data_ptr; - idx_t copied_copy = copied; - // Merge row data - MergeRows(l_ptr, l_idx_copy, l_count, r_ptr, r_idx_copy, r_count, *result_data_block, - result_data_ptr_copy, row_width, left_smaller, copied_copy, count); - const idx_t merged = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t l_heap_ptr_copy = l_heap_ptr; - data_ptr_t r_heap_ptr_copy = r_heap_ptr; - for (idx_t i = 0; i < merged; i++) { - // Store base heap offset in the row data - Store(result_heap_block->byte_offset + copy_bytes, result_data_ptr + heap_pointer_offset); - result_data_ptr += row_width; - // Compute entry size and add to total - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - auto &entry_size = next_entry_sizes[copied + i]; - entry_size = - l_smaller * Load(l_heap_ptr_copy) + r_smaller * Load(r_heap_ptr_copy); - D_ASSERT(entry_size >= sizeof(uint32_t)); - D_ASSERT(NumericCast(l_heap_ptr_copy - l.BaseHeapPtr(l_data)) + l_smaller * entry_size <= - l_data.heap_blocks[l.block_idx]->byte_offset); - D_ASSERT(NumericCast(r_heap_ptr_copy - r.BaseHeapPtr(r_data)) + r_smaller * entry_size <= - 
r_data.heap_blocks[r.block_idx]->byte_offset); - l_heap_ptr_copy += l_smaller * entry_size; - r_heap_ptr_copy += r_smaller * entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (result_heap_block->byte_offset + copy_bytes > result_heap_block->capacity) { - idx_t new_capacity = result_heap_block->byte_offset + copy_bytes; - buffer_manager.ReAllocate(result_heap_block->block, new_capacity); - result_heap_block->capacity = new_capacity; - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - D_ASSERT(result_heap_block->byte_offset + copy_bytes <= result_heap_block->capacity); - // Now copy the heap data - for (idx_t i = 0; i < merged; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - const auto &entry_size = next_entry_sizes[copied + i]; - memcpy(result_heap_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_heap_ptr) + - r_smaller * CastPointerToValue(r_heap_ptr)), - entry_size); - D_ASSERT(Load(result_heap_ptr) == entry_size); - result_heap_ptr += entry_size; - l_heap_ptr += l_smaller * entry_size; - r_heap_ptr += r_smaller * entry_size; - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - } - // Update result indices and pointers - result_heap_block->count += merged; - result_heap_block->byte_offset += copy_bytes; - copied += merged; - } else if (r_done) { - // Right side is exhausted - flush left - FlushBlobs(layout, l_count, l_ptr, l.entry_idx, l_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); - } else { - // Left side is exhausted - flush right - FlushBlobs(layout, r_count, r_ptr, r.entry_idx, r_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); - } - D_ASSERT(result_data_block->count == result_heap_block->count); - } - } - if (reset_indices) { - left->SetIndices(l_block_idx_before, 
l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); - } -} - -void MergeSorter::MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, - idx_t &r_entry_idx, const idx_t &r_count, RowDataBlock &target_block, - data_ptr_t &target_ptr, const idx_t &entry_size, const bool left_smaller[], idx_t &copied, - const idx_t &count) { - const idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - idx_t i; - for (i = 0; i < next && l_entry_idx < l_count && r_entry_idx < r_count; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to copy an entry from either side - FastMemcpy( - target_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_ptr) + r_smaller * CastPointerToValue(r_ptr)), - entry_size); - target_ptr += entry_size; - // Use the comparison bool to increment entries and pointers - l_entry_idx += l_smaller; - r_entry_idx += r_smaller; - l_ptr += l_smaller * entry_size; - r_ptr += r_smaller * entry_size; - } - // Update counts - target_block.count += i; - copied += i; -} - -void MergeSorter::FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, - RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, - const idx_t &count) { - // Compute how many entries we can fit - idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - next = MinValue(next, source_count - source_entry_idx); - // Copy them all in a single memcpy - const idx_t copy_bytes = next * entry_size; - memcpy(target_ptr, source_ptr, copy_bytes); - target_ptr += copy_bytes; - source_ptr += copy_bytes; - // Update counts - source_entry_idx += next; - target_block.count += next; - copied += next; -} - -void MergeSorter::FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, - idx_t &source_entry_idx, data_ptr_t 
&source_heap_ptr, RowDataBlock &target_data_block, - data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, - BufferHandle &target_heap_handle, data_ptr_t &target_heap_ptr, idx_t &copied, - const idx_t &count) { - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - idx_t source_entry_idx_copy = source_entry_idx; - data_ptr_t target_data_ptr_copy = target_data_ptr; - idx_t copied_copy = copied; - // Flush row data - FlushRows(source_data_ptr, source_entry_idx_copy, source_count, target_data_block, target_data_ptr_copy, row_width, - copied_copy, count); - const idx_t flushed = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t source_heap_ptr_copy = source_heap_ptr; - for (idx_t i = 0; i < flushed; i++) { - // Store base heap offset in the row data - Store(target_heap_block.byte_offset + copy_bytes, target_data_ptr + heap_pointer_offset); - target_data_ptr += row_width; - // Compute entry size and add to total - auto entry_size = Load(source_heap_ptr_copy); - D_ASSERT(entry_size >= sizeof(uint32_t)); - source_heap_ptr_copy += entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (target_heap_block.byte_offset + copy_bytes > target_heap_block.capacity) { - idx_t new_capacity = target_heap_block.byte_offset + copy_bytes; - buffer_manager.ReAllocate(target_heap_block.block, new_capacity); - target_heap_block.capacity = new_capacity; - target_heap_ptr = target_heap_handle.Ptr() + target_heap_block.byte_offset; - } - D_ASSERT(target_heap_block.byte_offset + copy_bytes <= target_heap_block.capacity); - // Copy the heap data in one go - memcpy(target_heap_ptr, source_heap_ptr, copy_bytes); - target_heap_ptr += copy_bytes; - source_heap_ptr += copy_bytes; - source_entry_idx += flushed; - copied += flushed; - // Update result indices and pointers - target_heap_block.count += flushed; - 
target_heap_block.byte_offset += copy_bytes; - D_ASSERT(target_heap_block.byte_offset <= target_heap_block.capacity); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/radix_sort.cpp b/src/duckdb/src/common/sort/radix_sort.cpp deleted file mode 100644 index b193cee61..000000000 --- a/src/duckdb/src/common/sort/radix_sort.cpp +++ /dev/null @@ -1,352 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/duckdb_pdqsort.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -//! Calls std::sort on strings that are tied by their prefix after the radix sort -static void SortTiedBlobs(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &start, const idx_t &end, - const idx_t &tie_col, bool *ties, const data_ptr_t blob_ptr, const SortLayout &sort_layout) { - const auto row_width = sort_layout.blob_layout.GetRowWidth(); - // Locate the first blob row in question - data_ptr_t row_ptr = dataptr + start * sort_layout.entry_size; - data_ptr_t blob_row_ptr = blob_ptr + Load(row_ptr + sort_layout.comparison_size) * row_width; - if (!Comparators::TieIsBreakable(tie_col, blob_row_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return; - } - // Fill pointer array for sorting - auto ptr_block = make_unsafe_uniq_array_uninitialized(end - start); - auto entry_ptrs = (data_ptr_t *)ptr_block.get(); - for (idx_t i = start; i < end; i++) { - entry_ptrs[i - start] = row_ptr; - row_ptr += sort_layout.entry_size; - } - // Slow pointer-based sorting - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? 
-1 : 1; - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - auto logical_type = sort_layout.blob_layout.GetTypes()[col_idx]; - std::sort(entry_ptrs, entry_ptrs + end - start, - [&blob_ptr, &order, &sort_layout, &tie_col_offset, &row_width, &logical_type](const data_ptr_t l, - const data_ptr_t r) { - idx_t left_idx = Load(l + sort_layout.comparison_size); - idx_t right_idx = Load(r + sort_layout.comparison_size); - data_ptr_t left_ptr = blob_ptr + left_idx * row_width + tie_col_offset; - data_ptr_t right_ptr = blob_ptr + right_idx * row_width + tie_col_offset; - return order * Comparators::CompareVal(left_ptr, right_ptr, logical_type) < 0; - }); - // Re-order - auto temp_block = buffer_manager.GetBufferAllocator().Allocate((end - start) * sort_layout.entry_size); - data_ptr_t temp_ptr = temp_block.get(); - for (idx_t i = 0; i < end - start; i++) { - FastMemcpy(temp_ptr, entry_ptrs[i], sort_layout.entry_size); - temp_ptr += sort_layout.entry_size; - } - memcpy(dataptr + start * sort_layout.entry_size, temp_block.get(), (end - start) * sort_layout.entry_size); - // Determine if there are still ties (if this is not the last column) - if (tie_col < sort_layout.column_count - 1) { - data_ptr_t idx_ptr = dataptr + start * sort_layout.entry_size + sort_layout.comparison_size; - // Load current entry - data_ptr_t current_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - for (idx_t i = 0; i < end - start - 1; i++) { - // Load next entry and compare - idx_ptr += sort_layout.entry_size; - data_ptr_t next_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - ties[start + i] = Comparators::CompareVal(current_ptr, next_ptr, logical_type) == 0; - current_ptr = next_ptr; - } - } -} - -//! 
Identifies sequences of rows that are tied by the prefix of a blob column, and sorts them -static void SortTiedBlobs(BufferManager &buffer_manager, SortedBlock &sb, bool *ties, data_ptr_t dataptr, - const idx_t &count, const idx_t &tie_col, const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - auto &blob_block = *sb.blob_sorting_data->data_blocks.back(); - auto blob_handle = buffer_manager.Pin(blob_block.block); - const data_ptr_t blob_ptr = blob_handle.Ptr(); - - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - SortTiedBlobs(buffer_manager, dataptr, i, j + 1, tie_col, ties, blob_ptr, sort_layout); - i = j; - } -} - -//! Returns whether there are any 'true' values in the ties[] array -static bool AnyTies(bool ties[], const idx_t &count) { - D_ASSERT(!ties[count - 1]); - bool any_ties = false; - for (idx_t i = 0; i < count - 1; i++) { - any_ties = any_ties || ties[i]; - } - return any_ties; -} - -//! Compares subsequent rows to check for ties -static void ComputeTies(data_ptr_t dataptr, const idx_t &count, const idx_t &col_offset, const idx_t &tie_size, - bool ties[], const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - D_ASSERT(col_offset + tie_size <= sort_layout.comparison_size); - // Align dataptr - dataptr += col_offset; - for (idx_t i = 0; i < count - 1; i++) { - ties[i] = ties[i] && FastMemcmp(dataptr, dataptr + sort_layout.entry_size, tie_size) == 0; - dataptr += sort_layout.entry_size; - } -} - -//! 
Textbook LSD radix sort -void RadixSortLSD(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &sorting_size) { - auto temp_block = buffer_manager.GetBufferAllocator().Allocate(count * row_width); - bool swap = false; - - idx_t counts[SortConstants::VALUES_PER_RADIX]; - for (idx_t r = 1; r <= sorting_size; r++) { - // Init counts to 0 - memset(counts, 0, sizeof(counts)); - // Const some values for convenience - const data_ptr_t source_ptr = swap ? temp_block.get() : dataptr; - const data_ptr_t target_ptr = swap ? dataptr : temp_block.get(); - const idx_t offset = col_offset + sorting_size - r; - // Collect counts - data_ptr_t offset_ptr = source_ptr + offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute offsets from counts - idx_t max_count = counts[0]; - for (idx_t val = 1; val < SortConstants::VALUES_PER_RADIX; val++) { - max_count = MaxValue(max_count, counts[val]); - counts[val] = counts[val] + counts[val - 1]; - } - if (max_count == count) { - continue; - } - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr + (count - 1) * row_width; - for (idx_t i = 0; i < count; i++) { - idx_t &radix_offset = --counts[*(row_ptr + offset)]; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr -= row_width; - } - swap = !swap; - } - // Move data back to original buffer (if it was swapped) - if (swap) { - memcpy(dataptr, temp_block.get(), count * row_width); - } -} - -//! Insertion sort, used when count of values is low -inline void InsertionSort(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, - const idx_t &col_offset, const idx_t &row_width, const idx_t &total_comp_width, - const idx_t &offset, bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? 
orig_ptr : temp_ptr; - if (count > 1) { - const idx_t total_offset = col_offset + offset; - auto temp_val = make_unsafe_uniq_array_uninitialized(row_width); - const data_ptr_t val = temp_val.get(); - const auto comp_width = total_comp_width - offset; - for (idx_t i = 1; i < count; i++) { - FastMemcpy(val, source_ptr + i * row_width, row_width); - idx_t j = i; - while (j > 0 && - FastMemcmp(source_ptr + (j - 1) * row_width + total_offset, val + total_offset, comp_width) > 0) { - FastMemcpy(source_ptr + j * row_width, source_ptr + (j - 1) * row_width, row_width); - j--; - } - FastMemcpy(source_ptr + j * row_width, val, row_width); - } - } - if (swap) { - memcpy(target_ptr, source_ptr, count * row_width); - } -} - -//! MSD radix sort that switches to insertion sort with low bucket sizes -void RadixSortMSD(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &comp_width, const idx_t &offset, idx_t locations[], bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? 
orig_ptr : temp_ptr; - // Init counts to 0 - memset(locations, 0, SortConstants::MSD_RADIX_LOCATIONS * sizeof(idx_t)); - idx_t *counts = locations + 1; - // Collect counts - const idx_t total_offset = col_offset + offset; - data_ptr_t offset_ptr = source_ptr + total_offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute locations from counts - idx_t max_count = 0; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - max_count = MaxValue(max_count, counts[radix]); - counts[radix] += locations[radix]; - } - if (max_count != count) { - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr; - for (idx_t i = 0; i < count; i++) { - const idx_t &radix_offset = locations[*(row_ptr + total_offset)]++; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr += row_width; - } - swap = !swap; - } - // Check if done - if (offset == comp_width - 1) { - if (swap) { - memcpy(orig_ptr, temp_ptr, count * row_width); - } - return; - } - if (max_count == count) { - RadixSortMSD(orig_ptr, temp_ptr, count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - return; - } - // Recurse - idx_t radix_count = locations[0]; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - const idx_t loc = (locations[radix] - radix_count) * row_width; - if (radix_count > SortConstants::INSERTION_SORT_THRESHOLD) { - RadixSortMSD(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - } else if (radix_count != 0) { - InsertionSort(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - swap); - } - radix_count = locations[radix + 1] - locations[radix]; - } -} - -//! 
Calls different sort functions, depending on the count and sorting sizes -void RadixSort(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &sorting_size, const SortLayout &sort_layout, bool contains_string) { - - if (contains_string) { - auto begin = duckdb_pdqsort::PDQIterator(dataptr, sort_layout.entry_size); - auto end = begin + count; - duckdb_pdqsort::PDQConstants constants(sort_layout.entry_size, col_offset, sorting_size, *end); - return duckdb_pdqsort::pdqsort_branchless(begin, begin + count, constants); - } - - if (count <= SortConstants::INSERTION_SORT_THRESHOLD) { - return InsertionSort(dataptr, nullptr, count, col_offset, sort_layout.entry_size, sorting_size, 0, false); - } - - if (sorting_size <= SortConstants::MSD_RADIX_SORT_SIZE_THRESHOLD) { - return RadixSortLSD(buffer_manager, dataptr, count, col_offset, sort_layout.entry_size, sorting_size); - } - - const auto block_size = buffer_manager.GetBlockSize(); - auto temp_block = - buffer_manager.Allocate(MemoryTag::ORDER_BY, MaxValue(count * sort_layout.entry_size, block_size)); - auto pre_allocated_array = - make_unsafe_uniq_array_uninitialized(sorting_size * SortConstants::MSD_RADIX_LOCATIONS); - RadixSortMSD(dataptr, temp_block.Ptr(), count, col_offset, sort_layout.entry_size, sorting_size, 0, - pre_allocated_array.get(), false); -} - -//! 
Identifies sequences of rows that are tied, and calls radix sort on these -static void SubSortTiedTuples(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &count, - const idx_t &col_offset, const idx_t &sorting_size, bool ties[], - const SortLayout &sort_layout, bool contains_string) { - D_ASSERT(!ties[count - 1]); - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - RadixSort(buffer_manager, dataptr + i * sort_layout.entry_size, j - i + 1, col_offset, sorting_size, - sort_layout, contains_string); - i = j; - } -} - -void LocalSortState::SortInMemory() { - auto &sb = *sorted_blocks.back(); - auto &block = *sb.radix_sorting_data.back(); - const auto &count = block.count; - auto handle = buffer_manager->Pin(block.block); - const auto dataptr = handle.Ptr(); - // Assign an index to each row - data_ptr_t idx_dataptr = dataptr + sort_layout->comparison_size; - for (uint32_t i = 0; i < count; i++) { - Store(i, idx_dataptr); - idx_dataptr += sort_layout->entry_size; - } - // Radix sort and break ties until no more ties, or until all columns are sorted - idx_t sorting_size = 0; - idx_t col_offset = 0; - unsafe_unique_array ties_ptr; - bool *ties = nullptr; - bool contains_string = false; - for (idx_t i = 0; i < sort_layout->column_count; i++) { - sorting_size += sort_layout->column_sizes[i]; - contains_string = contains_string || sort_layout->logical_types[i].InternalType() == PhysicalType::VARCHAR; - if (sort_layout->constant_size[i] && i < sort_layout->column_count - 1) { - // Add columns to the sorting size until we reach a variable size column, or the last column - continue; - } - - if (!ties) { - // This is the first sort - RadixSort(*buffer_manager, dataptr, count, col_offset, sorting_size, *sort_layout, contains_string); - ties_ptr = make_unsafe_uniq_array_uninitialized(count); - ties = ties_ptr.get(); - std::fill_n(ties, count - 1, true); - 
ties[count - 1] = false; - } else { - // For subsequent sorts, we only have to subsort the tied tuples - SubSortTiedTuples(*buffer_manager, dataptr, count, col_offset, sorting_size, ties, *sort_layout, - contains_string); - } - - contains_string = false; - - if (sort_layout->constant_size[i] && i == sort_layout->column_count - 1) { - // All columns are sorted, no ties to break because last column is constant size - break; - } - - ComputeTies(dataptr, count, col_offset, sorting_size, ties, *sort_layout); - if (!AnyTies(ties, count)) { - // No ties, stop sorting - break; - } - - if (!sort_layout->constant_size[i]) { - SortTiedBlobs(*buffer_manager, sb, ties, dataptr, count, i, *sort_layout); - if (!AnyTies(ties, count)) { - // No more ties after tie-breaking, stop - break; - } - } - - col_offset += sorting_size; - sorting_size = 0; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/sort.cpp b/src/duckdb/src/common/sort/sort.cpp similarity index 96% rename from src/duckdb/src/common/sorting/sort.cpp rename to src/duckdb/src/common/sort/sort.cpp index 56bde8499..8f2a1e6e7 100644 --- a/src/duckdb/src/common/sorting/sort.cpp +++ b/src/duckdb/src/common/sort/sort.cpp @@ -377,6 +377,15 @@ class SortGlobalSourceState : public GlobalSourceState { return merger_global_state ? merger_global_state->MaxThreads() : 1; } + void Destroy() { + if (!merger_global_state) { + return; + } + auto guard = merger_global_state->Lock(); + merger.sorted_runs.clear(); + sink.temporary_memory_state.reset(); + } + public: //! 
The global sink state SortGlobalSinkState &sink; @@ -476,16 +485,26 @@ SourceResultType Sort::MaterializeColumnData(ExecutionContext &context, Operator } // Merge into global output collection - auto guard = gstate.Lock(); - if (!gstate.column_data) { - gstate.column_data = std::move(local_column_data); - } else { - gstate.column_data->Merge(*local_column_data); + { + auto guard = gstate.Lock(); + if (!gstate.column_data) { + gstate.column_data = std::move(local_column_data); + } else { + gstate.column_data->Merge(*local_column_data); + } } + // Destroy local state before returning + input.local_state.Cast().merger_local_state.reset(); + // Return type indicates whether materialization is done const auto progress_data = GetProgress(context.client, input.global_state); - return progress_data.done == progress_data.total ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; + if (progress_data.done == progress_data.total) { + // Destroy global state before returning + gstate.Destroy(); + return SourceResultType::FINISHED; + } + return SourceResultType::HAVE_MORE_OUTPUT; } unique_ptr Sort::GetColumnData(OperatorSourceInput &input) const { diff --git a/src/duckdb/src/common/sort/sort_state.cpp b/src/duckdb/src/common/sort/sort_state.cpp deleted file mode 100644 index 369f032f1..000000000 --- a/src/duckdb/src/common/sort/sort_state.cpp +++ /dev/null @@ -1,487 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/storage/buffer/buffer_pool.hpp" - -#include -#include - -namespace duckdb { - -idx_t GetNestedSortingColSize(idx_t &col_size, const LogicalType &type) { - auto physical_type = type.InternalType(); - if (TypeIsConstantSize(physical_type)) { - col_size += GetTypeIdSize(physical_type); - return 0; - } else { - switch (physical_type) { - case 
PhysicalType::VARCHAR: { - // Nested strings are between 4 and 11 chars long for alignment - auto size_before_str = col_size; - col_size += 11; - col_size -= (col_size - 12) % 8; - return col_size - size_before_str; - } - case PhysicalType::LIST: - // Lists get 2 bytes (null and empty list) - col_size += 2; - return GetNestedSortingColSize(col_size, ListType::GetChildType(type)); - case PhysicalType::STRUCT: - // Structs get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, StructType::GetChildType(type, 0)); - case PhysicalType::ARRAY: - // Arrays get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, ArrayType::GetChildType(type)); - default: - throw NotImplementedException("Unable to order column with type %s", type.ToString()); - } - } -} - -SortLayout::SortLayout(const vector &orders) - : column_count(orders.size()), all_constant(true), comparison_size(0), entry_size(0) { - vector blob_layout_types; - for (idx_t i = 0; i < column_count; i++) { - const auto &order = orders[i]; - - order_types.push_back(order.type); - order_by_null_types.push_back(order.null_order); - auto &expr = *order.expression; - logical_types.push_back(expr.return_type); - - auto physical_type = expr.return_type.InternalType(); - constant_size.push_back(TypeIsConstantSize(physical_type)); - - if (order.stats) { - stats.push_back(order.stats.get()); - has_null.push_back(stats.back()->CanHaveNull()); - } else { - stats.push_back(nullptr); - has_null.push_back(true); - } - - idx_t col_size = has_null.back() ? 
1 : 0; - prefix_lengths.push_back(0); - if (!TypeIsConstantSize(physical_type) && physical_type != PhysicalType::VARCHAR) { - prefix_lengths.back() = GetNestedSortingColSize(col_size, expr.return_type); - } else if (physical_type == PhysicalType::VARCHAR) { - idx_t size_before = col_size; - if (stats.back() && StringStats::HasMaxStringLength(*stats.back())) { - col_size += StringStats::MaxStringLength(*stats.back()); - if (col_size > 12) { - col_size = 12; - } else { - constant_size.back() = true; - } - } else { - col_size = 12; - } - prefix_lengths.back() = col_size - size_before; - } else { - col_size += GetTypeIdSize(physical_type); - } - - comparison_size += col_size; - column_sizes.push_back(col_size); - } - entry_size = comparison_size + sizeof(uint32_t); - - // 8-byte alignment - if (entry_size % 8 != 0) { - // First assign more bytes to strings instead of aligning - idx_t bytes_to_fill = 8 - (entry_size % 8); - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - if (bytes_to_fill == 0) { - break; - } - if (logical_types[col_idx].InternalType() == PhysicalType::VARCHAR && stats[col_idx] && - StringStats::HasMaxStringLength(*stats[col_idx])) { - idx_t diff = StringStats::MaxStringLength(*stats[col_idx]) - prefix_lengths[col_idx]; - if (diff > 0) { - // Increase all sizes accordingly - idx_t increase = MinValue(bytes_to_fill, diff); - column_sizes[col_idx] += increase; - prefix_lengths[col_idx] += increase; - constant_size[col_idx] = increase == diff; - comparison_size += increase; - entry_size += increase; - bytes_to_fill -= increase; - } - } - } - entry_size = AlignValue(entry_size); - } - - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - all_constant = all_constant && constant_size[col_idx]; - if (!constant_size[col_idx]) { - sorting_to_blob_col[col_idx] = blob_layout_types.size(); - blob_layout_types.push_back(logical_types[col_idx]); - } - } - - blob_layout.Initialize(blob_layout_types); -} - -SortLayout 
SortLayout::GetPrefixComparisonLayout(idx_t num_prefix_cols) const { - SortLayout result; - result.column_count = num_prefix_cols; - result.all_constant = true; - result.comparison_size = 0; - for (idx_t col_idx = 0; col_idx < num_prefix_cols; col_idx++) { - result.order_types.push_back(order_types[col_idx]); - result.order_by_null_types.push_back(order_by_null_types[col_idx]); - result.logical_types.push_back(logical_types[col_idx]); - - result.all_constant = result.all_constant && constant_size[col_idx]; - result.constant_size.push_back(constant_size[col_idx]); - - result.comparison_size += column_sizes[col_idx]; - result.column_sizes.push_back(column_sizes[col_idx]); - - result.prefix_lengths.push_back(prefix_lengths[col_idx]); - result.stats.push_back(stats[col_idx]); - result.has_null.push_back(has_null[col_idx]); - } - result.entry_size = entry_size; - result.blob_layout = blob_layout; - result.sorting_to_blob_col = sorting_to_blob_col; - return result; -} - -LocalSortState::LocalSortState() : initialized(false) { - if (!Radix::IsLittleEndian()) { - throw NotImplementedException("Sorting is not supported on big endian architectures"); - } -} - -void LocalSortState::Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p) { - sort_layout = &global_sort_state.sort_layout; - payload_layout = &global_sort_state.payload_layout; - buffer_manager = &buffer_manager_p; - const auto block_size = buffer_manager->GetBlockSize(); - - // Radix sorting data - auto entries_per_block = RowDataCollection::EntriesPerBlock(sort_layout->entry_size, block_size); - radix_sorting_data = make_uniq(*buffer_manager, entries_per_block, sort_layout->entry_size); - - // Blob sorting data - if (!sort_layout->all_constant) { - auto blob_row_width = sort_layout->blob_layout.GetRowWidth(); - entries_per_block = RowDataCollection::EntriesPerBlock(blob_row_width, block_size); - blob_sorting_data = make_uniq(*buffer_manager, entries_per_block, blob_row_width); - 
blob_sorting_heap = make_uniq(*buffer_manager, block_size, 1U, true); - } - - // Payload data - auto payload_row_width = payload_layout->GetRowWidth(); - entries_per_block = RowDataCollection::EntriesPerBlock(payload_row_width, block_size); - payload_data = make_uniq(*buffer_manager, entries_per_block, payload_row_width); - payload_heap = make_uniq(*buffer_manager, block_size, 1U, true); - initialized = true; -} - -void LocalSortState::SinkChunk(DataChunk &sort, DataChunk &payload) { - D_ASSERT(sort.size() == payload.size()); - // Build and serialize sorting data to radix sortable rows - auto data_pointers = FlatVector::GetData(addresses); - auto handles = radix_sorting_data->Build(sort.size(), data_pointers, nullptr); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - bool has_null = sort_layout->has_null[sort_col]; - bool nulls_first = sort_layout->order_by_null_types[sort_col] == OrderByNullType::NULLS_FIRST; - bool desc = sort_layout->order_types[sort_col] == OrderType::DESCENDING; - RowOperations::RadixScatter(sort.data[sort_col], sort.size(), sel_ptr, sort.size(), data_pointers, desc, - has_null, nulls_first, sort_layout->prefix_lengths[sort_col], - sort_layout->column_sizes[sort_col]); - } - - // Also fully serialize blob sorting columns (to be able to break ties - if (!sort_layout->all_constant) { - DataChunk blob_chunk; - blob_chunk.SetCardinality(sort.size()); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - if (!sort_layout->constant_size[sort_col]) { - blob_chunk.data.emplace_back(sort.data[sort_col]); - } - } - handles = blob_sorting_data->Build(blob_chunk.size(), data_pointers, nullptr); - auto blob_data = blob_chunk.ToUnifiedFormat(); - RowOperations::Scatter(blob_chunk, blob_data.get(), sort_layout->blob_layout, addresses, *blob_sorting_heap, - sel_ptr, blob_chunk.size()); - D_ASSERT(blob_sorting_heap->keep_pinned); - } - - // Finally, serialize payload data - handles = 
payload_data->Build(payload.size(), data_pointers, nullptr); - auto input_data = payload.ToUnifiedFormat(); - RowOperations::Scatter(payload, input_data.get(), *payload_layout, addresses, *payload_heap, sel_ptr, - payload.size()); - D_ASSERT(payload_heap->keep_pinned); -} - -idx_t LocalSortState::SizeInBytes() const { - idx_t size_in_bytes = radix_sorting_data->SizeInBytes() + payload_data->SizeInBytes(); - if (!sort_layout->all_constant) { - size_in_bytes += blob_sorting_data->SizeInBytes() + blob_sorting_heap->SizeInBytes(); - } - if (!payload_layout->AllConstant()) { - size_in_bytes += payload_heap->SizeInBytes(); - } - return size_in_bytes; -} - -void LocalSortState::Sort(GlobalSortState &global_sort_state, bool reorder_heap) { - D_ASSERT(radix_sorting_data->count == payload_data->count); - if (radix_sorting_data->count == 0) { - return; - } - // Move all data to a single SortedBlock - sorted_blocks.emplace_back(make_uniq(*buffer_manager, global_sort_state)); - auto &sb = *sorted_blocks.back(); - // Fixed-size sorting data - auto sorting_block = ConcatenateBlocks(*radix_sorting_data); - sb.radix_sorting_data.push_back(std::move(sorting_block)); - // Variable-size sorting data - if (!sort_layout->all_constant) { - auto &blob_data = *blob_sorting_data; - auto new_block = ConcatenateBlocks(blob_data); - sb.blob_sorting_data->data_blocks.push_back(std::move(new_block)); - } - // Payload data - auto payload_block = ConcatenateBlocks(*payload_data); - sb.payload_data->data_blocks.push_back(std::move(payload_block)); - // Now perform the actual sort - SortInMemory(); - // Re-order before the merge sort - ReOrder(global_sort_state, reorder_heap); -} - -unique_ptr LocalSortState::ConcatenateBlocks(RowDataCollection &row_data) { - // Don't copy and delete if there is only one block. 
- if (row_data.blocks.size() == 1) { - auto new_block = std::move(row_data.blocks[0]); - row_data.blocks.clear(); - row_data.count = 0; - return new_block; - } - // Create block with the correct capacity - auto &buffer_manager = row_data.buffer_manager; - const idx_t &entry_size = row_data.entry_size; - idx_t capacity = MaxValue((buffer_manager.GetBlockSize() + entry_size - 1) / entry_size, row_data.count); - auto new_block = make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, entry_size); - new_block->count = row_data.count; - auto new_block_handle = buffer_manager.Pin(new_block->block); - data_ptr_t new_block_ptr = new_block_handle.Ptr(); - // Copy the data of the blocks into a single block - for (idx_t i = 0; i < row_data.blocks.size(); i++) { - auto &block = row_data.blocks[i]; - auto block_handle = buffer_manager.Pin(block->block); - memcpy(new_block_ptr, block_handle.Ptr(), block->count * entry_size); - new_block_ptr += block->count * entry_size; - block.reset(); - } - row_data.blocks.clear(); - row_data.count = 0; - return new_block; -} - -void LocalSortState::ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, - bool reorder_heap) { - sd.swizzled = reorder_heap; - auto &unordered_data_block = sd.data_blocks.back(); - const idx_t count = unordered_data_block->count; - auto unordered_data_handle = buffer_manager->Pin(unordered_data_block->block); - const data_ptr_t unordered_data_ptr = unordered_data_handle.Ptr(); - // Create new block that will hold re-ordered row data - auto ordered_data_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, - unordered_data_block->capacity, unordered_data_block->entry_size); - ordered_data_block->count = count; - auto ordered_data_handle = buffer_manager->Pin(ordered_data_block->block); - data_ptr_t ordered_data_ptr = ordered_data_handle.Ptr(); - // Re-order fixed-size row layout - const idx_t row_width = sd.layout.GetRowWidth(); - const idx_t sorting_entry_size = 
gstate.sort_layout.entry_size; - for (idx_t i = 0; i < count; i++) { - auto index = Load(sorting_ptr); - FastMemcpy(ordered_data_ptr, unordered_data_ptr + index * row_width, row_width); - ordered_data_ptr += row_width; - sorting_ptr += sorting_entry_size; - } - ordered_data_block->block->SetSwizzling( - sd.layout.AllConstant() || !sd.swizzled ? nullptr : "LocalSortState::ReOrder.ordered_data"); - // Replace the unordered data block with the re-ordered data block - sd.data_blocks.clear(); - sd.data_blocks.push_back(std::move(ordered_data_block)); - // Deal with the heap (if necessary) - if (!sd.layout.AllConstant() && reorder_heap) { - // Swizzle the column pointers to offsets - RowOperations::SwizzleColumns(sd.layout, ordered_data_handle.Ptr(), count); - sd.data_blocks.back()->block->SetSwizzling(nullptr); - // Create a single heap block to store the ordered heap - idx_t total_byte_offset = - std::accumulate(heap.blocks.begin(), heap.blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->byte_offset; }); - idx_t heap_block_size = MaxValue(total_byte_offset, buffer_manager->GetBlockSize()); - auto ordered_heap_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, heap_block_size, 1U); - ordered_heap_block->count = count; - ordered_heap_block->byte_offset = total_byte_offset; - auto ordered_heap_handle = buffer_manager->Pin(ordered_heap_block->block); - data_ptr_t ordered_heap_ptr = ordered_heap_handle.Ptr(); - // Fill the heap in order - ordered_data_ptr = ordered_data_handle.Ptr(); - const idx_t heap_pointer_offset = sd.layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto heap_row_ptr = Load(ordered_data_ptr + heap_pointer_offset); - auto heap_row_size = Load(heap_row_ptr); - memcpy(ordered_heap_ptr, heap_row_ptr, heap_row_size); - ordered_heap_ptr += heap_row_size; - ordered_data_ptr += row_width; - } - // Swizzle the base pointer to the offset of each row in the heap - RowOperations::SwizzleHeapPointer(sd.layout, 
ordered_data_handle.Ptr(), ordered_heap_handle.Ptr(), count); - // Move the re-ordered heap to the SortedData, and clear the local heap - sd.heap_blocks.push_back(std::move(ordered_heap_block)); - heap.pinned_blocks.clear(); - heap.blocks.clear(); - heap.count = 0; - } -} - -void LocalSortState::ReOrder(GlobalSortState &gstate, bool reorder_heap) { - auto &sb = *sorted_blocks.back(); - auto sorting_handle = buffer_manager->Pin(sb.radix_sorting_data.back()->block); - const data_ptr_t sorting_ptr = sorting_handle.Ptr() + gstate.sort_layout.comparison_size; - // Re-order variable size sorting columns - if (!gstate.sort_layout.all_constant) { - ReOrder(*sb.blob_sorting_data, sorting_ptr, *blob_sorting_heap, gstate, reorder_heap); - } - // And the payload - ReOrder(*sb.payload_data, sorting_ptr, *payload_heap, gstate, reorder_heap); -} - -GlobalSortState::GlobalSortState(ClientContext &context_p, const vector &orders, - RowLayout &payload_layout) - : context(context_p), buffer_manager(BufferManager::GetBufferManager(context)), sort_layout(SortLayout(orders)), - payload_layout(payload_layout), block_capacity(0), external(false) { -} - -void GlobalSortState::AddLocalState(LocalSortState &local_sort_state) { - if (!local_sort_state.radix_sorting_data) { - return; - } - - // Sort accumulated data - // we only re-order the heap when the data is expected to not fit in memory - // re-ordering the heap avoids random access when reading/merging but incurs a significant cost of shuffling data - // when data fits in memory, doing random access on reads is cheaper than re-shuffling - local_sort_state.Sort(*this, external || !local_sort_state.sorted_blocks.empty()); - - // Append local state sorted data to this global state - lock_guard append_guard(lock); - for (auto &sb : local_sort_state.sorted_blocks) { - sorted_blocks.push_back(std::move(sb)); - } - auto &payload_heap = local_sort_state.payload_heap; - for (idx_t i = 0; i < payload_heap->blocks.size(); i++) { - 
heap_blocks.push_back(std::move(payload_heap->blocks[i])); - pinned_blocks.push_back(std::move(payload_heap->pinned_blocks[i])); - } - if (!sort_layout.all_constant) { - auto &blob_heap = local_sort_state.blob_sorting_heap; - for (idx_t i = 0; i < blob_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(blob_heap->blocks[i])); - pinned_blocks.push_back(std::move(blob_heap->pinned_blocks[i])); - } - } -} - -void GlobalSortState::PrepareMergePhase() { - // Determine if we need to use do an external sort - idx_t total_heap_size = - std::accumulate(sorted_blocks.begin(), sorted_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->HeapSize(); }); - if (external || (pinned_blocks.empty() && total_heap_size * 4 > buffer_manager.GetQueryMaxMemory())) { - external = true; - } - // Use the data that we have to determine which partition size to use during the merge - if (external && total_heap_size > 0) { - // If we have variable size data we need to be conservative, as there might be skew - idx_t max_block_size = 0; - for (auto &sb : sorted_blocks) { - idx_t size_in_bytes = sb->SizeInBytes(); - if (size_in_bytes > max_block_size) { - max_block_size = size_in_bytes; - block_capacity = sb->Count(); - } - } - } else { - for (auto &sb : sorted_blocks) { - block_capacity = MaxValue(block_capacity, sb->Count()); - } - } - // Unswizzle and pin heap blocks if we can fit everything in memory - if (!external) { - for (auto &sb : sorted_blocks) { - sb->blob_sorting_data->Unswizzle(); - sb->payload_data->Unswizzle(); - } - } -} - -void GlobalSortState::InitializeMergeRound() { - D_ASSERT(sorted_blocks_temp.empty()); - // If we reverse this list, the blocks that were merged last will be merged first in the next round - // These are still in memory, therefore this reduces the amount of read/write to disk! 
- std::reverse(sorted_blocks.begin(), sorted_blocks.end()); - // Uneven number of blocks - keep one on the side - if (sorted_blocks.size() % 2 == 1) { - odd_one_out = std::move(sorted_blocks.back()); - sorted_blocks.pop_back(); - } - // Init merge path path indices - pair_idx = 0; - num_pairs = sorted_blocks.size() / 2; - l_start = 0; - r_start = 0; - // Allocate room for merge results - for (idx_t p_idx = 0; p_idx < num_pairs; p_idx++) { - sorted_blocks_temp.emplace_back(); - } -} - -void GlobalSortState::CompleteMergeRound(bool keep_radix_data) { - sorted_blocks.clear(); - for (auto &sorted_block_vector : sorted_blocks_temp) { - sorted_blocks.push_back(make_uniq(buffer_manager, *this)); - sorted_blocks.back()->AppendSortedBlocks(sorted_block_vector); - } - sorted_blocks_temp.clear(); - if (odd_one_out) { - sorted_blocks.push_back(std::move(odd_one_out)); - odd_one_out = nullptr; - } - // Only one block left: Done! - if (sorted_blocks.size() == 1 && !keep_radix_data) { - sorted_blocks[0]->radix_sorting_data.clear(); - sorted_blocks[0]->blob_sorting_data = nullptr; - } -} -void GlobalSortState::Print() { - PayloadScanner scanner(*this, false); - DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); - for (;;) { - scanner.Scan(chunk); - const auto count = chunk.size(); - if (!count) { - break; - } - chunk.Print(); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/sorted_block.cpp b/src/duckdb/src/common/sort/sorted_block.cpp deleted file mode 100644 index c4766c956..000000000 --- a/src/duckdb/src/common/sort/sorted_block.cpp +++ /dev/null @@ -1,387 +0,0 @@ -#include "duckdb/common/sort/sorted_block.hpp" - -#include "duckdb/common/constants.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" - -#include - -namespace duckdb { - -SortedData::SortedData(SortedDataType type, const RowLayout 
&layout, BufferManager &buffer_manager, - GlobalSortState &state) - : type(type), layout(layout), swizzled(state.external), buffer_manager(buffer_manager), state(state) { -} - -idx_t SortedData::Count() { - idx_t count = std::accumulate(data_blocks.begin(), data_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!layout.AllConstant() && state.external) { - D_ASSERT(count == std::accumulate(heap_blocks.begin(), heap_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; })); - } - return count; -} - -void SortedData::CreateBlock() { - const auto block_size = buffer_manager.GetBlockSize(); - auto capacity = MaxValue((block_size + layout.GetRowWidth() - 1) / layout.GetRowWidth(), state.block_capacity); - data_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, layout.GetRowWidth())); - if (!layout.AllConstant() && state.external) { - heap_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_size, 1U)); - D_ASSERT(data_blocks.size() == heap_blocks.size()); - } -} - -unique_ptr SortedData::CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index) { - // Add the corresponding blocks to the result - auto result = make_uniq(type, layout, buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->data_blocks.push_back(data_blocks[i]->Copy()); - if (!layout.AllConstant() && state.external) { - result->heap_blocks.push_back(heap_blocks[i]->Copy()); - } - } - // All of the blocks that come before block with idx = start_block_idx can be reset (other references exist) - for (idx_t i = 0; i < start_block_index; i++) { - data_blocks[i]->block = nullptr; - if (!layout.AllConstant() && state.external) { - heap_blocks[i]->block = nullptr; - } - } - // Use start and end entry indices to set the boundaries - D_ASSERT(end_entry_index <= result->data_blocks.back()->count); - result->data_blocks.back()->count = 
end_entry_index; - if (!layout.AllConstant() && state.external) { - result->heap_blocks.back()->count = end_entry_index; - } - return result; -} - -void SortedData::Unswizzle() { - if (layout.AllConstant() || !swizzled) { - return; - } - for (idx_t i = 0; i < data_blocks.size(); i++) { - auto &data_block = data_blocks[i]; - auto &heap_block = heap_blocks[i]; - D_ASSERT(data_block->block->IsSwizzled()); - auto data_handle_p = buffer_manager.Pin(data_block->block); - auto heap_handle_p = buffer_manager.Pin(heap_block->block); - RowOperations::UnswizzlePointers(layout, data_handle_p.Ptr(), heap_handle_p.Ptr(), data_block->count); - state.heap_blocks.push_back(std::move(heap_block)); - state.pinned_blocks.push_back(std::move(heap_handle_p)); - } - swizzled = false; - heap_blocks.clear(); -} - -SortedBlock::SortedBlock(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), state(state), sort_layout(state.sort_layout), - payload_layout(state.payload_layout) { - blob_sorting_data = make_uniq(SortedDataType::BLOB, sort_layout.blob_layout, buffer_manager, state); - payload_data = make_uniq(SortedDataType::PAYLOAD, payload_layout, buffer_manager, state); -} - -idx_t SortedBlock::Count() const { - idx_t count = std::accumulate(radix_sorting_data.begin(), radix_sorting_data.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!sort_layout.all_constant) { - D_ASSERT(count == blob_sorting_data->Count()); - } - D_ASSERT(count == payload_data->Count()); - return count; -} - -void SortedBlock::InitializeWrite() { - CreateBlock(); - if (!sort_layout.all_constant) { - blob_sorting_data->CreateBlock(); - } - payload_data->CreateBlock(); -} - -void SortedBlock::CreateBlock() { - const auto block_size = buffer_manager.GetBlockSize(); - auto capacity = MaxValue((block_size + sort_layout.entry_size - 1) / sort_layout.entry_size, state.block_capacity); - radix_sorting_data.push_back( - make_uniq(MemoryTag::ORDER_BY, 
buffer_manager, capacity, sort_layout.entry_size)); -} - -void SortedBlock::AppendSortedBlocks(vector> &sorted_blocks) { - D_ASSERT(Count() == 0); - for (auto &sb : sorted_blocks) { - for (auto &radix_block : sb->radix_sorting_data) { - radix_sorting_data.push_back(std::move(radix_block)); - } - if (!sort_layout.all_constant) { - for (auto &blob_block : sb->blob_sorting_data->data_blocks) { - blob_sorting_data->data_blocks.push_back(std::move(blob_block)); - } - for (auto &heap_block : sb->blob_sorting_data->heap_blocks) { - blob_sorting_data->heap_blocks.push_back(std::move(heap_block)); - } - } - for (auto &payload_data_block : sb->payload_data->data_blocks) { - payload_data->data_blocks.push_back(std::move(payload_data_block)); - } - if (!payload_data->layout.AllConstant()) { - for (auto &payload_heap_block : sb->payload_data->heap_blocks) { - payload_data->heap_blocks.push_back(std::move(payload_heap_block)); - } - } - } -} - -void SortedBlock::GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index) { - if (global_idx == Count()) { - local_block_index = radix_sorting_data.size() - 1; - local_entry_index = radix_sorting_data.back()->count; - return; - } - D_ASSERT(global_idx < Count()); - local_entry_index = global_idx; - for (local_block_index = 0; local_block_index < radix_sorting_data.size(); local_block_index++) { - const idx_t &block_count = radix_sorting_data[local_block_index]->count; - if (local_entry_index >= block_count) { - local_entry_index -= block_count; - } else { - break; - } - } - D_ASSERT(local_entry_index < radix_sorting_data[local_block_index]->count); -} - -unique_ptr SortedBlock::CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx) { - // Identify blocks/entry indices of this slice - idx_t start_block_index; - idx_t start_entry_index; - GlobalToLocalIndex(start, start_block_index, start_entry_index); - idx_t end_block_index; - idx_t end_entry_index; - GlobalToLocalIndex(end, 
end_block_index, end_entry_index); - // Add the corresponding blocks to the result - auto result = make_uniq(buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->radix_sorting_data.push_back(radix_sorting_data[i]->Copy()); - } - // Reset all blocks that come before block with idx = start_block_idx (slice holds new reference) - for (idx_t i = 0; i < start_block_index; i++) { - radix_sorting_data[i]->block = nullptr; - } - // Use start and end entry indices to set the boundaries - entry_idx = start_entry_index; - D_ASSERT(end_entry_index <= result->radix_sorting_data.back()->count); - result->radix_sorting_data.back()->count = end_entry_index; - // Same for the var size sorting data - if (!sort_layout.all_constant) { - result->blob_sorting_data = blob_sorting_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - } - // And the payload data - result->payload_data = payload_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - return result; -} - -idx_t SortedBlock::HeapSize() const { - idx_t result = 0; - if (!sort_layout.all_constant) { - for (auto &block : blob_sorting_data->heap_blocks) { - result += block->capacity; - } - } - if (!payload_layout.AllConstant()) { - for (auto &block : payload_data->heap_blocks) { - result += block->capacity; - } - } - return result; -} - -idx_t SortedBlock::SizeInBytes() const { - idx_t bytes = 0; - for (idx_t i = 0; i < radix_sorting_data.size(); i++) { - bytes += radix_sorting_data[i]->capacity * sort_layout.entry_size; - if (!sort_layout.all_constant) { - bytes += blob_sorting_data->data_blocks[i]->capacity * sort_layout.blob_layout.GetRowWidth(); - bytes += blob_sorting_data->heap_blocks[i]->capacity; - } - bytes += payload_data->data_blocks[i]->capacity * payload_layout.GetRowWidth(); - if (!payload_layout.AllConstant()) { - bytes += payload_data->heap_blocks[i]->capacity; - } - } - return bytes; -} - -SBScanState::SBScanState(BufferManager 
&buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), sort_layout(state.sort_layout), state(state), block_idx(0), entry_idx(0) { -} - -void SBScanState::PinRadix(idx_t block_idx_to) { - auto &radix_sorting_data = sb->radix_sorting_data; - D_ASSERT(block_idx_to < radix_sorting_data.size()); - auto &block = radix_sorting_data[block_idx_to]; - if (!radix_handle.IsValid() || radix_handle.GetBlockHandle() != block->block) { - radix_handle = buffer_manager.Pin(block->block); - } -} - -void SBScanState::PinData(SortedData &sd) { - D_ASSERT(block_idx < sd.data_blocks.size()); - auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; - auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; - - auto &data_block = sd.data_blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = buffer_manager.Pin(data_block->block); - } - if (sd.layout.AllConstant() || !state.external) { - return; - } - auto &heap_block = sd.heap_blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = buffer_manager.Pin(heap_block->block); - } -} - -data_ptr_t SBScanState::RadixPtr() const { - return radix_handle.Ptr() + entry_idx * sort_layout.entry_size; -} - -data_ptr_t SBScanState::DataPtr(SortedData &sd) const { - auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; - D_ASSERT(sd.data_blocks[block_idx]->block->Readers() != 0 && - data_handle.GetBlockHandle() == sd.data_blocks[block_idx]->block); - return data_handle.Ptr() + entry_idx * sd.layout.GetRowWidth(); -} - -data_ptr_t SBScanState::HeapPtr(SortedData &sd) const { - return BaseHeapPtr(sd) + Load(DataPtr(sd) + sd.layout.GetHeapOffset()); -} - -data_ptr_t SBScanState::BaseHeapPtr(SortedData &sd) const { - auto &heap_handle = sd.type == SortedDataType::BLOB ? 
blob_sorting_heap_handle : payload_heap_handle; - D_ASSERT(!sd.layout.AllConstant() && state.external); - D_ASSERT(sd.heap_blocks[block_idx]->block->Readers() != 0 && - heap_handle.GetBlockHandle() == sd.heap_blocks[block_idx]->block); - return heap_handle.Ptr(); -} - -idx_t SBScanState::Remaining() const { - const auto &blocks = sb->radix_sorting_data; - idx_t remaining = 0; - if (block_idx < blocks.size()) { - remaining += blocks[block_idx]->count - entry_idx; - for (idx_t i = block_idx + 1; i < blocks.size(); i++) { - remaining += blocks[i]->count; - } - } - return remaining; -} - -void SBScanState::SetIndices(idx_t block_idx_to, idx_t entry_idx_to) { - block_idx = block_idx_to; - entry_idx = entry_idx_to; -} - -PayloadScanner::PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush_p) { - auto count = sorted_data.Count(); - auto &layout = sorted_data.layout; - const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (!sorted_data.layout.AllConstant()) { - heap->count = count; - } - - if (flush_p) { - // If we are flushing, we can just move the data - rows->blocks = std::move(sorted_data.data_blocks); - if (!layout.AllConstant()) { - heap->blocks = std::move(sorted_data.heap_blocks); - } - } else { - // Not flushing, create references to the blocks - for (auto &block : sorted_data.data_blocks) { - rows->blocks.emplace_back(block->Copy()); - } - if (!layout.AllConstant()) { - for (auto &block : sorted_data.heap_blocks) { - heap->blocks.emplace_back(block->Copy()); - } - } - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, bool flush_p) - : 
PayloadScanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state, flush_p) { -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush_p) { - auto &sorted_data = *global_sort_state.sorted_blocks[0]->payload_data; - auto count = sorted_data.data_blocks[block_idx]->count; - auto &layout = sorted_data.layout; - const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (flush_p) { - rows->blocks.emplace_back(std::move(sorted_data.data_blocks[block_idx])); - } else { - rows->blocks.emplace_back(sorted_data.data_blocks[block_idx]->Copy()); - } - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (!sorted_data.layout.AllConstant() && sorted_data.swizzled) { - if (flush_p) { - heap->blocks.emplace_back(std::move(sorted_data.heap_blocks[block_idx])); - } else { - heap->blocks.emplace_back(sorted_data.heap_blocks[block_idx]->Copy()); - } - heap->count = count; - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -void PayloadScanner::Scan(DataChunk &chunk) { - scanner->Scan(chunk); -} - -int SBIterator::ComparisonValue(ExpressionType comparison) { - switch (comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHAN: - return -1; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; - default: - throw InternalException("Unimplemented comparison type for IEJoin!"); - } -} - -static idx_t GetBlockCountWithEmptyCheck(const GlobalSortState &gss) { - D_ASSERT(!gss.sorted_blocks.empty()); - return gss.sorted_blocks[0]->radix_sorting_data.size(); -} - -SBIterator::SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p) - : 
sort_layout(gss.sort_layout), block_count(GetBlockCountWithEmptyCheck(gss)), block_capacity(gss.block_capacity), - entry_size(sort_layout.entry_size), all_constant(sort_layout.all_constant), external(gss.external), - cmp(ComparisonValue(comparison)), scan(gss.buffer_manager, gss), block_ptr(nullptr), entry_ptr(nullptr) { - - scan.sb = gss.sorted_blocks[0].get(); - scan.block_idx = block_count; - SetIndex(entry_idx_p); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/sorted_run.cpp b/src/duckdb/src/common/sort/sorted_run.cpp similarity index 100% rename from src/duckdb/src/common/sorting/sorted_run.cpp rename to src/duckdb/src/common/sort/sorted_run.cpp diff --git a/src/duckdb/src/common/sorting/sorted_run_merger.cpp b/src/duckdb/src/common/sort/sorted_run_merger.cpp similarity index 99% rename from src/duckdb/src/common/sorting/sorted_run_merger.cpp rename to src/duckdb/src/common/sort/sorted_run_merger.cpp index d87cef470..874a7fc04 100644 --- a/src/duckdb/src/common/sorting/sorted_run_merger.cpp +++ b/src/duckdb/src/common/sort/sorted_run_merger.cpp @@ -844,6 +844,7 @@ SortedRunMerger::SortedRunMerger(const Sort &sort_p, vector SortedRunMerger::GetLocalSourceState(ExecutionContext &, GlobalSourceState &gstate_p) const { auto &gstate = gstate_p.Cast(); + auto guard = gstate.Lock(); return make_uniq(gstate); } diff --git a/src/duckdb/src/common/string_util.cpp b/src/duckdb/src/common/string_util.cpp index 1e6309ee0..d3931d0bf 100644 --- a/src/duckdb/src/common/string_util.cpp +++ b/src/duckdb/src/common/string_util.cpp @@ -287,9 +287,13 @@ bool StringUtil::IsUpper(const string &str) { // Jenkins hash function: https://en.wikipedia.org/wiki/Jenkins_hash_function uint64_t StringUtil::CIHash(const string &str) { + return StringUtil::CIHash(str.c_str(), str.size()); +} + +uint64_t StringUtil::CIHash(const char *str, idx_t size) { uint32_t hash = 0; - for (auto c : str) { - hash += static_cast(StringUtil::CharacterToLower(static_cast(c))); + for 
(idx_t i = 0; i < size; i++) { + hash += static_cast(StringUtil::CharacterToLower(static_cast(str[i]))); hash += hash << 10; hash ^= hash >> 6; } diff --git a/src/duckdb/src/common/types/row/row_data_collection.cpp b/src/duckdb/src/common/types/row/row_data_collection.cpp deleted file mode 100644 index b178b7fb5..000000000 --- a/src/duckdb/src/common/types/row/row_data_collection.cpp +++ /dev/null @@ -1,141 +0,0 @@ -#include "duckdb/common/types/row/row_data_collection.hpp" - -namespace duckdb { - -RowDataCollection::RowDataCollection(BufferManager &buffer_manager, idx_t block_capacity, idx_t entry_size, - bool keep_pinned) - : buffer_manager(buffer_manager), count(0), block_capacity(block_capacity), entry_size(entry_size), - keep_pinned(keep_pinned) { - D_ASSERT(block_capacity * entry_size + entry_size > buffer_manager.GetBlockSize()); -} - -idx_t RowDataCollection::AppendToBlock(RowDataBlock &block, BufferHandle &handle, - vector &append_entries, idx_t remaining, idx_t entry_sizes[]) { - idx_t append_count = 0; - data_ptr_t dataptr; - if (entry_sizes) { - D_ASSERT(entry_size == 1); - // compute how many entries fit if entry size is variable - dataptr = handle.Ptr() + block.byte_offset; - for (idx_t i = 0; i < remaining; i++) { - if (block.byte_offset + entry_sizes[i] > block.capacity) { - if (block.count == 0 && append_count == 0 && entry_sizes[i] > block.capacity) { - // special case: single entry is bigger than block capacity - // resize current block to fit the entry, append it, and move to the next block - block.capacity = entry_sizes[i]; - buffer_manager.ReAllocate(block.block, block.capacity); - dataptr = handle.Ptr(); - append_count++; - block.byte_offset += entry_sizes[i]; - } - break; - } - append_count++; - block.byte_offset += entry_sizes[i]; - } - } else { - append_count = MinValue(remaining, block.capacity - block.count); - dataptr = handle.Ptr() + block.count * entry_size; - } - append_entries.emplace_back(dataptr, append_count); - block.count += 
append_count; - return append_count; -} - -RowDataBlock &RowDataCollection::CreateBlock() { - blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_capacity, entry_size)); - return *blocks.back(); -} - -vector RowDataCollection::Build(idx_t added_count, data_ptr_t key_locations[], idx_t entry_sizes[], - const SelectionVector *sel) { - vector handles; - vector append_entries; - - // first allocate space of where to serialize the keys and payload columns - idx_t remaining = added_count; - { - // first append to the last block (if any) - lock_guard append_lock(rdc_lock); - count += added_count; - - if (!blocks.empty()) { - auto &last_block = *blocks.back(); - if (last_block.count < last_block.capacity) { - // last block has space: pin the buffer of this block - auto handle = buffer_manager.Pin(last_block.block); - // now append to the block - idx_t append_count = AppendToBlock(last_block, handle, append_entries, remaining, entry_sizes); - remaining -= append_count; - handles.push_back(std::move(handle)); - } - } - while (remaining > 0) { - // now for the remaining data, allocate new buffers to store the data and append there - auto &new_block = CreateBlock(); - auto handle = buffer_manager.Pin(new_block.block); - - // offset the entry sizes array if we have added entries already - idx_t *offset_entry_sizes = entry_sizes ? 
entry_sizes + added_count - remaining : nullptr; - - idx_t append_count = AppendToBlock(new_block, handle, append_entries, remaining, offset_entry_sizes); - D_ASSERT(new_block.count > 0); - remaining -= append_count; - - if (keep_pinned) { - pinned_blocks.push_back(std::move(handle)); - } else { - handles.push_back(std::move(handle)); - } - } - } - // now set up the key_locations based on the append entries - idx_t append_idx = 0; - for (auto &append_entry : append_entries) { - idx_t next = append_idx + append_entry.count; - if (entry_sizes) { - for (; append_idx < next; append_idx++) { - key_locations[append_idx] = append_entry.baseptr; - append_entry.baseptr += entry_sizes[append_idx]; - } - } else { - for (; append_idx < next; append_idx++) { - auto idx = sel->get_index(append_idx); - key_locations[idx] = append_entry.baseptr; - append_entry.baseptr += entry_size; - } - } - } - // return the unique pointers to the handles because they must stay pinned - return handles; -} - -void RowDataCollection::Merge(RowDataCollection &other) { - if (other.count == 0) { - return; - } - RowDataCollection temp(buffer_manager, buffer_manager.GetBlockSize(), 1); - { - // One lock at a time to avoid deadlocks - lock_guard read_lock(other.rdc_lock); - temp.count = other.count; - temp.block_capacity = other.block_capacity; - temp.entry_size = other.entry_size; - temp.blocks = std::move(other.blocks); - temp.pinned_blocks = std::move(other.pinned_blocks); - } - other.Clear(); - - lock_guard write_lock(rdc_lock); - count += temp.count; - block_capacity = MaxValue(block_capacity, temp.block_capacity); - entry_size = MaxValue(entry_size, temp.entry_size); - for (auto &block : temp.blocks) { - blocks.emplace_back(std::move(block)); - } - for (auto &handle : temp.pinned_blocks) { - pinned_blocks.emplace_back(std::move(handle)); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp 
b/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp deleted file mode 100644 index 9b3a4be06..000000000 --- a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp +++ /dev/null @@ -1,330 +0,0 @@ -#include "duckdb/common/types/row/row_data_collection_scanner.hpp" - -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -#include - -namespace duckdb { - -void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block_collection, - RowDataCollection &swizzled_string_heap, - RowDataCollection &block_collection, RowDataCollection &string_heap, - const RowLayout &layout) { - if (block_collection.count == 0) { - return; - } - - if (layout.AllConstant()) { - // No heap blocks! Just merge fixed-size data - swizzled_block_collection.Merge(block_collection); - return; - } - - // We create one heap block per data block and swizzle the pointers - D_ASSERT(string_heap.keep_pinned == swizzled_string_heap.keep_pinned); - auto &buffer_manager = block_collection.buffer_manager; - auto &heap_blocks = string_heap.blocks; - idx_t heap_block_idx = 0; - idx_t heap_block_remaining = heap_blocks[heap_block_idx]->count; - for (auto &data_block : block_collection.blocks) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - - // Pin the data block and swizzle the pointers within the rows - auto data_handle = buffer_manager.Pin(data_block->block); - auto data_ptr = data_handle.Ptr(); - if (!string_heap.keep_pinned) { - D_ASSERT(!data_block->block->IsSwizzled()); - RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - data_block->block->SetSwizzling(nullptr); - } - // At this point the data block is pinned and the heap pointer is valid - // so we can copy heap data as needed - - // We want to copy as little of the heap data as possible, check how the data and heap blocks 
line up - if (heap_block_remaining >= data_block->count) { - // Easy: current heap block contains all strings for this data block, just copy (reference) the block - swizzled_string_heap.blocks.emplace_back(heap_blocks[heap_block_idx]->Copy()); - swizzled_string_heap.blocks.back()->count = data_block->count; - - // Swizzle the heap pointer if we are not pinning the heap - auto &heap_block = swizzled_string_heap.blocks.back()->block; - auto heap_handle = buffer_manager.Pin(heap_block); - if (!swizzled_string_heap.keep_pinned) { - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block->count, - NumericCast(heap_offset)); - } else { - swizzled_string_heap.pinned_blocks.emplace_back(std::move(heap_handle)); - } - - // Update counter - heap_block_remaining -= data_block->count; - } else { - // Strings for this data block are spread over the current heap block and the next (and possibly more) - if (string_heap.keep_pinned) { - // The heap is changing underneath the data block, - // so swizzle the string pointers to make them portable. 
- RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - } - idx_t data_block_remaining = data_block->count; - vector> ptrs_and_sizes; - idx_t total_size = 0; - const auto base_row_ptr = data_ptr; - while (data_block_remaining > 0) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - auto next = MinValue(data_block_remaining, heap_block_remaining); - - // Figure out where to start copying strings, and how many bytes we need to copy - auto heap_start_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_end_ptr = - Load(data_ptr + layout.GetHeapOffset() + (next - 1) * layout.GetRowWidth()); - auto size = NumericCast(heap_end_ptr - heap_start_ptr + Load(heap_end_ptr)); - ptrs_and_sizes.emplace_back(heap_start_ptr, size); - D_ASSERT(size <= heap_blocks[heap_block_idx]->byte_offset); - - // Swizzle the heap pointer - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_start_ptr, next, total_size); - total_size += size; - - // Update where we are in the data and heap blocks - data_ptr += next * layout.GetRowWidth(); - data_block_remaining -= next; - heap_block_remaining -= next; - } - - // Finally, we allocate a new heap block and copy data to it - swizzled_string_heap.blocks.emplace_back(make_uniq( - MemoryTag::ORDER_BY, buffer_manager, MaxValue(total_size, buffer_manager.GetBlockSize()), 1U)); - auto new_heap_handle = buffer_manager.Pin(swizzled_string_heap.blocks.back()->block); - auto new_heap_ptr = new_heap_handle.Ptr(); - for (auto &ptr_and_size : ptrs_and_sizes) { - memcpy(new_heap_ptr, ptr_and_size.first, ptr_and_size.second); - new_heap_ptr += ptr_and_size.second; - } - new_heap_ptr = new_heap_handle.Ptr(); - if (swizzled_string_heap.keep_pinned) { - // Since the heap blocks are pinned, we can unswizzle the data again. 
- swizzled_string_heap.pinned_blocks.emplace_back(std::move(new_heap_handle)); - RowOperations::UnswizzlePointers(layout, base_row_ptr, new_heap_ptr, data_block->count); - RowOperations::UnswizzleHeapPointer(layout, base_row_ptr, new_heap_ptr, data_block->count); - } - } - } - - // We're done with variable-sized data, now just merge the fixed-size data - swizzled_block_collection.Merge(block_collection); - D_ASSERT(swizzled_block_collection.blocks.size() == swizzled_string_heap.blocks.size()); - - // Update counts and cleanup - swizzled_string_heap.count = string_heap.count; - string_heap.Clear(); -} - -void RowDataCollectionScanner::ScanState::PinData() { - auto &rows = scanner.rows; - D_ASSERT(block_idx < rows.blocks.size()); - auto &data_block = rows.blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = rows.buffer_manager.Pin(data_block->block); - } - if (scanner.layout.AllConstant() || !scanner.external) { - return; - } - - auto &heap = scanner.heap; - D_ASSERT(block_idx < heap.blocks.size()); - auto &heap_block = heap.blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = heap.buffer_manager.Pin(heap_block->block); - } -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, bool flush_p) - : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - ValidateUnscannedBlock(); -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, idx_t block_idx, - bool flush_p) - : rows(rows_p), heap(heap_p), 
layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - D_ASSERT(block_idx < rows.blocks.size()); - read_state.block_idx = block_idx; - read_state.entry_idx = 0; - - // Pretend that we have scanned up to the start block - // and will stop at the end - auto begin = rows.blocks.begin(); - auto end = begin + NumericCast(block_idx); - total_scanned = - std::accumulate(begin, end, idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); - total_count = total_scanned + (*end)->count; - - ValidateUnscannedBlock(); -} - -void RowDataCollectionScanner::SwizzleBlockInternal(RowDataBlock &data_block, RowDataBlock &heap_block) { - // Pin the data block and swizzle the pointers within the rows - D_ASSERT(!data_block.block->IsSwizzled()); - auto data_handle = rows.buffer_manager.Pin(data_block.block); - auto data_ptr = data_handle.Ptr(); - RowOperations::SwizzleColumns(layout, data_ptr, data_block.count); - data_block.block->SetSwizzling(nullptr); - - // Swizzle the heap pointers - auto heap_handle = heap.buffer_manager.Pin(heap_block.block); - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block.count, NumericCast(heap_offset)); -} - -void RowDataCollectionScanner::SwizzleBlock(idx_t block_idx) { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! - return; - } - - auto &data_block = rows.blocks[block_idx]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[block_idx]); - } -} - -void RowDataCollectionScanner::ReSwizzle() { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! 
- return; - } - - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - for (idx_t i = 0; i < rows.blocks.size(); ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[i]); - } - } -} - -void RowDataCollectionScanner::ValidateUnscannedBlock() const { - if (unswizzling && read_state.block_idx < rows.blocks.size() && Remaining()) { - D_ASSERT(rows.blocks[read_state.block_idx]->block->IsSwizzled()); - } -} - -void RowDataCollectionScanner::Scan(DataChunk &chunk) { - auto count = MinValue((idx_t)STANDARD_VECTOR_SIZE, total_count - total_scanned); - if (count == 0) { - chunk.SetCardinality(count); - return; - } - - // Only flush blocks we processed. - const auto flush_block_idx = read_state.block_idx; - - const idx_t &row_width = layout.GetRowWidth(); - // Set up a batch of pointers to scan data from - idx_t scanned = 0; - auto data_pointers = FlatVector::GetData(addresses); - - // We must pin ALL blocks we are going to gather from - vector pinned_blocks; - while (scanned < count) { - read_state.PinData(); - auto &data_block = rows.blocks[read_state.block_idx]; - idx_t next = MinValue(data_block->count - read_state.entry_idx, count - scanned); - const data_ptr_t data_ptr = read_state.data_handle.Ptr() + read_state.entry_idx * row_width; - // Set up the next pointers - data_ptr_t row_ptr = data_ptr; - for (idx_t i = 0; i < next; i++) { - data_pointers[scanned + i] = row_ptr; - row_ptr += row_width; - } - // Unswizzle the offsets back to pointers (if needed) - if (unswizzling) { - RowOperations::UnswizzlePointers(layout, data_ptr, read_state.heap_handle.Ptr(), next); - rows.blocks[read_state.block_idx]->block->SetSwizzling("RowDataCollectionScanner::Scan"); - } - // Update state indices - read_state.entry_idx += next; - scanned += next; - total_scanned += next; - if (read_state.entry_idx == data_block->count) { - // Pin completed blocks so we don't lose them - 
pinned_blocks.emplace_back(rows.buffer_manager.Pin(data_block->block)); - if (unswizzling) { - auto &heap_block = heap.blocks[read_state.block_idx]; - pinned_blocks.emplace_back(heap.buffer_manager.Pin(heap_block->block)); - } - read_state.block_idx++; - read_state.entry_idx = 0; - ValidateUnscannedBlock(); - } - } - D_ASSERT(scanned == count); - // Deserialize the payload data - for (idx_t col_no = 0; col_no < layout.ColumnCount(); col_no++) { - RowOperations::Gather(addresses, *FlatVector::IncrementalSelectionVector(), chunk.data[col_no], - *FlatVector::IncrementalSelectionVector(), count, layout, col_no); - } - chunk.SetCardinality(count); - chunk.Verify(); - - // Switch to a new set of pinned blocks - read_state.pinned_blocks.swap(pinned_blocks); - - if (flush) { - // Release blocks we have passed. - for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - rows.blocks[i]->block = nullptr; - if (unswizzling) { - heap.blocks[i]->block = nullptr; - } - } - } else if (unswizzling) { - // Reswizzle blocks we have passed so they can be flushed safely. 
- for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[i]); - } - } - } -} - -void RowDataCollectionScanner::Reset(bool flush_p) { - flush = flush_p; - total_scanned = 0; - - read_state.block_idx = 0; - read_state.entry_idx = 0; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_layout.cpp b/src/duckdb/src/common/types/row/row_layout.cpp deleted file mode 100644 index 3add8e425..000000000 --- a/src/duckdb/src/common/types/row/row_layout.cpp +++ /dev/null @@ -1,62 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row_layout.cpp -// -// -//===----------------------------------------------------------------------===// - -#include "duckdb/common/types/row/row_layout.hpp" - -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" - -namespace duckdb { - -RowLayout::RowLayout() : flag_width(0), data_width(0), row_width(0), all_constant(true), heap_pointer_offset(0) { -} - -void RowLayout::Initialize(vector types_p, bool align) { - offsets.clear(); - types = std::move(types_p); - - // Null mask at the front - 1 bit per value. - flag_width = ValidityBytes::ValidityMaskSize(types.size()); - row_width = flag_width; - - // Whether all columns are constant size. - for (const auto &type : types) { - all_constant = all_constant && TypeIsConstantSize(type.InternalType()); - } - - // This enables pointer swizzling for out-of-core computation. - if (!all_constant) { - // When unswizzled, the pointer lives here. - // When swizzled, the pointer is replaced by an offset. - heap_pointer_offset = row_width; - // The 8 byte pointer will be replaced with an 8 byte idx_t when swizzled. - // However, this cannot be sizeof(data_ptr_t), since 32 bit builds use 4 byte pointers. 
- row_width += sizeof(idx_t); - } - - // Data columns. No alignment required. - for (const auto &type : types) { - offsets.push_back(row_width); - const auto internal_type = type.InternalType(); - if (TypeIsConstantSize(internal_type) || internal_type == PhysicalType::VARCHAR) { - row_width += GetTypeIdSize(type.InternalType()); - } else { - // Variable size types use pointers to the actual data (can be swizzled). - // Again, we would use sizeof(data_ptr_t), but this is not guaranteed to be equal to sizeof(idx_t). - row_width += sizeof(idx_t); - } - } - - data_width = row_width - flag_width; - - // Alignment padding for the next row - if (align) { - row_width = AlignValue(row_width); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/selection_vector.cpp b/src/duckdb/src/common/types/selection_vector.cpp index 145b6bfa1..a1232340c 100644 --- a/src/duckdb/src/common/types/selection_vector.cpp +++ b/src/duckdb/src/common/types/selection_vector.cpp @@ -50,6 +50,14 @@ buffer_ptr SelectionVector::Slice(const SelectionVector &sel, idx return data; } +idx_t SelectionVector::SliceInPlace(const SelectionVector &source, idx_t count) { + for (idx_t i = 0; i < count; ++i) { + set_index(i, get_index(source.get_index(i))); + } + + return count; +} + void SelectionVector::Verify(idx_t count, idx_t vector_size) const { #ifdef DEBUG D_ASSERT(vector_size >= 1); diff --git a/src/duckdb/src/execution/expression_executor.cpp b/src/duckdb/src/execution/expression_executor.cpp index ec11c1289..6707e28e9 100644 --- a/src/duckdb/src/execution/expression_executor.cpp +++ b/src/duckdb/src/execution/expression_executor.cpp @@ -181,6 +181,7 @@ void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t co } else { VectorOperations::DefaultCast(vector, intermediate, count, true); } + intermediate.Verify(count); Vector result(vector.GetType(), true, false, count); //! 
Then cast back into the original type @@ -190,6 +191,7 @@ void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t co VectorOperations::DefaultCast(intermediate, result, count, true); } vector.Reference(result); + vector.Verify(count); } } diff --git a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp index ad974a475..31fc85160 100644 --- a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp @@ -20,9 +20,10 @@ PhysicalAsOfJoin::PhysicalAsOfJoin(PhysicalPlan &physical_plan, LogicalCompariso PhysicalOperator &right) : PhysicalComparisonJoin(physical_plan, op, PhysicalOperatorType::ASOF_JOIN, std::move(op.conditions), op.join_type, op.estimated_cardinality), - comparison_type(ExpressionType::INVALID), predicate(std::move(op.predicate)) { + comparison_type(ExpressionType::INVALID) { // Convert the conditions partitions and sorts + D_ASSERT(!op.predicate.get()); for (auto &cond : conditions) { D_ASSERT(cond.left->return_type == cond.right->return_type); join_key_types.push_back(cond.left->return_type); @@ -406,8 +407,6 @@ class AsOfProbeBuffer { // Predicate evaluation SelectionVector tail_sel; - SelectionVector filter_sel; - ExpressionExecutor filterer; idx_t lhs_match_count; bool fetch_next_left; @@ -415,7 +414,7 @@ class AsOfProbeBuffer { AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin &op) : client(client), op(op), strict(IsStrictComparison(op.comparison_type)), left_outer(IsLeftOuterJoin(op.join_type)), - lhs_executor(client), rhs_executor(client), filterer(client), fetch_next_left(true) { + lhs_executor(client), rhs_executor(client), fetch_next_left(true) { lhs_keys.Initialize(client, op.join_key_types); for (const auto &cond : op.conditions) { @@ -439,11 +438,6 @@ AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &client, const PhysicalAsOfJoin & 
rhs_executor.AddExpression(*cond.right); } } - - if (op.predicate) { - filter_sel.Initialize(); - filterer.AddExpression(*op.predicate); - } } void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { @@ -692,15 +686,6 @@ void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &ch } } -static idx_t SliceSelectionVector(SelectionVector &target, const SelectionVector &source, const idx_t count) { - idx_t result = 0; - for (idx_t i = 0; i < count; ++i) { - target.set_index(result++, target.get_index(source.get_index(i))); - } - - return result; -} - void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk) { // perform the actual join idx_t matches[STANDARD_VECTOR_SIZE]; @@ -775,21 +760,11 @@ void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &c } else { chunk.Slice(*sel, tail_count); // Slice lhs_match_sel to the remaining lhs rows - lhs_match_count = SliceSelectionVector(lhs_match_sel, *sel, tail_count); + lhs_match_count = lhs_match_sel.SliceInPlace(*sel, tail_count); } } } - // Apply the predicate filter - // TODO: This is wrong - we have to search for a match - if (filterer.expressions.size() == 1) { - const auto filter_count = filterer.SelectExpression(chunk, filter_sel); - if (filter_count < chunk.size()) { - chunk.Slice(filter_sel, filter_count); - lhs_match_count = SliceSelectionVector(lhs_match_sel, filter_sel, filter_count); - } - } - // Update the match masks for the rows we ended up with left_outer.Reset(); for (idx_t i = 0; i < lhs_match_count; ++i) { diff --git a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp index d96cda05d..b449eee2f 100644 --- a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp @@ -1,7 +1,5 @@ #include "duckdb/execution/operator/join/physical_nested_loop_join.hpp" #include 
"duckdb/parallel/thread_context.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/nested_loop_join.hpp" #include "duckdb/main/client_context.hpp" @@ -9,20 +7,23 @@ namespace duckdb { -PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, - PhysicalOperator &right, vector cond, JoinType join_type, +PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, + PhysicalOperator &left, PhysicalOperator &right, + vector cond, JoinType join_type, idx_t estimated_cardinality, unique_ptr pushdown_info_p) : PhysicalComparisonJoin(physical_plan, op, PhysicalOperatorType::NESTED_LOOP_JOIN, std::move(cond), join_type, - estimated_cardinality) { + estimated_cardinality), + predicate(std::move(op.predicate)) { filter_pushdown = std::move(pushdown_info_p); children.push_back(left); children.push_back(right); } -PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, - PhysicalOperator &right, vector cond, JoinType join_type, +PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, + PhysicalOperator &left, PhysicalOperator &right, + vector cond, JoinType join_type, idx_t estimated_cardinality) : PhysicalNestedLoopJoin(physical_plan, op, left, right, std::move(cond), join_type, estimated_cardinality, nullptr) { @@ -273,7 +274,7 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { PhysicalNestedLoopJoinState(ClientContext &context, const PhysicalNestedLoopJoin &op, const vector &conditions) : fetch_next_left(true), fetch_next_right(false), lhs_executor(context), left_tuple(0), right_tuple(0), - left_outer(IsLeftOuterJoin(op.join_type)) { + left_outer(IsLeftOuterJoin(op.join_type)), 
pred_executor(context) { vector condition_types; for (auto &cond : conditions) { lhs_executor.AddExpression(*cond.left); @@ -284,6 +285,11 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { right_condition.Initialize(allocator, condition_types); right_payload.Initialize(allocator, op.children[1].get().GetTypes()); left_outer.Initialize(STANDARD_VECTOR_SIZE); + + if (op.predicate) { + pred_executor.AddExpression(*op.predicate); + pred_matches.Initialize(); + } } bool fetch_next_left; @@ -302,6 +308,10 @@ class PhysicalNestedLoopJoinState : public CachingOperatorState { OuterJoinMarker left_outer; + //! Predicate + ExpressionExecutor pred_executor; + SelectionVector pred_matches; + public: void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { context.thread.profiler.Flush(op); @@ -438,11 +448,20 @@ OperatorResultType PhysicalNestedLoopJoin::ResolveComplexJoin(ExecutionContext & if (match_count > 0) { // we have matching tuples! // construct the result - state.left_outer.SetMatches(lvector, match_count); - gstate.right_outer.SetMatches(rvector, match_count, state.condition_scan_state.current_row_index); - chunk.Slice(input, lvector, match_count); chunk.Slice(right_payload, rvector, match_count, input.ColumnCount()); + + // If we have a predicate, apply it to the result + if (predicate) { + auto &sel = state.pred_matches; + match_count = state.pred_executor.SelectExpression(chunk, sel); + chunk.Slice(sel, match_count); + lvector.SliceInPlace(sel, match_count); + rvector.SliceInPlace(sel, match_count); + } + + state.left_outer.SetMatches(lvector, match_count); + gstate.right_outer.SetMatches(rvector, match_count, state.condition_scan_state.current_row_index); } // check if we exhausted the RHS, if we did we need to move to the next right chunk in the next iteration diff --git a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp index c457e9826..11ac71492 100644 
--- a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp @@ -13,6 +13,7 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/planner/expression_iterator.hpp" #include "duckdb/main/settings.hpp" namespace duckdb { @@ -43,10 +44,9 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera const auto &probe_types = op.children[0]->types; join_op.types.insert(join_op.types.end(), probe_types.begin(), probe_types.end()); - // TODO: We can't handle predicates right now because we would have to remap column references. - if (op.predicate) { - return nullptr; - } + // Project pk + LogicalType pk_type = LogicalType::BIGINT; + join_op.types.emplace_back(pk_type); // Fill in the projection maps to simplify the code below // Since NLJ doesn't support projection, but ASOF does, @@ -65,9 +65,25 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera } } - // Project pk - LogicalType pk_type = LogicalType::BIGINT; - join_op.types.emplace_back(pk_type); + // Remap predicate column references. 
+ if (op.predicate) { + vector swap_projection_map; + const auto rhs_width = op.children[1]->types.size(); + for (const auto &l : join_op.right_projection_map) { + swap_projection_map.emplace_back(l + rhs_width); + } + for (const auto &r : join_op.left_projection_map) { + swap_projection_map.emplace_back(r); + } + join_op.predicate = op.predicate->Copy(); + ExpressionIterator::EnumerateExpression(join_op.predicate, [&](Expression &child) { + if (child.GetExpressionClass() == ExpressionClass::BOUND_REF) { + auto &col_idx = child.Cast().index; + const auto new_idx = swap_projection_map[col_idx]; + col_idx = new_idx; + } + }); + } auto binder = Binder::CreateBinder(context); FunctionBinder function_binder(*binder); @@ -208,7 +224,7 @@ PhysicalPlanGenerator::PlanAsOfLoopJoin(LogicalComparisonJoin &op, PhysicalOpera auto window_types = probe.GetTypes(); window_types.emplace_back(pk_type); - idx_t probe_cardinality = op.children[0]->EstimateCardinality(context); + const auto probe_cardinality = op.EstimateCardinality(context); auto &window = Make(window_types, std::move(window_select), probe_cardinality); window.children.emplace_back(probe); @@ -275,10 +291,12 @@ PhysicalOperator &PhysicalPlanGenerator::PlanAsOfJoin(LogicalComparisonJoin &op) } D_ASSERT(asof_idx < op.conditions.size()); - bool force_asof_join = DBConfig::GetSetting(context); - if (!force_asof_join) { - idx_t asof_join_threshold = DBConfig::GetSetting(context); - if (op.children[0]->has_estimated_cardinality && lhs_cardinality < asof_join_threshold) { + // If there is a non-comparison predicate, we have to use NLJ. 
+ const bool has_predicate = op.predicate.get(); + const bool force_asof_join = DBConfig::GetSetting(context); + if (!force_asof_join || has_predicate) { + const idx_t asof_join_threshold = DBConfig::GetSetting(context); + if (has_predicate || (op.children[0]->has_estimated_cardinality && lhs_cardinality < asof_join_threshold)) { auto result = PlanAsOfLoopJoin(op, left, right); if (result) { return *result; diff --git a/src/duckdb/src/function/cast/variant/from_variant.cpp b/src/duckdb/src/function/cast/variant/from_variant.cpp index aa129b463..f29db2b85 100644 --- a/src/duckdb/src/function/cast/variant/from_variant.cpp +++ b/src/duckdb/src/function/cast/variant/from_variant.cpp @@ -349,6 +349,14 @@ static bool ConvertVariantToStruct(FromVariantConversionData &conversion_data, V SelectionVector child_values_sel; child_values_sel.Initialize(count); + SelectionVector row_sel(0, count); + if (row.IsValid()) { + auto row_index = row.GetIndex(); + for (idx_t i = 0; i < count; i++) { + row_sel[i] = static_cast(row_index); + } + } + for (idx_t child_idx = 0; child_idx < child_types.size(); child_idx++) { auto &child_name = child_types[child_idx].first; @@ -357,14 +365,21 @@ static bool ConvertVariantToStruct(FromVariantConversionData &conversion_data, V VariantPathComponent component; component.key = child_name; component.lookup_mode = VariantChildLookupMode::BY_KEY; - auto collection_result = - VariantUtils::FindChildValues(conversion_data.variant, component, row, child_values_sel, child_data, count); - if (!collection_result.Success()) { - D_ASSERT(collection_result.type == VariantChildDataCollectionResult::Type::COMPONENT_NOT_FOUND); - auto nested_index = collection_result.nested_data_index; - auto row_index = row.IsValid() ? 
row.GetIndex() : nested_index; + ValidityMask lookup_validity(count); + VariantUtils::FindChildValues(conversion_data.variant, component, row_sel, child_values_sel, lookup_validity, + child_data, count); + if (!lookup_validity.AllValid()) { + optional_idx nested_index; + for (idx_t i = 0; i < count; i++) { + if (!lookup_validity.RowIsValid(i)) { + nested_index = i; + break; + } + } + D_ASSERT(nested_index.IsValid()); + auto row_index = row.IsValid() ? row.GetIndex() : nested_index.GetIndex(); auto object_keys = - VariantUtils::GetObjectKeys(conversion_data.variant, row_index, child_data[nested_index]); + VariantUtils::GetObjectKeys(conversion_data.variant, row_index, child_data[nested_index.GetIndex()]); conversion_data.error = StringUtil::Format("VARIANT(OBJECT(%s)) is missing key '%s'", StringUtil::Join(object_keys, ","), component.key); return false; diff --git a/src/duckdb/src/function/cast/variant/to_variant.cpp b/src/duckdb/src/function/cast/variant/to_variant.cpp index ad1962d37..813d24835 100644 --- a/src/duckdb/src/function/cast/variant/to_variant.cpp +++ b/src/duckdb/src/function/cast/variant/to_variant.cpp @@ -130,6 +130,9 @@ static bool WriteVariantResultData(ToVariantSourceData &source, ToVariantGlobalR } static bool CastToVARIANT(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + if (!count) { + return true; + } DataChunk offsets; offsets.Initialize(Allocator::DefaultAllocator(), {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, diff --git a/src/duckdb/src/function/macro_function.cpp b/src/duckdb/src/function/macro_function.cpp index 2f407c025..66e36181b 100644 --- a/src/duckdb/src/function/macro_function.cpp +++ b/src/duckdb/src/function/macro_function.cpp @@ -47,7 +47,7 @@ MacroBindResult MacroFunction::BindMacroFunction( InsertionOrderPreservingMap> &named_arguments, idx_t depth) { ExpressionBinder expr_binder(binder, binder.context); - + expr_binder.lambda_bindings = 
binder.lambda_bindings; // Find argument types and separate positional and default arguments vector positional_arg_types; InsertionOrderPreservingMap named_arg_types; diff --git a/src/duckdb/src/function/scalar/create_sort_key.cpp b/src/duckdb/src/function/scalar/create_sort_key.cpp index 2f5463e3f..d9127d359 100644 --- a/src/duckdb/src/function/scalar/create_sort_key.cpp +++ b/src/duckdb/src/function/scalar/create_sort_key.cpp @@ -696,13 +696,15 @@ void PrepareSortData(Vector &result, idx_t size, SortKeyLengthInfo &key_lengths, } } -void FinalizeSortData(Vector &result, idx_t size) { +void FinalizeSortData(Vector &result, idx_t size, const SortKeyLengthInfo &key_lengths, + const unsafe_vector &offsets) { switch (result.GetType().id()) { case LogicalTypeId::BLOB: { auto result_data = FlatVector::GetData(result); // call Finalize on the result for (idx_t r = 0; r < size; r++) { - result_data[r].Finalize(); + result_data[r].SetSizeAndFinalize(NumericCast(offsets[r]), + key_lengths.variable_lengths[r] + key_lengths.constant_length); } break; } @@ -739,7 +741,7 @@ void CreateSortKeyInternal(vector> &sort_key_data, SortKeyConstructInfo info(modifiers[c], offsets, data_pointers.get()); ConstructSortKey(*sort_key_data[c], info); } - FinalizeSortData(result, row_count); + FinalizeSortData(result, row_count, key_lengths, offsets); } } // namespace diff --git a/src/duckdb/src/function/scalar/operator/arithmetic.cpp b/src/duckdb/src/function/scalar/operator/arithmetic.cpp index 82cd9b5b7..1dde43871 100644 --- a/src/duckdb/src/function/scalar/operator/arithmetic.cpp +++ b/src/duckdb/src/function/scalar/operator/arithmetic.cpp @@ -1220,7 +1220,7 @@ hugeint_t InterpolateOperator::Operation(const hugeint_t &lo, const double d, co template <> uhugeint_t InterpolateOperator::Operation(const uhugeint_t &lo, const double d, const uhugeint_t &hi) { - return Hugeint::Convert(Operation(Uhugeint::Cast(lo), d, Uhugeint::Cast(hi))); + return Uhugeint::Convert(Operation(Uhugeint::Cast(lo), 
d, Uhugeint::Cast(hi))); } static interval_t MultiplyByDouble(const interval_t &i, const double &d) { // NOLINT diff --git a/src/duckdb/src/function/scalar/variant/variant_extract.cpp b/src/duckdb/src/function/scalar/variant/variant_extract.cpp index e0c10fa73..2c6ff14cf 100644 --- a/src/duckdb/src/function/scalar/variant/variant_extract.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_extract.cpp @@ -28,8 +28,11 @@ BindData::BindData(const string &str) : FunctionData() { component.key = str; } BindData::BindData(uint32_t index) : FunctionData() { + if (index == 0) { + throw BinderException("Extracting index 0 from VARIANT(ARRAY) is invalid, indexes are 1-based"); + } component.lookup_mode = VariantChildLookupMode::BY_INDEX; - component.index = index; + component.index = index - 1; } unique_ptr BindData::Copy() const { @@ -142,22 +145,26 @@ static void VariantExtractFunction(DataChunk &input, ExpressionState &state, Vec } //! Look up the value_index of the child we're extracting - auto child_collection_result = - VariantUtils::FindChildValues(variant, component, optional_idx(), new_value_index_sel, nested_data, count); - if (!child_collection_result.Success()) { - if (child_collection_result.type == VariantChildDataCollectionResult::Type::INDEX_ZERO) { - throw InvalidInputException("Extracting index 0 from VARIANT(ARRAY) is invalid, indexes are 1-based"); + ValidityMask lookup_validity(count); + VariantUtils::FindChildValues(variant, component, nullptr, new_value_index_sel, lookup_validity, nested_data, + count); + if (!lookup_validity.AllValid()) { + optional_idx index; + for (idx_t i = 0; i < count; i++) { + if (!lookup_validity.RowIsValid(i)) { + index = i; + break; + } } + D_ASSERT(index.IsValid()); switch (component.lookup_mode) { case VariantChildLookupMode::BY_INDEX: { - D_ASSERT(child_collection_result.type == VariantChildDataCollectionResult::Type::COMPONENT_NOT_FOUND); - auto nested_index = child_collection_result.nested_data_index; + auto 
nested_index = index.GetIndex(); throw InvalidInputException("VARIANT(ARRAY(%d)) is missing index %d", nested_data[nested_index].child_count, component.index); } case VariantChildLookupMode::BY_KEY: { - D_ASSERT(child_collection_result.type == VariantChildDataCollectionResult::Type::COMPONENT_NOT_FOUND); - auto nested_index = child_collection_result.nested_data_index; + auto nested_index = index.GetIndex(); auto row_index = nested_index; auto object_keys = VariantUtils::GetObjectKeys(variant, row_index, nested_data[nested_index]); throw InvalidInputException("VARIANT(OBJECT(%s)) is missing key '%s'", StringUtil::Join(object_keys, ","), diff --git a/src/duckdb/src/function/scalar/variant/variant_utils.cpp b/src/duckdb/src/function/scalar/variant/variant_utils.cpp index 05e96a905..b9450188f 100644 --- a/src/duckdb/src/function/scalar/variant/variant_utils.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_utils.cpp @@ -7,6 +7,18 @@ namespace duckdb { +PhysicalType VariantDecimalData::GetPhysicalType() const { + if (width > DecimalWidth::max) { + return PhysicalType::INT128; + } else if (width > DecimalWidth::max) { + return PhysicalType::INT64; + } else if (width > DecimalWidth::max) { + return PhysicalType::INT32; + } else { + return PhysicalType::INT16; + } +} + bool VariantUtils::IsNestedType(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index) { auto type_id = variant.GetTypeId(row, value_index); return type_id == VariantLogicalType::ARRAY || type_id == VariantLogicalType::OBJECT; @@ -19,11 +31,10 @@ VariantDecimalData VariantUtils::DecodeDecimalData(const UnifiedVariantVectorDat auto data = const_data_ptr_cast(variant.GetData(row).GetData()); auto ptr = data + byte_offset; - VariantDecimalData result; - result.width = VarintDecode(ptr); - result.scale = VarintDecode(ptr); - result.value_ptr = ptr; - return result; + auto width = VarintDecode(ptr); + auto scale = VarintDecode(ptr); + auto value_ptr = ptr; + return 
VariantDecimalData(width, scale, value_ptr); } string_t VariantUtils::DecodeStringData(const UnifiedVariantVectorData &variant, idx_t row, uint32_t value_index) { @@ -63,13 +74,13 @@ vector VariantUtils::GetObjectKeys(const UnifiedVariantVectorData &varia return object_keys; } -VariantChildDataCollectionResult VariantUtils::FindChildValues(const UnifiedVariantVectorData &variant, - const VariantPathComponent &component, optional_idx row, - SelectionVector &res, VariantNestedData *nested_data, - idx_t count) { +//! FIXME: this shouldn't return a "result", it should populate a validity mask instead. +void VariantUtils::FindChildValues(const UnifiedVariantVectorData &variant, const VariantPathComponent &component, + optional_ptr sel, SelectionVector &res, + ValidityMask &res_validity, VariantNestedData *nested_data, idx_t count) { for (idx_t i = 0; i < count; i++) { - auto row_index = row.IsValid() ? row.GetIndex() : i; + auto row_index = sel ? sel->get_index(i) : i; auto &nested_data_entry = nested_data[i]; if (nested_data_entry.is_null) { @@ -77,13 +88,10 @@ VariantChildDataCollectionResult VariantUtils::FindChildValues(const UnifiedVari } if (component.lookup_mode == VariantChildLookupMode::BY_INDEX) { auto child_idx = component.index; - if (child_idx == 0) { - return VariantChildDataCollectionResult::IndexZero(); - } - child_idx--; if (child_idx >= nested_data_entry.child_count) { //! 
The list is too small to contain this index - return VariantChildDataCollectionResult::NotFound(i); + res_validity.SetInvalid(i); + continue; } auto value_id = variant.GetValuesIndex(row_index, nested_data_entry.children_idx + child_idx); res[i] = static_cast(value_id); @@ -103,10 +111,9 @@ VariantChildDataCollectionResult VariantUtils::FindChildValues(const UnifiedVari } } if (!found_child) { - return VariantChildDataCollectionResult::NotFound(i); + res_validity.SetInvalid(i); } } - return VariantChildDataCollectionResult(); } vector VariantUtils::ValueIsNull(const UnifiedVariantVectorData &variant, const SelectionVector &sel, diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index 9eb76671a..4d811b36a 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev828" +#define DUCKDB_PATCH_VERSION "0-dev966" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 5 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.5.0-dev828" +#define DUCKDB_VERSION "v1.5.0-dev966" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "353406bd7f" +#define DUCKDB_SOURCE_ID "9d77bcf518" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/include/duckdb.h b/src/duckdb/src/include/duckdb.h index ccf5ad5ac..ca07b602c 100644 --- a/src/duckdb/src/include/duckdb.h +++ b/src/duckdb/src/include/duckdb.h @@ -4675,6 +4675,14 @@ Check if the column at 'index' index of the table has a DEFAULT expression. */ DUCKDB_C_API duckdb_state duckdb_column_has_default(duckdb_table_description table_description, idx_t index, bool *out); +/*! +Return the number of columns of the described table. 
+ +* @param table_description The table_description to query. +* @return The column count. +*/ +DUCKDB_C_API idx_t duckdb_table_description_get_column_count(duckdb_table_description table_description); + /*! Obtain the column name at 'index'. The out result must be destroyed with `duckdb_free`. @@ -4685,6 +4693,17 @@ The out result must be destroyed with `duckdb_free`. */ DUCKDB_C_API char *duckdb_table_description_get_column_name(duckdb_table_description table_description, idx_t index); +/*! +Obtain the column type at 'index'. +The return value must be destroyed with `duckdb_destroy_logical_type`. + +* @param table_description The table_description to query. +* @param index The index of the column to query. +* @return The column type. +*/ +DUCKDB_C_API duckdb_logical_type duckdb_table_description_get_column_type(duckdb_table_description table_description, + idx_t index); + //===--------------------------------------------------------------------===// // Arrow Interface //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/include/duckdb/common/hugeint.hpp b/src/duckdb/src/include/duckdb/common/hugeint.hpp index c9b54bd95..acdc4fb4b 100644 --- a/src/duckdb/src/include/duckdb/common/hugeint.hpp +++ b/src/duckdb/src/include/duckdb/common/hugeint.hpp @@ -76,7 +76,7 @@ struct hugeint_t { // NOLINT: use numeric casing DUCKDB_API explicit operator int16_t() const; DUCKDB_API explicit operator int32_t() const; DUCKDB_API explicit operator int64_t() const; - DUCKDB_API operator uhugeint_t() const; // NOLINT: Allow implicit conversion from `hugeint_t` + DUCKDB_API explicit operator uhugeint_t() const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp b/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp index f1a6f6eb3..a847e217b 100644 --- a/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp +++ 
b/src/duckdb/src/include/duckdb/common/operator/comparison_operators.hpp @@ -210,15 +210,4 @@ inline bool GreaterThan::Operation(const interval_t &left, const interval_t &rig return Interval::GreaterThan(left, right); } -//===--------------------------------------------------------------------===// -// Specialized Hugeint Comparison Operators -//===--------------------------------------------------------------------===// -template <> -inline bool Equals::Operation(const hugeint_t &left, const hugeint_t &right) { - return Hugeint::Equals(left, right); -} -template <> -inline bool GreaterThan::Operation(const hugeint_t &left, const hugeint_t &right) { - return Hugeint::GreaterThan(left, right); -} } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp index 557e9cd5b..ee1d11afb 100644 --- a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp +++ b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp @@ -25,22 +25,6 @@ struct SelectionVector; class StringHeap; struct UnifiedVectorFormat; -// The NestedValidity class help to set/get the validity from inside nested vectors -class NestedValidity { - data_ptr_t list_validity_location; - data_ptr_t *struct_validity_locations; - idx_t entry_idx; - idx_t idx_in_entry; - idx_t list_validity_offset; - -public: - explicit NestedValidity(data_ptr_t validitymask_location); - NestedValidity(data_ptr_t *validitymask_locations, idx_t child_vector_index); - void SetInvalid(idx_t idx); - bool IsValid(idx_t idx); - void OffsetListBy(idx_t offset); -}; - struct RowOperationsState { explicit RowOperationsState(ArenaAllocator &allocator) : allocator(allocator) { } @@ -49,7 +33,7 @@ struct RowOperationsState { unique_ptr addresses; // Re-usable vector for row_aggregate.cpp }; -// RowOperations contains a set of operations that operate on data using a RowLayout +// RowOperations 
contains a set of operations that operate on data using a TupleDataLayout struct RowOperations { //===--------------------------------------------------------------------===// // Aggregation Operators @@ -70,66 +54,6 @@ struct RowOperations { //! finalize - unaligned addresses, updated static void FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, DataChunk &result, idx_t aggr_idx); - - //===--------------------------------------------------------------------===// - // Read/Write Operators - //===--------------------------------------------------------------------===// - //! Scatter group data to the rows. Initialises the ValidityMask. - static void Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, - RowDataCollection &string_heap, const SelectionVector &sel, idx_t count); - //! Gather a single column. - //! If heap_ptr is not null, then the data is assumed to contain swizzled pointers, - //! which will be unswizzled in memory. - static void Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size = 0, - data_ptr_t heap_ptr = nullptr); - - //===--------------------------------------------------------------------===// - // Heap Operators - //===--------------------------------------------------------------------===// - //! Compute the entry sizes of a vector with variable size type (used before building heap buffer space). - static void ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset = 0); - //! Compute the entry sizes of vector data with variable size type (used before building heap buffer space). - static void ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, - idx_t ser_count, const SelectionVector &sel, idx_t offset = 0); - //! 
Scatter vector with variable size type to the heap. - static void HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, idx_t offset = 0); - //! Scatter vector data with variable size type to the heap. - static void HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, - idx_t ser_count, data_ptr_t *key_locations, - optional_ptr parent_validity, idx_t offset = 0); - //! Gather a single column with variable size type from the heap. - static void HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, data_ptr_t key_locations[], - optional_ptr parent_validity); - - //===--------------------------------------------------------------------===// - // Sorting Operators - //===--------------------------------------------------------------------===// - //! Scatter vector data to the rows in radix-sortable format. - static void RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t key_locations[], bool desc, bool has_null, bool nulls_first, idx_t prefix_len, - idx_t width, idx_t offset = 0); - - //===--------------------------------------------------------------------===// - // Out-of-Core Operators - //===--------------------------------------------------------------------===// - //! Swizzles blob pointers to offset within heap row - static void SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count); - //! Swizzles the base pointer of each row to offset within heap block - static void SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - const idx_t count, const idx_t base_offset = 0); - //! 
Copies 'count' heap rows that are pointed to by the rows at 'row_ptr' to 'heap_ptr' and swizzles the pointers - static void CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - data_ptr_t heap_ptr, const idx_t count); - - //! Unswizzles the base offset within heap block the rows to pointers - static void UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count); - //! Unswizzles all offsets back to pointers - static void UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/comparators.hpp b/src/duckdb/src/include/duckdb/common/sort/comparators.hpp deleted file mode 100644 index 5f3cd3807..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/comparators.hpp +++ /dev/null @@ -1,65 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/comparators.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/types.hpp" -#include "duckdb/common/types/row/row_layout.hpp" - -namespace duckdb { - -struct SortLayout; -struct SBScanState; - -using ValidityBytes = RowLayout::ValidityBytes; - -struct Comparators { -public: - //! Whether a tie between two blobs can be broken - static bool TieIsBreakable(const idx_t &col_idx, const data_ptr_t &row_ptr, const SortLayout &sort_layout); - //! Compares the tuples that a being read from in the 'left' and 'right blocks during merge sort - //! (only in case we cannot simply 'memcmp' - if there are blob columns) - static int CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, - const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort); - //! 
Compare two blob values - static int CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type); - -private: - //! Compares two blob values that were initially tied by their prefix - static int BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, - const SortLayout &sort_layout, const bool &external); - //! Compare two fixed-size values - template - static int TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr); - - //! Compare two values at the pointers (can be recursive if nested type) - static int CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid); - //! Compares two fixed-size values at the given pointers - template - static int TemplatedCompareAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr); - //! Compares two string values at the given pointers - static int CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid); - //! Compares two struct values at the given pointers (recursive) - static int CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const child_list_t &types, bool valid); - static int CompareArrayAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, bool valid, - idx_t array_size); - //! Compare two list values at the pointers (can be recursive if nested type) - static int CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, bool valid); - //! Compares a list of fixed-size values - template - static int TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const ValidityBytes &left_validity, - const ValidityBytes &right_validity, const idx_t &count); - - //! Unwizzles an offset into a pointer - static void UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type); - //! 
Swizzles a pointer into an offset - static void SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type); -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp b/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp deleted file mode 100644 index c935a713a..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/duckdb_pdqsort.hpp +++ /dev/null @@ -1,710 +0,0 @@ -/* -pdqsort.h - Pattern-defeating quicksort. - -Copyright (c) 2021 Orson Peters - -This software is provided 'as-is', without any express or implied warranty. In no event will the -authors be held liable for any damages arising from the use of this software. - -Permission is granted to anyone to use this software for any purpose, including commercial -applications, and to alter it and redistribute it freely, subject to the following restrictions: - -1. The origin of this software must not be misrepresented; you must not claim that you wrote the - original software. If you use this software in a product, an acknowledgment in the product - documentation would be appreciated but is not required. - -2. Altered source versions must be plainly marked as such, and must not be misrepresented as - being the original software. - -3. This notice may not be removed or altered from any source distribution. 
-*/ - -#pragma once - -#include "duckdb/common/constants.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/unique_ptr.hpp" - -#include -#include -#include -#include -#include - -namespace duckdb_pdqsort { - -using duckdb::data_ptr_t; -using duckdb::data_t; -using duckdb::FastMemcmp; -using duckdb::FastMemcpy; -using duckdb::idx_t; -using duckdb::make_unsafe_uniq_array_uninitialized; -using duckdb::unique_ptr; -using duckdb::unsafe_unique_array; - -// NOLINTBEGIN - -enum { - // Partitions below this size are sorted using insertion sort. - insertion_sort_threshold = 24, - - // Partitions above this size use Tukey's ninther to select the pivot. - ninther_threshold = 128, - - // When we detect an already sorted partition, attempt an insertion sort that allows this - // amount of element moves before giving up. - partial_insertion_sort_limit = 8, - - // Must be multiple of 8 due to loop unrolling, and < 256 to fit in unsigned char. - block_size = 64, - - // Cacheline size, assumes power of two. - cacheline_size = 64 - -}; - -// Returns floor(log2(n)), assumes n > 0. 
-template -inline int log2(T n) { - int log = 0; - while (n >>= 1) { - ++log; - } - return log; -} - -struct PDQConstants { - PDQConstants(idx_t entry_size, idx_t comp_offset, idx_t comp_size, data_ptr_t end) - : entry_size(entry_size), comp_offset(comp_offset), comp_size(comp_size), - tmp_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), tmp_buf(tmp_buf_ptr.get()), - iter_swap_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), - iter_swap_buf(iter_swap_buf_ptr.get()), - swap_offsets_buf_ptr(make_unsafe_uniq_array_uninitialized(entry_size)), - swap_offsets_buf(swap_offsets_buf_ptr.get()), end(end) { - } - - const duckdb::idx_t entry_size; - const idx_t comp_offset; - const idx_t comp_size; - - unsafe_unique_array tmp_buf_ptr; - const data_ptr_t tmp_buf; - - unsafe_unique_array iter_swap_buf_ptr; - const data_ptr_t iter_swap_buf; - - unsafe_unique_array swap_offsets_buf_ptr; - const data_ptr_t swap_offsets_buf; - - const data_ptr_t end; -}; - -struct PDQIterator { - PDQIterator(data_ptr_t ptr, const idx_t &entry_size) : ptr(ptr), entry_size(entry_size) { - } - - inline PDQIterator(const PDQIterator &other) : ptr(other.ptr), entry_size(other.entry_size) { - } - - inline const data_ptr_t &operator*() const { - return ptr; - } - - inline PDQIterator &operator++() { - ptr += entry_size; - return *this; - } - - inline PDQIterator &operator--() { - ptr -= entry_size; - return *this; - } - - inline PDQIterator operator++(int) { - auto tmp = *this; - ptr += entry_size; - return tmp; - } - - inline PDQIterator operator--(int) { - auto tmp = *this; - ptr -= entry_size; - return tmp; - } - - inline PDQIterator operator+(const idx_t &i) const { - auto result = *this; - result.ptr += i * entry_size; - return result; - } - - inline PDQIterator operator-(const idx_t &i) const { - PDQIterator result = *this; - result.ptr -= i * entry_size; - return result; - } - - inline PDQIterator &operator=(const PDQIterator &other) { - D_ASSERT(entry_size == other.entry_size); 
- ptr = other.ptr; - return *this; - } - - inline friend idx_t operator-(const PDQIterator &lhs, const PDQIterator &rhs) { - D_ASSERT(duckdb::NumericCast(*lhs - *rhs) % lhs.entry_size == 0); - D_ASSERT(*lhs - *rhs >= 0); - return duckdb::NumericCast(*lhs - *rhs) / lhs.entry_size; - } - - inline friend bool operator<(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs < *rhs; - } - - inline friend bool operator>(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs > *rhs; - } - - inline friend bool operator>=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs >= *rhs; - } - - inline friend bool operator<=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs <= *rhs; - } - - inline friend bool operator==(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs == *rhs; - } - - inline friend bool operator!=(const PDQIterator &lhs, const PDQIterator &rhs) { - return *lhs != *rhs; - } - -private: - data_ptr_t ptr; - const idx_t &entry_size; -}; - -static inline bool comp(const data_ptr_t &l, const data_ptr_t &r, const PDQConstants &constants) { - D_ASSERT(l == constants.tmp_buf || l == constants.swap_offsets_buf || l < constants.end); - D_ASSERT(r == constants.tmp_buf || r == constants.swap_offsets_buf || r < constants.end); - return FastMemcmp(l + constants.comp_offset, r + constants.comp_offset, constants.comp_size) < 0; -} - -static inline const data_ptr_t &GET_TMP(const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(src != constants.tmp_buf && src != constants.swap_offsets_buf && src < constants.end); - FastMemcpy(constants.tmp_buf, src, constants.entry_size); - return constants.tmp_buf; -} - -static inline const data_ptr_t &SWAP_OFFSETS_GET_TMP(const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(src != constants.tmp_buf && src != constants.swap_offsets_buf && src < constants.end); - FastMemcpy(constants.swap_offsets_buf, src, constants.entry_size); - return 
constants.swap_offsets_buf; -} - -static inline void MOVE(const data_ptr_t &dest, const data_ptr_t &src, const PDQConstants &constants) { - D_ASSERT(dest == constants.tmp_buf || dest == constants.swap_offsets_buf || dest < constants.end); - D_ASSERT(src == constants.tmp_buf || src == constants.swap_offsets_buf || src < constants.end); - FastMemcpy(dest, src, constants.entry_size); -} - -static inline void iter_swap(const PDQIterator &lhs, const PDQIterator &rhs, const PDQConstants &constants) { - D_ASSERT(*lhs < constants.end); - D_ASSERT(*rhs < constants.end); - FastMemcpy(constants.iter_swap_buf, *lhs, constants.entry_size); - FastMemcpy(*lhs, *rhs, constants.entry_size); - FastMemcpy(*rhs, constants.iter_swap_buf, constants.entry_size); -} - -// Sorts [begin, end) using insertion sort with the given comparison function. -inline void insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (sift != begin && comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - } - } -} - -// Sorts [begin, end) using insertion sort with the given comparison function. Assumes -// *(begin - 1) is an element smaller than or equal to any element in [begin, end). -inline void unguarded_insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. 
- if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - } - } -} - -// Attempts to use insertion sort on [begin, end). Will return false if more than -// partial_insertion_sort_limit elements were moved, and abort sorting. Otherwise it will -// successfully sort and return true. -inline bool partial_insertion_sort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return true; - } - - std::size_t limit = 0; - for (PDQIterator cur = begin + 1; cur != end; ++cur) { - PDQIterator sift = cur; - PDQIterator sift_1 = cur - 1; - - // Compare first so we can avoid 2 moves for an element already positioned correctly. - if (comp(*sift, *sift_1, constants)) { - const auto &tmp = GET_TMP(*sift, constants); - - do { - MOVE(*sift--, *sift_1, constants); - } while (sift != begin && comp(tmp, *--sift_1, constants)); - - MOVE(*sift, tmp, constants); - limit += cur - sift; - } - - if (limit > partial_insertion_sort_limit) { - return false; - } - } - - return true; -} - -inline void sort2(const PDQIterator &a, const PDQIterator &b, const PDQConstants &constants) { - if (comp(*b, *a, constants)) { - iter_swap(a, b, constants); - } -} - -// Sorts the elements *a, *b and *c using comparison function comp. 
-inline void sort3(const PDQIterator &a, const PDQIterator &b, const PDQIterator &c, const PDQConstants &constants) { - sort2(a, b, constants); - sort2(b, c, constants); - sort2(a, b, constants); -} - -template -inline T *align_cacheline(T *p) { -#if defined(UINTPTR_MAX) && __cplusplus >= 201103L - std::uintptr_t ip = reinterpret_cast(p); -#else - std::size_t ip = reinterpret_cast(p); -#endif - ip = (ip + cacheline_size - 1) & -duckdb::UnsafeNumericCast(cacheline_size); - return reinterpret_cast(ip); -} - -inline void swap_offsets(const PDQIterator &first, const PDQIterator &last, unsigned char *offsets_l, - unsigned char *offsets_r, size_t num, bool use_swaps, const PDQConstants &constants) { - if (use_swaps) { - // This case is needed for the descending distribution, where we need - // to have proper swapping for pdqsort to remain O(n). - for (size_t i = 0; i < num; ++i) { - iter_swap(first + offsets_l[i], last - offsets_r[i], constants); - } - } else if (num > 0) { - PDQIterator l = first + offsets_l[0]; - PDQIterator r = last - offsets_r[0]; - const auto &tmp = SWAP_OFFSETS_GET_TMP(*l, constants); - MOVE(*l, *r, constants); - for (size_t i = 1; i < num; ++i) { - l = first + offsets_l[i]; - MOVE(*r, *l, constants); - r = last - offsets_r[i]; - MOVE(*l, *r, constants); - } - MOVE(*r, tmp, constants); - } -} - -// Partitions [begin, end) around pivot *begin using comparison function comp. Elements equal -// to the pivot are put in the right-hand partition. Returns the position of the pivot after -// partitioning and whether the passed sequence already was correctly partitioned. Assumes the -// pivot is a median of at least 3 elements and that [begin, end) is at least -// insertion_sort_threshold long. Uses branchless partitioning. -inline std::pair partition_right_branchless(const PDQIterator &begin, const PDQIterator &end, - const PDQConstants &constants) { - // Move pivot into local for speed. 
- const auto &pivot = GET_TMP(*begin, constants); - PDQIterator first = begin; - PDQIterator last = end; - - // Find the first element greater than or equal than the pivot (the median of 3 guarantees - // this exists). - while (comp(*++first, pivot, constants)) { - } - - // Find the first element strictly smaller than the pivot. We have to guard this search if - // there was no element before *first. - if (first - 1 == begin) { - while (first < last && !comp(*--last, pivot, constants)) { - } - } else { - while (!comp(*--last, pivot, constants)) { - } - } - - // If the first pair of elements that should be swapped to partition are the same element, - // the passed in sequence already was correctly partitioned. - bool already_partitioned = first >= last; - if (!already_partitioned) { - iter_swap(first, last, constants); - ++first; - - // The following branchless partitioning is derived from "BlockQuicksort: How Branch - // Mispredictions don’t affect Quicksort" by Stefan Edelkamp and Armin Weiss, but - // heavily micro-optimized. - unsigned char offsets_l_storage[block_size + cacheline_size]; - unsigned char offsets_r_storage[block_size + cacheline_size]; - unsigned char *offsets_l = align_cacheline(offsets_l_storage); - unsigned char *offsets_r = align_cacheline(offsets_r_storage); - - PDQIterator offsets_l_base = first; - PDQIterator offsets_r_base = last; - size_t num_l, num_r, start_l, start_r; - num_l = num_r = start_l = start_r = 0; - - while (first < last) { - // Fill up offset blocks with elements that are on the wrong side. - // First we determine how much elements are considered for each offset block. - size_t num_unknown = last - first; - size_t left_split = num_l == 0 ? (num_r == 0 ? num_unknown / 2 : num_unknown) : 0; - size_t right_split = num_r == 0 ? (num_unknown - left_split) : 0; - - // Fill the offset blocks. 
- if (left_split >= block_size) { - for (unsigned char i = 0; i < block_size;) { - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - } - } else { - for (unsigned char i = 0; i < left_split;) { - offsets_l[num_l] = i++; - num_l += !comp(*first, pivot, constants); - ++first; - } - } - - if (right_split >= block_size) { - for (unsigned char i = 0; i < block_size;) { - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - } - } else { - for (unsigned char i = 0; i < right_split;) { - offsets_r[num_r] = ++i; - num_r += comp(*--last, pivot, constants); - } - } - - // Swap elements and update block sizes and first/last boundaries. 
- size_t num = std::min(num_l, num_r); - swap_offsets(offsets_l_base, offsets_r_base, offsets_l + start_l, offsets_r + start_r, num, num_l == num_r, - constants); - num_l -= num; - num_r -= num; - start_l += num; - start_r += num; - - if (num_l == 0) { - start_l = 0; - offsets_l_base = first; - } - - if (num_r == 0) { - start_r = 0; - offsets_r_base = last; - } - } - - // We have now fully identified [first, last)'s proper position. Swap the last elements. - if (num_l) { - offsets_l += start_l; - while (num_l--) { - iter_swap(offsets_l_base + offsets_l[num_l], --last, constants); - } - first = last; - } - if (num_r) { - offsets_r += start_r; - while (num_r--) { - iter_swap(offsets_r_base - offsets_r[num_r], first, constants), ++first; - } - last = first; - } - } - - // Put the pivot in the right place. - PDQIterator pivot_pos = first - 1; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return std::make_pair(pivot_pos, already_partitioned); -} - -// Partitions [begin, end) around pivot *begin using comparison function comp. Elements equal -// to the pivot are put in the right-hand partition. Returns the position of the pivot after -// partitioning and whether the passed sequence already was correctly partitioned. Assumes the -// pivot is a median of at least 3 elements and that [begin, end) is at least -// insertion_sort_threshold long. -inline std::pair partition_right(const PDQIterator &begin, const PDQIterator &end, - const PDQConstants &constants) { - // Move pivot into local for speed. - const auto &pivot = GET_TMP(*begin, constants); - - PDQIterator first = begin; - PDQIterator last = end; - - // Find the first element greater than or equal than the pivot (the median of 3 guarantees - // this exists). - while (comp(*++first, pivot, constants)) { - } - - // Find the first element strictly smaller than the pivot. We have to guard this search if - // there was no element before *first. 
- if (first - 1 == begin) { - while (first < last && !comp(*--last, pivot, constants)) { - } - } else { - while (!comp(*--last, pivot, constants)) { - } - } - - // If the first pair of elements that should be swapped to partition are the same element, - // the passed in sequence already was correctly partitioned. - bool already_partitioned = first >= last; - - // Keep swapping pairs of elements that are on the wrong side of the pivot. Previously - // swapped pairs guard the searches, which is why the first iteration is special-cased - // above. - while (first < last) { - iter_swap(first, last, constants); - while (comp(*++first, pivot, constants)) { - } - while (!comp(*--last, pivot, constants)) { - } - } - - // Put the pivot in the right place. - PDQIterator pivot_pos = first - 1; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return std::make_pair(pivot_pos, already_partitioned); -} - -// Similar function to the one above, except elements equal to the pivot are put to the left of -// the pivot and it doesn't check or return if the passed sequence already was partitioned. -// Since this is rarely used (the many equal case), and in that case pdqsort already has O(n) -// performance, no block quicksort is applied here for simplicity. 
-inline PDQIterator partition_left(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - const auto &pivot = GET_TMP(*begin, constants); - PDQIterator first = begin; - PDQIterator last = end; - - while (comp(pivot, *--last, constants)) { - } - - if (last + 1 == end) { - while (first < last && !comp(pivot, *++first, constants)) { - } - } else { - while (!comp(pivot, *++first, constants)) { - } - } - - while (first < last) { - iter_swap(first, last, constants); - while (comp(pivot, *--last, constants)) { - } - while (!comp(pivot, *++first, constants)) { - } - } - - PDQIterator pivot_pos = last; - MOVE(*begin, *pivot_pos, constants); - MOVE(*pivot_pos, pivot, constants); - - return pivot_pos; -} - -template -inline void pdqsort_loop(PDQIterator begin, const PDQIterator &end, const PDQConstants &constants, int bad_allowed, - bool leftmost = true) { - // Use a while loop for tail recursion elimination. - while (true) { - idx_t size = end - begin; - - // Insertion sort is faster for small arrays. - if (size < insertion_sort_threshold) { - if (leftmost) { - insertion_sort(begin, end, constants); - } else { - unguarded_insertion_sort(begin, end, constants); - } - return; - } - - // Choose pivot as median of 3 or pseudomedian of 9. - idx_t s2 = size / 2; - if (size > ninther_threshold) { - sort3(begin, begin + s2, end - 1, constants); - sort3(begin + 1, begin + (s2 - 1), end - 2, constants); - sort3(begin + 2, begin + (s2 + 1), end - 3, constants); - sort3(begin + (s2 - 1), begin + s2, begin + (s2 + 1), constants); - iter_swap(begin, begin + s2, constants); - } else { - sort3(begin + s2, begin, end - 1, constants); - } - - // If *(begin - 1) is the end of the right partition of a previous partition operation - // there is no element in [begin, end) that is smaller than *(begin - 1). 
Then if our - // pivot compares equal to *(begin - 1) we change strategy, putting equal elements in - // the left partition, greater elements in the right partition. We do not have to - // recurse on the left partition, since it's sorted (all equal). - if (!leftmost && !comp(*(begin - 1), *begin, constants)) { - begin = partition_left(begin, end, constants) + 1; - continue; - } - - // Partition and get results. - std::pair part_result = - Branchless ? partition_right_branchless(begin, end, constants) : partition_right(begin, end, constants); - PDQIterator pivot_pos = part_result.first; - bool already_partitioned = part_result.second; - - // Check for a highly unbalanced partition. - idx_t l_size = pivot_pos - begin; - idx_t r_size = end - (pivot_pos + 1); - bool highly_unbalanced = l_size < size / 8 || r_size < size / 8; - - // If we got a highly unbalanced partition we shuffle elements to break many patterns. - if (highly_unbalanced) { - // If we had too many bad partitions, switch to heapsort to guarantee O(n log n). 
- // if (--bad_allowed == 0) { - // std::make_heap(begin, end, comp); - // std::sort_heap(begin, end, comp); - // return; - // } - - if (l_size >= insertion_sort_threshold) { - iter_swap(begin, begin + l_size / 4, constants); - iter_swap(pivot_pos - 1, pivot_pos - l_size / 4, constants); - - if (l_size > ninther_threshold) { - iter_swap(begin + 1, begin + (l_size / 4 + 1), constants); - iter_swap(begin + 2, begin + (l_size / 4 + 2), constants); - iter_swap(pivot_pos - 2, pivot_pos - (l_size / 4 + 1), constants); - iter_swap(pivot_pos - 3, pivot_pos - (l_size / 4 + 2), constants); - } - } - - if (r_size >= insertion_sort_threshold) { - iter_swap(pivot_pos + 1, pivot_pos + (1 + r_size / 4), constants); - iter_swap(end - 1, end - r_size / 4, constants); - - if (r_size > ninther_threshold) { - iter_swap(pivot_pos + 2, pivot_pos + (2 + r_size / 4), constants); - iter_swap(pivot_pos + 3, pivot_pos + (3 + r_size / 4), constants); - iter_swap(end - 2, end - (1 + r_size / 4), constants); - iter_swap(end - 3, end - (2 + r_size / 4), constants); - } - } - } else { - // If we were decently balanced and we tried to sort an already partitioned - // sequence try to use insertion sort. - if (already_partitioned && partial_insertion_sort(begin, pivot_pos, constants) && - partial_insertion_sort(pivot_pos + 1, end, constants)) { - return; - } - } - - // Sort the left partition first using recursion and do tail recursion elimination for - // the right-hand partition. 
- pdqsort_loop(begin, pivot_pos, constants, bad_allowed, leftmost); - begin = pivot_pos + 1; - leftmost = false; - } -} - -inline void pdqsort(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - pdqsort_loop(begin, end, constants, log2(end - begin)); -} - -inline void pdqsort_branchless(const PDQIterator &begin, const PDQIterator &end, const PDQConstants &constants) { - if (begin == end) { - return; - } - pdqsort_loop(begin, end, constants, log2(end - begin)); -} -// NOLINTEND - -} // namespace duckdb_pdqsort diff --git a/src/duckdb/src/include/duckdb/common/sort/sort.hpp b/src/duckdb/src/include/duckdb/common/sort/sort.hpp deleted file mode 100644 index 188ea2127..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/sort.hpp +++ /dev/null @@ -1,290 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/sort.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -class RowLayout; -struct LocalSortState; - -struct SortConstants { - static constexpr idx_t VALUES_PER_RADIX = 256; - static constexpr idx_t MSD_RADIX_LOCATIONS = VALUES_PER_RADIX + 1; - static constexpr idx_t INSERTION_SORT_THRESHOLD = 24; - static constexpr idx_t MSD_RADIX_SORT_SIZE_THRESHOLD = 4; -}; - -struct SortLayout { -public: - SortLayout() { - } - explicit SortLayout(const vector &orders); - SortLayout GetPrefixComparisonLayout(idx_t num_prefix_cols) const; - -public: - idx_t column_count; - vector order_types; - vector order_by_null_types; - vector logical_types; - - bool all_constant; - vector constant_size; - vector column_sizes; - vector prefix_lengths; - vector stats; - vector has_null; - - idx_t comparison_size; - 
idx_t entry_size; - - RowLayout blob_layout; - unordered_map sorting_to_blob_col; -}; - -struct GlobalSortState { -public: - GlobalSortState(ClientContext &context, const vector &orders, RowLayout &payload_layout); - - //! Add local state sorted data to this global state - void AddLocalState(LocalSortState &local_sort_state); - //! Prepares the GlobalSortState for the merge sort phase (after completing radix sort phase) - void PrepareMergePhase(); - //! Initializes the global sort state for another round of merging - void InitializeMergeRound(); - //! Completes the cascaded merge sort round. - //! Pass true if you wish to use the radix data for further comparisons. - void CompleteMergeRound(bool keep_radix_data = false); - //! Print the sorted data to the console. - void Print(); - -public: - //! The client context - ClientContext &context; - //! The lock for updating the order global state - mutex lock; - //! The buffer manager - BufferManager &buffer_manager; - - //! Sorting and payload layouts - const SortLayout sort_layout; - const RowLayout payload_layout; - - //! Sorted data - vector> sorted_blocks; - vector>> sorted_blocks_temp; - unique_ptr odd_one_out; - - //! Pinned heap data (if sorting in memory) - vector> heap_blocks; - vector pinned_blocks; - - //! Capacity (number of rows) used to initialize blocks - idx_t block_capacity; - //! Whether we are doing an external sort - bool external; - - //! Progress in merge path stage - idx_t pair_idx; - idx_t num_pairs; - idx_t l_start; - idx_t r_start; -}; - -struct LocalSortState { -public: - LocalSortState(); - - //! Initialize the layouts and RowDataCollections - void Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p); - //! Sink one DataChunk into the local sort state - void SinkChunk(DataChunk &sort, DataChunk &payload); - //! Size of accumulated data in bytes - idx_t SizeInBytes() const; - //! 
Sort the data accumulated so far - void Sort(GlobalSortState &global_sort_state, bool reorder_heap); - //! Concatenate the blocks held by a RowDataCollection into a single block - static unique_ptr ConcatenateBlocks(RowDataCollection &row_data); - -private: - //! Sorts the data in the newly created SortedBlock - void SortInMemory(); - //! Re-order the local state after sorting - void ReOrder(GlobalSortState &gstate, bool reorder_heap); - //! Re-order a SortedData object after sorting - void ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, - bool reorder_heap); - -public: - //! Whether this local state has been initialized - bool initialized; - //! The buffer manager - BufferManager *buffer_manager; - //! The sorting and payload layouts - const SortLayout *sort_layout; - const RowLayout *payload_layout; - //! Radix/memcmp sortable data - unique_ptr radix_sorting_data; - //! Variable sized sorting data and accompanying heap - unique_ptr blob_sorting_data; - unique_ptr blob_sorting_heap; - //! Payload data and accompanying heap - unique_ptr payload_data; - unique_ptr payload_heap; - //! Sorted data - vector> sorted_blocks; - -private: - //! Selection vector and addresses for scattering the data to rows - const SelectionVector &sel_ptr = *FlatVector::IncrementalSelectionVector(); - Vector addresses = Vector(LogicalType::POINTER); -}; - -struct MergeSorter { -public: - MergeSorter(GlobalSortState &state, BufferManager &buffer_manager); - - //! Finds and merges partitions until the current cascaded merge round is finished - void PerformInMergeRound(); - -private: - //! The global sorting state - GlobalSortState &state; - //! The sorting and payload layouts - BufferManager &buffer_manager; - const SortLayout &sort_layout; - - //! The left and right reader - unique_ptr left; - unique_ptr right; - - //! Input and output blocks - unique_ptr left_input; - unique_ptr right_input; - SortedBlock *result; - -private: - //! 
Computes the left and right block that will be merged next (Merge Path partition) - void GetNextPartition(); - //! Finds the boundary of the next partition using binary search - void GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx); - //! Compare values within SortedBlocks using a global index - int CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx); - - //! Finds the next partition and merges it - void MergePartition(); - - //! Computes how the next 'count' tuples should be merged by setting the 'left_smaller' array - void ComputeMerge(const idx_t &count, bool left_smaller[]); - - //! Merges the radix sorting blocks according to the 'left_smaller' array - void MergeRadix(const idx_t &count, const bool left_smaller[]); - //! Merges SortedData according to the 'left_smaller' array - void MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, - const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices); - //! Merges constant size rows according to the 'left_smaller' array - void MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, idx_t &r_entry_idx, - const idx_t &r_count, RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, - const bool left_smaller[], idx_t &copied, const idx_t &count); - //! Flushes constant size rows into the result - void FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, - RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, - const idx_t &count); - //! 
Flushes blob rows and accompanying heap - void FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, - idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock &target_data_block, - data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, BufferHandle &target_heap_handle, - data_ptr_t &target_heap_ptr, idx_t &copied, const idx_t &count); -}; - -struct SBIterator { - static int ComparisonValue(ExpressionType comparison); - - SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p = 0); - - inline idx_t GetIndex() const { - return entry_idx; - } - - inline void SetIndex(idx_t entry_idx_p) { - const auto new_block_idx = entry_idx_p / block_capacity; - if (new_block_idx != scan.block_idx) { - scan.SetIndices(new_block_idx, 0); - if (new_block_idx < block_count) { - scan.PinRadix(scan.block_idx); - block_ptr = scan.RadixPtr(); - if (!all_constant) { - scan.PinData(*scan.sb->blob_sorting_data); - } - } - } - - scan.entry_idx = entry_idx_p % block_capacity; - entry_ptr = block_ptr + scan.entry_idx * entry_size; - entry_idx = entry_idx_p; - } - - inline SBIterator &operator++() { - if (++scan.entry_idx < block_capacity) { - entry_ptr += entry_size; - ++entry_idx; - } else { - SetIndex(entry_idx + 1); - } - - return *this; - } - - inline SBIterator &operator--() { - if (scan.entry_idx) { - --scan.entry_idx; - --entry_idx; - entry_ptr -= entry_size; - } else { - SetIndex(entry_idx - 1); - } - - return *this; - } - - inline bool Compare(const SBIterator &other, const SortLayout &prefix) const { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(entry_ptr, other.entry_ptr, prefix.comparison_size); - } else { - comp_res = Comparators::CompareTuple(scan, other.scan, entry_ptr, other.entry_ptr, prefix, external); - } - - return comp_res <= cmp; - } - - inline bool Compare(const SBIterator &other) const { - return Compare(other, sort_layout); - } - - // Fixed comparison parameters - const 
SortLayout &sort_layout; - const idx_t block_count; - const idx_t block_capacity; - const size_t entry_size; - const bool all_constant; - const bool external; - const int cmp; - - // Iteration state - SBScanState scan; - idx_t entry_idx; - data_ptr_t block_ptr; - data_ptr_t entry_ptr; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp b/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp deleted file mode 100644 index b6941bda2..000000000 --- a/src/duckdb/src/include/duckdb/common/sort/sorted_block.hpp +++ /dev/null @@ -1,165 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/sort/sorted_block.hpp -// -// -//===----------------------------------------------------------------------===// -#pragma once - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/types/row/row_data_collection_scanner.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/storage/buffer/buffer_handle.hpp" - -namespace duckdb { - -class BufferManager; -struct RowDataBlock; -struct SortLayout; -struct GlobalSortState; - -enum class SortedDataType { BLOB, PAYLOAD }; - -//! Object that holds sorted rows, and an accompanying heap if there are blobs -struct SortedData { -public: - SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, GlobalSortState &state); - //! Number of rows that this object holds - idx_t Count(); - //! Initialize new block to write to - void CreateBlock(); - //! Create a slice that holds the rows between the start and end indices - unique_ptr CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index); - //! Unswizzles all - void Unswizzle(); - -public: - const SortedDataType type; - //! Layout of this data - const RowLayout layout; - //! Data and heap blocks - vector> data_blocks; - vector> heap_blocks; - //! 
Whether the pointers in this sorted data are swizzled - bool swizzled; - -private: - //! The buffer manager - BufferManager &buffer_manager; - //! The global state - GlobalSortState &state; -}; - -//! Block that holds sorted rows: radix, blob and payload data -struct SortedBlock { -public: - SortedBlock(BufferManager &buffer_manager, GlobalSortState &gstate); - //! Number of rows that this object holds - idx_t Count() const; - //! Initialize this block to write data to - void InitializeWrite(); - //! Init new block to write to - void CreateBlock(); - //! Fill this sorted block by appending the blocks held by a vector of sorted blocks - void AppendSortedBlocks(vector> &sorted_blocks); - //! Locate the block and entry index of a row in this block, - //! given an index between 0 and the total number of rows in this block - void GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index); - //! Create a slice that holds the rows between the start and end indices - unique_ptr CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx); - - //! Size (in bytes) of the heap of this block - idx_t HeapSize() const; - //! Total size (in bytes) of this block - idx_t SizeInBytes() const; - -public: - //! Radix/memcmp sortable data - vector> radix_sorting_data; - //! Variable sized sorting data - unique_ptr blob_sorting_data; - //! Payload data - unique_ptr payload_data; - -private: - //! Buffer manager, global state, and sorting layout constants - BufferManager &buffer_manager; - GlobalSortState &state; - const SortLayout &sort_layout; - const RowLayout &payload_layout; -}; - -//! State used to scan a SortedBlock e.g. 
during merge sort -struct SBScanState { -public: - SBScanState(BufferManager &buffer_manager, GlobalSortState &state); - - void PinRadix(idx_t block_idx_to); - void PinData(SortedData &sd); - - data_ptr_t RadixPtr() const; - data_ptr_t DataPtr(SortedData &sd) const; - data_ptr_t HeapPtr(SortedData &sd) const; - data_ptr_t BaseHeapPtr(SortedData &sd) const; - - idx_t Remaining() const; - - void SetIndices(idx_t block_idx_to, idx_t entry_idx_to); - -public: - BufferManager &buffer_manager; - const SortLayout &sort_layout; - GlobalSortState &state; - - SortedBlock *sb; - - idx_t block_idx; - idx_t entry_idx; - - BufferHandle radix_handle; - - BufferHandle blob_sorting_data_handle; - BufferHandle blob_sorting_heap_handle; - - BufferHandle payload_data_handle; - BufferHandle payload_heap_handle; -}; - -//! Used to scan the data into DataChunks after sorting -struct PayloadScanner { -public: - PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush = true); - explicit PayloadScanner(GlobalSortState &global_sort_state, bool flush = true); - - //! Scan a single block - PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush = false); - - //! The type layout of the payload - inline const vector &GetPayloadTypes() const { - return scanner->GetTypes(); - } - - //! The number of rows scanned so far - inline idx_t Scanned() const { - return scanner->Scanned(); - } - - //! The number of remaining rows - inline idx_t Remaining() const { - return scanner->Remaining(); - } - - //! Scans the next data chunk from the sorted data - void Scan(DataChunk &chunk); - -private: - //! The sorted data being scanned - unique_ptr rows; - unique_ptr heap; - //! 
The actual scanner - unique_ptr scanner; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/string_map_set.hpp b/src/duckdb/src/include/duckdb/common/string_map_set.hpp index 00600c421..40bd51171 100644 --- a/src/duckdb/src/include/duckdb/common/string_map_set.hpp +++ b/src/duckdb/src/include/duckdb/common/string_map_set.hpp @@ -28,9 +28,26 @@ struct StringEquality { } }; +struct StringCIHash { + std::size_t operator()(const string_t &k) const { + return StringUtil::CIHash(k.GetData(), k.GetSize()); + } +}; + +struct StringCIEquality { + bool operator()(const string_t &a, const string_t &b) const { + return StringUtil::CIEquals(a.GetData(), a.GetSize(), b.GetData(), b.GetSize()); + } +}; + template using string_map_t = unordered_map; using string_set_t = unordered_set; +template +using case_insensitive_string_map_t = unordered_map; + +using case_insensitive_string_set_t = unordered_set; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/string_util.hpp b/src/duckdb/src/include/duckdb/common/string_util.hpp index 1448c559a..85c24a1d6 100644 --- a/src/duckdb/src/include/duckdb/common/string_util.hpp +++ b/src/duckdb/src/include/duckdb/common/string_util.hpp @@ -217,6 +217,7 @@ class StringUtil { //! Case insensitive hash DUCKDB_API static uint64_t CIHash(const string &str); + DUCKDB_API static uint64_t CIHash(const char *str, idx_t size); //! 
Case insensitive equals DUCKDB_API static bool CIEquals(const string &l1, const string &l2); diff --git a/src/duckdb/src/include/duckdb/common/types/hugeint.hpp b/src/duckdb/src/include/duckdb/common/types/hugeint.hpp index 3720bf844..9fa5d447b 100644 --- a/src/duckdb/src/include/duckdb/common/types/hugeint.hpp +++ b/src/duckdb/src/include/duckdb/common/types/hugeint.hpp @@ -129,38 +129,38 @@ class Hugeint { static int Sign(hugeint_t n); static hugeint_t Abs(hugeint_t n); // comparison operators - static bool Equals(hugeint_t lhs, hugeint_t rhs) { + static bool Equals(const hugeint_t &lhs, const hugeint_t &rhs) { bool lower_equals = lhs.lower == rhs.lower; bool upper_equals = lhs.upper == rhs.upper; return lower_equals && upper_equals; } - static bool NotEquals(hugeint_t lhs, hugeint_t rhs) { + static bool NotEquals(const hugeint_t &lhs, const hugeint_t &rhs) { return !Equals(lhs, rhs); } - static bool GreaterThan(hugeint_t lhs, hugeint_t rhs) { + static bool GreaterThan(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_bigger = lhs.upper > rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_bigger = lhs.lower > rhs.lower; return upper_bigger || (upper_equal && lower_bigger); } - static bool GreaterThanEquals(hugeint_t lhs, hugeint_t rhs) { + static bool GreaterThanEquals(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_bigger = lhs.upper > rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_bigger_equals = lhs.lower >= rhs.lower; return upper_bigger || (upper_equal && lower_bigger_equals); } - static bool LessThan(hugeint_t lhs, hugeint_t rhs) { + static bool LessThan(const hugeint_t &lhs, const hugeint_t &rhs) { bool upper_smaller = lhs.upper < rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_smaller = lhs.lower < rhs.lower; return upper_smaller || (upper_equal && lower_smaller); } - static bool LessThanEquals(hugeint_t lhs, hugeint_t rhs) { + static bool LessThanEquals(const hugeint_t &lhs, const hugeint_t 
&rhs) { bool upper_smaller = lhs.upper < rhs.upper; bool upper_equal = lhs.upper == rhs.upper; bool lower_smaller_equals = lhs.lower <= rhs.lower; diff --git a/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp b/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp index ceb5637ac..5575e5a08 100644 --- a/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp @@ -108,6 +108,7 @@ struct SelectionVector { return selection_data; } buffer_ptr Slice(const SelectionVector &sel, idx_t count) const; + idx_t SliceInPlace(const SelectionVector &sel, idx_t count); string ToString(idx_t count = 0) const; void Print(idx_t count = 0) const; diff --git a/src/duckdb/src/include/duckdb/common/types/variant.hpp b/src/duckdb/src/include/duckdb/common/types/variant.hpp index 0d5917892..bef2f2353 100644 --- a/src/duckdb/src/include/duckdb/common/types/variant.hpp +++ b/src/duckdb/src/include/duckdb/common/types/variant.hpp @@ -29,9 +29,18 @@ struct VariantNestedData { }; struct VariantDecimalData { +public: + VariantDecimalData(uint32_t width, uint32_t scale, const_data_ptr_t value_ptr) + : width(width), scale(scale), value_ptr(value_ptr) { + } + +public: + PhysicalType GetPhysicalType() const; + +public: uint32_t width; uint32_t scale; - const_data_ptr_t value_ptr; + const_data_ptr_t value_ptr = nullptr; }; struct VariantVectorData { diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp index f46496c9a..24382e8d9 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_asof_join.hpp @@ -37,9 +37,6 @@ class PhysicalAsOfJoin : public PhysicalComparisonJoin { // Projection mappings vector right_projection_map; - // Predicate (join conditions that don't reference both sides) 
- unique_ptr predicate; - protected: // CachingOperator Interface OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp index 25ed9ed06..2cdff1374 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_nested_loop_join.hpp @@ -18,13 +18,16 @@ class PhysicalNestedLoopJoin : public PhysicalComparisonJoin { static constexpr const PhysicalOperatorType TYPE = PhysicalOperatorType::NESTED_LOOP_JOIN; public: - PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, + PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperator &left, PhysicalOperator &right, vector cond, JoinType join_type, idx_t estimated_cardinality, unique_ptr pushdown_info); - PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalOperator &op, PhysicalOperator &left, + PhysicalNestedLoopJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperator &left, PhysicalOperator &right, vector cond, JoinType join_type, idx_t estimated_cardinality); + // Predicate (join conditions that don't reference both sides) + unique_ptr predicate; + public: // Operator Interface unique_ptr GetOperatorState(ExecutionContext &context) const override; diff --git a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp index d8605eb1f..3e90c4365 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp @@ -70,10 +70,10 @@ struct VariantUtils { uint32_t value_index); DUCKDB_API static vector GetObjectKeys(const UnifiedVariantVectorData 
&variant, idx_t row, const VariantNestedData &nested_data); - DUCKDB_API static VariantChildDataCollectionResult FindChildValues(const UnifiedVariantVectorData &variant, - const VariantPathComponent &component, - optional_idx row, SelectionVector &res, - VariantNestedData *nested_data, idx_t count); + DUCKDB_API static void FindChildValues(const UnifiedVariantVectorData &variant, + const VariantPathComponent &component, + optional_ptr sel, SelectionVector &res, + ValidityMask &res_validity, VariantNestedData *nested_data, idx_t count); DUCKDB_API static VariantNestedDataCollectionResult CollectNestedData(const UnifiedVariantVectorData &variant, VariantLogicalType expected_type, const SelectionVector &sel, idx_t count, optional_idx row, idx_t offset, diff --git a/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp b/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp index 2ce10061a..899a331cb 100644 --- a/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp +++ b/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp @@ -554,6 +554,11 @@ typedef struct { // New string functions that are added char *(*duckdb_value_to_string)(duckdb_value value); + // New functions around the table description + + idx_t (*duckdb_table_description_get_column_count)(duckdb_table_description table_description); + duckdb_logical_type (*duckdb_table_description_get_column_type)(duckdb_table_description table_description, + idx_t index); // New functions around table function binding void (*duckdb_table_function_get_client_context)(duckdb_bind_info info, duckdb_client_context *out_context); @@ -1044,6 +1049,8 @@ inline duckdb_ext_api_v1 CreateAPIv1() { result.duckdb_scalar_function_bind_get_argument = duckdb_scalar_function_bind_get_argument; result.duckdb_scalar_function_set_bind_data_copy = duckdb_scalar_function_set_bind_data_copy; result.duckdb_value_to_string = duckdb_value_to_string; + result.duckdb_table_description_get_column_count = 
duckdb_table_description_get_column_count; + result.duckdb_table_description_get_column_type = duckdb_table_description_get_column_type; result.duckdb_table_function_get_client_context = duckdb_table_function_get_client_context; result.duckdb_create_map_value = duckdb_create_map_value; result.duckdb_create_union_value = duckdb_create_union_value; diff --git a/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp b/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp index 90d028035..79290367d 100644 --- a/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/database_file_path_manager.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/mutex.hpp" #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/enums/on_create_conflict.hpp" +#include "duckdb/common/enums/access_mode.hpp" namespace duckdb { struct AttachInfo; @@ -20,11 +21,14 @@ struct AttachOptions; enum class InsertDatabasePathResult { SUCCESS, ALREADY_EXISTS }; struct DatabasePathInfo { - explicit DatabasePathInfo(string name_p) : name(std::move(name_p)), is_attached(true) { + explicit DatabasePathInfo(string name_p, AccessMode access_mode) + : name(std::move(name_p)), access_mode(access_mode), is_attached(true) { } string name; + AccessMode access_mode; bool is_attached; + idx_t reference_count = 1; }; //! 
The DatabaseFilePathManager is used to ensure we only ever open a single database file once diff --git a/src/duckdb/src/include/duckdb/main/relation.hpp b/src/duckdb/src/include/duckdb/main/relation.hpp index 9d9e67686..bc383ffe0 100644 --- a/src/duckdb/src/include/duckdb/main/relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation.hpp @@ -78,7 +78,8 @@ class Relation : public enable_shared_from_this { public: DUCKDB_API virtual const vector &Columns() = 0; - DUCKDB_API virtual unique_ptr GetQueryNode(); + DUCKDB_API virtual unique_ptr GetQueryNode() = 0; + DUCKDB_API virtual string GetQuery(); DUCKDB_API virtual BoundStatement Bind(Binder &binder); DUCKDB_API virtual string GetAlias(); diff --git a/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp index 7d5462941..8df59b8d2 100644 --- a/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/create_table_relation.hpp @@ -26,6 +26,8 @@ class CreateTableRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp index cb826a86c..aa09b0def 100644 --- a/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/create_view_relation.hpp @@ -26,6 +26,8 @@ class CreateViewRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git 
a/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp index c07445ba4..0c25c6576 100644 --- a/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/delete_relation.hpp @@ -26,6 +26,8 @@ class DeleteRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp index 888583b2b..96be08d8f 100644 --- a/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/explain_relation.hpp @@ -24,6 +24,8 @@ class ExplainRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp index 3695cde7b..fccb0ae92 100644 --- a/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/insert_relation.hpp @@ -23,6 +23,8 @@ class InsertRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp index bdb035652..b1be001b9 100644 --- 
a/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/query_relation.hpp @@ -28,6 +28,7 @@ class QueryRelation : public Relation { public: static unique_ptr ParseStatement(ClientContext &context, const string &query, const string &error); unique_ptr GetQueryNode() override; + string GetQuery() override; unique_ptr GetTableRef() override; BoundStatement Bind(Binder &binder) override; diff --git a/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp index 58ad203b2..91eac246e 100644 --- a/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/update_relation.hpp @@ -29,6 +29,8 @@ class UpdateRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp index 99d2ebe8e..cf0853ff3 100644 --- a/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp +++ b/src/duckdb/src/include/duckdb/main/relation/write_csv_relation.hpp @@ -23,6 +23,8 @@ class WriteCSVRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp b/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp index d32089212..138eee7c7 100644 --- a/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp +++ 
b/src/duckdb/src/include/duckdb/main/relation/write_parquet_relation.hpp @@ -24,6 +24,8 @@ class WriteParquetRelation : public Relation { public: BoundStatement Bind(Binder &binder) override; + unique_ptr GetQueryNode() override; + string GetQuery() override; const vector &Columns() override; string ToString(idx_t depth) override; bool IsReadOnly() override { diff --git a/src/duckdb/src/include/duckdb/main/settings.hpp b/src/duckdb/src/include/duckdb/main/settings.hpp index 1c96771ee..6817f8e86 100644 --- a/src/duckdb/src/include/duckdb/main/settings.hpp +++ b/src/duckdb/src/include/duckdb/main/settings.hpp @@ -657,7 +657,7 @@ struct ExperimentalMetadataReuseSetting { static constexpr const char *Name = "experimental_metadata_reuse"; static constexpr const char *Description = "EXPERIMENTAL: Re-use row group and table metadata when checkpointing."; static constexpr const char *InputType = "BOOLEAN"; - static constexpr const char *DefaultValue = "false"; + static constexpr const char *DefaultValue = "true"; static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; diff --git a/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp b/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp index 5e920ca2d..29c2f0ac4 100644 --- a/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/filter_pushdown.hpp @@ -104,8 +104,9 @@ class FilterPushdown { void ExtractFilterBindings(const Expression &expr, vector &bindings); //! Generate filters from the current set of filters stored in the FilterCombiner void GenerateFilters(); - //! if there are filters in this FilterPushdown node, push them into the combiner - void PushFilters(); + //! if there are filters in this FilterPushdown node, push them into the combiner. Returns + //! 
FilterResult::UNSATISFIABLE if the subtree should be stripped, or FilterResult::SUCCESS otherwise + FilterResult PushFilters(); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/query_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node.hpp index 956bd63f7..5c091b259 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node.hpp @@ -60,8 +60,6 @@ class QueryNode { //! CTEs (used by SelectNode and SetOperationNode) CommonTableExpressionMap cte_map; - virtual const vector> &GetSelectList() const = 0; - public: //! Convert the query node to a string virtual string ToString() const = 0; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp index bc997a6c7..4ad41748b 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp @@ -32,10 +32,6 @@ class CTENode : public QueryNode { CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; - const vector> &GetSelectList() const override { - return query->GetSelectList(); - } - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp index 6d73fda4a..1f5f16ead 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/recursive_cte_node.hpp @@ -33,10 +33,6 @@ class RecursiveCTENode : public QueryNode { //! targets for key variants vector> key_targets; - const vector> &GetSelectList() const override { - return left->GetSelectList(); - } - public: //! 
Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp index 62aa9c0b2..dfc474d14 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/select_node.hpp @@ -43,10 +43,6 @@ class SelectNode : public QueryNode { //! The SAMPLE clause unique_ptr sample; - const vector> &GetSelectList() const override { - return select_list; - } - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp index 960f6c2d6..3070e2245 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/set_operation_node.hpp @@ -29,8 +29,6 @@ class SetOperationNode : public QueryNode { //! The children of the set operation vector> children; - const vector> &GetSelectList() const override; - public: //! Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp index 9e813335c..26db46a58 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/statement_node.hpp @@ -24,7 +24,6 @@ class StatementNode : public QueryNode { SQLStatement &stmt; public: - const vector> &GetSelectList() const override; //! 
Convert the query node to a string string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index 5603ab36b..bd93bc8b8 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -419,8 +419,6 @@ class Binder : public enable_shared_from_this { BoundStatement BindNode(StatementNode &node); unique_ptr VisitQueryNode(BoundQueryNode &node, unique_ptr root); - unique_ptr CreatePlan(BoundRecursiveCTENode &node); - unique_ptr CreatePlan(BoundCTENode &node); unique_ptr CreatePlan(BoundSelectNode &statement); unique_ptr CreatePlan(BoundSetOperationNode &node); unique_ptr CreatePlan(BoundQueryNode &node); diff --git a/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp b/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp index cd5a78b6a..76c461e78 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_query_node.hpp @@ -17,13 +17,8 @@ namespace duckdb { //! Bound equivalent of QueryNode class BoundQueryNode { public: - explicit BoundQueryNode(QueryNodeType type) : type(type) { - } - virtual ~BoundQueryNode() { - } + virtual ~BoundQueryNode() = default; - //! The type of the query node, either SetOperation or Select - QueryNodeType type; //! 
The result modifiers that should be applied to this query node vector> modifiers; @@ -34,23 +29,6 @@ class BoundQueryNode { public: virtual idx_t GetRootIndex() = 0; - -public: - template - TARGET &Cast() { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound query node to type - query node type mismatch"); - } - return reinterpret_cast(*this); - } - - template - const TARGET &Cast() const { - if (type != TARGET::TYPE) { - throw InternalException("Failed to cast bound query node to type - query node type mismatch"); - } - return reinterpret_cast(*this); - } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp index ac8aef099..862ef5a11 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp @@ -16,8 +16,6 @@ namespace duckdb { class BoundQueryNode; class BoundSelectNode; class BoundSetOperationNode; -class BoundRecursiveCTENode; -class BoundCTENode; //===--------------------------------------------------------------------===// // Expressions diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp deleted file mode 100644 index 67c076ab6..000000000 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_cte_node.hpp +++ /dev/null @@ -1,46 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/query_node/bound_cte_node.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -class BoundCTENode : public BoundQueryNode { -public: - static constexpr const QueryNodeType TYPE = QueryNodeType::CTE_NODE; - -public: - BoundCTENode() : 
BoundQueryNode(QueryNodeType::CTE_NODE) { - } - - //! Keep track of the CTE name this node represents - string ctename; - - //! The cte node - BoundStatement query; - //! The child node - BoundStatement child; - //! Index used by the set operation - idx_t setop_index; - //! The binder used by the query side of the CTE - shared_ptr query_binder; - //! The binder used by the child side of the CTE - shared_ptr child_binder; - - CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; - -public: - idx_t GetRootIndex() override { - return child.plan->GetRootIndex(); - } -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp deleted file mode 100644 index 6a1819464..000000000 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_recursive_cte_node.hpp +++ /dev/null @@ -1,49 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/query_node/bound_recursive_cte_node.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_query_node.hpp" - -namespace duckdb { - -//! Bound equivalent of SetOperationNode -class BoundRecursiveCTENode : public BoundQueryNode { -public: - static constexpr const QueryNodeType TYPE = QueryNodeType::RECURSIVE_CTE_NODE; - -public: - BoundRecursiveCTENode() : BoundQueryNode(QueryNodeType::RECURSIVE_CTE_NODE) { - } - - //! Keep track of the CTE name this node represents - string ctename; - - bool union_all; - //! The left side of the set operation - BoundStatement left; - //! The right side of the set operation - BoundStatement right; - //! Target columns for the recursive key variant - vector> key_targets; - - //! Index used by the set operation - idx_t setop_index; - //! 
The binder used by the left side of the set operation - shared_ptr left_binder; - //! The binder used by the right side of the set operation - shared_ptr right_binder; - -public: - idx_t GetRootIndex() override { - return setop_index; - } -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp index d94195698..3fdc186e9 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_select_node.hpp @@ -35,12 +35,6 @@ struct BoundUnnestNode { //! Bound equivalent of SelectNode class BoundSelectNode : public BoundQueryNode { public: - static constexpr const QueryNodeType TYPE = QueryNodeType::SELECT_NODE; - -public: - BoundSelectNode() : BoundQueryNode(QueryNodeType::SELECT_NODE) { - } - //! Bind information SelectBindState bind_state; //! The projection list diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp index 391ca26e6..0939da695 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp @@ -18,13 +18,6 @@ struct BoundSetOpChild; //! Bound equivalent of SetOperationNode class BoundSetOperationNode : public BoundQueryNode { public: - static constexpr const QueryNodeType TYPE = QueryNodeType::SET_OPERATION_NODE; - -public: - BoundSetOperationNode() : BoundQueryNode(QueryNodeType::SET_OPERATION_NODE) { - } - ~BoundSetOperationNode() override; - //! The type of set operation SetOperationType setop_type = SetOperationType::NONE; //! 
whether the ALL modifier was used or not diff --git a/src/duckdb/src/include/duckdb/planner/query_node/list.hpp b/src/duckdb/src/include/duckdb/planner/query_node/list.hpp index 5c7dbda94..dcac81248 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/list.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/list.hpp @@ -1,4 +1,2 @@ -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" diff --git a/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp b/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp index 99e949418..93cec496d 100644 --- a/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp +++ b/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp @@ -61,7 +61,8 @@ struct CachingFileHandle { //! Tries to read from the cache, filling "overlapping_ranges" with ranges that overlap with the request. //! Returns an invalid BufferHandle if it fails BufferHandle TryReadFromCache(data_ptr_t &buffer, idx_t nr_bytes, idx_t location, - vector> &overlapping_ranges); + vector> &overlapping_ranges, + optional_idx &start_location_of_next_range); //! 
Try to read from the specified range, return an invalid BufferHandle if it fails BufferHandle TryReadFromFileRange(const unique_ptr &guard, CachedFileRange &file_range, data_ptr_t &buffer, idx_t nr_bytes, idx_t location); diff --git a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp index f4c943a79..c246d68b6 100644 --- a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp @@ -48,10 +48,10 @@ class ArrayColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp index b688f8ed7..ab8a5970e 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp @@ -156,10 +156,10 @@ class ColumnData { virtual void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx); - 
virtual void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count); - virtual void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth); + virtual void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count); + virtual void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth); virtual unique_ptr GetUpdateStatistics(); virtual void CommitDropColumn(); @@ -220,8 +220,8 @@ class ColumnData { void FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result, idx_t scan_count, bool allow_updates, bool scan_committed); void FetchUpdateRow(TransactionData transaction, row_t row_id, Vector &result, idx_t result_idx); - void UpdateInternal(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count, Vector &base_vector); + void UpdateInternal(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count, Vector &base_vector); idx_t FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector &base_vector); idx_t GetVectorCount(idx_t vector_index) const; diff --git a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp index 98d8c662b..621ece451 100644 --- a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp @@ -46,10 +46,10 @@ class ListColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector 
&result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp index d003d1378..62b54eeed 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp @@ -171,12 +171,12 @@ class RowGroup : public SegmentBase { void InitializeAppend(RowGroupAppendState &append_state); void Append(RowGroupAppendState &append_state, DataChunk &chunk, idx_t append_count); - void Update(TransactionData transaction, DataChunk &updates, row_t *ids, idx_t offset, idx_t count, - const vector &column_ids); + void Update(TransactionData transaction, DataTable &data_table, DataChunk &updates, row_t *ids, idx_t offset, + idx_t count, const vector &column_ids); //! Update a single column; corresponds to DataTable::UpdateColumn //! 
This method should only be called from the WAL - void UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, idx_t offset, idx_t count, - const vector &column_path); + void UpdateColumn(TransactionData transaction, DataTable &data_table, DataChunk &updates, Vector &row_ids, + idx_t offset, idx_t count, const vector &column_path); void MergeStatistics(idx_t column_idx, const BaseStatistics &other); void MergeIntoStatistics(idx_t column_idx, BaseStatistics &other); diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp index e5c745829..80aa52668 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp @@ -36,6 +36,7 @@ struct CollectionCheckpointState; struct PersistentCollectionData; class CheckpointTask; class TableIOManager; +class DataTable; class RowGroupCollection { public: @@ -101,9 +102,10 @@ class RowGroupCollection { void RemoveFromIndexes(const QueryContext &context, TableIndexList &indexes, Vector &row_identifiers, idx_t count); idx_t Delete(TransactionData transaction, DataTable &table, row_t *ids, idx_t count); - void Update(TransactionData transaction, row_t *ids, const vector &column_ids, DataChunk &updates); - void UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, - DataChunk &updates); + void Update(TransactionData transaction, DataTable &table, row_t *ids, const vector &column_ids, + DataChunk &updates); + void UpdateColumn(TransactionData transaction, DataTable &table, Vector &row_ids, + const vector &column_path, DataChunk &updates); void Checkpoint(TableDataWriter &writer, TableStatistics &global_stats); diff --git a/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp index 3bc6572ae..f839c6b24 100644 --- 
a/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_id_column_data.hpp @@ -48,10 +48,10 @@ class RowIdColumnData : public ColumnData { void AppendData(BaseStatistics &stats, ColumnAppendState &state, UnifiedVectorFormat &vdata, idx_t count) override; void RevertAppend(row_t start_row) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; void CommitDropColumn() override; diff --git a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp index 8d233139a..ec06eb30a 100644 --- a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp @@ -47,10 +47,10 @@ class StandardColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable 
&data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; diff --git a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp index 91c7f1e19..798a21326 100644 --- a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp @@ -46,10 +46,10 @@ class StructColumnData : public ColumnData { idx_t Fetch(ColumnScanState &state, row_t row_id, Vector &result) override; void FetchRow(TransactionData transaction, ColumnFetchState &state, row_t row_id, Vector &result, idx_t result_idx) override; - void Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) override; - void UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) override; + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) override; + void UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) override; unique_ptr GetUpdateStatistics() override; void CommitDropColumn() override; diff --git a/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp index 75cf25ecf..3f5b9d211 100644 --- a/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/update_segment.hpp @@ -38,8 +38,8 
@@ class UpdateSegment { void FetchUpdates(TransactionData transaction, idx_t vector_index, Vector &result); void FetchCommitted(idx_t vector_index, Vector &result); void FetchCommittedRange(idx_t start_row, idx_t count, Vector &result); - void Update(TransactionData transaction, idx_t column_index, Vector &update, row_t *ids, idx_t count, - Vector &base_data); + void Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update, row_t *ids, + idx_t count, Vector &base_data); void FetchRow(TransactionData transaction, idx_t row_id, Vector &result, idx_t result_idx); void RollbackUpdate(UpdateInfo &info); diff --git a/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp b/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp index 12c4d180c..79e0567b3 100644 --- a/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp +++ b/src/duckdb/src/include/duckdb/transaction/duck_transaction.hpp @@ -76,7 +76,7 @@ class DuckTransaction : public Transaction { idx_t base_row); void PushSequenceUsage(SequenceCatalogEntry &entry, const SequenceData &data); void PushAppend(DataTable &table, idx_t row_start, idx_t row_count); - UndoBufferReference CreateUpdateInfo(idx_t type_size, idx_t entries); + UndoBufferReference CreateUpdateInfo(idx_t type_size, DataTable &data_table, idx_t entries); bool IsDuckTransaction() const override { return true; @@ -90,6 +90,7 @@ class DuckTransaction : public Transaction { //! Get a shared lock on a table shared_ptr SharedLockTable(DataTableInfo &info); + //! 
Hold an owning reference of the table, needed to safely reference it inside the transaction commit/undo logic void ModifyTable(DataTable &tbl); private: diff --git a/src/duckdb/src/include/duckdb/transaction/update_info.hpp b/src/duckdb/src/include/duckdb/transaction/update_info.hpp index 7cccd923e..5eb139261 100644 --- a/src/duckdb/src/include/duckdb/transaction/update_info.hpp +++ b/src/duckdb/src/include/duckdb/transaction/update_info.hpp @@ -17,6 +17,7 @@ namespace duckdb { class UpdateSegment; struct DataTableInfo; +class DataTable; //! UpdateInfo is a class that represents a set of updates applied to a single vector. //! The UpdateInfo struct contains metadata associated with the update. @@ -26,6 +27,8 @@ struct DataTableInfo; struct UpdateInfo { //! The update segment that this update info affects UpdateSegment *segment; + //! The table this was update was made on + DataTable *table; //! The column index of which column we are updating idx_t column_index; //! The version number @@ -87,7 +90,7 @@ struct UpdateInfo { //! Returns the total allocation size for an UpdateInfo entry, together with space for the tuple data static idx_t GetAllocSize(idx_t type_size); //! Initialize an UpdateInfo struct that has been allocated using GetAllocSize (i.e. 
has extra space after it) - static void Initialize(UpdateInfo &info, transaction_t transaction_id); + static void Initialize(UpdateInfo &info, DataTable &data_table, transaction_t transaction_id); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp b/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp index aad1a672c..4c68da487 100644 --- a/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp +++ b/src/duckdb/src/include/duckdb/transaction/wal_write_state.hpp @@ -31,7 +31,7 @@ class WALWriteState { void CommitEntry(UndoFlags type, data_ptr_t data); private: - void SwitchTable(DataTableInfo *table, UndoFlags new_op); + void SwitchTable(DataTableInfo &table, UndoFlags new_op); void WriteCatalogEntry(CatalogEntry &entry, data_ptr_t extra_data); void WriteDelete(DeleteInfo &info); diff --git a/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp index a60abf187..77fed9815 100644 --- a/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp +++ b/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp @@ -85,6 +85,8 @@ class StatementVerifier { private: const vector> empty_select_list = {}; + + const vector> &GetSelectList(QueryNode &node); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb_extension.h b/src/duckdb/src/include/duckdb_extension.h index 7c5136059..f014be548 100644 --- a/src/duckdb/src/include/duckdb_extension.h +++ b/src/duckdb/src/include/duckdb_extension.h @@ -643,6 +643,13 @@ typedef struct { char *(*duckdb_value_to_string)(duckdb_value value); #endif +// New functions around the table description +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + idx_t (*duckdb_table_description_get_column_count)(duckdb_table_description table_description); + duckdb_logical_type (*duckdb_table_description_get_column_type)(duckdb_table_description table_description, + idx_t index); +#endif 
+ // New functions around table function binding #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE void (*duckdb_table_function_get_client_context)(duckdb_bind_info info, duckdb_client_context *out_context); @@ -1164,6 +1171,10 @@ typedef struct { // Version unstable_new_string_functions #define duckdb_value_to_string duckdb_ext_api.duckdb_value_to_string +// Version unstable_new_table_description_functions +#define duckdb_table_description_get_column_count duckdb_ext_api.duckdb_table_description_get_column_count +#define duckdb_table_description_get_column_type duckdb_ext_api.duckdb_table_description_get_column_type + // Version unstable_new_table_function_functions #define duckdb_table_function_get_client_context duckdb_ext_api.duckdb_table_function_get_client_context diff --git a/src/duckdb/src/main/capi/table_description-c.cpp b/src/duckdb/src/main/capi/table_description-c.cpp index 26624bbfc..cfcd01c43 100644 --- a/src/duckdb/src/main/capi/table_description-c.cpp +++ b/src/duckdb/src/main/capi/table_description-c.cpp @@ -1,5 +1,5 @@ -#include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/main/capi/capi_internal.hpp" using duckdb::Connection; using duckdb::ErrorData; @@ -68,14 +68,14 @@ const char *duckdb_table_description_error(duckdb_table_description table) { return wrapper->error.c_str(); } -duckdb_state GetTableDescription(TableDescriptionWrapper *wrapper, idx_t index) { +duckdb_state GetTableDescription(TableDescriptionWrapper *wrapper, duckdb::optional_idx index) { if (!wrapper) { return DuckDBError; } auto &table = wrapper->description; - if (index >= table->columns.size()) { - wrapper->error = duckdb::StringUtil::Format("Column index %d is out of range, table only has %d columns", index, - table->columns.size()); + if (index.IsValid() && index.GetIndex() >= table->columns.size()) { + wrapper->error = duckdb::StringUtil::Format("Column index %d is out of range, table only has %d columns", + index.GetIndex(), 
table->columns.size()); return DuckDBError; } return DuckDBSuccess; @@ -97,6 +97,16 @@ duckdb_state duckdb_column_has_default(duckdb_table_description table_descriptio return DuckDBSuccess; } +idx_t duckdb_table_description_get_column_count(duckdb_table_description table_description) { + auto wrapper = reinterpret_cast(table_description); + if (GetTableDescription(wrapper, duckdb::optional_idx()) == DuckDBError) { + return 0; + } + + auto &table = wrapper->description; + return table->columns.size(); +} + char *duckdb_table_description_get_column_name(duckdb_table_description table_description, idx_t index) { auto wrapper = reinterpret_cast(table_description); if (GetTableDescription(wrapper, index) == DuckDBError) { @@ -113,3 +123,16 @@ char *duckdb_table_description_get_column_name(duckdb_table_description table_de return result; } + +duckdb_logical_type duckdb_table_description_get_column_type(duckdb_table_description table_description, idx_t index) { + auto wrapper = reinterpret_cast(table_description); + if (GetTableDescription(wrapper, index) == DuckDBError) { + return nullptr; + } + + auto &table = wrapper->description; + auto &column = table->columns[index]; + + auto logical_type = new duckdb::LogicalType(column.Type()); + return reinterpret_cast(logical_type); +} diff --git a/src/duckdb/src/main/database_file_path_manager.cpp b/src/duckdb/src/main/database_file_path_manager.cpp index f1825780e..0ea9b9ad4 100644 --- a/src/duckdb/src/main/database_file_path_manager.cpp +++ b/src/duckdb/src/main/database_file_path_manager.cpp @@ -18,20 +18,34 @@ InsertDatabasePathResult DatabaseFilePathManager::InsertDatabasePath(const strin } lock_guard path_lock(db_paths_lock); - auto entry = db_paths.emplace(path, DatabasePathInfo(name)); + auto entry = db_paths.emplace(path, DatabasePathInfo(name, options.access_mode)); if (!entry.second) { auto &existing = entry.first->second; + bool already_exists = false; if (on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT && 
existing.name == name) { - if (existing.is_attached) { + already_exists = true; + } + if (options.access_mode == AccessMode::READ_ONLY && existing.access_mode == AccessMode::READ_ONLY) { + if (already_exists && existing.is_attached) { return InsertDatabasePathResult::ALREADY_EXISTS; } - throw BinderException("Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is in " - "the process of being detached", - name, path); + // all attaches are in read-only mode - there is no conflict, just increase the reference count + existing.reference_count++; + } else { + if (already_exists) { + if (existing.is_attached) { + return InsertDatabasePathResult::ALREADY_EXISTS; + } + throw BinderException( + "Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is in " + "the process of being detached", + name, path); + } + throw BinderException( + "Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is already " + "attached by database \"%s\"", + name, path, existing.name); } - throw BinderException("Unique file handle conflict: Cannot attach \"%s\" - the database file \"%s\" is already " - "attached by database \"%s\"", - name, path, existing.name); } options.stored_database_path = make_uniq(*this, path, name); return InsertDatabasePathResult::SUCCESS; @@ -42,7 +56,14 @@ void DatabaseFilePathManager::EraseDatabasePath(const string &path) { return; } lock_guard path_lock(db_paths_lock); - db_paths.erase(path); + auto entry = db_paths.find(path); + if (entry != db_paths.end()) { + if (entry->second.reference_count <= 1) { + db_paths.erase(entry); + } else { + entry->second.reference_count--; + } + } } void DatabaseFilePathManager::DetachDatabase(const string &path) { diff --git a/src/duckdb/src/main/relation.cpp b/src/duckdb/src/main/relation.cpp index 9a28349e7..b9e4d50ff 100644 --- a/src/duckdb/src/main/relation.cpp +++ b/src/duckdb/src/main/relation.cpp @@ -394,8 +394,8 @@ string Relation::ToString() { } // 
LCOV_EXCL_START -unique_ptr Relation::GetQueryNode() { - throw InternalException("Cannot create a query node from this node type"); +string Relation::GetQuery() { + return GetQueryNode()->ToString(); } void Relation::Head(idx_t limit) { diff --git a/src/duckdb/src/main/relation/create_table_relation.cpp b/src/duckdb/src/main/relation/create_table_relation.cpp index 2492f244b..39aa65e36 100644 --- a/src/duckdb/src/main/relation/create_table_relation.cpp +++ b/src/duckdb/src/main/relation/create_table_relation.cpp @@ -29,6 +29,14 @@ BoundStatement CreateTableRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr CreateTableRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a create table relation"); +} + +string CreateTableRelation::GetQuery() { + return string(); +} + const vector &CreateTableRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/create_view_relation.cpp b/src/duckdb/src/main/relation/create_view_relation.cpp index c00deef38..6f77f013f 100644 --- a/src/duckdb/src/main/relation/create_view_relation.cpp +++ b/src/duckdb/src/main/relation/create_view_relation.cpp @@ -35,6 +35,14 @@ BoundStatement CreateViewRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr CreateViewRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an update relation"); +} + +string CreateViewRelation::GetQuery() { + return string(); +} + const vector &CreateViewRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/delete_relation.cpp b/src/duckdb/src/main/relation/delete_relation.cpp index 64b3f231e..2ec60f664 100644 --- a/src/duckdb/src/main/relation/delete_relation.cpp +++ b/src/duckdb/src/main/relation/delete_relation.cpp @@ -26,6 +26,14 @@ BoundStatement DeleteRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr DeleteRelation::GetQueryNode() { + throw InternalException("Cannot 
create a query node from a delete relation"); +} + +string DeleteRelation::GetQuery() { + return string(); +} + const vector &DeleteRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/explain_relation.cpp b/src/duckdb/src/main/relation/explain_relation.cpp index f91e1d29f..9f2976c9d 100644 --- a/src/duckdb/src/main/relation/explain_relation.cpp +++ b/src/duckdb/src/main/relation/explain_relation.cpp @@ -20,6 +20,14 @@ BoundStatement ExplainRelation::Bind(Binder &binder) { return binder.Bind(explain.Cast()); } +unique_ptr ExplainRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an explain relation"); +} + +string ExplainRelation::GetQuery() { + return string(); +} + const vector &ExplainRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/insert_relation.cpp b/src/duckdb/src/main/relation/insert_relation.cpp index 9728570a0..84ef16ec6 100644 --- a/src/duckdb/src/main/relation/insert_relation.cpp +++ b/src/duckdb/src/main/relation/insert_relation.cpp @@ -24,6 +24,14 @@ BoundStatement InsertRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr InsertRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an insert relation"); +} + +string InsertRelation::GetQuery() { + return string(); +} + const vector &InsertRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/query_relation.cpp b/src/duckdb/src/main/relation/query_relation.cpp index e0cf2e280..ee7ac72d9 100644 --- a/src/duckdb/src/main/relation/query_relation.cpp +++ b/src/duckdb/src/main/relation/query_relation.cpp @@ -49,6 +49,10 @@ unique_ptr QueryRelation::GetQueryNode() { return std::move(select->node); } +string QueryRelation::GetQuery() { + return query; +} + unique_ptr QueryRelation::GetTableRef() { auto subquery_ref = make_uniq(GetSelectStatement(), GetAlias()); return std::move(subquery_ref); diff --git 
a/src/duckdb/src/main/relation/update_relation.cpp b/src/duckdb/src/main/relation/update_relation.cpp index 9176cf2f2..81d85ca89 100644 --- a/src/duckdb/src/main/relation/update_relation.cpp +++ b/src/duckdb/src/main/relation/update_relation.cpp @@ -35,6 +35,14 @@ BoundStatement UpdateRelation::Bind(Binder &binder) { return binder.Bind(stmt.Cast()); } +unique_ptr UpdateRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from an update relation"); +} + +string UpdateRelation::GetQuery() { + return string(); +} + const vector &UpdateRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/write_csv_relation.cpp b/src/duckdb/src/main/relation/write_csv_relation.cpp index 4795c7a51..f77d6f1ee 100644 --- a/src/duckdb/src/main/relation/write_csv_relation.cpp +++ b/src/duckdb/src/main/relation/write_csv_relation.cpp @@ -25,6 +25,14 @@ BoundStatement WriteCSVRelation::Bind(Binder &binder) { return binder.Bind(copy.Cast()); } +unique_ptr WriteCSVRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a write CSV relation"); +} + +string WriteCSVRelation::GetQuery() { + return string(); +} + const vector &WriteCSVRelation::Columns() { return columns; } diff --git a/src/duckdb/src/main/relation/write_parquet_relation.cpp b/src/duckdb/src/main/relation/write_parquet_relation.cpp index d6e403618..b1dfdb29f 100644 --- a/src/duckdb/src/main/relation/write_parquet_relation.cpp +++ b/src/duckdb/src/main/relation/write_parquet_relation.cpp @@ -25,6 +25,14 @@ BoundStatement WriteParquetRelation::Bind(Binder &binder) { return binder.Bind(copy.Cast()); } +unique_ptr WriteParquetRelation::GetQueryNode() { + throw InternalException("Cannot create a query node from a write parquet relation"); +} + +string WriteParquetRelation::GetQuery() { + return string(); +} + const vector &WriteParquetRelation::Columns() { return columns; } diff --git a/src/duckdb/src/optimizer/filter_combiner.cpp 
b/src/duckdb/src/optimizer/filter_combiner.cpp index 8e4a295b4..ddbe82ab0 100644 --- a/src/duckdb/src/optimizer/filter_combiner.cpp +++ b/src/duckdb/src/optimizer/filter_combiner.cpp @@ -1,5 +1,6 @@ #include "duckdb/optimizer/filter_combiner.hpp" +#include "duckdb/common/enums/expression_type.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/optimizer/optimizer.hpp" #include "duckdb/planner/expression.hpp" @@ -907,6 +908,12 @@ FilterResult FilterCombiner::AddTransitiveFilters(BoundComparisonExpression &com idx_t left_equivalence_set = GetEquivalenceSet(left_node); idx_t right_equivalence_set = GetEquivalenceSet(right_node); if (left_equivalence_set == right_equivalence_set) { + if (comparison.GetExpressionType() == ExpressionType::COMPARE_GREATERTHAN || + comparison.GetExpressionType() == ExpressionType::COMPARE_LESSTHAN) { + // non equal comparison has equal equivalence set, then it is unsatisfiable + // e.g., j > i AND i < j is unsatisfiable + return FilterResult::UNSATISFIABLE; + } // this equality filter already exists, prune it return FilterResult::SUCCESS; } diff --git a/src/duckdb/src/optimizer/filter_pushdown.cpp b/src/duckdb/src/optimizer/filter_pushdown.cpp index 4fa17f7d0..7c13386d9 100644 --- a/src/duckdb/src/optimizer/filter_pushdown.cpp +++ b/src/duckdb/src/optimizer/filter_pushdown.cpp @@ -208,17 +208,23 @@ unique_ptr FilterPushdown::PushdownJoin(unique_ptrfilter)); D_ASSERT(result != FilterResult::UNSUPPORTED); - (void)result; + if (result == FilterResult::UNSATISFIABLE) { + // one of the filters is unsatisfiable - abort filter pushdown + return FilterResult::UNSATISFIABLE; + } } filters.clear(); + return FilterResult::SUCCESS; } FilterResult FilterPushdown::AddFilter(unique_ptr expr) { - PushFilters(); + if (PushFilters() == FilterResult::UNSATISFIABLE) { + return FilterResult::UNSATISFIABLE; + } // split up the filters by AND predicate vector> expressions; expressions.push_back(std::move(expr)); diff --git 
a/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp index 90dbbb823..ac4b6532a 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_get.cpp @@ -4,6 +4,7 @@ #include "duckdb/planner/expression/bound_parameter_expression.hpp" #include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_empty_result.hpp" namespace duckdb { unique_ptr FilterPushdown::PushdownGet(unique_ptr op) { @@ -48,7 +49,9 @@ unique_ptr FilterPushdown::PushdownGet(unique_ptr(std::move(op)); + } //! We generate the table filters that will be executed during the table scan vector pushdown_results; diff --git a/src/duckdb/src/optimizer/rule/regex_optimizations.cpp b/src/duckdb/src/optimizer/rule/regex_optimizations.cpp index 24786867b..3a0697e99 100644 --- a/src/duckdb/src/optimizer/rule/regex_optimizations.cpp +++ b/src/duckdb/src/optimizer/rule/regex_optimizations.cpp @@ -184,6 +184,13 @@ unique_ptr RegexOptimizationRule::Apply(LogicalOperator &op, vector< if (!escaped_like_string.exists) { return nullptr; } + + // if regexp had options, remove them so the new Contains Expression can be matched for other optimizers. 
+ if (root.children.size() == 3) { + root.children.pop_back(); + D_ASSERT(root.children.size() == 2); + } + auto parameter = make_uniq(Value(std::move(escaped_like_string.like_string))); auto contains = make_uniq(root.return_type, GetStringContains(), std::move(root.children), nullptr); diff --git a/src/duckdb/src/parser/query_node/set_operation_node.cpp b/src/duckdb/src/parser/query_node/set_operation_node.cpp index a8b624f21..cdc188820 100644 --- a/src/duckdb/src/parser/query_node/set_operation_node.cpp +++ b/src/duckdb/src/parser/query_node/set_operation_node.cpp @@ -8,10 +8,6 @@ namespace duckdb { SetOperationNode::SetOperationNode() : QueryNode(QueryNodeType::SET_OPERATION_NODE) { } -const vector> &SetOperationNode::GetSelectList() const { - return children[0]->GetSelectList(); -} - string SetOperationNode::ToString() const { string result; result = cte_map.ToString(); diff --git a/src/duckdb/src/parser/query_node/statement_node.cpp b/src/duckdb/src/parser/query_node/statement_node.cpp index e27b2e6c0..66e7b8e5a 100644 --- a/src/duckdb/src/parser/query_node/statement_node.cpp +++ b/src/duckdb/src/parser/query_node/statement_node.cpp @@ -5,9 +5,6 @@ namespace duckdb { StatementNode::StatementNode(SQLStatement &stmt_p) : QueryNode(QueryNodeType::STATEMENT_NODE), stmt(stmt_p) { } -const vector> &StatementNode::GetSelectList() const { - throw InternalException("StatementNode has no select list"); -} //! 
Convert the query node to a string string StatementNode::ToString() const { return stmt.ToString(); diff --git a/src/duckdb/src/parser/statement/relation_statement.cpp b/src/duckdb/src/parser/statement/relation_statement.cpp index 9b3801495..023d3cac9 100644 --- a/src/duckdb/src/parser/statement/relation_statement.cpp +++ b/src/duckdb/src/parser/statement/relation_statement.cpp @@ -5,10 +5,7 @@ namespace duckdb { RelationStatement::RelationStatement(shared_ptr relation_p) : SQLStatement(StatementType::RELATION_STATEMENT), relation(std::move(relation_p)) { - if (relation->type == RelationType::QUERY_RELATION) { - auto &query_relation = relation->Cast(); - query = query_relation.query; - } + query = relation->GetQuery(); } unique_ptr RelationStatement::Copy() const { diff --git a/src/duckdb/src/parser/transform/expression/transform_subquery.cpp b/src/duckdb/src/parser/transform/expression/transform_subquery.cpp index bc8a9762d..986e46e25 100644 --- a/src/duckdb/src/parser/transform/expression/transform_subquery.cpp +++ b/src/duckdb/src/parser/transform/expression/transform_subquery.cpp @@ -24,7 +24,6 @@ unique_ptr Transformer::TransformSubquery(duckdb_libpgquery::P subquery_expr->subquery = TransformSelectStmt(*root.subselect); SetQueryLocation(*subquery_expr, root.location); D_ASSERT(subquery_expr->subquery); - D_ASSERT(!subquery_expr->subquery->node->GetSelectList().empty()); switch (root.subLinkType) { case duckdb_libpgquery::PG_EXISTS_SUBLINK: { diff --git a/src/duckdb/src/planner/binder.cpp b/src/duckdb/src/planner/binder.cpp index 07d225378..d99521743 100644 --- a/src/duckdb/src/planner/binder.cpp +++ b/src/duckdb/src/planner/binder.cpp @@ -343,7 +343,6 @@ optional_ptr Binder::GetMatchingBinding(const string &catalog_name, con const string &table_name, const string &column_name, ErrorData &error) { optional_ptr binding; - D_ASSERT(!lambda_bindings); if (macro_binding && table_name == macro_binding->GetAlias()) { binding = optional_ptr(macro_binding.get()); } 
else { diff --git a/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp index cce06d712..5d2bce798 100644 --- a/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_macro_expression.cpp @@ -98,6 +98,7 @@ void ExpressionBinder::UnfoldMacroExpression(FunctionExpression &function, Scala // validate the arguments and separate positional and default arguments vector> positional_arguments; InsertionOrderPreservingMap> named_arguments; + binder.lambda_bindings = lambda_bindings; auto bind_result = MacroFunction::BindMacroFunction(binder, macro_func.macros, macro_func.name, function, positional_arguments, named_arguments, depth); if (!bind_result.error.empty()) { diff --git a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp index 8e15f3b28..7f03f0e32 100644 --- a/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_subquery_expression.cpp @@ -23,10 +23,6 @@ class BoundSubqueryNode : public QueryNode { BoundStatement bound_node; unique_ptr subquery; - const vector> &GetSelectList() const override { - throw InternalException("Cannot get select list of bound subquery node"); - } - string ToString() const override { throw InternalException("Cannot ToString bound subquery node"); } diff --git a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp index 92ca383c7..120fdbceb 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp @@ -1,10 +1,6 @@ -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/expression_map.hpp" -#include "duckdb/parser/query_node/select_node.hpp" #include 
"duckdb/parser/query_node/cte_node.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/operator/logical_materialized_cte.hpp" namespace duckdb { @@ -17,25 +13,25 @@ BoundStatement Binder::BindNode(CTENode &statement) { } BoundStatement Binder::BindCTE(CTENode &statement) { - BoundCTENode result; + BoundStatement result; // first recursively visit the materialized CTE operations // the left side is visited first and is added to the BindContext of the right side D_ASSERT(statement.query); - result.ctename = statement.ctename; - result.materialized = statement.materialized; - result.setop_index = GenerateTableIndex(); + auto ctename = statement.ctename; + auto materialized = statement.materialized; + auto setop_index = GenerateTableIndex(); - AddCTE(result.ctename); + AddCTE(ctename); - result.query_binder = Binder::CreateBinder(context, this); - result.query = result.query_binder->BindNode(*statement.query); + auto query_binder = Binder::CreateBinder(context, this); + auto query = query_binder->BindNode(*statement.query); // the result types of the CTE are the types of the LHS - result.types = result.query.types; + result.types = query.types; // names are picked from the LHS, unless aliases are explicitly specified - result.names = result.query.names; + result.names = query.names; for (idx_t i = 0; i < statement.aliases.size() && i < result.names.size(); i++) { result.names[i] = statement.aliases[i]; } @@ -55,43 +51,43 @@ BoundStatement Binder::BindCTE(CTENode &statement) { } // This allows the right side to reference the CTE - bind_context.AddGenericBinding(result.setop_index, statement.ctename, names, result.types); + bind_context.AddGenericBinding(setop_index, statement.ctename, names, result.types); - result.child_binder = Binder::CreateBinder(context, this); + auto child_binder = Binder::CreateBinder(context, this); // Add bindings 
of left side to temporary CTE bindings context // If there is already a binding for the CTE, we need to remove it first // as we are binding a CTE currently, we take precedence over the existing binding. // This implements the CTE shadowing behavior. - result.child_binder->bind_context.AddCTEBinding(result.setop_index, statement.ctename, names, result.types); + child_binder->bind_context.AddCTEBinding(setop_index, statement.ctename, names, result.types); + BoundStatement child; if (statement.child) { - // Move all modifiers to the child node. - for (auto &modifier : statement.modifiers) { - statement.child->modifiers.push_back(std::move(modifier)); - } - - statement.modifiers.clear(); - - result.child = result.child_binder->BindNode(*statement.child); - for (auto &c : result.query_binder->correlated_columns) { - result.child_binder->AddCorrelatedColumn(c); + child = child_binder->BindNode(*statement.child); + for (auto &c : query_binder->correlated_columns) { + child_binder->AddCorrelatedColumn(c); } // the result types of the CTE are the types of the LHS - result.types = result.child.types; - result.names = result.child.names; + result.types = child.types; + result.names = child.names; - MoveCorrelatedExpressions(*result.child_binder); + MoveCorrelatedExpressions(*child_binder); } - MoveCorrelatedExpressions(*result.query_binder); + MoveCorrelatedExpressions(*query_binder); + + auto cte_query = std::move(query.plan); + auto cte_child = std::move(child.plan); + + auto root = make_uniq(ctename, setop_index, result.types.size(), std::move(cte_query), + std::move(cte_child), materialized); - BoundStatement result_statement; - result_statement.types = result.types; - result_statement.names = result.names; - result_statement.plan = CreatePlan(result); - return result_statement; + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = has_unplanned_dependent_joins || child_binder->has_unplanned_dependent_joins || +
query_binder->has_unplanned_dependent_joins; + result.plan = std::move(root); + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp index 8cb62ab73..0795c13cd 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_recursive_cte_node.cpp @@ -3,13 +3,12 @@ #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/query_node/recursive_cte_node.hpp" #include "duckdb/planner/binder.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" +#include "duckdb/planner/operator/logical_recursive_cte.hpp" namespace duckdb { BoundStatement Binder::BindNode(RecursiveCTENode &statement) { - BoundRecursiveCTENode result; // first recursively visit the recursive CTE operations // the left side is visited first and is added to the BindContext of the right side @@ -19,49 +18,51 @@ BoundStatement Binder::BindNode(RecursiveCTENode &statement) { throw BinderException("UNION ALL cannot be used with USING KEY in recursive CTE."); } - result.ctename = statement.ctename; - result.union_all = statement.union_all; - result.setop_index = GenerateTableIndex(); + auto ctename = statement.ctename; + auto union_all = statement.union_all; + auto setop_index = GenerateTableIndex(); - result.left_binder = Binder::CreateBinder(context, this); - result.left = result.left_binder->BindNode(*statement.left); + auto left_binder = Binder::CreateBinder(context, this); + auto left = left_binder->BindNode(*statement.left); + BoundStatement result; // the result types of the CTE are the types of the LHS - result.types = result.left.types; + result.types = left.types; // names are picked from the LHS, unless aliases are explicitly specified - 
result.names = result.left.names; + result.names = left.names; for (idx_t i = 0; i < statement.aliases.size() && i < result.names.size(); i++) { result.names[i] = statement.aliases[i]; } // This allows the right side to reference the CTE recursively - bind_context.AddGenericBinding(result.setop_index, statement.ctename, result.names, result.types); + bind_context.AddGenericBinding(setop_index, statement.ctename, result.names, result.types); - result.right_binder = Binder::CreateBinder(context, this); + auto right_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context - result.right_binder->bind_context.AddCTEBinding(result.setop_index, statement.ctename, result.names, result.types, - !statement.key_targets.empty()); + right_binder->bind_context.AddCTEBinding(setop_index, statement.ctename, result.names, result.types, + !statement.key_targets.empty()); - result.right = result.right_binder->BindNode(*statement.right); - for (auto &c : result.left_binder->correlated_columns) { - result.right_binder->AddCorrelatedColumn(c); + auto right = right_binder->BindNode(*statement.right); + for (auto &c : left_binder->correlated_columns) { + right_binder->AddCorrelatedColumn(c); } // move the correlated expressions from the child binders to this binder - MoveCorrelatedExpressions(*result.left_binder); - MoveCorrelatedExpressions(*result.right_binder); + MoveCorrelatedExpressions(*left_binder); + MoveCorrelatedExpressions(*right_binder); + vector> key_targets; // bind specified keys to the referenced column auto expression_binder = ExpressionBinder(*this, context); - for (unique_ptr &expr : statement.key_targets) { + for (auto &expr : statement.key_targets) { auto bound_expr = expression_binder.Bind(expr); D_ASSERT(bound_expr->type == ExpressionType::BOUND_COLUMN_REF); - result.key_targets.push_back(std::move(bound_expr)); + key_targets.push_back(std::move(bound_expr)); } // now both sides have been bound we can resolve types 
- if (result.left.types.size() != result.right.types.size()) { + if (left.types.size() != right.types.size()) { throw BinderException("Set operations can only apply to expressions with the " "same number of result columns"); } @@ -70,11 +71,42 @@ BoundStatement Binder::BindNode(RecursiveCTENode &statement) { throw NotImplementedException("FIXME: bind modifiers in recursive CTE"); } - BoundStatement result_statement; - result_statement.types = result.types; - result_statement.names = result.names; - result_statement.plan = CreatePlan(result); - return result_statement; + // Generate the logical plan for the left and right sides of the set operation + left_binder->is_outside_flattened = is_outside_flattened; + right_binder->is_outside_flattened = is_outside_flattened; + + auto left_node = std::move(left.plan); + auto right_node = std::move(right.plan); + + // check if there are any unplanned subqueries left in either child + has_unplanned_dependent_joins = has_unplanned_dependent_joins || left_binder->has_unplanned_dependent_joins || + right_binder->has_unplanned_dependent_joins; + + // for both the left and right sides, cast them to the same types + left_node = CastLogicalOperatorToTypes(left.types, result.types, std::move(left_node)); + right_node = CastLogicalOperatorToTypes(right.types, result.types, std::move(right_node)); + + auto recurring_binding = right_binder->GetCTEBinding("recurring." + ctename); + bool ref_recurring = recurring_binding && recurring_binding->Cast().reference_count > 0; + if (key_targets.empty() && ref_recurring) { + throw InvalidInputException("RECURRING can only be used with USING KEY in recursive CTE."); + } + + // Check if there is a reference to the recursive or recurring table, if not create a set operator. 
+ auto cte_binding = right_binder->GetCTEBinding(ctename); + bool ref_cte = cte_binding && cte_binding->Cast().reference_count > 0; + if (!ref_cte && !ref_recurring) { + auto root = + make_uniq(setop_index, result.types.size(), std::move(left_node), + std::move(right_node), LogicalOperatorType::LOGICAL_UNION, union_all); + result.plan = std::move(root); + } else { + auto root = make_uniq(ctename, setop_index, result.types.size(), union_all, + std::move(key_targets), std::move(left_node), std::move(right_node)); + root->ref_recurring = ref_recurring; + result.plan = std::move(root); + } + return result; } } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp index d70f6d2cc..6c301d1de 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp @@ -14,9 +14,6 @@ namespace duckdb { -BoundSetOperationNode::~BoundSetOperationNode() { -} - struct SetOpAliasGatherer { public: explicit SetOpAliasGatherer(SelectBindState &bind_state_p) : bind_state(bind_state_p) { diff --git a/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp deleted file mode 100644 index dc4cc8770..000000000 --- a/src/duckdb/src/planner/binder/query_node/plan_cte_node.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include "duckdb/common/string_util.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/operator/logical_materialized_cte.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/operator/logical_set_operation.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundCTENode &node) { - // Generate the logical plan for the cte_query and child. 
- auto cte_query = std::move(node.query.plan); - auto cte_child = std::move(node.child.plan); - - auto root = make_uniq(node.ctename, node.setop_index, node.types.size(), - std::move(cte_query), std::move(cte_child), node.materialized); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || node.child_binder->has_unplanned_dependent_joins || - node.query_binder->has_unplanned_dependent_joins; - - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp deleted file mode 100644 index f51a03c50..000000000 --- a/src/duckdb/src/planner/binder/query_node/plan_recursive_cte_node.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/operator/logical_recursive_cte.hpp" -#include "duckdb/planner/operator/logical_set_operation.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -unique_ptr Binder::CreatePlan(BoundRecursiveCTENode &node) { - // Generate the logical plan for the left and right sides of the set operation - node.left_binder->is_outside_flattened = is_outside_flattened; - node.right_binder->is_outside_flattened = is_outside_flattened; - - auto left_node = std::move(node.left.plan); - auto right_node = std::move(node.right.plan); - - // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || node.left_binder->has_unplanned_dependent_joins || - node.right_binder->has_unplanned_dependent_joins; - - // for both the left and right sides, cast them to the same types - left_node = 
CastLogicalOperatorToTypes(node.left.types, node.types, std::move(left_node)); - right_node = CastLogicalOperatorToTypes(node.right.types, node.types, std::move(right_node)); - - auto recurring_binding = node.right_binder->GetCTEBinding("recurring." + node.ctename); - bool ref_recurring = recurring_binding && recurring_binding->Cast().reference_count > 0; - if (node.key_targets.empty() && ref_recurring) { - throw InvalidInputException("RECURRING can only be used with USING KEY in recursive CTE."); - } - - // Check if there is a reference to the recursive or recurring table, if not create a set operator. - auto cte_binding = node.right_binder->GetCTEBinding(node.ctename); - bool ref_cte = cte_binding && cte_binding->Cast().reference_count > 0; - if (!ref_cte && !ref_recurring) { - auto root = - make_uniq(node.setop_index, node.types.size(), std::move(left_node), - std::move(right_node), LogicalOperatorType::LOGICAL_UNION, node.union_all); - return VisitQueryNode(node, std::move(root)); - } - - auto root = - make_uniq(node.ctename, node.setop_index, node.types.size(), node.union_all, - std::move(node.key_targets), std::move(left_node), std::move(right_node)); - root->ref_recurring = ref_recurring; - return VisitQueryNode(node, std::move(root)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_pragma.cpp b/src/duckdb/src/planner/binder/statement/bind_pragma.cpp index 3955cf897..b5fc04677 100644 --- a/src/duckdb/src/planner/binder/statement/bind_pragma.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_pragma.cpp @@ -2,6 +2,7 @@ #include "duckdb/parser/statement/pragma_statement.hpp" #include "duckdb/planner/operator/logical_pragma.hpp" #include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/planner/expression_binder/constant_binder.hpp" 
@@ -28,16 +29,32 @@ unique_ptr Binder::BindPragma(PragmaInfo &info, QueryErrorConte } // bind the pragma function - auto &entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name); + auto entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name, + OnEntryNotFound::RETURN_NULL); + if (!entry) { + // try to find whether a table entry might exist + auto table_entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, + info.name, OnEntryNotFound::RETURN_NULL); + if (table_entry) { + // there is a table entry with the same name, now throw more explicit error message + throw CatalogException("Pragma Function with name %s does not exist, but a table function with the same " + "name exists, try `CALL %s(...)`", + info.name, info.name); + } + // rebind to throw exception + entry = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, info.name, + OnEntryNotFound::THROW_EXCEPTION); + } + FunctionBinder function_binder(*this); ErrorData error; - auto bound_idx = function_binder.BindFunction(entry.name, entry.functions, params, error); + auto bound_idx = function_binder.BindFunction(entry->name, entry->functions, params, error); if (!bound_idx.IsValid()) { D_ASSERT(error.HasError()); error.AddQueryLocation(error_context); error.Throw(); } - auto bound_function = entry.functions.GetFunctionByOffset(bound_idx.GetIndex()); + auto bound_function = entry->functions.GetFunctionByOffset(bound_idx.GetIndex()); // bind and check named params BindNamedParameters(bound_function.named_parameters, named_parameters, error_context, bound_function.name); return make_uniq(std::move(bound_function), std::move(params), std::move(named_parameters)); diff --git a/src/duckdb/src/planner/expression_iterator.cpp b/src/duckdb/src/planner/expression_iterator.cpp index 9f67f915c..3d1407900 100644 --- a/src/duckdb/src/planner/expression_iterator.cpp +++ b/src/duckdb/src/planner/expression_iterator.cpp @@ -4,8 +4,6 @@ #include
"duckdb/planner/expression/list.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" -#include "duckdb/planner/query_node/bound_recursive_cte_node.hpp" -#include "duckdb/planner/query_node/bound_cte_node.hpp" #include "duckdb/planner/tableref/list.hpp" #include "duckdb/common/enum_util.hpp" diff --git a/src/duckdb/src/storage/caching_file_system.cpp b/src/duckdb/src/storage/caching_file_system.cpp index 3de905228..347107673 100644 --- a/src/duckdb/src/storage/caching_file_system.cpp +++ b/src/duckdb/src/storage/caching_file_system.cpp @@ -79,6 +79,21 @@ FileHandle &CachingFileHandle::GetFileHandle() { return *file_handle; } +static bool ShouldExpandToFillGap(const idx_t current_length, const idx_t added_length) { + const idx_t MAX_BOUND_TO_BE_ADDED_LENGTH = 1048576; + + if (added_length > MAX_BOUND_TO_BE_ADDED_LENGTH) { + // Absolute value of what would be needed to be added is too high + return false; + } + if (added_length > current_length) { + // Relative value of what would be needed to be added is too high + return false; + } + + return true; +} + BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, const idx_t nr_bytes, const idx_t location) { BufferHandle result; if (!external_file_cache.IsEnabled()) { @@ -90,30 +105,42 @@ BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, const idx_t nr_bytes, c // Try to read from the cache, filling overlapping_ranges in the process vector> overlapping_ranges; - result = TryReadFromCache(buffer, nr_bytes, location, overlapping_ranges); + optional_idx start_location_of_next_range; + result = TryReadFromCache(buffer, nr_bytes, location, overlapping_ranges, start_location_of_next_range); if (result.IsValid()) { return result; // Success } + idx_t new_nr_bytes = nr_bytes; + if (start_location_of_next_range.IsValid()) { + const idx_t nr_bytes_to_be_added = start_location_of_next_range.GetIndex() - location - nr_bytes; + if
(ShouldExpandToFillGap(nr_bytes, nr_bytes_to_be_added)) { + // Grow the range from location to start_location_of_next_range, so that to fill gaps in the cached ranges + new_nr_bytes = nr_bytes + nr_bytes_to_be_added; + } + } + // Finally, if we weren't able to find the file range in the cache, we have to create a new file range - result = external_file_cache.GetBufferManager().Allocate(MemoryTag::EXTERNAL_FILE_CACHE, nr_bytes); - auto new_file_range = make_shared_ptr(result.GetBlockHandle(), nr_bytes, location, version_tag); + result = external_file_cache.GetBufferManager().Allocate(MemoryTag::EXTERNAL_FILE_CACHE, new_nr_bytes); + auto new_file_range = + make_shared_ptr(result.GetBlockHandle(), new_nr_bytes, location, version_tag); buffer = result.Ptr(); // Interleave reading and copying from cached buffers if (OnDiskFile()) { // On-disk file: prefer interleaving reading and copying from cached buffers - ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, true); + ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, new_nr_bytes, location, true); } else { - // Remote file: prefer interleaving reading and copying from cached buffers only if reduces number of real reads - if (ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, false) <= 1) { - ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, true); + // Remote file: prefer interleaving reading and copying from cached buffers only if reduces number of real + // reads + if (ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, new_nr_bytes, location, false) <= 1) { + ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, new_nr_bytes, location, true); } else { - GetFileHandle().Read(context, buffer, nr_bytes, location); + GetFileHandle().Read(context, buffer, new_nr_bytes, location); } } - return TryInsertFileRange(result, buffer, nr_bytes, location, new_file_range); + 
return TryInsertFileRange(result, buffer, new_nr_bytes, location, new_file_range); } BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, idx_t &nr_bytes) { @@ -131,7 +158,12 @@ BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, idx_t &nr_bytes) { // Try to read from the cache first vector> overlapping_ranges; - result = TryReadFromCache(buffer, nr_bytes, position, overlapping_ranges); + { + optional_idx start_location_of_next_range; + result = TryReadFromCache(buffer, nr_bytes, position, overlapping_ranges, start_location_of_next_range); + // start_location_of_next_range is in this case discarded + } + if (result.IsValid()) { position += nr_bytes; return result; // Success @@ -214,7 +246,8 @@ const string &CachingFileHandle::GetVersionTag(const unique_ptr } BufferHandle CachingFileHandle::TryReadFromCache(data_ptr_t &buffer, idx_t nr_bytes, idx_t location, - vector> &overlapping_ranges) { + vector> &overlapping_ranges, + optional_idx &start_location_of_next_range) { BufferHandle result; // Get read lock for cached ranges @@ -246,7 +279,8 @@ BufferHandle CachingFileHandle::TryReadFromCache(data_ptr_t &buffer, idx_t nr_by } while (it != ranges.end()) { if (it->second->location >= this_end) { - // We're past the requested location + // We're past the requested location, we are going to bail out, save start_location_of_next_range + start_location_of_next_range = it->second->location; break; } // Check if the cached range overlaps the requested one diff --git a/src/duckdb/src/storage/data_table.cpp b/src/duckdb/src/storage/data_table.cpp index 75f8dd694..4302ebb62 100644 --- a/src/duckdb/src/storage/data_table.cpp +++ b/src/duckdb/src/storage/data_table.cpp @@ -1544,7 +1544,7 @@ void DataTable::Update(TableUpdateState &state, ClientContext &context, Vector & row_ids_slice.Slice(row_ids, sel_global_update, n_global_update); row_ids_slice.Flatten(n_global_update); - row_groups->Update(transaction, FlatVector::GetData(row_ids_slice), column_ids, updates_slice); 
+ row_groups->Update(transaction, *this, FlatVector::GetData(row_ids_slice), column_ids, updates_slice); } } @@ -1568,7 +1568,7 @@ void DataTable::UpdateColumn(TableCatalogEntry &table, ClientContext &context, V updates.Flatten(); row_ids.Flatten(updates.size()); - row_groups->UpdateColumn(transaction, row_ids, column_path, updates); + row_groups->UpdateColumn(transaction, *this, row_ids, column_path, updates); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/storage/local_storage.cpp b/src/duckdb/src/storage/local_storage.cpp index 4c58c2d5d..39612b09b 100644 --- a/src/duckdb/src/storage/local_storage.cpp +++ b/src/duckdb/src/storage/local_storage.cpp @@ -580,7 +580,7 @@ void LocalStorage::Update(DataTable &table, Vector &row_ids, const vector(row_ids); - storage->GetCollection().Update(TransactionData(0, 0), ids, column_ids, updates); + storage->GetCollection().Update(TransactionData(0, 0), table, ids, column_ids, updates); } void LocalStorage::Flush(DataTable &table, LocalTableStorage &storage, optional_ptr commit_state) { diff --git a/src/duckdb/src/storage/table/array_column_data.cpp b/src/duckdb/src/storage/table/array_column_data.cpp index d92562a94..849e1dec8 100644 --- a/src/duckdb/src/storage/table/array_column_data.cpp +++ b/src/duckdb/src/storage/table/array_column_data.cpp @@ -224,13 +224,14 @@ idx_t ArrayColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &resul throw NotImplementedException("Array Fetch"); } -void ArrayColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ArrayColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { throw NotImplementedException("Array Update is not supported."); } -void ArrayColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector 
&update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void ArrayColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { throw NotImplementedException("Array Update Column is not supported"); } diff --git a/src/duckdb/src/storage/table/column_data.cpp b/src/duckdb/src/storage/table/column_data.cpp index c38fff709..2b48d90cf 100644 --- a/src/duckdb/src/storage/table/column_data.cpp +++ b/src/duckdb/src/storage/table/column_data.cpp @@ -293,13 +293,13 @@ void ColumnData::FetchUpdateRow(TransactionData transaction, row_t row_id, Vecto updates->FetchRow(transaction, NumericCast(row_id), result, result_idx); } -void ColumnData::UpdateInternal(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count, Vector &base_vector) { +void ColumnData::UpdateInternal(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count, Vector &base_vector) { lock_guard update_guard(update_lock); if (!updates) { updates = make_uniq(*this); } - updates->Update(transaction, column_index, update_vector, row_ids, update_count, base_vector); + updates->Update(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); } idx_t ColumnData::ScanVector(TransactionData transaction, idx_t vector_index, ColumnScanState &state, Vector &result, @@ -578,20 +578,20 @@ idx_t ColumnData::FetchUpdateData(ColumnScanState &state, row_t *row_ids, Vector return fetch_count; } -void ColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_vector, + row_t *row_ids, idx_t update_count) { Vector base_vector(type); ColumnScanState state; 
FetchUpdateData(state, row_ids, base_vector); - UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); + UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); } -void ColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, Vector &update_vector, - row_t *row_ids, idx_t update_count, idx_t depth) { +void ColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, const vector &column_path, + Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { // this method should only be called at the end of the path in the base column case D_ASSERT(depth >= column_path.size()); - ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); + ColumnData::Update(transaction, data_table, column_path[0], update_vector, row_ids, update_count); } void ColumnData::AppendTransientSegment(SegmentLock &l, idx_t start_row) { diff --git a/src/duckdb/src/storage/table/list_column_data.cpp b/src/duckdb/src/storage/table/list_column_data.cpp index 986b32dc7..5672482c7 100644 --- a/src/duckdb/src/storage/table/list_column_data.cpp +++ b/src/duckdb/src/storage/table/list_column_data.cpp @@ -263,13 +263,14 @@ idx_t ListColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result throw NotImplementedException("List Fetch"); } -void ListColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void ListColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { throw NotImplementedException("List Update is not supported."); } -void ListColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void ListColumnData::UpdateColumn(TransactionData 
transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { throw NotImplementedException("List Update Column is not supported"); } diff --git a/src/duckdb/src/storage/table/row_group.cpp b/src/duckdb/src/storage/table/row_group.cpp index eb02ba7e9..4c6bb794a 100644 --- a/src/duckdb/src/storage/table/row_group.cpp +++ b/src/duckdb/src/storage/table/row_group.cpp @@ -854,8 +854,8 @@ void RowGroup::CleanupAppend(transaction_t lowest_transaction, idx_t start, idx_ vinfo.CleanupAppend(lowest_transaction, start, count); } -void RowGroup::Update(TransactionData transaction, DataChunk &update_chunk, row_t *ids, idx_t offset, idx_t count, - const vector &column_ids) { +void RowGroup::Update(TransactionData transaction, DataTable &data_table, DataChunk &update_chunk, row_t *ids, + idx_t offset, idx_t count, const vector &column_ids) { #ifdef DEBUG for (size_t i = offset; i < offset + count; i++) { D_ASSERT(ids[i] >= row_t(this->start) && ids[i] < row_t(this->start + this->count)); @@ -868,16 +868,16 @@ void RowGroup::Update(TransactionData transaction, DataChunk &update_chunk, row_ if (offset > 0) { Vector sliced_vector(update_chunk.data[i], offset, offset + count); sliced_vector.Flatten(count); - col_data.Update(transaction, column.index, sliced_vector, ids + offset, count); + col_data.Update(transaction, data_table, column.index, sliced_vector, ids + offset, count); } else { - col_data.Update(transaction, column.index, update_chunk.data[i], ids, count); + col_data.Update(transaction, data_table, column.index, update_chunk.data[i], ids, count); } MergeStatistics(column.index, *col_data.GetUpdateStatistics()); } } -void RowGroup::UpdateColumn(TransactionData transaction, DataChunk &updates, Vector &row_ids, idx_t offset, idx_t count, - const vector &column_path) { +void RowGroup::UpdateColumn(TransactionData transaction, DataTable &data_table, DataChunk &updates, Vector &row_ids, + idx_t 
offset, idx_t count, const vector &column_path) { D_ASSERT(updates.ColumnCount() == 1); auto ids = FlatVector::GetData(row_ids); @@ -887,9 +887,9 @@ void RowGroup::UpdateColumn(TransactionData transaction, DataChunk &updates, Vec if (offset > 0) { Vector sliced_vector(updates.data[0], offset, offset + count); sliced_vector.Flatten(count); - col_data.UpdateColumn(transaction, column_path, sliced_vector, ids + offset, count, 1); + col_data.UpdateColumn(transaction, data_table, column_path, sliced_vector, ids + offset, count, 1); } else { - col_data.UpdateColumn(transaction, column_path, updates.data[0], ids, count, 1); + col_data.UpdateColumn(transaction, data_table, column_path, updates.data[0], ids, count, 1); } MergeStatistics(primary_column_idx, *col_data.GetUpdateStatistics()); } @@ -1093,8 +1093,10 @@ RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData write_data, RowGroupWrite row_group_pointer.has_metadata_blocks = has_metadata_blocks; row_group_pointer.extra_metadata_blocks = extra_metadata_blocks; row_group_pointer.deletes_pointers = deletes_pointers; - metadata_manager->ClearModifiedBlocks(write_data.existing_pointers); - metadata_manager->ClearModifiedBlocks(deletes_pointers); + if (metadata_manager) { + metadata_manager->ClearModifiedBlocks(write_data.existing_pointers); + metadata_manager->ClearModifiedBlocks(deletes_pointers); + } return row_group_pointer; } D_ASSERT(write_data.states.size() == columns.size()); diff --git a/src/duckdb/src/storage/table/row_group_collection.cpp b/src/duckdb/src/storage/table/row_group_collection.cpp index 42c453ea0..7c5300bdd 100644 --- a/src/duckdb/src/storage/table/row_group_collection.cpp +++ b/src/duckdb/src/storage/table/row_group_collection.cpp @@ -650,14 +650,14 @@ optional_ptr RowGroupCollection::NextUpdateRowGroup(row_t *ids, idx_t return row_group; } -void RowGroupCollection::Update(TransactionData transaction, row_t *ids, const vector &column_ids, - DataChunk &updates) { +void 
RowGroupCollection::Update(TransactionData transaction, DataTable &data_table, row_t *ids, + const vector &column_ids, DataChunk &updates) { D_ASSERT(updates.size() >= 1); idx_t pos = 0; do { idx_t start = pos; auto row_group = NextUpdateRowGroup(ids, pos, updates.size()); - row_group->Update(transaction, updates, ids, start, pos - start, column_ids); + row_group->Update(transaction, data_table, updates, ids, start, pos - start, column_ids); auto l = stats.GetLock(); for (idx_t i = 0; i < column_ids.size(); i++) { @@ -770,15 +770,15 @@ void RowGroupCollection::RemoveFromIndexes(const QueryContext &context, TableInd } } -void RowGroupCollection::UpdateColumn(TransactionData transaction, Vector &row_ids, const vector &column_path, - DataChunk &updates) { +void RowGroupCollection::UpdateColumn(TransactionData transaction, DataTable &data_table, Vector &row_ids, + const vector &column_path, DataChunk &updates) { D_ASSERT(updates.size() >= 1); auto ids = FlatVector::GetData(row_ids); idx_t pos = 0; do { idx_t start = pos; auto row_group = NextUpdateRowGroup(ids, pos, updates.size()); - row_group->UpdateColumn(transaction, updates, row_ids, start, pos - start, column_path); + row_group->UpdateColumn(transaction, data_table, updates, row_ids, start, pos - start, column_path); auto lock = stats.GetLock(); auto primary_column_idx = column_path[0]; diff --git a/src/duckdb/src/storage/table/row_id_column_data.cpp b/src/duckdb/src/storage/table/row_id_column_data.cpp index d869913bf..4bc3c4148 100644 --- a/src/duckdb/src/storage/table/row_id_column_data.cpp +++ b/src/duckdb/src/storage/table/row_id_column_data.cpp @@ -138,13 +138,14 @@ void RowIdColumnData::RevertAppend(row_t start_row) { throw InternalException("RowIdColumnData cannot be appended to"); } -void RowIdColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void RowIdColumnData::Update(TransactionData transaction, DataTable &data_table, 
idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { throw InternalException("RowIdColumnData cannot be updated"); } -void RowIdColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void RowIdColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { throw InternalException("RowIdColumnData cannot be updated"); } diff --git a/src/duckdb/src/storage/table/standard_column_data.cpp b/src/duckdb/src/storage/table/standard_column_data.cpp index fde7d2463..ad8814ab4 100644 --- a/src/duckdb/src/storage/table/standard_column_data.cpp +++ b/src/duckdb/src/storage/table/standard_column_data.cpp @@ -152,8 +152,8 @@ idx_t StandardColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &re return scan_count; } -void StandardColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { +void StandardColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { ColumnScanState standard_state, validity_state; Vector base_vector(type); auto standard_fetch = FetchUpdateData(standard_state, row_ids, base_vector); @@ -162,18 +162,19 @@ void StandardColumnData::Update(TransactionData transaction, idx_t column_index, throw InternalException("Unaligned fetch in validity and main column data for update"); } - UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); - validity.UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); + UpdateInternal(transaction, data_table, column_index, update_vector, row_ids, update_count, base_vector); + validity.UpdateInternal(transaction, data_table, 
column_index, update_vector, row_ids, update_count, base_vector); } -void StandardColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void StandardColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { if (depth >= column_path.size()) { // update this column - ColumnData::Update(transaction, column_path[0], update_vector, row_ids, update_count); + ColumnData::Update(transaction, data_table, column_path[0], update_vector, row_ids, update_count); } else { // update the child column (i.e. the validity column) - validity.UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, depth + 1); + validity.UpdateColumn(transaction, data_table, column_path, update_vector, row_ids, update_count, depth + 1); } } diff --git a/src/duckdb/src/storage/table/struct_column_data.cpp b/src/duckdb/src/storage/table/struct_column_data.cpp index 65f322e79..b1de02b2d 100644 --- a/src/duckdb/src/storage/table/struct_column_data.cpp +++ b/src/duckdb/src/storage/table/struct_column_data.cpp @@ -207,17 +207,18 @@ idx_t StructColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &resu return scan_count; } -void StructColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, - idx_t update_count) { - validity.Update(transaction, column_index, update_vector, row_ids, update_count); +void StructColumnData::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, + Vector &update_vector, row_t *row_ids, idx_t update_count) { + validity.Update(transaction, data_table, column_index, update_vector, row_ids, update_count); auto &child_entries = StructVector::GetEntries(update_vector); for (idx_t i = 0; i < child_entries.size(); i++) { - sub_columns[i]->Update(transaction, 
column_index, *child_entries[i], row_ids, update_count); + sub_columns[i]->Update(transaction, data_table, column_index, *child_entries[i], row_ids, update_count); } } -void StructColumnData::UpdateColumn(TransactionData transaction, const vector &column_path, - Vector &update_vector, row_t *row_ids, idx_t update_count, idx_t depth) { +void StructColumnData::UpdateColumn(TransactionData transaction, DataTable &data_table, + const vector &column_path, Vector &update_vector, row_t *row_ids, + idx_t update_count, idx_t depth) { // we can never DIRECTLY update a struct column if (depth >= column_path.size()) { throw InternalException("Attempting to directly update a struct column - this should not be possible"); @@ -225,13 +226,13 @@ void StructColumnData::UpdateColumn(TransactionData transaction, const vector sub_columns.size()) { throw InternalException("Update column_path out of range"); } - sub_columns[update_column - 1]->UpdateColumn(transaction, column_path, update_vector, row_ids, update_count, - depth + 1); + sub_columns[update_column - 1]->UpdateColumn(transaction, data_table, column_path, update_vector, row_ids, + update_count, depth + 1); } } diff --git a/src/duckdb/src/storage/table/update_segment.cpp b/src/duckdb/src/storage/table/update_segment.cpp index 8056907bc..c47851ead 100644 --- a/src/duckdb/src/storage/table/update_segment.cpp +++ b/src/duckdb/src/storage/table/update_segment.cpp @@ -7,6 +7,7 @@ #include "duckdb/transaction/duck_transaction.hpp" #include "duckdb/transaction/update_info.hpp" #include "duckdb/transaction/undo_buffer.hpp" +#include "duckdb/storage/data_table.hpp" #include @@ -104,9 +105,10 @@ idx_t UpdateInfo::GetAllocSize(idx_t type_size) { return AlignValue(sizeof(UpdateInfo) + (sizeof(sel_t) + type_size) * STANDARD_VECTOR_SIZE); } -void UpdateInfo::Initialize(UpdateInfo &info, transaction_t transaction_id) { +void UpdateInfo::Initialize(UpdateInfo &info, DataTable &data_table, transaction_t transaction_id) { info.max = 
STANDARD_VECTOR_SIZE; info.version_number = transaction_id; + info.table = &data_table; info.segment = nullptr; info.prev.entry = nullptr; info.next.entry = nullptr; @@ -1236,11 +1238,11 @@ static idx_t SortSelectionVector(SelectionVector &sel, idx_t count, row_t *ids) return pos; } -UpdateInfo *CreateEmptyUpdateInfo(TransactionData transaction, idx_t type_size, idx_t count, +UpdateInfo *CreateEmptyUpdateInfo(TransactionData transaction, DataTable &data_table, idx_t type_size, idx_t count, unsafe_unique_array &data) { data = make_unsafe_uniq_array_uninitialized(UpdateInfo::GetAllocSize(type_size)); auto update_info = reinterpret_cast(data.get()); - UpdateInfo::Initialize(*update_info, transaction.transaction_id); + UpdateInfo::Initialize(*update_info, data_table, transaction.transaction_id); return update_info; } @@ -1258,8 +1260,8 @@ void UpdateSegment::InitializeUpdateInfo(idx_t vector_idx) { } } -void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vector &update_p, row_t *ids, idx_t count, - Vector &base_data) { +void UpdateSegment::Update(TransactionData transaction, DataTable &data_table, idx_t column_index, Vector &update_p, + row_t *ids, idx_t count, Vector &base_data) { // obtain an exclusive lock auto write_lock = lock.GetExclusiveLock(); @@ -1322,10 +1324,10 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect // no updates made yet by this transaction: initially the update info to empty if (transaction.transaction) { auto &dtransaction = transaction.transaction->Cast(); - node_ref = dtransaction.CreateUpdateInfo(type_size, count); + node_ref = dtransaction.CreateUpdateInfo(type_size, data_table, count); node = &UpdateInfo::Get(node_ref); } else { - node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); + node = CreateEmptyUpdateInfo(transaction, data_table, type_size, count, update_info_data); } node->segment = this; node->vector_index = vector_index; @@ -1360,7 +1362,7 @@ 
void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect idx_t alloc_size = UpdateInfo::GetAllocSize(type_size); auto handle = root->allocator.Allocate(alloc_size); auto &update_info = UpdateInfo::Get(handle); - UpdateInfo::Initialize(update_info, TRANSACTION_ID_START - 1); + UpdateInfo::Initialize(update_info, data_table, TRANSACTION_ID_START - 1); update_info.column_index = column_index; InitializeUpdateInfo(update_info, ids, sel, count, vector_index, vector_offset); @@ -1370,10 +1372,10 @@ void UpdateSegment::Update(TransactionData transaction, idx_t column_index, Vect UndoBufferReference node_ref; optional_ptr transaction_node; if (transaction.transaction) { - node_ref = transaction.transaction->CreateUpdateInfo(type_size, count); + node_ref = transaction.transaction->CreateUpdateInfo(type_size, data_table, count); transaction_node = &UpdateInfo::Get(node_ref); } else { - transaction_node = CreateEmptyUpdateInfo(transaction, type_size, count, update_info_data); + transaction_node = CreateEmptyUpdateInfo(transaction, data_table, type_size, count, update_info_data); } InitializeUpdateInfo(*transaction_node, ids, sel, count, vector_index, vector_offset); diff --git a/src/duckdb/src/transaction/commit_state.cpp b/src/duckdb/src/transaction/commit_state.cpp index 0f5d75bd2..6eba8ab10 100644 --- a/src/duckdb/src/transaction/commit_state.cpp +++ b/src/duckdb/src/transaction/commit_state.cpp @@ -165,6 +165,12 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::INSERT_TUPLE: { // append: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } // mark the tuples as committed info->table->CommitAppend(commit_id, info->start_row, info->count); break; 
@@ -172,6 +178,12 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::DELETE_TUPLE: { // deletion: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } // mark the tuples as committed info->version_info->CommitDelete(info->vector_idx, commit_id, *info); break; @@ -179,6 +191,12 @@ void CommitState::CommitEntry(UndoFlags type, data_ptr_t data) { case UndoFlags::UPDATE_TUPLE: { // update: auto info = reinterpret_cast(data); + if (!info->table->IsMainTable()) { + auto table_name = info->table->GetTableName(); + auto table_modification = info->table->TableModification(); + throw TransactionException("Attempting to modify table %s but another transaction has %s this table", + table_name, table_modification); + } info->version_number = commit_id; break; } diff --git a/src/duckdb/src/transaction/duck_transaction.cpp b/src/duckdb/src/transaction/duck_transaction.cpp index dc6afccb7..cf8acccde 100644 --- a/src/duckdb/src/transaction/duck_transaction.cpp +++ b/src/duckdb/src/transaction/duck_transaction.cpp @@ -126,11 +126,11 @@ void DuckTransaction::PushAppend(DataTable &table, idx_t start_row, idx_t row_co append_info->count = row_count; } -UndoBufferReference DuckTransaction::CreateUpdateInfo(idx_t type_size, idx_t entries) { +UndoBufferReference DuckTransaction::CreateUpdateInfo(idx_t type_size, DataTable &data_table, idx_t entries) { idx_t alloc_size = UpdateInfo::GetAllocSize(type_size); auto undo_entry = undo_buffer.CreateEntry(UndoFlags::UPDATE_TUPLE, alloc_size); auto &update_info = UpdateInfo::Get(undo_entry); - UpdateInfo::Initialize(update_info, transaction_id); + UpdateInfo::Initialize(update_info, data_table, transaction_id); return undo_entry; } @@ -246,14 
+246,6 @@ ErrorData DuckTransaction::Commit(AttachedDatabase &db, transaction_t new_commit // no need to flush anything if we made no changes return ErrorData(); } - for (auto &entry : modified_tables) { - auto &tbl = entry.first.get(); - if (!tbl.IsMainTable()) { - return ErrorData( - TransactionException("Attempting to modify table %s but another transaction has %s this table", - tbl.GetTableName(), tbl.TableModification())); - } - } D_ASSERT(db.IsSystem() || db.IsTemporary() || !IsReadOnly()); UndoBuffer::IteratorState iterator_state; diff --git a/src/duckdb/src/transaction/wal_write_state.cpp b/src/duckdb/src/transaction/wal_write_state.cpp index 5fe17e050..0036ad0c6 100644 --- a/src/duckdb/src/transaction/wal_write_state.cpp +++ b/src/duckdb/src/transaction/wal_write_state.cpp @@ -27,10 +27,10 @@ WALWriteState::WALWriteState(DuckTransaction &transaction_p, WriteAheadLog &log, : transaction(transaction_p), log(log), commit_state(commit_state), current_table_info(nullptr) { } -void WALWriteState::SwitchTable(DataTableInfo *table_info, UndoFlags new_op) { - if (current_table_info != table_info) { +void WALWriteState::SwitchTable(DataTableInfo &table_info, UndoFlags new_op) { + if (current_table_info != &table_info) { // write the current table to the log - log.WriteSetTable(table_info->GetSchemaName(), table_info->GetTableName()); + log.WriteSetTable(table_info.GetSchemaName(), table_info.GetTableName()); current_table_info = table_info; } } @@ -171,7 +171,7 @@ void WALWriteState::WriteCatalogEntry(CatalogEntry &entry, data_ptr_t dataptr) { void WALWriteState::WriteDelete(DeleteInfo &info) { // switch to the current table, if necessary - SwitchTable(info.table->GetDataTableInfo().get(), UndoFlags::DELETE_TUPLE); + SwitchTable(*info.table->GetDataTableInfo(), UndoFlags::DELETE_TUPLE); if (!delete_chunk) { delete_chunk = make_uniq(); @@ -198,7 +198,7 @@ void WALWriteState::WriteUpdate(UpdateInfo &info) { auto &column_data = info.segment->column_data; auto 
&table_info = column_data.GetTableInfo(); - SwitchTable(&table_info, UndoFlags::UPDATE_TUPLE); + SwitchTable(table_info, UndoFlags::UPDATE_TUPLE); // initialize the update chunk vector update_types; diff --git a/src/duckdb/src/verification/statement_verifier.cpp b/src/duckdb/src/verification/statement_verifier.cpp index 81f4c4aba..31d21f045 100644 --- a/src/duckdb/src/verification/statement_verifier.cpp +++ b/src/duckdb/src/verification/statement_verifier.cpp @@ -1,5 +1,9 @@ #include "duckdb/verification/statement_verifier.hpp" +#include "duckdb/parser/query_node/select_node.hpp" +#include "duckdb/parser/query_node/set_operation_node.hpp" +#include "duckdb/parser/query_node/cte_node.hpp" + #include "duckdb/common/error_data.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" #include "duckdb/parser/parser.hpp" @@ -15,13 +19,26 @@ namespace duckdb { +const vector> &StatementVerifier::GetSelectList(QueryNode &node) { + switch (node.type) { + case QueryNodeType::SELECT_NODE: + return node.Cast().select_list; + case QueryNodeType::SET_OPERATION_NODE: + return GetSelectList(*node.Cast().children[0]); + case QueryNodeType::CTE_NODE: + return GetSelectList(*node.Cast().query); + default: + return empty_select_list; + } +} + StatementVerifier::StatementVerifier(VerificationType type, string name, unique_ptr statement_p, optional_ptr> parameters_p) : type(type), name(std::move(name)), statement(std::move(statement_p)), select_statement(statement->type == StatementType::SELECT_STATEMENT ? &statement->Cast() : nullptr), parameters(parameters_p), - select_list(select_statement ? select_statement->node->GetSelectList() : empty_select_list) { + select_list(select_statement ? 
GetSelectList(*select_statement->node) : empty_select_list) { } StatementVerifier::StatementVerifier(unique_ptr statement_p, diff --git a/src/duckdb/ub_extension_parquet_writer.cpp b/src/duckdb/ub_extension_parquet_writer.cpp index 5efcd2c3c..cca73c331 100644 --- a/src/duckdb/ub_extension_parquet_writer.cpp +++ b/src/duckdb/ub_extension_parquet_writer.cpp @@ -10,7 +10,5 @@ #include "extension/parquet/writer/primitive_column_writer.cpp" -#include "extension/parquet/writer/variant_column_writer.cpp" - #include "extension/parquet/writer/struct_column_writer.cpp" diff --git a/src/duckdb/ub_src_common_row_operations.cpp b/src/duckdb/ub_src_common_row_operations.cpp index f1ac77f8e..f8f47aee8 100644 --- a/src/duckdb/ub_src_common_row_operations.cpp +++ b/src/duckdb/ub_src_common_row_operations.cpp @@ -1,16 +1,4 @@ #include "src/common/row_operations/row_aggregate.cpp" -#include "src/common/row_operations/row_scatter.cpp" - -#include "src/common/row_operations/row_gather.cpp" - #include "src/common/row_operations/row_matcher.cpp" -#include "src/common/row_operations/row_external.cpp" - -#include "src/common/row_operations/row_radix_scatter.cpp" - -#include "src/common/row_operations/row_heap_scatter.cpp" - -#include "src/common/row_operations/row_heap_gather.cpp" - diff --git a/src/duckdb/ub_src_common_sort.cpp b/src/duckdb/ub_src_common_sort.cpp index bcddfcb4e..7aeebf9a5 100644 --- a/src/duckdb/ub_src_common_sort.cpp +++ b/src/duckdb/ub_src_common_sort.cpp @@ -1,10 +1,8 @@ -#include "src/common/sort/comparators.cpp" +#include "src/common/sort/hashed_sort.cpp" -#include "src/common/sort/merge_sorter.cpp" +#include "src/common/sort/sort.cpp" -#include "src/common/sort/radix_sort.cpp" +#include "src/common/sort/sorted_run.cpp" -#include "src/common/sort/sort_state.cpp" - -#include "src/common/sort/sorted_block.cpp" +#include "src/common/sort/sorted_run_merger.cpp" diff --git a/src/duckdb/ub_src_common_sorting.cpp b/src/duckdb/ub_src_common_sorting.cpp deleted file mode 
100644 index b444cb55b..000000000 --- a/src/duckdb/ub_src_common_sorting.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include "src/common/sorting/hashed_sort.cpp" - -#include "src/common/sorting/sort.cpp" - -#include "src/common/sorting/sorted_run.cpp" - -#include "src/common/sorting/sorted_run_merger.cpp" - diff --git a/src/duckdb/ub_src_common_types_row.cpp b/src/duckdb/ub_src_common_types_row.cpp index 3d4ff32c2..4fe55c37c 100644 --- a/src/duckdb/ub_src_common_types_row.cpp +++ b/src/duckdb/ub_src_common_types_row.cpp @@ -2,12 +2,6 @@ #include "src/common/types/row/partitioned_tuple_data.cpp" -#include "src/common/types/row/row_data_collection.cpp" - -#include "src/common/types/row/row_data_collection_scanner.cpp" - -#include "src/common/types/row/row_layout.cpp" - #include "src/common/types/row/tuple_data_allocator.cpp" #include "src/common/types/row/tuple_data_collection.cpp" diff --git a/src/duckdb/ub_src_planner_binder_query_node.cpp b/src/duckdb/ub_src_planner_binder_query_node.cpp index 3ec7b7ecb..acecbaf63 100644 --- a/src/duckdb/ub_src_planner_binder_query_node.cpp +++ b/src/duckdb/ub_src_planner_binder_query_node.cpp @@ -12,10 +12,6 @@ #include "src/planner/binder/query_node/plan_query_node.cpp" -#include "src/planner/binder/query_node/plan_recursive_cte_node.cpp" - -#include "src/planner/binder/query_node/plan_cte_node.cpp" - #include "src/planner/binder/query_node/plan_select_node.cpp" #include "src/planner/binder/query_node/plan_setop.cpp" From 8c8ae146816251062218feb4650ba5f861f7635b Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Sun, 12 Oct 2025 05:12:50 +0000 Subject: [PATCH 6/6] Update vendored DuckDB sources to 2762f1aa72 --- .../function/table/version/pragma_version.cpp | 6 +- .../include/duckdb/main/client_context.hpp | 4 - .../duckdb/parser/query_node/cte_node.hpp | 9 +- .../src/include/duckdb/parser/transformer.hpp | 3 +- .../src/include/duckdb/planner/binder.hpp | 12 +- .../duckdb/planner/bound_statement.hpp | 13 ++ 
.../query_node/bound_set_operation_node.hpp | 19 +- .../src/main/relation/query_relation.cpp | 22 --- .../src/parser/parsed_expression_iterator.cpp | 6 - src/duckdb/src/parser/query_node/cte_node.cpp | 29 +-- .../transform/helpers/transform_cte.cpp | 4 +- .../statement/transform_pivot_stmt.cpp | 2 +- .../transform/statement/transform_select.cpp | 7 +- src/duckdb/src/parser/transformer.cpp | 25 --- src/duckdb/src/planner/binder.cpp | 43 +---- .../binder/query_node/bind_cte_node.cpp | 110 ++++++++---- .../binder/query_node/bind_select_node.cpp | 23 +-- .../binder/query_node/bind_setop_node.cpp | 165 ++++++++---------- .../planner/binder/query_node/plan_setop.cpp | 34 +--- .../binder/tableref/bind_basetableref.cpp | 23 --- src/duckdb/src/storage/data_table.cpp | 2 +- .../serialization/serialize_query_node.cpp | 3 + .../storage/table/row_group_collection.cpp | 2 +- src/duckdb/src/transaction/undo_buffer.cpp | 2 +- .../src/verification/statement_verifier.cpp | 2 - 25 files changed, 204 insertions(+), 366 deletions(-) diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index 4d811b36a..36a1707e8 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev966" +#define DUCKDB_PATCH_VERSION "0-dev988" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 5 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.5.0-dev966" +#define DUCKDB_VERSION "v1.5.0-dev988" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "9d77bcf518" +#define DUCKDB_SOURCE_ID "2762f1aa72" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/include/duckdb/main/client_context.hpp b/src/duckdb/src/include/duckdb/main/client_context.hpp 
index d0ad964ef..21f56ccf9 100644 --- a/src/duckdb/src/include/duckdb/main/client_context.hpp +++ b/src/duckdb/src/include/duckdb/main/client_context.hpp @@ -339,9 +339,6 @@ class QueryContext { } QueryContext(ClientContext &context) : context(&context) { // NOLINT: allow implicit construction } - QueryContext(weak_ptr context) // NOLINT: allow implicit construction - : owning_context(context.lock()), context(owning_context.get()) { - } public: bool Valid() const { @@ -352,7 +349,6 @@ class QueryContext { } private: - shared_ptr owning_context; optional_ptr context; }; diff --git a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp index 4ad41748b..fd2589fd2 100644 --- a/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp +++ b/src/duckdb/src/include/duckdb/parser/query_node/cte_node.hpp @@ -14,6 +14,7 @@ namespace duckdb { +//! DEPRECATED - CTENode is only preserved for backwards compatibility when serializing older databases class CTENode : public QueryNode { public: static constexpr const QueryNodeType TYPE = QueryNodeType::CTE_NODE; @@ -23,26 +24,18 @@ class CTENode : public QueryNode { } string ctename; - //! The query of the CTE unique_ptr query; - //! Child unique_ptr child; - //! Aliases of the CTE node vector aliases; CTEMaterialize materialized = CTEMaterialize::CTE_MATERIALIZE_DEFAULT; public: - //! Convert the query node to a string string ToString() const override; bool Equals(const QueryNode *other) const override; - //! Create a copy of this SelectNode unique_ptr Copy() const override; - //! Serializes a QueryNode to a stand-alone binary blob - //! 
Deserializes a blob back into a QueryNode - void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &source); }; diff --git a/src/duckdb/src/include/duckdb/parser/transformer.hpp b/src/duckdb/src/include/duckdb/parser/transformer.hpp index 59e4f0419..1945ebc5a 100644 --- a/src/duckdb/src/include/duckdb/parser/transformer.hpp +++ b/src/duckdb/src/include/duckdb/parser/transformer.hpp @@ -80,7 +80,7 @@ class Transformer { //! The set of pivot entries to create vector> pivot_entries; //! Sets of stored CTEs, if any - vector stored_cte_map; + vector> stored_cte_map; //! Whether or not we are currently binding a window definition bool in_window_definition = false; @@ -304,7 +304,6 @@ class Transformer { string TransformAlias(duckdb_libpgquery::PGAlias *root, vector &column_name_alias); vector TransformStringList(duckdb_libpgquery::PGList *list); void TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map); - static unique_ptr TransformMaterializedCTE(unique_ptr root); unique_ptr TransformRecursiveCTE(duckdb_libpgquery::PGCommonTableExpr &node, CommonTableExpressionInfo &info); diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index bd93bc8b8..914d86b6b 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -69,6 +69,7 @@ struct UnpivotEntry; struct CopyInfo; struct CopyOption; struct BoundSetOpChild; +struct BoundCTEData; template class IndexVector; @@ -409,12 +410,11 @@ class Binder : public enable_shared_from_this { unique_ptr BindTableMacro(FunctionExpression &function, TableMacroCatalogEntry ¯o_func, idx_t depth); - BoundStatement BindCTE(CTENode &statement); + BoundStatement BindCTE(const string &ctename, CommonTableExpressionInfo &info); BoundStatement BindNode(SelectNode &node); BoundStatement BindNode(SetOperationNode &node); BoundStatement 
BindNode(RecursiveCTENode &node); - BoundStatement BindNode(CTENode &node); BoundStatement BindNode(QueryNode &node); BoundStatement BindNode(StatementNode &node); @@ -423,8 +423,7 @@ class Binder : public enable_shared_from_this { unique_ptr CreatePlan(BoundSetOperationNode &node); unique_ptr CreatePlan(BoundQueryNode &node); - BoundSetOpChild BindSetOpChild(QueryNode &child); - unique_ptr BindSetOpNode(SetOperationNode &statement); + void BuildUnionByNameInfo(BoundSetOperationNode &result); BoundStatement BindJoin(Binder &parent, TableRef &ref); BoundStatement Bind(BaseTableRef &ref); @@ -515,8 +514,6 @@ class Binder : public enable_shared_from_this { LogicalType BindLogicalTypeInternal(const LogicalType &type, optional_ptr catalog, const string &schema); BoundStatement BindSelectNode(SelectNode &statement, BoundStatement from_table); - unique_ptr BindSelectNodeInternal(SelectNode &statement); - unique_ptr BindSelectNodeInternal(SelectNode &statement, BoundStatement from_table); unique_ptr BindCopyDatabaseSchema(Catalog &source_catalog, const string &target_database_name); unique_ptr BindCopyDatabaseData(Catalog &source_catalog, const string &target_database_name); @@ -544,6 +541,9 @@ class Binder : public enable_shared_from_this { static void CheckInsertColumnCountMismatch(idx_t expected_columns, idx_t result_columns, bool columns_provided, const string &tname); + BoundCTEData PrepareCTE(const string &ctename, CommonTableExpressionInfo &statement); + BoundStatement FinishCTE(BoundCTEData &bound_cte, BoundStatement child_data); + private: Binder(ClientContext &context, shared_ptr parent, BinderType binder_type); }; diff --git a/src/duckdb/src/include/duckdb/planner/bound_statement.hpp b/src/duckdb/src/include/duckdb/planner/bound_statement.hpp index bb1f7bfec..f89f0bf68 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_statement.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_statement.hpp @@ -10,16 +10,29 @@ #include "duckdb/common/string.hpp" 
#include "duckdb/common/vector.hpp" +#include "duckdb/common/enums/set_operation_type.hpp" +#include "duckdb/common/shared_ptr.hpp" namespace duckdb { class LogicalOperator; struct LogicalType; +struct BoundStatement; +class ParsedExpression; +class Binder; + +struct ExtraBoundInfo { + SetOperationType setop_type = SetOperationType::NONE; + vector> child_binders; + vector bound_children; + vector> original_expressions; +}; struct BoundStatement { unique_ptr plan; vector types; vector names; + ExtraBoundInfo extra_info; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp index 0939da695..675007b50 100644 --- a/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp +++ b/src/duckdb/src/include/duckdb/planner/query_node/bound_set_operation_node.hpp @@ -13,7 +13,6 @@ #include "duckdb/planner/bound_query_node.hpp" namespace duckdb { -struct BoundSetOpChild; //! Bound equivalent of SetOperationNode class BoundSetOperationNode : public BoundQueryNode { @@ -23,7 +22,9 @@ class BoundSetOperationNode : public BoundQueryNode { //! whether the ALL modifier was used or not bool setop_all = false; //! The bound children - vector bound_children; + vector bound_children; + //! Child binders + vector> child_binders; //! Index used by the set operation idx_t setop_index; @@ -34,18 +35,4 @@ class BoundSetOperationNode : public BoundQueryNode { } }; -struct BoundSetOpChild { - unique_ptr bound_node; - BoundStatement node; - shared_ptr binder; - //! Original select list (if this was a SELECT statement) - vector> select_list; - //! 
Exprs used by the UNION BY NAME operations to add a new projection - vector> reorder_expressions; - - const vector &GetNames(); - const vector &GetTypes(); - idx_t GetRootIndex(); -}; - } // namespace duckdb diff --git a/src/duckdb/src/main/relation/query_relation.cpp b/src/duckdb/src/main/relation/query_relation.cpp index ee7ac72d9..6ebebbc70 100644 --- a/src/duckdb/src/main/relation/query_relation.cpp +++ b/src/duckdb/src/main/relation/query_relation.cpp @@ -67,7 +67,6 @@ BoundStatement QueryRelation::Bind(Binder &binder) { if (first_bind) { auto &query_node = *select_stmt->node; auto &cte_map = query_node.cte_map; - vector> materialized_ctes; for (auto &kv : replacements) { auto &name = kv.first; auto &tableref = kv.second; @@ -88,28 +87,7 @@ BoundStatement QueryRelation::Bind(Binder &binder) { cte_info->query = std::move(select); cte_map.map[name] = std::move(cte_info); - - // We can not rely on CTE inlining anymore, so we need to add a materialized CTE node - // to the query node to ensure that the CTE exists - auto &cte_entry = cte_map.map[name]; - auto mat_cte = make_uniq(); - mat_cte->ctename = name; - mat_cte->query = cte_entry->query->node->Copy(); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - auto root = std::move(select_stmt->node); - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->cte_map = root->cte_map.Copy(); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); } - select_stmt->node = std::move(root); } replacements.clear(); binder.SetBindingMode(saved_binding_mode); diff --git a/src/duckdb/src/parser/parsed_expression_iterator.cpp b/src/duckdb/src/parser/parsed_expression_iterator.cpp index 7ca38a10e..47e39dc76 100644 --- a/src/duckdb/src/parser/parsed_expression_iterator.cpp +++ 
b/src/duckdb/src/parser/parsed_expression_iterator.cpp @@ -271,12 +271,6 @@ void ParsedExpressionIterator::EnumerateQueryNodeChildren( EnumerateQueryNodeChildren(*rcte_node.right, expr_callback, ref_callback); break; } - case QueryNodeType::CTE_NODE: { - auto &cte_node = node.Cast(); - EnumerateQueryNodeChildren(*cte_node.query, expr_callback, ref_callback); - EnumerateQueryNodeChildren(*cte_node.child, expr_callback, ref_callback); - break; - } case QueryNodeType::SELECT_NODE: { auto &sel_node = node.Cast(); for (idx_t i = 0; i < sel_node.select_list.size(); i++) { diff --git a/src/duckdb/src/parser/query_node/cte_node.cpp b/src/duckdb/src/parser/query_node/cte_node.cpp index 1e1f0e199..250f31d1b 100644 --- a/src/duckdb/src/parser/query_node/cte_node.cpp +++ b/src/duckdb/src/parser/query_node/cte_node.cpp @@ -5,38 +5,15 @@ namespace duckdb { string CTENode::ToString() const { - string result; - result += child->ToString(); - return result; + throw InternalException("CTENode is a legacy type"); } bool CTENode::Equals(const QueryNode *other_p) const { - if (!QueryNode::Equals(other_p)) { - return false; - } - if (this == other_p) { - return true; - } - auto &other = other_p->Cast(); - - if (!query->Equals(other.query.get())) { - return false; - } - if (!child->Equals(other.child.get())) { - return false; - } - return true; + throw InternalException("CTENode is a legacy type"); } unique_ptr CTENode::Copy() const { - auto result = make_uniq(); - result->ctename = ctename; - result->query = query->Copy(); - result->child = child->Copy(); - result->aliases = aliases; - result->materialized = materialized; - this->CopyProperties(*result); - return std::move(result); + throw InternalException("CTENode is a legacy type"); } } // namespace duckdb diff --git a/src/duckdb/src/parser/transform/helpers/transform_cte.cpp b/src/duckdb/src/parser/transform/helpers/transform_cte.cpp index 2de5d8334..f53c6dbb8 100644 --- a/src/duckdb/src/parser/transform/helpers/transform_cte.cpp 
+++ b/src/duckdb/src/parser/transform/helpers/transform_cte.cpp @@ -25,7 +25,7 @@ CommonTableExpressionInfo::~CommonTableExpressionInfo() { void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) { for (auto &cte_entry : stored_cte_map) { - for (auto &entry : cte_entry->map) { + for (auto &entry : cte_entry.get().map) { auto found_entry = cte_map.map.find(entry.first); if (found_entry != cte_map.map.end()) { // entry already present - use top-most entry @@ -40,7 +40,7 @@ void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) { } void Transformer::TransformCTE(duckdb_libpgquery::PGWithClause &de_with_clause, CommonTableExpressionMap &cte_map) { - stored_cte_map.push_back(&cte_map); + stored_cte_map.push_back(cte_map); // TODO: might need to update in case of future lawsuit D_ASSERT(de_with_clause.ctes); diff --git a/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp b/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp index 07dfb420d..4572a3a36 100644 --- a/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_pivot_stmt.cpp @@ -95,7 +95,7 @@ unique_ptr Transformer::GenerateCreateEnumStmt(unique_ptr(); - select->node = TransformMaterializedCTE(std::move(subselect)); + select->node = std::move(subselect); info->query = std::move(select); info->type = LogicalType::INVALID; diff --git a/src/duckdb/src/parser/transform/statement/transform_select.cpp b/src/duckdb/src/parser/transform/statement/transform_select.cpp index 2e5135ef6..16cd1a490 100644 --- a/src/duckdb/src/parser/transform/statement/transform_select.cpp +++ b/src/duckdb/src/parser/transform/statement/transform_select.cpp @@ -26,13 +26,10 @@ unique_ptr Transformer::TransformSelectNodeInternal(duckdb_libpgquery throw ParserException("SELECT locking clause is not supported!"); } } - unique_ptr stmt = nullptr; if (select.pivot) { - stmt = TransformPivotStatement(select); - } 
else { - stmt = TransformSelectInternal(select); + return TransformPivotStatement(select); } - return TransformMaterializedCTE(std::move(stmt)); + return TransformSelectInternal(select); } unique_ptr Transformer::TransformSelectStmt(duckdb_libpgquery::PGSelectStmt &select, bool is_select) { diff --git a/src/duckdb/src/parser/transformer.cpp b/src/duckdb/src/parser/transformer.cpp index 4ab39fca7..32ddaa87a 100644 --- a/src/duckdb/src/parser/transformer.cpp +++ b/src/duckdb/src/parser/transformer.cpp @@ -232,31 +232,6 @@ unique_ptr Transformer::TransformStatementInternal(duckdb_libpgque } } -unique_ptr Transformer::TransformMaterializedCTE(unique_ptr root) { - // Extract materialized CTEs from cte_map - vector> materialized_ctes; - - for (auto &cte : root->cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = TransformMaterializedCTE(cte_entry->query->node->Copy()); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); - } - - return root; -} - void Transformer::SetQueryLocation(ParsedExpression &expr, int query_location) { if (query_location < 0) { return; diff --git a/src/duckdb/src/planner/binder.cpp b/src/duckdb/src/planner/binder.cpp index d99521743..3e0c6ee02 100644 --- a/src/duckdb/src/planner/binder.cpp +++ b/src/duckdb/src/planner/binder.cpp @@ -68,28 +68,9 @@ BoundStatement Binder::BindWithCTE(T &statement) { return Bind(statement); } - // Extract materialized CTEs from cte_map - vector> materialized_ctes; - for (auto &cte : cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = 
std::move(cte_entry->query->node); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - unique_ptr cte_root = make_uniq(statement); - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->child = std::move(cte_root); - cte_root = std::move(node_result); - materialized_ctes.pop_back(); - } - - return Bind(*cte_root); + auto stmt_node = make_uniq(statement); + stmt_node->cte_map = cte_map.Copy(); + return Bind(*stmt_node); } BoundStatement Binder::Bind(SQLStatement &statement) { @@ -152,24 +133,6 @@ BoundStatement Binder::Bind(SQLStatement &statement) { } // LCOV_EXCL_STOP } -BoundStatement Binder::BindNode(QueryNode &node) { - // now we bind the node - switch (node.type) { - case QueryNodeType::SELECT_NODE: - return BindNode(node.Cast()); - case QueryNodeType::RECURSIVE_CTE_NODE: - return BindNode(node.Cast()); - case QueryNodeType::CTE_NODE: - return BindNode(node.Cast()); - case QueryNodeType::SET_OPERATION_NODE: - return BindNode(node.Cast()); - case QueryNodeType::STATEMENT_NODE: - return BindNode(node.Cast()); - default: - throw InternalException("Unsupported query node type"); - } -} - BoundStatement Binder::Bind(QueryNode &node) { return BindNode(node); } diff --git a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp index 120fdbceb..7ae3c513c 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_cte_node.cpp @@ -1,37 +1,74 @@ #include "duckdb/parser/query_node/cte_node.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/operator/logical_materialized_cte.hpp" +#include "duckdb/parser/query_node/list.hpp" +#include "duckdb/parser/statement/select_statement.hpp" namespace duckdb { -BoundStatement Binder::BindNode(CTENode &statement) { - 
// first recursively visit the materialized CTE operations - // the left side is visited first and is added to the BindContext of the right side - D_ASSERT(statement.query); - - return BindCTE(statement); +struct BoundCTEData { + string ctename; + CTEMaterialize materialized; + idx_t setop_index; + shared_ptr query_binder; + shared_ptr child_binder; + BoundStatement query; + vector names; + vector types; +}; + +BoundStatement Binder::BindNode(QueryNode &node) { + reference current_binder(*this); + vector bound_ctes; + for (auto &cte : node.cte_map.map) { + bound_ctes.push_back(current_binder.get().PrepareCTE(cte.first, *cte.second)); + current_binder = *bound_ctes.back().child_binder; + } + BoundStatement result; + // now we bind the node + switch (node.type) { + case QueryNodeType::SELECT_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::RECURSIVE_CTE_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::SET_OPERATION_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + case QueryNodeType::STATEMENT_NODE: + result = current_binder.get().BindNode(node.Cast()); + break; + default: + throw InternalException("Unsupported query node type"); + } + for (idx_t i = bound_ctes.size(); i > 0; i--) { + auto &finish_binder = i == 1 ? 
*this : *bound_ctes[i - 2].child_binder; + result = finish_binder.FinishCTE(bound_ctes[i - 1], std::move(result)); + } + return result; } -BoundStatement Binder::BindCTE(CTENode &statement) { - BoundStatement result; +BoundCTEData Binder::PrepareCTE(const string &ctename, CommonTableExpressionInfo &statement) { + BoundCTEData result; // first recursively visit the materialized CTE operations // the left side is visited first and is added to the BindContext of the right side D_ASSERT(statement.query); - auto ctename = statement.ctename; - auto materialized = statement.materialized; - auto setop_index = GenerateTableIndex(); + result.ctename = ctename; + result.materialized = statement.materialized; + result.setop_index = GenerateTableIndex(); AddCTE(ctename); - auto query_binder = Binder::CreateBinder(context, this); - auto query = query_binder->BindNode(*statement.query); + result.query_binder = Binder::CreateBinder(context, this); + result.query = result.query_binder->BindNode(*statement.query->node); // the result types of the CTE are the types of the LHS - result.types = query.types; + result.types = result.query.types; // names are picked from the LHS, unless aliases are explicitly specified - result.names = query.names; + result.names = result.query.names; for (idx_t i = 0; i < statement.aliases.size() && i < result.names.size(); i++) { result.names[i] = statement.aliases[i]; } @@ -50,42 +87,39 @@ BoundStatement Binder::BindCTE(CTENode &statement) { ci_names.insert(name); } - // This allows the right side to reference the CTE - bind_context.AddGenericBinding(setop_index, statement.ctename, names, result.types); - - auto child_binder = Binder::CreateBinder(context, this); + result.child_binder = Binder::CreateBinder(context, this); // Add bindings of left side to temporary CTE bindings context // If there is already a binding for the CTE, we need to remove it first // as we are binding a CTE currently, we take precendence over the existing binding. 
// This implements the CTE shadowing behavior. - child_binder->bind_context.AddCTEBinding(setop_index, statement.ctename, names, result.types); - - BoundStatement child; - if (statement.child) { - child = child_binder->BindNode(*statement.child); - for (auto &c : query_binder->correlated_columns) { - child_binder->AddCorrelatedColumn(c); - } - - // the result types of the CTE are the types of the LHS - result.types = child.types; - result.names = child.names; + result.child_binder->bind_context.AddCTEBinding(result.setop_index, ctename, names, result.types); + return result; +} - MoveCorrelatedExpressions(*child_binder); +BoundStatement Binder::FinishCTE(BoundCTEData &bound_cte, BoundStatement child) { + for (auto &c : bound_cte.query_binder->correlated_columns) { + bound_cte.child_binder->AddCorrelatedColumn(c); } - MoveCorrelatedExpressions(*query_binder); + BoundStatement result; + // the result types of the CTE are the types of the LHS + result.types = child.types; + result.names = child.names; + + MoveCorrelatedExpressions(*bound_cte.child_binder); + MoveCorrelatedExpressions(*bound_cte.query_binder); - auto cte_query = std::move(query.plan); + auto cte_query = std::move(bound_cte.query.plan); auto cte_child = std::move(child.plan); - auto root = make_uniq(ctename, setop_index, result.types.size(), std::move(cte_query), - std::move(cte_child), materialized); + auto root = make_uniq(bound_cte.ctename, bound_cte.setop_index, result.types.size(), + std::move(cte_query), std::move(cte_child), bound_cte.materialized); // check if there are any unplanned subqueries left in either child - has_unplanned_dependent_joins = has_unplanned_dependent_joins || child_binder->has_unplanned_dependent_joins || - query_binder->has_unplanned_dependent_joins; + has_unplanned_dependent_joins = has_unplanned_dependent_joins || + bound_cte.child_binder->has_unplanned_dependent_joins || + bound_cte.query_binder->has_unplanned_dependent_joins; result.plan = std::move(root); return 
result; } diff --git a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp index a02f878a9..7bfc4b22e 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp @@ -372,15 +372,6 @@ BoundStatement Binder::BindNode(SelectNode &statement) { return BindSelectNode(statement, std::move(from_table)); } -unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement) { - D_ASSERT(statement.from_table); - - // first bind the FROM table statement - auto from = std::move(statement.from_table); - auto from_table = Bind(*from); - return BindSelectNodeInternal(statement, std::move(from_table)); -} - void Binder::BindWhereStarExpression(unique_ptr &expr) { // expand any expressions in the upper AND recursively if (expr->GetExpressionType() == ExpressionType::CONJUNCTION_AND) { @@ -412,7 +403,7 @@ void Binder::BindWhereStarExpression(unique_ptr &expr) { } } -unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement, BoundStatement from_table) { +BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from_table) { D_ASSERT(from_table.plan); D_ASSERT(!statement.from_table); auto result_ptr = make_uniq(); @@ -688,16 +679,12 @@ unique_ptr Binder::BindSelectNodeInternal(SelectNode &statement // now that the SELECT list is bound, we set the types of DISTINCT/ORDER BY expressions BindModifiers(result, result.projection_index, result.names, internal_sql_types, bind_state); - return result_ptr; -} - -BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from_table) { - auto result = BindSelectNodeInternal(statement, std::move(from_table)); BoundStatement result_statement; - result_statement.types = result->types; - result_statement.names = result->names; - result_statement.plan = CreatePlan(*result); + result_statement.types = result.types; + result_statement.names = result.names; + 
result_statement.plan = CreatePlan(result); + result_statement.extra_info.original_expressions = std::move(result.bind_state.original_expressions); return result_statement; } diff --git a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp index 6c301d1de..91a501b2f 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_setop_node.cpp @@ -10,6 +10,7 @@ #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/query_node/bound_set_operation_node.hpp" #include "duckdb/planner/expression_binder/select_bind_state.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/common/enum_util.hpp" namespace duckdb { @@ -19,30 +20,22 @@ struct SetOpAliasGatherer { explicit SetOpAliasGatherer(SelectBindState &bind_state_p) : bind_state(bind_state_p) { } - void GatherAliases(BoundSetOpChild &node, const vector &reorder_idx); - void GatherAliases(BoundSetOperationNode &node, const vector &reorder_idx); + void GatherAliases(BoundStatement &stmt, const vector &reorder_idx); + void GatherSetOpAliases(SetOperationType setop_type, const vector &names, + vector &bound_children, const vector &reorder_idx); private: SelectBindState &bind_state; }; -const vector &BoundSetOpChild::GetNames() { - return bound_node ? bound_node->names : node.names; -} -const vector &BoundSetOpChild::GetTypes() { - return bound_node ? bound_node->types : node.types; -} -idx_t BoundSetOpChild::GetRootIndex() { - return bound_node ? 
bound_node->GetRootIndex() : node.plan->GetRootIndex(); -} -void SetOpAliasGatherer::GatherAliases(BoundSetOpChild &node, const vector &reorder_idx) { - if (node.bound_node) { - GatherAliases(*node.bound_node, reorder_idx); +void SetOpAliasGatherer::GatherAliases(BoundStatement &stmt, const vector &reorder_idx) { + if (stmt.extra_info.setop_type != SetOperationType::NONE) { + GatherSetOpAliases(stmt.extra_info.setop_type, stmt.names, stmt.extra_info.bound_children, reorder_idx); return; } // query node - auto &select_names = node.GetNames(); + auto &select_names = stmt.names; // fill the alias lists with the names D_ASSERT(reorder_idx.size() == select_names.size()); for (idx_t i = 0; i < select_names.size(); i++) { @@ -58,8 +51,9 @@ void SetOpAliasGatherer::GatherAliases(BoundSetOpChild &node, const vector &reorder_idx) { +void SetOpAliasGatherer::GatherSetOpAliases(SetOperationType setop_type, const vector &stmt_names, + vector &bound_children, const vector &reorder_idx) { // create new reorder index - if (setop.setop_type == SetOperationType::UNION_BY_NAME) { + if (setop_type == SetOperationType::UNION_BY_NAME) { + auto &setop_names = stmt_names; // for UNION BY NAME - create a new re-order index case_insensitive_map_t reorder_map; - for (idx_t col_idx = 0; col_idx < setop.names.size(); ++col_idx) { - reorder_map[setop.names[col_idx]] = reorder_idx[col_idx]; + for (idx_t col_idx = 0; col_idx < setop_names.size(); ++col_idx) { + reorder_map[setop_names[col_idx]] = reorder_idx[col_idx]; } // use new reorder index - for (auto &child : setop.bound_children) { + for (auto &child : bound_children) { vector new_reorder_idx; - auto &child_names = child.GetNames(); + auto &child_names = child.names; for (idx_t col_idx = 0; col_idx < child_names.size(); col_idx++) { auto &col_name = child_names[col_idx]; auto entry = reorder_map.find(col_name); @@ -100,22 +96,23 @@ void SetOpAliasGatherer::GatherAliases(BoundSetOperationNode &setop, const vecto GatherAliases(child, 
new_reorder_idx); } } else { - for (auto &child : setop.bound_children) { + for (auto &child : bound_children) { GatherAliases(child, reorder_idx); } } } -static void GatherAliases(BoundSetOperationNode &node, SelectBindState &bind_state) { +static void GatherAliases(BoundSetOperationNode &root, vector &child_statements, + SelectBindState &bind_state) { SetOpAliasGatherer gatherer(bind_state); vector reorder_idx; - for (idx_t i = 0; i < node.names.size(); i++) { + for (idx_t i = 0; i < root.names.size(); i++) { reorder_idx.push_back(i); } - gatherer.GatherAliases(node, reorder_idx); + gatherer.GatherSetOpAliases(root.setop_type, root.names, child_statements, reorder_idx); } -static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode &result, bool can_contain_nulls) { +void Binder::BuildUnionByNameInfo(BoundSetOperationNode &result) { D_ASSERT(result.setop_type == SetOperationType::UNION_BY_NAME); vector> node_name_maps; case_insensitive_set_t global_name_set; @@ -124,7 +121,7 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & // We throw a binder exception if two same name in the SELECT list D_ASSERT(result.names.empty()); for (auto &child : result.bound_children) { - auto &child_names = child.GetNames(); + auto &child_names = child.names; case_insensitive_map_t node_name_map; for (idx_t i = 0; i < child_names.size(); ++i) { auto &col_name = child_names[i]; @@ -152,7 +149,7 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & auto &col_name = result.names[i]; LogicalType result_type(LogicalTypeId::INVALID); for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { - auto &child_types = result.bound_children[child_idx].GetTypes(); + auto &child_types = result.bound_children[child_idx].types; auto &child_name_map = node_name_maps[child_idx]; // check if the column exists in this child node auto entry = child_name_map.find(col_name); @@ -188,6 +185,8 @@ static 
void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & return; } // If reorder is required, generate the expressions for each node + vector>> reorder_expressions; + reorder_expressions.resize(result.bound_children.size()); for (idx_t i = 0; i < new_size; ++i) { auto &col_name = result.names[i]; for (idx_t child_idx = 0; child_idx < result.bound_children.size(); ++child_idx) { @@ -202,52 +201,42 @@ static void BuildUnionByNameInfo(ClientContext &context, BoundSetOperationNode & } else { // the column exists - reference it auto col_idx_in_child = entry->second; - auto &child_col_type = child.GetTypes()[col_idx_in_child]; - expr = make_uniq(child_col_type, - ColumnBinding(child.GetRootIndex(), col_idx_in_child)); + auto &child_col_type = child.types[col_idx_in_child]; + auto root_idx = child.plan->GetRootIndex(); + expr = make_uniq(child_col_type, ColumnBinding(root_idx, col_idx_in_child)); } - child.reorder_expressions.push_back(std::move(expr)); + reorder_expressions[child_idx].push_back(std::move(expr)); } } -} - -BoundSetOpChild Binder::BindSetOpChild(QueryNode &child) { - BoundSetOpChild bound_child; - if (child.type == QueryNodeType::SET_OPERATION_NODE) { - bound_child.bound_node = BindSetOpNode(child.Cast()); - } else { - bound_child.binder = Binder::CreateBinder(context, this); - bound_child.binder->can_contain_nulls = true; - if (child.type == QueryNodeType::SELECT_NODE) { - auto &select_node = child.Cast(); - auto bound_select_node = bound_child.binder->BindSelectNodeInternal(select_node); - for (auto &expr : bound_select_node->bind_state.original_expressions) { - bound_child.select_list.push_back(expr->Copy()); - } - bound_child.node.names = bound_select_node->names; - bound_child.node.types = bound_select_node->types; - bound_child.node.plan = bound_child.binder->CreatePlan(*bound_select_node); - } else { - bound_child.node = bound_child.binder->BindNode(child); + // now push projections for each node + for (idx_t child_idx = 0; 
child_idx < result.bound_children.size(); ++child_idx) { + auto &child = result.bound_children[child_idx]; + auto &child_reorder_expressions = reorder_expressions[child_idx]; + // if we have re-order expressions push a projection + vector child_types; + for (auto &expr : child_reorder_expressions) { + child_types.push_back(expr->return_type); } + auto child_projection = + make_uniq(GenerateTableIndex(), std::move(child_reorder_expressions)); + child_projection->children.push_back(std::move(child.plan)); + child.plan = std::move(child_projection); + child.types = std::move(child_types); } - return bound_child; } -static void GatherSetOpBinders(BoundSetOpChild &setop_child, vector> &binders) { - if (setop_child.binder) { - binders.push_back(*setop_child.binder); - return; +static void GatherSetOpBinders(vector &children, vector> &binders, + vector> &result) { + for (auto &child_binder : binders) { + result.push_back(*child_binder); } - auto &setop_node = *setop_child.bound_node; - for (auto &child : setop_node.bound_children) { - GatherSetOpBinders(child, binders); + for (auto &child_node : children) { + GatherSetOpBinders(child_node.extra_info.bound_children, child_node.extra_info.child_binders, result); } } -unique_ptr Binder::BindSetOpNode(SetOperationNode &statement) { - auto result_ptr = make_uniq(); - auto &result = *result_ptr; +BoundStatement Binder::BindNode(SetOperationNode &statement) { + BoundSetOperationNode result; result.setop_type = statement.setop_type; result.setop_all = statement.setop_all; @@ -262,27 +251,23 @@ unique_ptr Binder::BindSetOpNode(SetOperationNode &statem throw InternalException("Set Operation type must have exactly 2 children - except for UNION/UNION_BY_NAME"); } for (auto &child : statement.children) { - result.bound_children.push_back(BindSetOpChild(*child)); - } - - vector> binders; - for (auto &child : result.bound_children) { - GatherSetOpBinders(child, binders); - } - // move the correlated expressions from the child binders to 
this binder - for (auto &child_binder : binders) { - MoveCorrelatedExpressions(child_binder.get()); + auto child_binder = Binder::CreateBinder(context, this); + child_binder->can_contain_nulls = true; + auto child_node = child_binder->BindNode(*child); + MoveCorrelatedExpressions(*child_binder); + result.bound_children.push_back(std::move(child_node)); + result.child_binders.push_back(std::move(child_binder)); } if (result.setop_type == SetOperationType::UNION_BY_NAME) { // UNION BY NAME - merge the columns from all sides - BuildUnionByNameInfo(context, result, can_contain_nulls); + BuildUnionByNameInfo(result); } else { // UNION ALL BY POSITION - the columns of both sides must match exactly - result.names = result.bound_children[0].GetNames(); - auto result_columns = result.bound_children[0].GetTypes().size(); + result.names = result.bound_children[0].names; + auto result_columns = result.bound_children[0].types.size(); for (idx_t i = 1; i < result.bound_children.size(); ++i) { - if (result.bound_children[i].GetTypes().size() != result_columns) { + if (result.bound_children[i].types.size() != result_columns) { throw BinderException("Set operations can only apply to expressions with the " "same number of result columns"); } @@ -290,9 +275,9 @@ unique_ptr Binder::BindSetOpNode(SetOperationNode &statem // figure out the types of the setop result by picking the max of both for (idx_t i = 0; i < result_columns; i++) { - auto result_type = result.bound_children[0].GetTypes()[i]; + auto result_type = result.bound_children[0].types[i]; for (idx_t child_idx = 1; child_idx < result.bound_children.size(); ++child_idx) { - auto &child_types = result.bound_children[child_idx].GetTypes(); + auto &child_types = result.bound_children[child_idx].types; result_type = LogicalType::ForceMaxLogicalType(result_type, child_types[i]); } if (!can_contain_nulls) { @@ -307,7 +292,9 @@ unique_ptr Binder::BindSetOpNode(SetOperationNode &statem SelectBindState bind_state; if 
(!statement.modifiers.empty()) { // handle the ORDER BY/DISTINCT clauses - GatherAliases(result, bind_state); + vector> binders; + GatherSetOpBinders(result.bound_children, result.child_binders, binders); + GatherAliases(result, result.bound_children, bind_state); // now we perform the actual resolution of the ORDER BY/DISTINCT expressions OrderBinder order_binder(binders, bind_state); @@ -316,16 +303,14 @@ unique_ptr Binder::BindSetOpNode(SetOperationNode &statem // finally bind the types of the ORDER/DISTINCT clause expressions BindModifiers(result, result.setop_index, result.names, result.types, bind_state); - return result_ptr; -} - -BoundStatement Binder::BindNode(SetOperationNode &statement) { - auto result = BindSetOpNode(statement); BoundStatement result_statement; - result_statement.types = result->types; - result_statement.names = result->names; - result_statement.plan = CreatePlan(*result); + result_statement.types = result.types; + result_statement.names = result.names; + result_statement.plan = CreatePlan(result); + result_statement.extra_info.setop_type = statement.setop_type; + result_statement.extra_info.bound_children = std::move(result.bound_children); + result_statement.extra_info.child_binders = std::move(result.child_binders); return result_statement; } diff --git a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp index fec93aa51..a1a7f60b0 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp @@ -113,34 +113,16 @@ unique_ptr Binder::CreatePlan(BoundSetOperationNode &node) { D_ASSERT(node.bound_children.size() >= 2); vector> children; - for (auto &child : node.bound_children) { - unique_ptr child_node; - if (child.bound_node) { - child_node = CreatePlan(*child.bound_node); - } else { - child.binder->is_outside_flattened = is_outside_flattened; + for (idx_t child_idx = 0; child_idx < 
node.bound_children.size(); child_idx++) { + auto &child = node.bound_children[child_idx]; + auto &child_binder = *node.child_binders[child_idx]; - // construct the logical plan for the child node - child_node = std::move(child.node.plan); - } - if (!child.reorder_expressions.empty()) { - // if we have re-order expressions push a projection - vector child_types; - for (auto &expr : child.reorder_expressions) { - child_types.push_back(expr->return_type); - } - auto child_projection = - make_uniq(GenerateTableIndex(), std::move(child.reorder_expressions)); - child_projection->children.push_back(std::move(child_node)); - child_node = std::move(child_projection); - - child_node = CastLogicalOperatorToTypes(child_types, node.types, std::move(child_node)); - } else { - // otherwise push only casts - child_node = CastLogicalOperatorToTypes(child.GetTypes(), node.types, std::move(child_node)); - } + // construct the logical plan for the child node + auto child_node = std::move(child.plan); + // push casts for the target types + child_node = CastLogicalOperatorToTypes(child.types, node.types, std::move(child_node)); // check if there are any unplanned subqueries left in any child - if (child.binder && child.binder->has_unplanned_dependent_joins) { + if (child_binder.has_unplanned_dependent_joins) { has_unplanned_dependent_joins = true; } children.push_back(std::move(child_node)); diff --git a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp index de775a198..07f3fb023 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_basetableref.cpp @@ -312,29 +312,6 @@ BoundStatement Binder::Bind(BaseTableRef &ref) { // The view may contain CTEs, but maybe only in the cte_map, so we need create CTE nodes for them auto query = view_catalog_entry.GetQuery().Copy(); - auto &select_stmt = query->Cast(); - - vector> materialized_ctes; - for (auto 
&cte : select_stmt.node->cte_map.map) { - auto &cte_entry = cte.second; - auto mat_cte = make_uniq(); - mat_cte->ctename = cte.first; - mat_cte->query = cte_entry->query->node->Copy(); - mat_cte->aliases = cte_entry->aliases; - mat_cte->materialized = cte_entry->materialized; - materialized_ctes.push_back(std::move(mat_cte)); - } - - auto root = std::move(select_stmt.node); - while (!materialized_ctes.empty()) { - unique_ptr node_result; - node_result = std::move(materialized_ctes.back()); - node_result->child = std::move(root); - root = std::move(node_result); - materialized_ctes.pop_back(); - } - select_stmt.node = std::move(root); - SubqueryRef subquery(unique_ptr_cast(std::move(query))); subquery.alias = ref.alias; diff --git a/src/duckdb/src/storage/data_table.cpp b/src/duckdb/src/storage/data_table.cpp index 4302ebb62..7cc6bf3e4 100644 --- a/src/duckdb/src/storage/data_table.cpp +++ b/src/duckdb/src/storage/data_table.cpp @@ -253,7 +253,7 @@ void DataTable::InitializeScanWithOffset(DuckTransaction &transaction, TableScan const vector &column_ids, idx_t start_row, idx_t end_row) { state.checkpoint_lock = transaction.SharedLockTable(*info); state.Initialize(column_ids); - row_groups->InitializeScanWithOffset(transaction.context, state.table_state, column_ids, start_row, end_row); + row_groups->InitializeScanWithOffset(QueryContext(), state.table_state, column_ids, start_row, end_row); } idx_t DataTable::GetRowGroupSize() const { diff --git a/src/duckdb/src/storage/serialization/serialize_query_node.cpp b/src/duckdb/src/storage/serialization/serialize_query_node.cpp index 50ab535d2..25b167558 100644 --- a/src/duckdb/src/storage/serialization/serialize_query_node.cpp +++ b/src/duckdb/src/storage/serialization/serialize_query_node.cpp @@ -38,6 +38,9 @@ unique_ptr QueryNode::Deserialize(Deserializer &deserializer) { } result->modifiers = std::move(modifiers); result->cte_map = std::move(cte_map); + if (type == QueryNodeType::CTE_NODE) { + result = 
std::move(result->Cast().child); + } return result; } diff --git a/src/duckdb/src/storage/table/row_group_collection.cpp b/src/duckdb/src/storage/table/row_group_collection.cpp index 7c5300bdd..c0dc479ad 100644 --- a/src/duckdb/src/storage/table/row_group_collection.cpp +++ b/src/duckdb/src/storage/table/row_group_collection.cpp @@ -271,7 +271,7 @@ bool RowGroupCollection::Scan(DuckTransaction &transaction, const vector> &StatementVerifier::GetSelectList(Que return node.Cast().select_list; case QueryNodeType::SET_OPERATION_NODE: return GetSelectList(*node.Cast().children[0]); - case QueryNodeType::CTE_NODE: - return GetSelectList(*node.Cast().query); default: return empty_select_list; }