From 07f3b020330aa913b7a016f5fbdcf418014c8f24 Mon Sep 17 00:00:00 2001 From: "Zhang, Chaojun" Date: Wed, 8 Feb 2023 02:30:44 +0000 Subject: [PATCH] feat: upgrade substrait to 0.23.0 --- core/CMakeLists.txt | 16 - core/common/CMakeLists.txt | 20 - core/common/Exceptions.cpp | 36 -- core/common/tests/CMakeLists.txt | 25 - core/common/tests/ExceptionsTest.cpp | 20 - core/function/CMakeLists.txt | 29 - core/function/Extension.cpp | 291 --------- core/function/Function.cpp | 100 --- core/function/FunctionLookup.cpp | 40 -- core/function/tests/CMakeLists.txt | 25 - core/function/tests/FunctionLookupTest.cpp | 144 ----- core/type/CMakeLists.txt | 24 - core/type/Type.cpp | 529 ---------------- core/type/tests/CMakeLists.txt | 25 - core/type/tests/TypeTest.cpp | 175 ------ include/common/Exceptions.h | 140 ----- include/function/Extension.h | 90 --- include/function/Function.h | 120 ---- include/function/FunctionLookup.h | 98 --- include/function/FunctionMapping.h | 48 -- include/function/FunctionSignature.h | 32 - include/substrait/type/Type.h | 7 +- include/type/Type.h | 693 --------------------- scripts/setup-macos.sh | 89 --- src/CMakeLists.txt | 59 -- src/SubstraitExtension.cpp | 300 --------- src/SubstraitExtension.h | 75 --- src/SubstraitFunction.cpp | 98 --- src/SubstraitFunction.h | 122 ---- src/SubstraitType.cpp | 348 ----------- src/SubstraitType.h | 465 -------------- src/tests/CMakeLists.txt | 28 - src/tests/SubstraitExtensionTest.cpp | 82 --- src/tests/SubstraitTypeTest.cpp | 140 ----- substrait/function/Extension.cpp | 4 +- substrait/type/Type.cpp | 69 +- third_party/googletest | 2 +- 37 files changed, 66 insertions(+), 4542 deletions(-) delete mode 100644 core/CMakeLists.txt delete mode 100644 core/common/CMakeLists.txt delete mode 100644 core/common/Exceptions.cpp delete mode 100644 core/common/tests/CMakeLists.txt delete mode 100644 core/common/tests/ExceptionsTest.cpp delete mode 100644 core/function/CMakeLists.txt delete mode 100644 core/function/Extension.cpp delete mode 100644 core/function/Function.cpp delete mode 100644 core/function/FunctionLookup.cpp delete mode 100644 core/function/tests/CMakeLists.txt delete mode 100644 core/function/tests/FunctionLookupTest.cpp delete mode 100644 core/type/CMakeLists.txt delete mode 100644 core/type/Type.cpp delete mode 100644 core/type/tests/CMakeLists.txt delete mode 100644 core/type/tests/TypeTest.cpp delete mode 100644 include/common/Exceptions.h delete mode 100644 include/function/Extension.h delete mode 100644 include/function/Function.h delete mode 100644 include/function/FunctionLookup.h delete mode 100644 include/function/FunctionMapping.h delete mode 100644 include/function/FunctionSignature.h delete mode 100644 include/type/Type.h delete mode 100755 scripts/setup-macos.sh delete mode 100644 src/CMakeLists.txt delete mode 100644 src/SubstraitExtension.cpp delete mode 100644 src/SubstraitExtension.h delete mode 100644 src/SubstraitFunction.cpp delete mode 100644 src/SubstraitFunction.h delete mode 100644 src/SubstraitType.cpp delete mode 100644 src/SubstraitType.h delete mode 100644 src/tests/CMakeLists.txt delete mode 100644 src/tests/SubstraitExtensionTest.cpp delete mode 100644 src/tests/SubstraitTypeTest.cpp diff --git a/core/CMakeLists.txt b/core/CMakeLists.txt deleted file mode 100644 index ae435aad..00000000 --- a/core/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -add_subdirectory(common) -add_subdirectory(type) -add_subdirectory(function) diff --git a/core/common/CMakeLists.txt b/core/common/CMakeLists.txt deleted file mode 100644 index 97ab6836..00000000 --- a/core/common/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -add_library( - substrait_common - Exceptions.cpp) - -target_link_libraries( - substrait_common - fmt) - diff --git a/core/common/Exceptions.cpp b/core/common/Exceptions.cpp deleted file mode 100644 index 5eeb608e..00000000 --- a/core/common/Exceptions.cpp +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/Exceptions.h" -#include "fmt/format.h" - -namespace io::substrait::common { - -SubstraitException::SubstraitException( - std::string exceptionCode, - std::string& exceptionMessage, - Type exceptionType, - std::string exceptionName) - : msg_(fmt::format( - "Exception: {}\nError Code: {}\nType: {}\nReason: {}\n" - "Function: {}\nFile: {}\n:Line: {}\n", - exceptionName, - exceptionCode, - exceptionType == Type::kSystem ? "system" : "user", - exceptionMessage, - __FUNCTION__, - __FILE__, - std::to_string(__LINE__))) {} - -} // namespace io::substrait::common diff --git a/core/common/tests/CMakeLists.txt b/core/common/tests/CMakeLists.txt deleted file mode 100644 index dbf17ecb..00000000 --- a/core/common/tests/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -add_executable( - substrait_common_test - ExceptionsTest.cpp) - -add_test( - substrait_common_test - substrait_common_test) - -target_link_libraries( - substrait_common_test - substrait_common - gtest - gtest_main) diff --git a/core/common/tests/ExceptionsTest.cpp b/core/common/tests/ExceptionsTest.cpp deleted file mode 100644 index 6956ca11..00000000 --- a/core/common/tests/ExceptionsTest.cpp +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/Exceptions.h" -#include - -class SubstraitExceptionTest : public ::testing::Test {}; - -TEST_F(SubstraitExceptionTest, decodeTest) {} diff --git a/core/function/CMakeLists.txt b/core/function/CMakeLists.txt deleted file mode 100644 index dce4c4d4..00000000 --- a/core/function/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set(FUNCTION_SRCS - Function.cpp - Extension.cpp - ../../include/function/FunctionMapping.h - ../../include/function/FunctionSignature.h - FunctionLookup.cpp) - -add_library(substrait_function ${FUNCTION_SRCS}) - -target_link_libraries( - substrait_function - substrait_type - yaml-cpp) - -if (${BUILD_TESTING}) - add_subdirectory(tests) -endif () \ No newline at end of file diff --git a/core/function/Extension.cpp b/core/function/Extension.cpp deleted file mode 100644 index f7a6a374..00000000 --- a/core/function/Extension.cpp +++ /dev/null @@ -1,291 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "function/Extension.h" -#include "yaml-cpp/yaml.h" - -bool decodeFunctionVariant( - const YAML::Node& node, - io::substrait::FunctionVariant& function) { - const auto& returnType = node["return"]; - if (returnType && returnType.IsScalar()) { - /// Return type can be an expression. - const auto& returnExpr = returnType.as(); - std::stringstream ss(returnExpr); - - // TODO: currently we only parse the last sentence of type definition, use - // ANTLR in future. - std::string lastReturnType; - while (std::getline(ss, lastReturnType, '\n')) { - } - function.returnType = io::substrait::Type::decode(lastReturnType); - } - const auto& args = node["args"]; - if (args && args.IsSequence()) { - for (auto& arg : args) { - if (arg["options"]) { // enum argument - auto enumArgument = std::make_shared( - arg.as()); - function.arguments.emplace_back(enumArgument); - } else if (arg["value"]) { // value argument - auto valueArgument = std::make_shared( - arg.as()); - function.arguments.emplace_back(valueArgument); - } else { // type argument - auto typeArgument = std::make_shared( - arg.as()); - function.arguments.emplace_back(typeArgument); - } - } - } - - const auto& variadic = node["variadic"]; - if (variadic) { - auto& min = variadic["min"]; - auto& max = variadic["max"]; - if (min) { - function.variadic = std::make_optional( - {min.as(), - max ? std::make_optional(max.as()) : std::nullopt}); - } else { - function.variadic = std::nullopt; - } - } else { - function.variadic = std::nullopt; - } - - return true; -} - -template <> -struct YAML::convert { - static bool decode(const Node& node, io::substrait::EnumArgument& argument) { - // 'options' is required property - const auto& options = node["options"]; - if (options && options.IsSequence()) { - auto& required = node["required"]; - argument.required = required && required.as(); - return true; - } else { - return false; - } - } -}; - -template <> -struct YAML::convert { - static bool decode(const Node& node, io::substrait::ValueArgument& argument) { - const auto& value = node["value"]; - if (value && value.IsScalar()) { - auto valueType = value.as(); - argument.type = io::substrait::Type::decode(valueType); - return true; - } - return false; - } -}; - -template <> -struct YAML::convert { - static bool decode( - const YAML::Node& node, - io::substrait::TypeArgument& argument) { - // no properties need to populate for type argument, just return true if - // 'type' element exists. - if (node["type"]) { - return true; - } - return false; - } -}; - -template <> -struct YAML::convert { - static bool decode( - const Node& node, - io::substrait::ScalarFunctionVariant& function) { - return decodeFunctionVariant(node, function); - }; -}; - -template <> -struct YAML::convert { - static bool decode( - const Node& node, - io::substrait::AggregateFunctionVariant& function) { - const auto& res = decodeFunctionVariant(node, function); - if (res) { - const auto& intermediate = node["intermediate"]; - if (intermediate) { - function.intermediate = - io::substrait::ParameterizedType::decode(intermediate.as()); - } - } - return res; - } -}; - -template <> -struct YAML::convert { - static bool decode(const Node& node, io::substrait::TypeVariant& typeAnchor) { - const auto& name = node["name"]; - if (name && name.IsScalar()) { - typeAnchor.name = name.as(); - return true; - } - return false; - } -}; - -namespace io::substrait { - -std::shared_ptr Extension::load(const std::string& basePath) { - static const std::vector extensionFiles{ - "functions_aggregate_approx.yaml", - "functions_aggregate_generic.yaml", - "functions_arithmetic.yaml", - "functions_arithmetic_decimal.yaml", - "functions_boolean.yaml", - "functions_comparison.yaml", - "functions_datetime.yaml", - "functions_logarithmic.yaml", - "functions_rounding.yaml", - "functions_string.yaml", - "functions_set.yaml", - }; - return load(basePath, extensionFiles); -} - -std::shared_ptr Extension::load( - const std::string& basePath, - const std::vector& extensionFiles) { - std::vector yamlExtensionFiles; - yamlExtensionFiles.reserve(extensionFiles.size()); - for (auto& extensionFile : extensionFiles) { - auto const pos = basePath.find_last_of('/'); - const auto& extensionUri = basePath.substr(0, pos) + "/" + extensionFile; - yamlExtensionFiles.emplace_back(extensionUri); - } - return load(yamlExtensionFiles); -} - -std::shared_ptr Extension::load( - const std::vector& extensionFiles) { - auto extension = std::make_shared(); - for (const auto& extensionUri : extensionFiles) { - const auto& node = YAML::LoadFile(extensionUri); - - const auto& scalarFunctions = node["scalar_functions"]; - if (scalarFunctions && scalarFunctions.IsSequence()) { - for (auto& scalarFunctionNode : scalarFunctions) { - const auto functionName = scalarFunctionNode["name"].as(); - for (auto& scalaFunctionVariantNode : scalarFunctionNode["impls"]) { - auto scalarFunctionVariant = - scalaFunctionVariantNode.as(); - scalarFunctionVariant.name = functionName; - scalarFunctionVariant.uri = extensionUri; - extension->addScalarFunctionVariant( - std::make_shared(scalarFunctionVariant)); - } - } - } - - const auto& aggregateFunctions = node["aggregate_functions"]; - if (aggregateFunctions && aggregateFunctions.IsSequence()) { - for (auto& aggregateFunctionNode : aggregateFunctions) { - const auto functionName = - aggregateFunctionNode["name"].as(); - for (auto& aggregateFunctionVariantNode : - aggregateFunctionNode["impls"]) { - auto aggregateFunctionVariant = - aggregateFunctionVariantNode.as(); - aggregateFunctionVariant.name = functionName; - aggregateFunctionVariant.uri = extensionUri; - extension->addAggregateFunctionVariant( - std::make_shared( - aggregateFunctionVariant)); - } - } - } - - const auto& types = node["types"]; - if (types && types.IsSequence()) { - for (auto& type : types) { - auto typeAnchor = type.as(); - typeAnchor.uri = extensionUri; - extension->addTypeVariant(std::make_shared(typeAnchor)); - } - } - } - return extension; -} - -void Extension::addWindowFunctionVariant( - const FunctionVariantPtr& functionVariant) { - const auto& functionVariants = - windowFunctionVariantMap_.find(functionVariant->name); - if (functionVariants != windowFunctionVariantMap_.end()) { - auto& variants = functionVariants->second; - variants.emplace_back(functionVariant); - } else { - std::vector variants; - variants.emplace_back(functionVariant); - windowFunctionVariantMap_.insert( - {functionVariant->name, std::move(variants)}); - } -} - -void Extension::addTypeVariant(const TypeVariantPtr& functionVariant) { - typeVariantMap_.insert({functionVariant->name, functionVariant}); -} - -TypeVariantPtr Extension::lookupType(const std::string& typeName) const { - auto typeVariantIter = typeVariantMap_.find(typeName); - if (typeVariantIter != typeVariantMap_.end()) { - return typeVariantIter->second; - } - return nullptr; -} - -void Extension::addScalarFunctionVariant( - const FunctionVariantPtr& functionVariant) { - const auto& functionVariants = - scalarFunctionVariantMap_.find(functionVariant->name); - if (functionVariants != scalarFunctionVariantMap_.end()) { - auto& variants = functionVariants->second; - variants.emplace_back(functionVariant); - } else { - std::vector variants; - variants.emplace_back(functionVariant); - scalarFunctionVariantMap_.insert( - {functionVariant->name, std::move(variants)}); - } -} - -void Extension::addAggregateFunctionVariant( - const FunctionVariantPtr& functionVariant) { - const auto& functionVariants = - aggregateFunctionVariantMap_.find(functionVariant->name); - if (functionVariants != aggregateFunctionVariantMap_.end()) { - auto& variants = functionVariants->second; - variants.emplace_back(functionVariant); - } else { - std::vector variants; - variants.emplace_back(functionVariant); - aggregateFunctionVariantMap_.insert( - {functionVariant->name, std::move(variants)}); - } -} - -} // namespace io::substrait diff --git a/core/function/Function.cpp b/core/function/Function.cpp deleted file mode 100644 index c831c521..00000000 --- a/core/function/Function.cpp +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "function/Function.h" -#include - -namespace io::substrait { - -std::string FunctionVariant::signature( - const std::string& name, - const std::vector& arguments) { - std::stringstream ss; - ss << name; - if (!arguments.empty()) { - ss << ":"; - for (auto it = arguments.begin(); it != arguments.end(); ++it) { - const auto& typeSign = (*it)->toTypeString(); - if (it == arguments.end() - 1) { - ss << typeSign; - } else { - ss << typeSign << "_"; - } - } - } - - return ss.str(); -} - -bool FunctionVariant::tryMatch(const FunctionSignature& signature) { - const auto& actualTypes = signature.arguments; - if (variadic.has_value()) { - // return false if actual types length less than min of variadic - const auto max = variadic->max; - if ((actualTypes.size() < variadic->min) || - (max.has_value() && actualTypes.size() > max.value())) { - return false; - } - - const auto& variadicArgument = arguments[0]; - // actual type must same as the variadicArgument - if (const auto& variadicValueArgument = - std::dynamic_pointer_cast(variadicArgument)) { - for (auto& actualType : actualTypes) { - if (!variadicValueArgument->type->isMatch(actualType)) { - return false; - } - } - } - } else { - std::vector> valueArguments; - for (const auto& argument : arguments) { - if (const auto& variadicValueArgument = - std::dynamic_pointer_cast(argument)) { - valueArguments.emplace_back(variadicValueArgument); - } - } - // return false if size of actual types not equal to size of value - // arguments. - if (valueArguments.size() != actualTypes.size()) { - return false; - } - - for (auto i = 0; i < actualTypes.size(); i++) { - const auto& valueArgument = valueArguments[i]; - if (!valueArgument->type->isMatch(actualTypes[i])) { - return false; - } - } - } - const auto& sigReturnType = signature.returnType; - if (this->returnType && sigReturnType) { - return returnType->isMatch(sigReturnType); - } else { - return true; - } -} - -bool AggregateFunctionVariant::tryMatch(const FunctionSignature& signature) { - bool matched = FunctionVariant::tryMatch(signature); - if (!matched && intermediate) { - const auto& actualTypes = signature.arguments; - if (actualTypes.size() == 1) { - return intermediate->isMatch(actualTypes[0]); - } - } - return matched; -} - -} // namespace io::substrait diff --git a/core/function/FunctionLookup.cpp b/core/function/FunctionLookup.cpp deleted file mode 100644 index f85a0c02..00000000 --- a/core/function/FunctionLookup.cpp +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "function/FunctionLookup.h" - -namespace io::substrait { - -FunctionVariantPtr FunctionLookup::lookupFunction( - const FunctionSignature& signature) const { - const auto& functionMappings = getFunctionMap(); - - const auto& substraitFunctionName = - functionMappings.find(signature.name) != functionMappings.end() - ? functionMappings.at(signature.name) - : signature.name; - - const auto& functionVariants = getFunctionVariants(); - auto functionVariantIter = functionVariants.find(substraitFunctionName); - if (functionVariantIter != functionVariants.end()) { - for (const auto& candidateFunctionVariant : functionVariantIter->second) { - if (candidateFunctionVariant->tryMatch(signature)) { - return candidateFunctionVariant; - } - } - } - return nullptr; -} - -} // namespace io::substrait diff --git a/core/function/tests/CMakeLists.txt b/core/function/tests/CMakeLists.txt deleted file mode 100644 index 209176a0..00000000 --- a/core/function/tests/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -add_executable( - substrait_function_test - FunctionLookupTest.cpp) - -add_test( - substrait_function_test - substrait_function_test) - -target_link_libraries( - substrait_function_test - substrait_function - gtest - gtest_main) diff --git a/core/function/tests/FunctionLookupTest.cpp b/core/function/tests/FunctionLookupTest.cpp deleted file mode 100644 index 217f7217..00000000 --- a/core/function/tests/FunctionLookupTest.cpp +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "function/FunctionLookup.h" -#include -#include - -using namespace io::substrait; - -class VeloxFunctionMappings : public FunctionMapping { - public: - /// scalar function names in difference between velox and Substrait. - const FunctionMap& scalaMapping() const override { - static const FunctionMap scalarMappings{ - {"plus", "add"}, - {"minus", "subtract"}, - {"mod", "modulus"}, - {"eq", "equal"}, - {"neq", "not_equal"}, - {"substr", "substring"}, - }; - return scalarMappings; - }; -}; - -class FunctionLookupTest : public ::testing::Test { - protected: - std::string getExtensionAbsolutePath() { - const std::string absolute_path = __FILE__; - auto const pos = absolute_path.find_last_of('/'); - return absolute_path.substr(0, pos) + - "/../../../third_party/substrait/extensions/"; - } - - void SetUp() override { - ExtensionPtr extension_ = Extension::load(getExtensionAbsolutePath()); - FunctionMappingPtr mappings_ = - std::make_shared(); - scalarFunctionLookup_ = - std::make_shared(extension_, mappings_); - aggregateFunctionLookup_ = - std::make_shared(extension_, mappings_); - } - - void testScalarFunctionLookup( - const FunctionSignature& inputSignature, - const std::string& outputSignature) { - const auto& functionVariant = - scalarFunctionLookup_->lookupFunction(inputSignature); - - ASSERT_TRUE(functionVariant != nullptr); - ASSERT_EQ(functionVariant->signature(), outputSignature); - } - - void testAggregateFunctionLookup( - const FunctionSignature& inputSignature, - const std::string& outputSignature) { - const auto& functionVariant = - aggregateFunctionLookup_->lookupFunction(inputSignature); - - ASSERT_TRUE(functionVariant != nullptr); - ASSERT_EQ(functionVariant->signature(), outputSignature); - } - - private: - FunctionLookupPtr scalarFunctionLookup_; - FunctionLookupPtr aggregateFunctionLookup_; -}; - -TEST_F(FunctionLookupTest, compare_function) { - testScalarFunctionLookup( - {"lt", {TINYINT(), TINYINT()}, BOOL()}, "lt:any1_any1"); - - testScalarFunctionLookup( - {"lt", {SMALLINT(), SMALLINT()}, BOOL()}, "lt:any1_any1"); - - testScalarFunctionLookup( - {"lt", {INTEGER(), INTEGER()}, BOOL()}, "lt:any1_any1"); - - testScalarFunctionLookup( - {"lt", {BIGINT(), BIGINT()}, BOOL()}, "lt:any1_any1"); - - testScalarFunctionLookup({"lt", {FLOAT(), FLOAT()}, BOOL()}, "lt:any1_any1"); - - testScalarFunctionLookup( - {"lt", {DOUBLE(), DOUBLE()}, BOOL()}, "lt:any1_any1"); - testScalarFunctionLookup( - {"between", {TINYINT(), TINYINT(), TINYINT()}, BOOL()}, - "between:any1_any1_any1"); -} - -TEST_F(FunctionLookupTest, arithmetic_function) { - testScalarFunctionLookup( - {"add", {TINYINT(), TINYINT()}, TINYINT()}, "add:opt_i8_i8"); - - testScalarFunctionLookup( - {"plus", {TINYINT(), TINYINT()}, TINYINT()}, "add:opt_i8_i8"); - testScalarFunctionLookup( - {"divide", - { - FLOAT(), - FLOAT(), - }, - FLOAT()}, - "divide:opt_opt_opt_fp32_fp32"); -} - -TEST_F(FunctionLookupTest, aggregate) { - // for intermediate type - testAggregateFunctionLookup( - {"avg", {ROW({DOUBLE(), BIGINT()})}, FLOAT()}, "avg:opt_fp32"); -} - -TEST_F(FunctionLookupTest, logical) { - testScalarFunctionLookup({"and", {}, BOOL()}, "and:bool"); - testScalarFunctionLookup({"and", {BOOL()},BOOL()}, "and:bool"); - testScalarFunctionLookup({"and", {BOOL(), BOOL()},BOOL()}, "and:bool"); - - testScalarFunctionLookup({"or", {BOOL(), BOOL()}, BOOL()}, "or:bool"); - testScalarFunctionLookup({"not", {BOOL()}, BOOL()}, "not:bool"); - testScalarFunctionLookup({"xor", {BOOL(), BOOL()}, BOOL()}, "xor:bool_bool"); -} - -TEST_F(FunctionLookupTest, string_function) { - testScalarFunctionLookup( - {"like", {STRING(), STRING()}, BOOL()}, "like:opt_str_str"); - testScalarFunctionLookup( - {"like", {VARCHAR(3), VARCHAR(4)}, BOOL()}, - "like:opt_vchar_vchar"); - testScalarFunctionLookup( - {"substr", {STRING(), INTEGER(), INTEGER()}, STRING()}, - "substring:str_i32_i32"); -} diff --git a/core/type/CMakeLists.txt b/core/type/CMakeLists.txt deleted file mode 100644 index 0de6bd5f..00000000 --- a/core/type/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set(TYPE_SRCS - Type.cpp) - -add_library(substrait_type ${TYPE_SRCS}) - -target_link_libraries( - substrait_type - substrait_common) - -if (${BUILD_TESTING}) - add_subdirectory(tests) -endif () \ No newline at end of file diff --git a/core/type/Type.cpp b/core/type/Type.cpp deleted file mode 100644 index fbbf7db5..00000000 --- a/core/type/Type.cpp +++ /dev/null @@ -1,529 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "type/Type.h" -#include -#include -#include -#include "common/Exceptions.h" - -namespace io::substrait { - -namespace { - -size_t findNextComma(const std::string& str, size_t start) { - int cnt = 0; - for (auto i = start; i < str.size(); i++) { - if (str[i] == '<') { - cnt++; - } else if (str[i] == '>') { - cnt--; - } else if (cnt == 0 && str[i] == ',') { - return i; - } - } - - return std::string::npos; -} - -} // namespace - -ParameterizedTypePtr ParameterizedType::decode(const std::string& rawType) { - std::string matchingType = rawType; - std::transform( - matchingType.begin(), - matchingType.end(), - matchingType.begin(), - [](unsigned char c) { return std::tolower(c); }); - - const auto& questionMaskPos = matchingType.find_last_of('?'); - - bool nullable = questionMaskPos != std::string::npos; - - const auto& leftAngleBracketPos = matchingType.find('<'); - if (leftAngleBracketPos == std::string::npos) { - // deal with type and with a question mask like "i32?". - const auto& baseType = nullable - ? matchingType = matchingType.substr(0, questionMaskPos) - : matchingType; - - if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>( - nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>( - nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>( - nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared>(nullable); - } else if (matchingType.rfind("unknown", 0) == 0) { - return std::make_shared(rawType, nullable); - } else { - return std::make_shared(rawType); - } - } else { - const auto& rightAngleBracketPos = rawType.rfind('>'); - const auto& baseTypePos = nullable - ? std::min(leftAngleBracketPos, questionMaskPos) - : leftAngleBracketPos; - - const auto& baseType = matchingType.substr(0, baseTypePos); - - std::vector nestedTypes; - auto prevPos = leftAngleBracketPos + 1; - auto commaPos = findNextComma(rawType, prevPos); - while (commaPos != std::string::npos) { - auto token = rawType.substr(prevPos, commaPos - prevPos); - nestedTypes.emplace_back(decode(token)); - prevPos = commaPos + 1; - commaPos = findNextComma(rawType, prevPos); - } - auto token = rawType.substr(prevPos, rightAngleBracketPos - prevPos); - nestedTypes.emplace_back(decode(token)); - - if (TypeTraits::typeString == baseType) { - return std::make_shared(nestedTypes[0], nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared( - nestedTypes[0], nestedTypes[1], nullable); - } else if (TypeTraits::typeString == baseType) { - return std::make_shared(nestedTypes, nullable); - } else if (TypeTraits::typeString == baseType) { - StringLiteralPtr precision = - std::dynamic_pointer_cast(nestedTypes[0]); - StringLiteralPtr scale = - std::dynamic_pointer_cast(nestedTypes[1]); - return std::make_shared(precision, scale, nullable); - } else if (TypeTraits::typeString == baseType) { - auto length = - std::dynamic_pointer_cast(nestedTypes[0]); - return std::make_shared(length, nullable); - } else if (TypeTraits::typeString == baseType) { - auto length = - std::dynamic_pointer_cast(nestedTypes[0]); - return std::make_shared(length, nullable); - } else if (TypeTraits::typeString == baseType) { - auto length = - std::dynamic_pointer_cast(nestedTypes[0]); - return std::make_shared(length, nullable); - } else { - SUBSTRAIT_UNSUPPORTED("Unsupported type: " + rawType); - } - } -} - -std::string Decimal::signature() const { - std::stringstream sign; - sign << TypeBase::signature(); - sign << "<" << precision_ << "," << scale_ << ">"; - return sign.str(); -} - -bool Decimal::isMatch( - const std::shared_ptr& type) const { - if (auto decimalType = std::dynamic_pointer_cast(type)) { - return TypeBase::isMatch(type) && precision_ == decimalType->precision() && - scale_ == decimalType->scale(); - } - - return false; -} - -std::string FixedBinary::signature() const { - std::stringstream sign; - sign << TypeBase::signature(); - sign << "<" << length() << ">"; - return sign.str(); -} -bool FixedBinary::isMatch( - const std::shared_ptr& type) const { - if (auto fBinaryType = std::dynamic_pointer_cast(type)) { - return TypeBase::isMatch(type) && length_ == fBinaryType->length(); - } - - return false; -} - -std::string FixedChar::signature() const { - std::stringstream sign; - sign << TypeBase::signature(); - sign << "<" << length() << ">"; - return sign.str(); -} - -bool FixedChar::isMatch( - const std::shared_ptr& type) const { - if (auto fBinaryType = std::dynamic_pointer_cast(type)) { - return TypeBase::isMatch(type) && length_ == fBinaryType->length(); - } - return false; -} - -std::string Varchar::signature() const { - std::stringstream sign; - sign << TypeBase::signature(); - sign << "<" << length() << ">"; - return sign.str(); -} - -bool Varchar::isMatch( - const std::shared_ptr& type) const { - if (auto varcharType = std::dynamic_pointer_cast(type)) { - return TypeBase::isMatch(type) && length_ == varcharType->length(); - } - return false; -} - -std::string List::signature() const { - std::stringstream sign; - sign << TypeBase::signature(); - sign << "<" << elementType_->signature() << ">"; - return sign.str(); -} - -bool List::isMatch(const std::shared_ptr& type) const { - if (auto listType = std::dynamic_pointer_cast(type)) { - return TypeBase::isMatch(type) && - elementType()->isMatch(listType->elementType()); - } - return false; -} - -std::string Struct::signature() const { - std::stringstream sign; - sign << TypeBase::signature(); - sign << "<"; - for (auto it = children_.begin(); it != children_.end(); ++it) { - const auto& typeSign = (*it)->signature(); - if (it == children_.end() - 1) { - sign << typeSign; - } else { - sign << typeSign << ","; - } - } - sign << ">"; - return sign.str(); -} -bool Struct::isMatch( - const std::shared_ptr& type) const { - if (auto structType = std::dynamic_pointer_cast(type)) { - bool sameSize = structType->children_.size() == children_.size(); - if (sameSize) { - for (int i = 0; i < children_.size(); i++) { - if (!children_[i]->isMatch(structType->children_[i])) { - return false; - } - } - return true; - } - } - return false; -} - -std::string Map::signature() const { - std::stringstream sign; - sign << TypeBase::signature(); - sign << "<"; - sign << keyType()->signature(); - sign << ","; - sign << valueType()->signature(); - sign << ">"; - return sign.str(); -} - -bool Map::isMatch(const std::shared_ptr& type) const { - if (auto mapType = std::dynamic_pointer_cast(type)) { - return TypeBase::isMatch(type) && keyType()->isMatch(mapType->keyType()) && - valueType()->isMatch(mapType->valueType()); - } - return false; -} - -std::string ParameterizedFixedBinary::signature() const { - std::stringstream sign; - sign << TypeTraits::signature; - sign << "<" << length_->value() << ">"; - return sign.str(); -} - -bool ParameterizedFixedBinary::isMatch( - const std::shared_ptr& type) const { - if (auto parameterizedFixedBinary = - std::dynamic_pointer_cast(type)) { - return length()->isMatch(parameterizedFixedBinary->length()) && - nullMatch(type); - } - - return false; -} - -std::string ParameterizedDecimal::signature() const { - std::stringstream sign; - sign << TypeTraits::signature; - sign << "<" << precision_->value() << "," << scale_->value() << ">"; - return sign.str(); -} - -bool ParameterizedDecimal::isMatch( - const std::shared_ptr& type) const { - if (auto decimal = std::dynamic_pointer_cast(type)) { - return nullMatch(type); - } - - return false; -} - -std::string ParameterizedFixedChar::signature() const { - std::stringstream sign; - sign << TypeTraits::signature; - sign << "<" << length_->value() << ">"; - return sign.str(); -} -bool ParameterizedFixedChar::isMatch( - const std::shared_ptr& type) const { - if (auto fixedChar = std::dynamic_pointer_cast(type)) { - return nullMatch(type); - } - - return false; -} - -std::string ParameterizedVarchar::signature() const { - std::stringstream sign; - sign << TypeTraits::signature; - sign << "<" << length_->value() << ">"; - return sign.str(); -} - -bool ParameterizedVarchar::isMatch( - const std::shared_ptr& type) const { - if (auto varchar = std::dynamic_pointer_cast(type)) { - return nullMatch(type); - } - - return false; -} - -std::string ParameterizedList::signature() const { - std::stringstream sign; - sign << TypeTraits::signature; - sign << "<" << elementType()->signature() << ">"; - return sign.str(); -} - -bool ParameterizedList::isMatch( - const std::shared_ptr& type) const { - if (auto list = std::dynamic_pointer_cast(type)) { - return elementType()->isMatch(list->elementType()) && nullMatch(type); - } - - return false; -} - -std::string ParameterizedStruct::signature() const { - std::stringstream sign; - sign << TypeTraits::signature; - sign << "<"; - for (auto it = children_.begin(); it != children_.end(); ++it) { - const auto& typeSign = (*it)->signature(); - if (it == children_.end() - 1) { - sign << typeSign; - } else { - sign << typeSign << ","; - } - } - sign << ">"; - return sign.str(); -} - -bool ParameterizedStruct::isMatch( - const std::shared_ptr& type) const { - if (auto structType = std::dynamic_pointer_cast(type)) { - bool sameSize = structType->children().size() == children_.size(); - if (sameSize) { - for (int i = 0; i < children_.size(); i++) { - if (!children_[i]->isMatch(structType->children()[i])) { - return false; - } - } - return nullMatch(type); - } - } - return false; -} - -std::string ParameterizedMap::signature() const { - std::stringstream sign; - sign << TypeTraits::signature; - sign << "<"; - sign << keyType()->signature(); - sign << ","; - sign << valueType()->signature(); - sign << ">"; - return sign.str(); -} - -bool ParameterizedMap::isMatch( - const std::shared_ptr& type) const { - if (auto mapType = std::dynamic_pointer_cast(type)) { - return keyType()->isMatch(mapType->keyType()) && - valueType()->isMatch(mapType->valueType()) && nullMatch(type); - } - return false; -} - -std::shared_ptr> BOOL() { - return std::make_shared>(false); -} - -std::shared_ptr> TINYINT() { - return std::make_shared>(false); -} - -std::shared_ptr> SMALLINT() { - return std::make_shared>(false); -} - -std::shared_ptr> INTEGER() { - return std::make_shared>(false); -} - -std::shared_ptr> BIGINT() { - return std::make_shared>(false); -} -std::shared_ptr> FLOAT() { - return std::make_shared>(false); -} - -std::shared_ptr> DOUBLE() { - return std::make_shared>(false); -} - -std::shared_ptr> STRING() { - return std::make_shared>(false); -} - -std::shared_ptr> BINARY() { - return std::make_shared>(false); -} - -std::shared_ptr> TIMESTAMP() { - return std::make_shared>(false); -} - -std::shared_ptr> DATE() { - return std::make_shared>(false); -} - -std::shared_ptr> TIME() { - return std::make_shared>(false); -} - -std::shared_ptr> INTERVAL_YEAR() { - return std::make_shared>(false); -} - -std::shared_ptr> INTERVAL_DAY() { - return std::make_shared>(false); -} - -std::shared_ptr> TIMESTAMP_TZ() { - return std::make_shared>(false); -} - -std::shared_ptr> UUID() { - return std::make_shared>(false); -} - -std::shared_ptr DECIMAL(int precision, int scale) { - return std::make_shared(precision, scale, false); -} - -std::shared_ptr VARCHAR(int len) { - return std::make_shared(len, false); -} -std::shared_ptr FCHAR(int len) { - return std::make_shared(len, false); -} - -std::shared_ptr FBinary(int len) { - return std::make_shared(len, false); -} - -std::shared_ptr LIST(const TypePtr& elementType) { - return std::make_shared(elementType, false); -} - -std::shared_ptr MAP( - const TypePtr& keyType, - const TypePtr& valueType) { - return std::make_shared(keyType, valueType, false); -} - -std::shared_ptr ROW(const std::vector& children) { - return std::make_shared(children, false); -} - -std::shared_ptr FChar(int len) { - return std::make_shared(len); -} - -bool StringLiteral::isMatch( - const std::shared_ptr& type) const { - if (isWildcard()) { - return true; - } else { - if (auto stringLiteral = - std::dynamic_pointer_cast(type)) { - return value_ == stringLiteral->value_; - } - return false; - } -} - -bool UsedDefinedType::isMatch( - const std::shared_ptr& type) const { - if (auto udt = std::dynamic_pointer_cast(type)) { - return value_ == udt->value_ && nullable() == udt->nullable(); - } - return true; -} - -} // namespace io::substrait diff --git a/core/type/tests/CMakeLists.txt b/core/type/tests/CMakeLists.txt deleted file mode 100644 index 6b7119d2..00000000 --- a/core/type/tests/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -add_executable( - substrait_type_test - TypeTest.cpp) - -add_test( - substrait_type_test - substrait_type_test) - -target_link_libraries( - substrait_type_test - substrait_type - gtest - gtest_main) diff --git a/core/type/tests/TypeTest.cpp b/core/type/tests/TypeTest.cpp deleted file mode 100644 index b9fe2cfe..00000000 --- a/core/type/tests/TypeTest.cpp +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "type/Type.h" -#include - -using namespace io::substrait; - -class TypeTest : public ::testing::Test { - protected: - template - void testDecode(const std::string& rawType, const std::string& signature) { - const auto& type = ParameterizedType::decode(rawType); - ASSERT_TRUE(type->kind() == kind); - ASSERT_EQ(type->signature(), signature); - } - - void testType( - const ParameterizedTypePtr& type, - TypeKind kind, - const std::string& signature) { - ASSERT_EQ(type->kind(), kind); - ASSERT_EQ(type->signature(), signature); - } - - template - void testDecode( - const std::string& rawType, - const std::function&)>& - typeCallBack) { - const auto& type = ParameterizedType::decode(rawType); - if (typeCallBack) { - typeCallBack(std::dynamic_pointer_cast(type)); - } - } -}; - -TEST_F(TypeTest, typeCreator) { - testType(BOOL(), TypeKind::kBool, "bool"); - testType(TINYINT(), TypeKind::kI8, "i8"); - testType(SMALLINT(), TypeKind::kI16, "i16"); - testType(INTEGER(), TypeKind::kI32, "i32"); - testType(BIGINT(), TypeKind::kI64, "i64"); - testType(FLOAT(), TypeKind::kFp32, "fp32"); - testType(DOUBLE(), TypeKind::kFp64, "fp64"); - testType(BINARY(), TypeKind::kBinary, "vbin"); - testType(TIMESTAMP(), TypeKind::kTimestamp, "ts"); - testType(STRING(), TypeKind::kString, "str"); - testType(TIMESTAMP_TZ(), TypeKind::kTimestampTz, "tstz"); - testType(DATE(), TypeKind::kDate, "date"); - testType(TIME(), TypeKind::kTime, "time"); - testType(INTERVAL_DAY(), TypeKind::kIntervalDay, "iday"); - testType(INTERVAL_YEAR(), TypeKind::kIntervalYear, "iyear"); - testType(UUID(), TypeKind::kUuid, "uuid"); - testType(FChar(12), TypeKind::kFixedChar, "fchar<12>"); - testType(FBinary(12), TypeKind::kFixedBinary, "fbin<12>"); - testType(VARCHAR(12), TypeKind::kVarchar, "vchar<12>"); - testType(DECIMAL(12,23), TypeKind::kDecimal, "dec<12,23>"); - testType(LIST(FLOAT()), TypeKind::kList, "list"); - testType(MAP(STRING(),FLOAT()), TypeKind::kMap, "map"); - testType(ROW({STRING(),FLOAT()}), TypeKind::kStruct, "struct"); -} - -TEST_F(TypeTest, decodeTest) { - testDecode("i32?", "i32"); - testDecode("BOOLEAN", "bool"); - testDecode("boolean", "bool"); - testDecode("i8", "i8"); - testDecode("i16", "i16"); - testDecode("i32", "i32"); - testDecode("i64", "i64"); - testDecode("fp32", "fp32"); - testDecode("fp64", "fp64"); - testDecode("binary", "vbin"); - testDecode("timestamp", "ts"); - testDecode("string", "str"); - testDecode("timestamp_tz", "tstz"); - testDecode("date", "date"); - testDecode("time", "time"); - testDecode("interval_day", "iday"); - testDecode("interval_year", "iyear"); - testDecode("uuid", "uuid"); - - testDecode( - "fixedchar", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->length()->value(), "L1"); - ASSERT_EQ(typePtr->signature(), "fchar"); - }); - - testDecode( - "fixedbinary", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->length()->value(), "L1"); - ASSERT_EQ(typePtr->signature(), "fbin"); - }); - - testDecode( - "varchar", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "vchar"); - ASSERT_EQ(typePtr->length()->value(), "L1"); - }); - - testDecode( - "decimal", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "dec"); - ASSERT_EQ(typePtr->precision()->value(), "P"); - ASSERT_EQ(typePtr->scale()->value(), "S"); - }); - - testDecode( - "struct", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "struct"); - }); - - testDecode( - "struct>", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "struct>"); - }); - - testDecode( - "list", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "list"); - }); - testDecode( - "LIST?", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "list"); - }); - - testDecode( - "map", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "map"); - }); - - testDecode( - "any1", [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "any1"); - ASSERT_TRUE(typePtr->isWildcard()); - }); - - testDecode( - "any", [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "any"); - ASSERT_TRUE(typePtr->isWildcard()); - }); - - testDecode( - "T", [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "T"); - ASSERT_TRUE(typePtr->isWildcard()); - }); - - testDecode( - "unknown", [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "u!name"); - }); -} diff --git a/include/common/Exceptions.h b/include/common/Exceptions.h deleted file mode 100644 index df1f5720..00000000 --- a/include/common/Exceptions.h +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include "fmt/format.h" - -namespace io::substrait::common { -namespace error_code { - -//====================== User Error Codes ======================: - -// An error raised when an argument verification fails -inline constexpr auto kInvalidArgument = "INVALID_ARGUMENT"; - -// An error raised when a requested operation is not supported. -inline constexpr auto kUnsupported = "UNSUPPORTED"; - -//====================== Runtime Error Codes ======================: - -// An error raised when the current state of a component is invalid. -inline constexpr auto kInvalidState = "INVALID_STATE"; - -// An error raised when unreachable code point was executed. -inline constexpr auto kUnreachableCode = "UNREACHABLE_CODE"; - -// An error raised when a requested operation is not yet supported. -inline constexpr auto kNotImplemented = "NOT_IMPLEMENTED"; - -// An error raised when a method has been passed an illegal or inappropriate -// argument. -inline constexpr auto kIllegalArgument = "ILLEGAL_ARGUMENT"; - -} // namespace error_code - -class SubstraitException : public std::exception { - public: - enum class Type { kUser = 0, kSystem = 1 }; - - SubstraitException( - std::string exceptionCode, - std::string& exceptionMessage, - Type exceptionType = Type::kSystem, - std::string exceptionName = "SubstraitException"); - - // Inherited - const char* what() const noexcept override { - return msg_.c_str(); - } - - private: - const std::string msg_; -}; - -class SubstraitUserError : public SubstraitException { - public: - SubstraitUserError( - std::string exceptionCode, - std::string& exceptionMessage, - std::string exceptionName = "SubstraitUserError") - : SubstraitException( - exceptionCode, - exceptionMessage, - Type::kUser, - exceptionName) {} -}; - -class SubstraitRuntimeError final : public SubstraitException { - public: - SubstraitRuntimeError( - std::string exceptionCode, - std::string& exceptionMessage, - std::string exceptionName = "SubstraitRuntimeError") - : SubstraitException( - exceptionCode, - exceptionMessage, - Type::kSystem, - exceptionName) {} -}; - -template -std::string errorMessage(fmt::string_view fmt, const Args&... args) { - return fmt::vformat(fmt, fmt::make_format_args(args...)); -} - -#define _SUBSTRAIT_THROW(exception, errorCode, ...) \ - { \ - auto message = io::substrait::common::errorMessage(__VA_ARGS__); \ - throw exception(errorCode, message); \ - } - -#define SUBSTRAIT_UNSUPPORTED(...) \ - _SUBSTRAIT_THROW( \ - ::io::substrait::common::SubstraitUserError, \ - ::io::substrait::common::error_code::kUnsupported, \ - ##__VA_ARGS__) - -#define SUBSTRAIT_UNREACHABLE(...) \ - _SUBSTRAIT_THROW( \ - ::io::substrait::common::SubstraitRuntimeError, \ - ::io::substrait::common::error_code::kUnreachableCode, \ - ##__VA_ARGS__) - -#define SUBSTRAIT_FAIL(...) \ - _SUBSTRAIT_THROW( \ - ::io::substrait::common::SubstraitRuntimeError, \ - ::io::substrait::common::error_code::kInvalidState, \ - ##__VA_ARGS__) - -#define SUBSTRAIT_USER_FAIL(...) \ - _SUBSTRAIT_THROW( \ - ::io::substrait::common::SubstraitUserError, \ - ::io::substrait::common::error_code::kInvalidState, \ - ##__VA_ARGS__) - -#define SUBSTRAIT_NYI(...) \ - _SUBSTRAIT_THROW( \ - ::io::substrait::common::SubstraitRuntimeError, \ - ::io::substrait::common::error_code::kNotImplemented, \ - ##__VA_ARGS__) - -#define SUBSTRAIT_ILLEGAL_ARGUMENT(...) \ - _SUBSTRAIT_THROW( \ - ::io::substrait::common::SubstraitUserError, \ - ::io::substrait::common::error_code::kIllegalArgument, \ - ##__VA_ARGS__) - -} // namespace io::substrait::common diff --git a/include/function/Extension.h b/include/function/Extension.h deleted file mode 100644 index e6f05ea6..00000000 --- a/include/function/Extension.h +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "function/FunctionSignature.h" -#include "function/Function.h" -#include "type/Type.h" - -namespace io::substrait { - -struct TypeVariant { - std::string name; - std::string uri; -}; - -using TypeVariantPtr = std::shared_ptr; - -using FunctionVariantMap = - std::unordered_map>; - -using TypeVariantMap = std::unordered_map; - -class Extension { - public: - /// Deserialize default substrait extension by given basePath - /// @throws exception if file not found - static std::shared_ptr load(const std::string& basePath); - - /// Deserialize substrait extension by given basePath and extensionFiles. - static std::shared_ptr load( - const std::string& basePath, - const std::vector& extensionFiles); - - /// Deserialize substrait extension by given extensionFiles. - static std::shared_ptr load( - const std::vector& extensionFiles); - - /// Add a scalar function variant. - void addScalarFunctionVariant(const FunctionVariantPtr& functionVariant); - - /// Add a aggregate function variant. - void addAggregateFunctionVariant(const FunctionVariantPtr& functionVariant); - - /// Add a window function variant. - void addWindowFunctionVariant(const FunctionVariantPtr& functionVariant); - - /// Add a type variant. - void addTypeVariant(const TypeVariantPtr& functionVariant); - - /// Lookup type variant by given type name. - /// @return matched type variant - TypeVariantPtr lookupType(const std::string& typeName) const; - - const FunctionVariantMap& scalaFunctionVariantMap() const { - return scalarFunctionVariantMap_; - } - - const FunctionVariantMap& windowFunctionVariantMap() const { - return windowFunctionVariantMap_; - } - - const FunctionVariantMap& aggregateFunctionVariantMap() const { - return aggregateFunctionVariantMap_; - } - - private: - FunctionVariantMap scalarFunctionVariantMap_; - - FunctionVariantMap aggregateFunctionVariantMap_; - - FunctionVariantMap windowFunctionVariantMap_; - - TypeVariantMap typeVariantMap_; -}; - -using ExtensionPtr = std::shared_ptr; - -} // namespace io::substrait diff --git a/include/function/Function.h b/include/function/Function.h deleted file mode 100644 index 43a0ac32..00000000 --- a/include/function/Function.h +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "function/FunctionSignature.h" -#include "type/Type.h" - -namespace io::substrait { - -struct FunctionArgument { - virtual bool isRequired() const = 0; - - /// Convert argument type to short type string based on - /// https://substrait.io/extensions/#function-signature-compound-names - virtual std::string toTypeString() const = 0; - - virtual bool isWildcardType() const { - return false; - }; - - virtual bool isValueArgument() const { - return false; - } -}; - -using FunctionArgumentPtr = std::shared_ptr; - -struct EnumArgument : public FunctionArgument { - bool required; - - bool isRequired() const override { - return required; - } - - std::string toTypeString() const override { - return required ? "req" : "opt"; - } -}; - -struct TypeArgument : public FunctionArgument { - std::string toTypeString() const override { - return "type"; - } - - bool isRequired() const override { - return true; - } -}; - -struct ValueArgument : public FunctionArgument { - ParameterizedTypePtr type; - - std::string toTypeString() const override { - return type->signature(); - } - - bool isRequired() const override { - return true; - } - - bool isWildcardType() const override { - return type->isWildcard(); - } - - bool isValueArgument() const override { - return true; - } -}; - -struct FunctionVariadic { - int min; - std::optional max; -}; - -struct FunctionVariant { - std::string name; - std::string uri; - std::vector arguments; - ParameterizedTypePtr returnType; - std::optional variadic; - - /// Test if the actual types matched with this function variant. - virtual bool tryMatch(const FunctionSignature& signature); - - /// Create function signature by given function name and arguments. - static std::string signature( - const std::string& name, - const std::vector& arguments); - - /// Create function signature by function name and arguments. - const std::string signature() const { - return signature(name, arguments); - } -}; - -using FunctionVariantPtr = std::shared_ptr; - -struct ScalarFunctionVariant : public FunctionVariant {}; - -struct AggregateFunctionVariant : public FunctionVariant { - ParameterizedTypePtr intermediate; - bool deterministic; - - - bool tryMatch(const FunctionSignature& signature) override; -}; - -} // namespace io::substrait diff --git a/include/function/FunctionLookup.h b/include/function/FunctionLookup.h deleted file mode 100644 index 27948125..00000000 --- a/include/function/FunctionLookup.h +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "function/Extension.h" -#include "function/FunctionMapping.h" -#include "function/FunctionSignature.h" - -namespace io::substrait { - -class FunctionLookup { - public: - FunctionLookup( - const ExtensionPtr& extension, - const FunctionMappingPtr& functionMapping) - : extension_(extension), functionMapping_(functionMapping) {} - - virtual FunctionVariantPtr lookupFunction( - const FunctionSignature& signature) const; - - virtual ~FunctionLookup() {} - - protected: - virtual FunctionMap getFunctionMap() const = 0; - - virtual FunctionVariantMap getFunctionVariants() const = 0; - - const FunctionMappingPtr functionMapping_; - - ExtensionPtr extension_; -}; - -using FunctionLookupPtr = std::shared_ptr; - -class ScalarFunctionLookup : public FunctionLookup { - public: - ScalarFunctionLookup( - const ExtensionPtr& extension, - const FunctionMappingPtr& functionMapping) - : FunctionLookup(extension, functionMapping) {} - - protected: - FunctionMap getFunctionMap() const override { - return functionMapping_->scalaMapping(); - } - - FunctionVariantMap getFunctionVariants() const override { - return extension_->scalaFunctionVariantMap(); - } -}; - -class AggregateFunctionLookup : public FunctionLookup { - public: - AggregateFunctionLookup( - const ExtensionPtr& extension, - const FunctionMappingPtr& functionMapping) - : FunctionLookup(extension, functionMapping) {} - - protected: - FunctionMap getFunctionMap() const override { - return functionMapping_->aggregateMapping(); - } - - FunctionVariantMap getFunctionVariants() const override { - return extension_->aggregateFunctionVariantMap(); - } -}; - -class WindowFunctionLookup : public FunctionLookup { - public: - WindowFunctionLookup( - const ExtensionPtr& extension, - const FunctionMappingPtr& functionMapping) - : FunctionLookup(extension, functionMapping) {} - - protected: - FunctionMap getFunctionMap() const override { - return functionMapping_->windowMapping(); - } - - FunctionVariantMap getFunctionVariants() const override { - return extension_->windowFunctionVariantMap(); - } -}; - -} // namespace io::substrait diff --git a/include/function/FunctionMapping.h b/include/function/FunctionMapping.h deleted file mode 100644 index 0e1bb8ab..00000000 --- a/include/function/FunctionMapping.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace io::substrait { - -using FunctionMap = std::unordered_map; - -/// An interface describe the function names in difference between engine-own -/// and substrait system. -class FunctionMapping { - public: - /// Scalar function names in difference between engine own and substrait. - virtual const FunctionMap& scalaMapping() const { - static const FunctionMap scalaFunctionMap{}; - return scalaFunctionMap; - } - - /// Scalar function names in difference between engine own and substrait. - virtual const FunctionMap& aggregateMapping() const { - static const FunctionMap aggregateFunctionMap{}; - return aggregateFunctionMap; - } - - /// Window function names in difference between engine own and substrait. - virtual const FunctionMap& windowMapping() const { - static const FunctionMap windowFunctionMap{}; - return windowFunctionMap; - } -}; - -using FunctionMappingPtr = std::shared_ptr; -} // namespace io::substrait diff --git a/include/function/FunctionSignature.h b/include/function/FunctionSignature.h deleted file mode 100644 index 4a3c78a4..00000000 --- a/include/function/FunctionSignature.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#pragma once - -#include "type/Type.h" - -namespace io::substrait { - -struct FunctionSignature { - std::string name; - std::vector arguments; - TypePtr returnType; -}; - -} // namespace io::substrait diff --git a/include/substrait/type/Type.h b/include/substrait/type/Type.h index 93208f9f..552befd2 100644 --- a/include/substrait/type/Type.h +++ b/include/substrait/type/Type.h @@ -189,7 +189,12 @@ class ParameterizedType { /// Deserialize substrait raw type string into Substrait extension type. /// @param rawType - substrait extension raw string type static std::shared_ptr decode( - const std::string& rawType); + const std::string& rawType){ + return decode(rawType, true); + } + + static std::shared_ptr decode( + const std::string& rawType,bool isParameterized); [[nodiscard]] const bool& nullable() const { return nullable_; diff --git a/include/type/Type.h b/include/type/Type.h deleted file mode 100644 index 0c3cc6c0..00000000 --- a/include/type/Type.h +++ /dev/null @@ -1,693 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -namespace io::substrait { - -enum class TypeKind : int8_t { - kBool = 1, - kI8 = 2, - kI16 = 3, - kI32 = 5, - kI64 = 7, - kFp32 = 10, - kFp64 = 11, - kString = 12, - kBinary = 13, - kTimestamp = 14, - kDate = 16, - kTime = 17, - kIntervalYear = 19, - kIntervalDay = 20, - kTimestampTz = 29, - kUuid = 32, - kFixedChar = 21, - kVarchar = 22, - kFixedBinary = 23, - kDecimal = 24, - kStruct = 25, - kList = 27, - kMap = 28, - kUserDefined = 30, - KIND_NOT_SET = 0, -}; - -template -struct TypeTraits {}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "bool"; - static constexpr const char* typeString = "boolean"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "i8"; - static constexpr const char* typeString = "i8"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "i16"; - static constexpr const char* typeString = "i16"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "i32"; - static constexpr const char* typeString = "i32"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "i64"; - static constexpr const char* typeString = "i64"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "fp32"; - static constexpr const char* typeString = "fp32"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "fp64"; - static constexpr const char* typeString = "fp64"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "str"; - static constexpr const char* typeString = "string"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "vbin"; - static constexpr const char* typeString = "binary"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "ts"; - static constexpr const char* typeString = "timestamp"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "tstz"; - static constexpr const char* typeString = "timestamp_tz"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "date"; - static constexpr const char* typeString = "date"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "time"; - static constexpr const char* typeString = "time"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "iyear"; - static constexpr const char* typeString = "interval_year"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "iday"; - static constexpr const char* typeString = "interval_day"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "uuid"; - static constexpr const char* typeString = "uuid"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "fchar"; - static constexpr const char* typeString = "fixedchar"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "vchar"; - static constexpr const char* typeString = "varchar"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "fbin"; - static constexpr const char* typeString = "fixedbinary"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "dec"; - static constexpr const char* typeString = "decimal"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "struct"; - static constexpr const char* typeString = "struct"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "list"; - static constexpr const char* typeString = "list"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "map"; - static constexpr const char* typeString = "map"; -}; - -template <> -struct TypeTraits { - static constexpr const char* signature = "u!name"; - static constexpr const char* typeString = "user defined type"; -}; - -class ParameterizedType { - public: - ParameterizedType(bool nullable = false) : nullable_(nullable) {} - - virtual std::string signature() const = 0; - - virtual TypeKind kind() const = 0; - - /// Deserialize substrait raw type string into Substrait extension type. - /// @param rawType - substrait extension raw string type - static std::shared_ptr decode( - const std::string& rawType); - - const bool& nullable() const { - return nullable_; - } - - bool nullMatch(const std::shared_ptr& type) const { - return nullable() || nullable() == type->nullable(); - } - /// Test type is a Wildcard type or not. - virtual bool isWildcard() const { - return false; - } - - virtual bool isMatch( - const std::shared_ptr& type) const = 0; - - private: - const bool nullable_; -}; - -using ParameterizedTypePtr = std::shared_ptr; - -class Type : public ParameterizedType { - public: - Type(bool nullable = false) : ParameterizedType(nullable) {} -}; - -using TypePtr = std::shared_ptr; - -/// Types used in function argument declarations. -template -class TypeBase : public Type { - public: - TypeBase(bool nullable = false) : Type(nullable) {} - - std::string signature() const override { - return TypeTraits::signature; - } - - virtual TypeKind kind() const override { - return Kind; - } - - bool isMatch( - const std::shared_ptr& type) const override { - return kind() == type->kind() && nullMatch(type); - } -}; - -template -class ScalarType : public TypeBase { - public: - ScalarType(bool nullable) : TypeBase(nullable) {} -}; - -class Decimal : public TypeBase { - public: - Decimal(int precision, int scale, bool nullable = false) - : TypeBase(nullable), - precision_(precision), - scale_(scale) {} - - std::string signature() const override; - - const int& precision() const { - return precision_; - } - - const int& scale() const { - return scale_; - } - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const int precision_; - const int scale_; -}; - -class FixedBinary : public TypeBase { - public: - FixedBinary(int length, bool nullable = false) - : TypeBase(nullable), length_(length) {} - - const int& length() const { - return length_; - } - - std::string signature() const override; - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const int length_; -}; - -class FixedChar : public TypeBase { - public: - FixedChar(int length, bool nullable = false) - : TypeBase(nullable), length_(length){}; - - const int& length() const { - return length_; - } - - std::string signature() const override; - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const int length_; -}; - -class Varchar : public TypeBase { - public: - Varchar(int length, bool nullable = false) - : TypeBase(nullable), length_(length){}; - - const int& length() const { - return length_; - } - - std::string signature() const override; - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const int length_; -}; - -class List : public TypeBase { - public: - List(TypePtr elementType, bool nullable = false) - : TypeBase(nullable), - elementType_(std::move(elementType)){}; - - const TypePtr& elementType() const { - return elementType_; - } - - std::string signature() const override; - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const TypePtr elementType_; -}; - -class Struct : public TypeBase { - public: - Struct(std::vector types, bool nullable = false) - : TypeBase(nullable), children_(std::move(types)) {} - - std::string signature() const override; - - const std::vector& children() const { - return children_; - } - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const std::vector children_; -}; - -class Map : public TypeBase { - public: - Map(TypePtr keyType, TypePtr valueType, bool nullable = false) - : TypeBase(nullable), - keyType_(std::move(keyType)), - valueType_(std::move(valueType)) {} - - const TypePtr& keyType() const { - return keyType_; - } - - const TypePtr& valueType() const { - return valueType_; - } - - std::string signature() const override; - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const TypePtr keyType_; - const TypePtr valueType_; -}; - -/// ParameterizedType represent a type in -class ParameterizedTypeBase : public ParameterizedType { - public: - ParameterizedTypeBase(bool nullable = false) : ParameterizedType(nullable) {} -}; - -class UsedDefinedType : public ParameterizedTypeBase { - public: - UsedDefinedType(std::string value, bool nullable) - : ParameterizedTypeBase(nullable), value_(std::move(value)) {} - - const std::string& value() const { - return value_; - } - - TypeKind kind() const override { - return TypeKind::kUserDefined; - } - - std::string signature() const override { - return TypeTraits::signature; - } - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - /// raw string of wildcard type. - const std::string value_; -}; - -/// A string literal type can present the 'any1'. -class StringLiteral : public ParameterizedTypeBase { - public: - StringLiteral(std::string value) - : ParameterizedTypeBase(false), value_(std::move(value)) {} - - std::string signature() const override { - return value_; - } - - TypeKind kind() const override { - return TypeKind::KIND_NOT_SET; - } - - const std::string& value() const { - return value_; - } - - bool isWildcard() const override { - return value_.find("any") == 0 || value_ == "T"; - } - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const std::string value_; -}; - -using StringLiteralPtr = std::shared_ptr; - -class ParameterizedDecimal : public ParameterizedTypeBase { - public: - ParameterizedDecimal( - StringLiteralPtr precision, - StringLiteralPtr scale, - bool nullable = false) - : ParameterizedTypeBase(nullable), - precision_(std::move(precision)), - scale_(std::move(scale)) {} - - std::string signature() const override; - - const StringLiteralPtr& precision() const { - return precision_; - } - - TypeKind kind() const override { - return TypeKind::kDecimal; - } - - const StringLiteralPtr& scale() const { - return scale_; - } - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - StringLiteralPtr precision_; - StringLiteralPtr scale_; -}; - -class ParameterizedFixedBinary : public ParameterizedTypeBase { - public: - ParameterizedFixedBinary(StringLiteralPtr length, bool nullable = false) - : ParameterizedTypeBase(nullable), length_(std::move(length)) {} - - const StringLiteralPtr& length() const { - return length_; - } - - TypeKind kind() const override { - return TypeKind::kFixedBinary; - } - - std::string signature() const override; - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const StringLiteralPtr length_; -}; - -class ParameterizedFixedChar : public ParameterizedTypeBase { - public: - ParameterizedFixedChar(StringLiteralPtr length, bool nullable = false) - : ParameterizedTypeBase(nullable), length_(std::move(length)) {} - - const StringLiteralPtr& length() const { - return length_; - } - - TypeKind kind() const override { - return TypeKind::kFixedChar; - } - - std::string signature() const override; - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const StringLiteralPtr length_; -}; - -class ParameterizedVarchar : public ParameterizedTypeBase { - public: - ParameterizedVarchar(const StringLiteralPtr& length, bool nullable = false) - : ParameterizedTypeBase(nullable), length_(length) {} - - const StringLiteralPtr& length() const { - return length_; - } - - TypeKind kind() const override { - return TypeKind::kVarchar; - } - - std::string signature() const override; - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const StringLiteralPtr length_; -}; - -class ParameterizedList : public ParameterizedTypeBase { - public: - ParameterizedList(ParameterizedTypePtr elementType, bool nullable = false) - : ParameterizedTypeBase(nullable), elementType_(std::move(elementType)){}; - - const ParameterizedTypePtr& elementType() const { - return elementType_; - } - - TypeKind kind() const override { - return TypeKind::kList; - } - - std::string signature() const override; - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const ParameterizedTypePtr elementType_; -}; - -class ParameterizedStruct : public ParameterizedTypeBase { - public: - ParameterizedStruct( - std::vector types, - bool nullable = false) - : ParameterizedTypeBase(nullable), children_(std::move(types)) {} - - std::string signature() const override; - - const std::vector& children() const { - return children_; - } - - TypeKind kind() const override { - return TypeKind::kStruct; - } - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const std::vector children_; -}; - -class ParameterizedMap : public ParameterizedTypeBase { - public: - ParameterizedMap( - ParameterizedTypePtr keyType, - ParameterizedTypePtr valueType, - bool nullable = false) - : ParameterizedTypeBase(nullable), - keyType_(std::move(keyType)), - valueType_(std::move(valueType)) {} - - const ParameterizedTypePtr& keyType() const { - return keyType_; - } - - TypeKind kind() const override { - return TypeKind::kMap; - } - const ParameterizedTypePtr& valueType() const { - return valueType_; - } - - std::string signature() const override; - - bool isMatch( - const std::shared_ptr& type) const override; - - private: - const ParameterizedTypePtr keyType_; - const ParameterizedTypePtr valueType_; -}; - -std::shared_ptr> BOOL(); - -std::shared_ptr> TINYINT(); - -std::shared_ptr> SMALLINT(); - -std::shared_ptr> INTEGER(); - -std::shared_ptr> BIGINT(); - -std::shared_ptr> FLOAT(); - -std::shared_ptr> DOUBLE(); - -std::shared_ptr> STRING(); - -std::shared_ptr> BINARY(); - -std::shared_ptr> TIMESTAMP(); - -std::shared_ptr> TIMESTAMP_TZ(); - -std::shared_ptr> DATE(); - -std::shared_ptr> TIME(); - -std::shared_ptr> INTERVAL_YEAR(); - -std::shared_ptr> INTERVAL_DAY(); - -std::shared_ptr> UUID(); - -std::shared_ptr DECIMAL(int precision, int scale); - -std::shared_ptr VARCHAR(int len); - -std::shared_ptr FChar(int len); - -std::shared_ptr FBinary(int len); - -std::shared_ptr LIST(const TypePtr& elementType); - -std::shared_ptr MAP( - const TypePtr& keyType, - const TypePtr& valueType); - -std::shared_ptr ROW(const std::vector& children); - -} // namespace io::substrait diff --git a/scripts/setup-macos.sh b/scripts/setup-macos.sh deleted file mode 100755 index 6c3760ef..00000000 --- a/scripts/setup-macos.sh +++ /dev/null @@ -1,89 +0,0 @@ -#!/bin/bash -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -e # Exit on error. -set -x # Print commands that are executed. - -SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}") -source $SCRIPTDIR/setup-helper-functions.sh - -CPU_TARGET="${CPU_TARGET:-avx}" -NPROC=$(getconf _NPROCESSORS_ONLN) -COMPILER_FLAGS=$(get_cxx_flags $CPU_TARGET) - -DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)} -MACOS_DEPS="ninja bison cmake ccache gflags protobuf" - -function run_and_time { - time "$@" - { echo "+ Finished running $*"; } 2> /dev/null -} - -function prompt { - ( - while true; do - local input="${PROMPT_ALWAYS_RESPOND:-}" - echo -n "$(tput bold)$* [Y, n]$(tput sgr0) " - [[ -z "${input}" ]] && read input - if [[ "${input}" == "Y" || "${input}" == "y" || "${input}" == "" ]]; then - return 0 - elif [[ "${input}" == "N" || "${input}" == "n" ]]; then - return 1 - fi - done - ) 2> /dev/null -} - -function update_brew { - /usr/local/bin/brew update --force --quiet -} - -function install_build_prerequisites { - for pkg in ${MACOS_DEPS} - do - if [[ "${pkg}" =~ ^([0-9a-z-]*):([0-9](\.[0-9\])*)$ ]]; - then - pkg=${BASH_REMATCH[1]} - ver=${BASH_REMATCH[2]} - echo "Installing '${pkg}' at '${ver}'" - tap="local-${pkg}" - brew tap-new "${tap}" - brew extract "--version=${ver}" "${pkg}" "${tap}" - brew install "${tap}/${pkg}@${ver}" - else - brew install --formula "${pkg}" && echo "Installation of ${pkg} is successful" || brew upgrade --formula "$pkg" - fi - done - - pip3 install --user cmake-format regex -} - -function install_deps { - if [ "${INSTALL_PREREQUISITES:-Y}" == "Y" ]; then - run_and_time install_build_prerequisites - fi -} - -(return 2> /dev/null) && return # If script was sourced, don't run commands. - -( - if [[ $# -ne 0 ]]; then - for cmd in "$@"; do - run_and_time "${cmd}" - done - else - install_deps - fi -) - -echo "All deps installed! Now try \"make\"" \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt deleted file mode 100644 index c70c9a39..00000000 --- a/src/CMakeLists.txt +++ /dev/null @@ -1,59 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# set the project name -project(substrait-cpp) - -# Set up Proto -set(proto_directory ${CMAKE_SOURCE_DIR}/third_party/substrait/proto) -set(substrait_proto_directory ${proto_directory}/substrait) -set(PROTO_OUTPUT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/proto/") -file(MAKE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/proto/substrait) -file(GLOB PROTO_FILES ${substrait_proto_directory}/*.proto - ${substrait_proto_directory}/extensions/*.proto) -foreach(PROTO ${PROTO_FILES}) - file(RELATIVE_PATH REL_PROTO ${substrait_proto_directory} ${PROTO}) - string(REGEX REPLACE "\\.proto" "" PROTO_NAME ${REL_PROTO}) - list(APPEND PROTO_SRCS "${PROTO_OUTPUT_DIR}/substrait/${PROTO_NAME}.pb.cc") - list(APPEND PROTO_HDRS "${PROTO_OUTPUT_DIR}/substrait/${PROTO_NAME}.pb.h") -endforeach() -set(PROTO_OUTPUT_FILES ${PROTO_HDRS} ${PROTO_SRCS}) -set_source_files_properties(${PROTO_OUTPUT_FILES} PROPERTIES GENERATED TRUE) - -get_filename_component(PROTO_DIR ${substrait_proto_directory}/, DIRECTORY) - -# Generate Substrait hearders -add_custom_command( - OUTPUT ${PROTO_OUTPUT_FILES} - COMMAND protoc --proto_path ${proto_directory}/ --cpp_out ${PROTO_OUTPUT_DIR} - ${PROTO_FILES} - DEPENDS ${PROTO_DIR} - COMMENT "Running PROTO compiler" - VERBATIM) -add_custom_target(substrait_proto ALL DEPENDS ${PROTO_OUTPUT_FILES}) -add_dependencies(substrait_proto protobuf::libprotobuf) - - -set(SRCS - ${PROTO_SRCS} - SubstraitType.cpp - SubstraitFunction.cpp - SubstraitExtension.cpp) - -add_library(substrait-cpp ${SRCS}) - -target_include_directories(substrait-cpp - PUBLIC ${PROTO_OUTPUT_DIR}) -target_link_libraries(substrait-cpp yaml-cpp) - -#if (${BUILD_TESTING}) - add_subdirectory(tests) -#endif () diff --git a/src/SubstraitExtension.cpp b/src/SubstraitExtension.cpp deleted file mode 100644 index 22aeab4b..00000000 --- a/src/SubstraitExtension.cpp +++ /dev/null @@ -1,300 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "yaml-cpp/yaml.h" - -#include "SubstraitExtension.h" - -namespace YAML { - -using namespace io::substrait; - -static bool decodeFunctionVariant( - const Node& node, - SubstraitFunctionVariant& function) { - auto& returnType = node["return"]; - if (returnType && returnType.IsScalar()) { - /// Return type can be an expression. - const auto& returnExpr = returnType.as(); - std::stringstream ss(returnExpr); - - // TODO: currently we only parse the last sentence of type definition, use - // ANTLR in future. - std::string lastReturnType; - while (std::getline(ss, lastReturnType, '\n')) { - } - function.returnType = SubstraitType::decode(lastReturnType); - } - auto& args = node["args"]; - if (args && args.IsSequence()) { - for (auto& arg : args) { - if (arg["options"]) { // enum argument - auto enumArgument = std::make_shared( - arg.as()); - function.arguments.emplace_back(enumArgument); - } else if (arg["value"]) { // value argument - auto valueArgument = std::make_shared( - arg.as()); - function.arguments.emplace_back(valueArgument); - } else { // type argument - auto typeArgument = std::make_shared( - arg.as()); - function.arguments.emplace_back(typeArgument); - } - } - } - - auto& variadic = node["variadic"]; - if (variadic) { - auto& min = variadic["min"]; - auto& max = variadic["max"]; - if (min) { - function.variadic = std::make_optional( - {min.as(), - max ? std::make_optional(max.as()) : std::nullopt}); - } else { - function.variadic = std::nullopt; - } - } else { - function.variadic = std::nullopt; - } - - return true; -} - -template <> -struct convert { - static bool decode(const Node& node, SubstraitEnumArgument& argument) { - // 'options' is required property - auto& options = node["options"]; - if (options && options.IsSequence()) { - auto& required = node["required"]; - argument.required = required && required.as(); - return true; - } else { - return false; - } - } -}; - -template <> -struct convert { - static bool decode(const Node& node, SubstraitValueArgument& argument) { - auto& value = node["value"]; - if (value && value.IsScalar()) { - auto valueType = value.as(); - argument.type = SubstraitType::decode(valueType); - return true; - } - return false; - } -}; - -template <> -struct convert { - static bool decode(const Node& node, SubstraitTypeArgument& argument) { - // no properties need to populate for type argument, just return true if - // 'type' element exists. - return node["type"]; - } -}; - -template <> -struct convert { - static bool decode( - const Node& node, - SubstraitScalarFunctionVariant& function) { - return decodeFunctionVariant(node, function); - }; -}; - -template <> -struct convert { - static bool decode( - const Node& node, - SubstraitAggregateFunctionVariant& function) { - const auto& res = decodeFunctionVariant(node, function); - if (res) { - const auto& intermediate = node["intermediate"]; - if (intermediate) { - function.intermediate = - SubstraitType::decode(intermediate.as()); - } - } - return res; - } -}; - -template <> -struct convert { - static bool decode( - const Node& node, - io::substrait::SubstraitTypeVariant& typeAnchor) { - auto& name = node["name"]; - if (name && name.IsScalar()) { - typeAnchor.name = name.as(); - return true; - } - return false; - } -}; - -} // namespace YAML - -namespace io::substrait { - -namespace { - -std::string getSubstraitExtensionAbsolutePath() { - const std::string absolute_path = __FILE__; - auto const pos = absolute_path.find_last_of('/'); - return absolute_path.substr(0, pos) + "/../third_party/substrait/extensions/"; -} - -} // namespace - -std::shared_ptr SubstraitExtension::load() { - static const auto registry = loadDefault(); - return registry; -} - -std::shared_ptr SubstraitExtension::loadDefault() { - static const std::vector extensionFiles = { - "functions_aggregate_approx.yaml", - "functions_aggregate_generic.yaml", - "functions_arithmetic.yaml", - "functions_arithmetic_decimal.yaml", - "functions_boolean.yaml", - "functions_comparison.yaml", - "functions_datetime.yaml", - "functions_logarithmic.yaml", - "functions_rounding.yaml", - "functions_string.yaml", - "functions_set.yaml", - "unknown.yaml", - }; - const auto& extensionRootPath = getSubstraitExtensionAbsolutePath(); - return load(extensionRootPath, extensionFiles); -} - -std::shared_ptr SubstraitExtension::load( - const std::string& basePath, - const std::vector& extensionFiles) { - std::vector yamlExtensionFiles; - yamlExtensionFiles.reserve(extensionFiles.size()); - for (auto& extensionFile : extensionFiles) { - auto const pos = basePath.find_last_of('/'); - const auto& extensionUri = basePath.substr(0, pos) + "/" + extensionFile; - yamlExtensionFiles.emplace_back(extensionUri); - } - return load(yamlExtensionFiles); -} - -std::shared_ptr SubstraitExtension::load( - const std::vector& extensionFiles) { - auto registry = std::make_shared(); - for (const auto& extensionUri : extensionFiles) { - const auto& node = YAML::LoadFile(extensionUri); - - auto& scalarFunctions = node["scalar_functions"]; - if (scalarFunctions && scalarFunctions.IsSequence()) { - for (auto& scalarFunctionNode : scalarFunctions) { - const auto functionName = scalarFunctionNode["name"].as(); - for (auto& scalaFunctionVariantNode : scalarFunctionNode["impls"]) { - auto scalarFunctionVariant = - scalaFunctionVariantNode.as(); - scalarFunctionVariant.name = functionName; - scalarFunctionVariant.uri = extensionUri; - registry->addFunctionVariant( - std::make_shared( - scalarFunctionVariant)); - } - } - } - - auto& aggregateFunctions = node["aggregate_functions"]; - if (aggregateFunctions && aggregateFunctions.IsSequence()) { - for (auto& aggregateFunctionNode : aggregateFunctions) { - const auto functionName = - aggregateFunctionNode["name"].as(); - for (auto& aggregateFunctionVariantNode : - aggregateFunctionNode["impls"]) { - auto aggregateFunctionVariant = - aggregateFunctionVariantNode - .as(); - aggregateFunctionVariant.name = functionName; - aggregateFunctionVariant.uri = extensionUri; - registry->addFunctionVariant( - std::make_shared( - aggregateFunctionVariant)); - } - } - } - - auto& types = node["types"]; - if (types && types.IsSequence()) { - for (auto& type : types) { - auto typeAnchor = type.as(); - typeAnchor.uri = extensionUri; - registry->addTypeVariant( - std::make_shared(typeAnchor)); - } - } - } - return registry; -} - -void SubstraitExtension::addFunctionVariant( - const SubstraitFunctionVariantPtr& functionVariant) { - const auto& functionVariants = - functionVariantMap_.find(functionVariant->name); - if (functionVariants != functionVariantMap_.end()) { - auto& variants = functionVariants->second; - variants.emplace_back(functionVariant); - } else { - std::vector variants; - variants.emplace_back(functionVariant); - functionVariantMap_.insert({functionVariant->name, variants}); - } -} - -void SubstraitExtension::addTypeVariant( - const SubstraitTypeVariantPtr& functionVariant) { - typeVariantMap_.insert({functionVariant->name, functionVariant}); -} - -SubstraitTypeVariantPtr SubstraitExtension::lookupType( - const std::string& typeName) const { - auto typeVariantIter = typeVariantMap_.find(typeName); - if (typeVariantIter != typeVariantMap_.end()) { - return typeVariantIter->second; - } - return nullptr; -} - -SubstraitFunctionVariantPtr SubstraitExtension::lookupFunction( - const std::string& name, - const std::vector& types) const { - auto functionVariantIter = functionVariantMap_.find(name); - if (functionVariantIter != functionVariantMap_.end()) { - for (const auto& candidateFunctionVariant : functionVariantIter->second) { - if (candidateFunctionVariant->tryMatch(types)) { - return candidateFunctionVariant; - } - } - } - return nullptr; -} - -} // namespace io::substrait diff --git a/src/SubstraitExtension.h b/src/SubstraitExtension.h deleted file mode 100644 index f5c18d98..00000000 --- a/src/SubstraitExtension.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "SubstraitFunction.h" -#include "SubstraitType.h" - -namespace io::substrait { - -struct SubstraitTypeVariant { - std::string name; - std::string uri; -}; - -using SubstraitTypeVariantPtr = std::shared_ptr; - -using FunctionVariantMap = - std::unordered_map>; - -using TypeVariantMap = std::unordered_map; - -class SubstraitExtension { - public: - /// deserialize default substrait extension. - static std::shared_ptr load(); - - /// deserialize substrait extension by given basePath and extensionFiles. - static std::shared_ptr load( - const std::string& basePath, - const std::vector& extensionFiles); - - /// deserialize substrait extension by given extensionFiles. - static std::shared_ptr load( - const std::vector& extensionFiles); - - /// Add a function variant - void addFunctionVariant(const SubstraitFunctionVariantPtr& functionVariant); - - /// Add a type variant - void addTypeVariant(const SubstraitTypeVariantPtr& functionVariant); - - /// lookup function variant by given function name and function types. - /// @return matched function variant - SubstraitFunctionVariantPtr lookupFunction( - const std::string& name, - const std::vector& types) const; - - /// lookup type variant by given type name. - /// @return matched type variant - SubstraitTypeVariantPtr lookupType(const std::string& typeName) const; - - private: - /// deserialize default substrait extension. - static std::shared_ptr loadDefault(); - /// function variants loaded in registry. - FunctionVariantMap functionVariantMap_; - /// type variants loaded in registry. - TypeVariantMap typeVariantMap_; -}; - -using SubstraitExtensionPtr = std::shared_ptr; - -} // namespace io::substrait diff --git a/src/SubstraitFunction.cpp b/src/SubstraitFunction.cpp deleted file mode 100644 index f1ae10df..00000000 --- a/src/SubstraitFunction.cpp +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "SubstraitFunction.h" -#include - -namespace io::substrait { - -std::string SubstraitFunctionVariant::signature( - const std::string& name, - const std::vector& arguments) { - std::stringstream ss; - ss << name; - if (!arguments.empty()) { - ss << ":"; - for (auto it = arguments.begin(); it != arguments.end(); ++it) { - const auto& typeSign = (*it)->toTypeString(); - if (it == arguments.end() - 1) { - ss << typeSign; - } else { - ss << typeSign << "_"; - } - } - } - - return ss.str(); -} - -bool SubstraitFunctionVariant::tryMatch( - const std::vector& actualTypes) { - if (variadic.has_value()) { - // return false if actual types length less than min of variadic - const auto max = variadic->max; - if ((actualTypes.size() < variadic->min) || - (max.has_value() && actualTypes.size() > max.value())) { - return false; - } - - const auto& variadicArgument = arguments[0]; - // actual type must same as the variadicArgument - if (const auto& variadicValueArgument = - std::dynamic_pointer_cast( - variadicArgument)) { - for (auto& actualType : actualTypes) { - if (!variadicValueArgument->type->isSameAs(actualType)) { - return false; - } - } - } - return true; - } else { - std::vector> valueArguments; - for (const auto& argument : arguments) { - if (const auto& variadicValueArgument = - std::dynamic_pointer_cast( - argument)) { - valueArguments.emplace_back(variadicValueArgument); - } - } - // return false if size of actual types not equal to size of value - // arguments. - if (valueArguments.size() != actualTypes.size()) { - return false; - } - - for (auto i = 0; i < actualTypes.size(); i++) { - const auto& valueArgument = valueArguments[i]; - if (!valueArgument->type->isSameAs(actualTypes[i])) { - return false; - } - } - return true; - } -} - -bool SubstraitAggregateFunctionVariant::tryMatch( - const std::vector& actualTypes) { - bool matched = SubstraitFunctionVariant::tryMatch(actualTypes); - if (!matched && intermediate) { - if (actualTypes.size() == 1) { - return intermediate->isSameAs(actualTypes[0]); - } - } - return matched; -} - -} // namespace io::substrait diff --git a/src/SubstraitFunction.h b/src/SubstraitFunction.h deleted file mode 100644 index cac84559..00000000 --- a/src/SubstraitFunction.h +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include "SubstraitType.h" - -namespace io::substrait { - -struct SubstraitFunctionArgument { - /// whether the argument is required or not. - virtual const bool isRequired() const = 0; - /// convert argument type to short type string based on - /// https://substrait.io/extensions/#function-signature-compound-names - virtual const std::string toTypeString() const = 0; - - virtual const bool isWildcardType() const { - return false; - }; - - virtual const bool isValueArgument() const { - return false; - } -}; - -using SubstraitFunctionArgumentPtr = std::shared_ptr; - -struct SubstraitEnumArgument : public SubstraitFunctionArgument { - bool required; - bool const isRequired() const override { - return required; - } - - const std::string toTypeString() const override { - return required ? "req" : "opt"; - } -}; - -struct SubstraitTypeArgument : public SubstraitFunctionArgument { - const std::string toTypeString() const override { - return "type"; - } - const bool isRequired() const override { - return true; - } -}; - -struct SubstraitValueArgument : public SubstraitFunctionArgument { - SubstraitTypePtr type; - - const std::string toTypeString() const override { - return type->signature(); - } - - const bool isRequired() const override { - return true; - } - - const bool isWildcardType() const override { - return type->isWildcard(); - } - - const bool isValueArgument() const override { - return true; - } -}; - -struct SubstraitFunctionVariadic { - int min; - std::optional max; -}; - -struct SubstraitFunctionVariant { - /// function name. - std::string name; - /// function uri. - std::string uri; - /// function arguments. - std::vector arguments; - /// return type. - SubstraitTypePtr returnType; - /// function variadic - std::optional variadic; - - ///test if the actual types matched with this function variant. - virtual bool tryMatch(const std::vector& actualTypes); - - /// create function signature by given function name and arguments. - static std::string signature( - const std::string& name, - const std::vector& arguments); - - /// create function signature by function name and arguments. - const std::string signature() const { - return signature(name, arguments); - } -}; - -using SubstraitFunctionVariantPtr = std::shared_ptr; - -struct SubstraitScalarFunctionVariant : public SubstraitFunctionVariant { -}; - -struct SubstraitAggregateFunctionVariant : public SubstraitFunctionVariant { - SubstraitTypePtr intermediate; - - bool tryMatch(const std::vector& actualTypes) override; -}; - -} // namespace io::substrait diff --git a/src/SubstraitType.cpp b/src/SubstraitType.cpp deleted file mode 100644 index 9a2f8783..00000000 --- a/src/SubstraitType.cpp +++ /dev/null @@ -1,348 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "SubstraitType.h" -#include - -namespace io::substrait { - -namespace { - -size_t findNextComma(const std::string& str, size_t start) { - int cnt = 0; - for (auto i = start; i < str.size(); i++) { - if (str[i] == '<') { - cnt++; - } else if (str[i] == '>') { - cnt--; - } else if (cnt == 0 && str[i] == ',') { - return i; - } - } - - return std::string::npos; -} - -} // namespace - -SubstraitTypePtr SubstraitType::decode(const std::string& rawType) { - std::string matchingType = rawType; - const auto& questionMaskPos = rawType.find_last_of('?'); - // deal with type and with a question mask like "i32?". - if (questionMaskPos != std::string::npos) { - matchingType = rawType.substr(0, questionMaskPos); - } - std::transform( - matchingType.begin(), - matchingType.end(), - matchingType.begin(), - [](unsigned char c) { return std::tolower(c); }); - - const auto& leftAngleBracketPos = rawType.find('<'); - if (leftAngleBracketPos == std::string::npos) { - const auto& scalarType = scalarTypeMapping().find(matchingType); - if (scalarType != scalarTypeMapping().end()) { - return scalarType->second; - } else if (matchingType.rfind("unknown", 0) == 0) { - return std::make_shared(rawType); - } else { - return std::make_shared(rawType); - } - } - const auto& rightAngleBracketPos = rawType.rfind('>'); - - - auto baseType = matchingType.substr(0, leftAngleBracketPos); - - std::vector nestedTypes; - nestedTypes.reserve(8); - auto prevPos = leftAngleBracketPos + 1; - auto commaPos = findNextComma(rawType, prevPos); - while (commaPos != std::string::npos) { - auto token = rawType.substr(prevPos, commaPos - prevPos); - nestedTypes.emplace_back(decode(token)); - prevPos = commaPos + 1; - commaPos = findNextComma(rawType, prevPos); - } - auto token = rawType.substr(prevPos, rightAngleBracketPos - prevPos); - nestedTypes.emplace_back(decode(token)); - - if (SubstraitTypeTraits::typeString == baseType) { - - return std::make_shared(nestedTypes[0]); - } else if ( - SubstraitTypeTraits::typeString == baseType) { - - return std::make_shared(nestedTypes[0], nestedTypes[1]); - } else if ( - SubstraitTypeTraits::typeString == - baseType) { - - auto precision = - std::dynamic_pointer_cast( - nestedTypes[0]); - auto scale = std::dynamic_pointer_cast( - nestedTypes[1]); - return std::make_shared(precision, scale); - } else if ( - SubstraitTypeTraits::typeString == - baseType) { - auto length = std::dynamic_pointer_cast( - nestedTypes[0]); - return std::make_shared(length); - } else if ( - SubstraitTypeTraits::typeString == - baseType) { - - auto length = std::dynamic_pointer_cast( - nestedTypes[0]); - return std::make_shared(length); - } else if ( - SubstraitTypeTraits::typeString == - baseType) { - - auto length = std::dynamic_pointer_cast( - nestedTypes[0]); - return std::make_shared(length); - } else if ( - SubstraitTypeTraits::typeString == baseType) { - - return std::make_shared(nestedTypes); - } else { - throw std::runtime_error("Unsupported substrait type: " + rawType); - } -} - -#define SUBSTRAIT_SCALAR_TYPE_MAPPING(typeKind) \ - { \ - SubstraitTypeTraits::typeString, \ - std::make_shared>( \ - SubstraitTypeBase()) \ - } - -const std::unordered_map& -SubstraitType::scalarTypeMapping() { - static const std::unordered_map scalarTypeMap{ - SUBSTRAIT_SCALAR_TYPE_MAPPING(kBool), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kI8), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kI16), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kI32), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kI64), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kFp32), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kFp64), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kString), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kBinary), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kTimestamp), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kTimestampTz), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kDate), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kTime), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kIntervalDay), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kIntervalYear), - SUBSTRAIT_SCALAR_TYPE_MAPPING(kUuid), - }; - return scalarTypeMap; -} - -const std::string SubstraitFixedBinaryType::signature() const { - std::stringstream sign; - sign << SubstraitTypeBase::signature(); - sign << "<"; - sign << length_->value(); - sign << ">"; - return sign.str(); -} - -bool SubstraitFixedBinaryType::isSameAs( - const std::shared_ptr& other) const { - if (const auto& type = - std::dynamic_pointer_cast(other)) { - return true; - } - return false; -} - -const std::string SubstraitDecimalType::signature() const { - std::stringstream signature; - signature << SubstraitTypeBase::signature(); - signature << "<"; - signature << precision_->value() << "," << scale_->value(); - signature << ">"; - return signature.str(); -} - -bool SubstraitDecimalType::isSameAs( - const std::shared_ptr& other) const { - if (const auto& type = - std::dynamic_pointer_cast(other)) { - return true; - } - return false; -} - -const std::string SubstraitFixedCharType::signature() const { - std::ostringstream sign; - sign << SubstraitTypeBase::signature(); - sign << "<"; - sign << length_->value(); - sign << ">"; - return sign.str(); -} - -bool SubstraitFixedCharType::isSameAs( - const std::shared_ptr& other) const { - if (const auto& type = - std::dynamic_pointer_cast(other)) { - return true; - } - return false; -} - -const std::string SubstraitVarcharType::signature() const { - std::ostringstream sign; - sign << SubstraitTypeBase::signature(); - sign << "<"; - sign << length_->value(); - sign << ">"; - return sign.str(); -} - -bool SubstraitVarcharType::isSameAs( - const std::shared_ptr& other) const { - if (const auto& type = - std::dynamic_pointer_cast(other)) { - return true; - } - return false; -} - -const std::string SubstraitStructType::signature() const { - std::ostringstream signature; - signature << SubstraitTypeBase::signature(); - signature << "<"; - for (auto it = children_.begin(); it != children_.end(); ++it) { - const auto& typeSign = (*it)->signature(); - if (it == children_.end() - 1) { - signature << typeSign; - } else { - signature << typeSign << ","; - } - } - signature << ">"; - return signature.str(); -} - -bool SubstraitStructType::isSameAs( - const std::shared_ptr& other) const { - if (const auto& type = - std::dynamic_pointer_cast(other)) { - bool sameSize = type->children_.size() == children_.size(); - if (sameSize) { - for (int i = 0; i < children_.size(); i++) { - if (!children_[i]->isSameAs(type->children_[i])) { - return false; - } - } - return true; - } - } - return false; -} - -const std::string SubstraitMapType::signature() const { - std::ostringstream signature; - signature << SubstraitTypeBase::signature(); - signature << "<"; - signature << keyType_->signature(); - signature << ","; - signature << valueType_->signature(); - signature << ">"; - return signature.str(); -} - -bool SubstraitMapType::isSameAs( - const std::shared_ptr& other) const { - if (const auto& type = - std::dynamic_pointer_cast(other)) { - return keyType_->isSameAs(type->keyType_) && - valueType_->isSameAs(type->valueType_); - } - return false; -} - -const std::string SubstraitListType::signature() const { - std::ostringstream signature; - signature << SubstraitTypeBase::signature(); - signature << "<"; - signature << elementType_->signature(); - signature << ">"; - return signature.str(); -} - -bool SubstraitListType::isSameAs( - const std::shared_ptr& other) const { - if (const auto& type = - std::dynamic_pointer_cast(other)) { - return elementType_->isSameAs(type->elementType_); - } - return false; -} - -bool SubstraitUsedDefinedType::isSameAs( - const std::shared_ptr& other) const { - if (const auto& type = - std::dynamic_pointer_cast(other)) { - return type->value_ == value_; - } - return false; -} - -bool SubstraitStringLiteralType::isSameAs( - const std::shared_ptr& other) const { - if (isWildcard()) { - return true; - } - if (const auto& type = - std::dynamic_pointer_cast(other)) { - return type->value_ == value_; - } - return false; -} - -#define DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(typeKind) \ - std::shared_ptr> \ - typeKind() { \ - return std::make_shared< \ - const SubstraitScalarType>(); \ - } - -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kBool); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kI8); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kI16); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kI32); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kI64); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kFp32); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kFp64); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kString); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kBinary); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kTimestamp); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kDate); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kTime); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kIntervalYear); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kIntervalDay); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kTimestampTz); -DEFINE_SUBSTRAIT_SCALAR_ACCESSOR(kUuid); - -#undef DEFINE_SUBSTRAIT_SCALAR_ACCESSOR - -} // namespace io::substrait diff --git a/src/SubstraitType.h b/src/SubstraitType.h deleted file mode 100644 index 02ad6024..00000000 --- a/src/SubstraitType.h +++ /dev/null @@ -1,465 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include "proto/substrait/algebra.pb.h" -#include "proto/substrait/type.pb.h" - -namespace io::substrait { - -using SubstraitTypeKind = ::substrait::Type::KindCase; - -template -struct SubstraitTypeTraits {}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "bool"; - static constexpr const char* typeString = "boolean"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "i8"; - static constexpr const char* typeString = "i8"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "i16"; - static constexpr const char* typeString = "i16"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "i32"; - static constexpr const char* typeString = "i32"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "i64"; - static constexpr const char* typeString = "i64"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "fp32"; - static constexpr const char* typeString = "fp32"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "fp64"; - static constexpr const char* typeString = "fp64"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "str"; - static constexpr const char* typeString = "string"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "vbin"; - static constexpr const char* typeString = "binary"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "ts"; - static constexpr const char* typeString = "timestamp"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "tstz"; - static constexpr const char* typeString = "timestamp_tz"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "date"; - static constexpr const char* typeString = "date"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "time"; - static constexpr const char* typeString = "time"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "iyear"; - static constexpr const char* typeString = "interval_year"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "iday"; - static constexpr const char* typeString = "interval_day"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "uuid"; - static constexpr const char* typeString = "uuid"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "fchar"; - static constexpr const char* typeString = "fixedchar"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "vchar"; - static constexpr const char* typeString = "varchar"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "fbin"; - static constexpr const char* typeString = "fixedbinary"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "dec"; - static constexpr const char* typeString = "decimal"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "struct"; - static constexpr const char* typeString = "struct"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "list"; - static constexpr const char* typeString = "list"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "map"; - static constexpr const char* typeString = "map"; -}; - -template <> -struct SubstraitTypeTraits { - static constexpr const char* signature = "u!name"; - static constexpr const char* typeString = "user defined type"; -}; - -class SubstraitType { - public: - /// deserialize substrait raw type string into Substrait extension type. - /// @param rawType - substrait extension raw string type - static std::shared_ptr decode( - const std::string& rawType); - - /// signature name. - virtual const std::string signature() const = 0; - - /// test type is a Wildcard type or not. - virtual const bool isWildcard() const { - return false; - } - - /// a known substrait type kind - virtual const SubstraitTypeKind kind() const = 0; - - virtual const std::string typeString() const = 0; - - /// whether two types are same as each other - virtual bool isSameAs( - const std::shared_ptr& other) const { - return kind() == other->kind(); - } - - private: - /// A map store the raw type string and corresponding Substrait Type - static const std:: - unordered_map>& - scalarTypeMapping(); -}; - -using SubstraitTypePtr = std::shared_ptr; - -/// Types used in function argument declarations. -template -class SubstraitTypeBase : public SubstraitType { - public: - const std::string signature() const override { - return SubstraitTypeTraits::signature; - } - - virtual const SubstraitTypeKind kind() const override { - return Kind; - } - - const std::string typeString() const override { - return SubstraitTypeTraits::typeString; - } -}; - -template -class SubstraitScalarType : public SubstraitTypeBase {}; - -/// A string literal type can present the 'any1' -class SubstraitStringLiteralType : public SubstraitType { - public: - SubstraitStringLiteralType(const std::string& value) : value_(value) {} - - const std::string& value() const { - return value_; - } - - const std::string signature() const override { - return value_; - } - - const std::string typeString() const override { - return value_; - } - const bool isWildcard() const override { - return value_.find("any") == 0 || value_ == "T"; - } - - bool isSameAs( - const std::shared_ptr& other) const override; - - const SubstraitTypeKind kind() const override { - return SubstraitTypeKind ::KIND_NOT_SET; - } - - private: - /// raw string of wildcard type. - const std::string value_; -}; - -using SubstraitStringLiteralTypePtr = - std::shared_ptr; - -class SubstraitDecimalType - : public SubstraitTypeBase { - public: - SubstraitDecimalType( - const SubstraitStringLiteralTypePtr& precision, - const SubstraitStringLiteralTypePtr& scale) - : precision_(precision), scale_(scale) {} - - SubstraitDecimalType(const std::string& precision, const std::string& scale) - : precision_(std::make_shared(precision)), - scale_(std::make_shared(scale)) {} - - SubstraitDecimalType(const int precision, const int scale) - : SubstraitDecimalType(std::to_string(precision), std::to_string(scale)) { - } - - bool isSameAs( - const std::shared_ptr& other) const override; - - const std::string signature() const override; - - const std::string precision() const { - return precision_->value(); - } - - const std::string scale() const { - return scale_->value(); - } - - private: - SubstraitStringLiteralTypePtr precision_; - SubstraitStringLiteralTypePtr scale_; -}; - -class SubstraitFixedBinaryType - : public SubstraitTypeBase { - public: - SubstraitFixedBinaryType(const SubstraitStringLiteralTypePtr& length) - : length_(length) {} - - SubstraitFixedBinaryType(const int length) - : SubstraitFixedBinaryType(std::make_shared( - std::to_string(length))) {} - - bool isSameAs( - const std::shared_ptr& other) const override; - - const SubstraitStringLiteralTypePtr& length() const { - return length_; - } - - const std::string signature() const override; - - protected: - SubstraitStringLiteralTypePtr length_; -}; - -class SubstraitFixedCharType - : public SubstraitTypeBase { - public: - SubstraitFixedCharType(const SubstraitStringLiteralTypePtr& length) - : length_(length) {} - - SubstraitFixedCharType(const int length) - : SubstraitFixedCharType(std::make_shared( - std::to_string(length))) {} - - bool isSameAs( - const std::shared_ptr& other) const override; - - const SubstraitStringLiteralTypePtr& length() const { - return length_; - } - - const std::string signature() const override; - - protected: - SubstraitStringLiteralTypePtr length_; -}; - -class SubstraitVarcharType - : public SubstraitTypeBase { - public: - SubstraitVarcharType(const SubstraitStringLiteralTypePtr& length) - : length_(length) {} - - SubstraitVarcharType(const int length) - : SubstraitVarcharType(std::make_shared( - std::to_string(length))) {} - - bool isSameAs( - const std::shared_ptr& other) const override; - - const SubstraitStringLiteralTypePtr& length() const { - return length_; - } - - const std::string signature() const override; - - protected: - SubstraitStringLiteralTypePtr length_; -}; - -class SubstraitListType : public SubstraitTypeBase { - public: - SubstraitListType(const SubstraitTypePtr& elementType) - : elementType_(elementType){}; - - const SubstraitTypePtr elementType() const { - return elementType_; - } - - bool isSameAs( - const std::shared_ptr& other) const override; - - const std::string signature() const override; - - private: - SubstraitTypePtr elementType_; -}; - -class SubstraitStructType - : public SubstraitTypeBase { - public: - SubstraitStructType(const std::vector& types) - : children_(types) {} - - bool isSameAs( - const std::shared_ptr& other) const override; - - const std::string signature() const override; - - const std::vector& children() const { - return children_; - } - - private: - std::vector children_; -}; - -class SubstraitMapType : public SubstraitTypeBase { - public: - SubstraitMapType( - const SubstraitTypePtr& keyType, - const SubstraitTypePtr& valueType) - : keyType_(keyType), valueType_(valueType) {} - - const SubstraitTypePtr keyType() const { - return keyType_; - } - - const SubstraitTypePtr valueType() const { - return valueType_; - } - - bool isSameAs( - const std::shared_ptr& other) const override; - - const std::string signature() const override; - - private: - SubstraitTypePtr keyType_; - SubstraitTypePtr valueType_; -}; - -class SubstraitUsedDefinedType - : public SubstraitTypeBase { - public: - SubstraitUsedDefinedType(const std::string& value) : value_(value) {} - - const std::string& value() const { - return value_; - } - - bool isSameAs( - const std::shared_ptr& other) const override; - - private: - /// raw string of wildcard type. - const std::string value_; -}; - -#define SUBSTRAIT_SCALAR_ACCESSOR(KIND) \ - std::shared_ptr> KIND() - -SUBSTRAIT_SCALAR_ACCESSOR(kBool); -SUBSTRAIT_SCALAR_ACCESSOR(kI8); -SUBSTRAIT_SCALAR_ACCESSOR(kI16); -SUBSTRAIT_SCALAR_ACCESSOR(kI32); -SUBSTRAIT_SCALAR_ACCESSOR(kI64); -SUBSTRAIT_SCALAR_ACCESSOR(kFp32); -SUBSTRAIT_SCALAR_ACCESSOR(kFp64); -SUBSTRAIT_SCALAR_ACCESSOR(kString); -SUBSTRAIT_SCALAR_ACCESSOR(kBinary); -SUBSTRAIT_SCALAR_ACCESSOR(kTimestamp); -SUBSTRAIT_SCALAR_ACCESSOR(kDate); -SUBSTRAIT_SCALAR_ACCESSOR(kTime); -SUBSTRAIT_SCALAR_ACCESSOR(kIntervalYear); -SUBSTRAIT_SCALAR_ACCESSOR(kIntervalDay); -SUBSTRAIT_SCALAR_ACCESSOR(kTimestampTz); -SUBSTRAIT_SCALAR_ACCESSOR(kUuid); - -} // namespace io::substrait diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt deleted file mode 100644 index 5c3a3e71..00000000 --- a/src/tests/CMakeLists.txt +++ /dev/null @@ -1,28 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -add_executable( - substrait_cpp_test - SubstraitExtensionTest.cpp - SubstraitTypeTest.cpp) - - -add_test( - NAME substrait_cpp_test - COMMAND substrait_cpp_test - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) - -target_link_libraries( - substrait_cpp_test - substrait-cpp - gtest - gtest_main) diff --git a/src/tests/SubstraitExtensionTest.cpp b/src/tests/SubstraitExtensionTest.cpp deleted file mode 100644 index e29ff2bf..00000000 --- a/src/tests/SubstraitExtensionTest.cpp +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -#include - -#include "iostream" -#include "../SubstraitExtension.h" - -using namespace io::substrait; - -class SubstraitExtensionTest : public ::testing::Test { -protected: - void testLookupFunction( - const std::string &name, - const std::vector &arguments, - const std::string &matchedSignature) { - const auto &functionOption = registry_->lookupFunction(name, arguments); - - ASSERT_TRUE(functionOption != nullptr); - ASSERT_EQ(functionOption->signature(), matchedSignature); - } - - /// Load registry from substrait core extension YAML files. - SubstraitExtensionPtr registry_ = SubstraitExtension::load(); -}; - -TEST_F(SubstraitExtensionTest, comparison_function) { - testLookupFunction("lt", {kI8(), kI8()}, "lt:any1_any1"); - testLookupFunction("lt", {kI16(), kI16()}, "lt:any1_any1"); - testLookupFunction("lt", {kI32(), kI32()}, "lt:any1_any1"); - testLookupFunction("lt", {kI64(), kI64()}, "lt:any1_any1"); - testLookupFunction("lt", {kFp32(), kFp32()}, "lt:any1_any1"); - testLookupFunction("lt", {kFp64(), kFp64()}, "lt:any1_any1"); - - testLookupFunction( - "between", {kI8(), kI8(), kI8()}, "between:any1_any1_any1"); -} - -TEST_F(SubstraitExtensionTest, arithmetic_function) { - testLookupFunction("add", {kI8(), kI8()}, "add:opt_i8_i8"); - testLookupFunction( - "divide", - { - kFp32(), - kFp32(), - }, - "divide:opt_opt_opt_fp32_fp32"); - - testLookupFunction( - "avg", {SubstraitType::decode("struct")}, "avg:opt_fp32"); -} - -TEST_F(SubstraitExtensionTest, boolean_function) { - testLookupFunction("and", {kBool(), kBool()}, "and:bool"); - testLookupFunction("or", {kBool(), kBool()}, "or:bool"); - testLookupFunction("not", {kBool()}, "not:bool"); - testLookupFunction( - "xor", {kBool(), kBool()}, "xor:bool_bool"); -} - -TEST_F(SubstraitExtensionTest, string_function) { - testLookupFunction( - "like", {kString(), kString()}, "like:opt_str_str"); -} - -TEST_F(SubstraitExtensionTest, unknowLookup) { - auto unknown = registry_->lookupType("unknown"); - ASSERT_TRUE(unknown); - ASSERT_EQ(unknown->name, "unknown"); -} diff --git a/src/tests/SubstraitTypeTest.cpp b/src/tests/SubstraitTypeTest.cpp deleted file mode 100644 index f10b8f0e..00000000 --- a/src/tests/SubstraitTypeTest.cpp +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "iostream" -#include "../SubstraitType.h" - -using namespace io::substrait; - -class SubstraitTypeTest : public ::testing::Test { - protected: - template - void testDecode(const std::string& rawType, const std::string& signature) { - const auto& type = SubstraitType::decode(rawType); - ASSERT_TRUE(type->kind() == kind); - ASSERT_EQ(type->signature(), signature); - } - - template - void testDecode( - const std::string& rawType, - const std::function&)>& - typeCallBack) { - const auto& type = SubstraitType::decode(rawType); - if (typeCallBack) { - typeCallBack(std::dynamic_pointer_cast(type)); - } - } -}; - -TEST_F(SubstraitTypeTest, decodeTest) { - testDecode("i32?", "i32"); - testDecode("BOOLEAN", "bool"); - testDecode("boolean", "bool"); - testDecode("i8", "i8"); - testDecode("i16", "i16"); - testDecode("i32", "i32"); - testDecode("i64", "i64"); - testDecode("fp32", "fp32"); - testDecode("fp64", "fp64"); - testDecode("binary", "vbin"); - testDecode("timestamp", "ts"); - testDecode("date", "date"); - testDecode("time", "time"); - testDecode("interval_day", "iday"); - testDecode("interval_year", "iyear"); - testDecode("timestamp_tz", "tstz"); - testDecode("uuid", "uuid"); - - testDecode( - "fixedchar", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->length()->value(), "L1"); - ASSERT_EQ(typePtr->signature(), "fchar"); - }); - - testDecode( - "fixedbinary", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->length()->value(), "L1"); - ASSERT_EQ(typePtr->signature(), "fbin"); - }); - - testDecode( - "varchar", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "vchar"); - ASSERT_EQ(typePtr->length()->value(), "L1"); - }); - - testDecode( - "decimal", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "dec"); - ASSERT_EQ(typePtr->precision(), "P"); - ASSERT_EQ(typePtr->scale(), "S"); - }); - - testDecode( - "struct", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "struct"); - }); - - testDecode( - "struct>", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "struct>"); - }); - - testDecode( - "list", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "list"); - }); - - testDecode( - "map", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "map"); - }); - - testDecode( - "any1", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "any1"); - ASSERT_TRUE(typePtr->isWildcard()); - }); - - testDecode( - "any", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "any"); - ASSERT_TRUE(typePtr->isWildcard()); - }); - - testDecode( - "T", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "T"); - ASSERT_TRUE(typePtr->isWildcard()); - }); - - testDecode( - "unknown", - [](const std::shared_ptr& typePtr) { - ASSERT_EQ(typePtr->signature(), "u!name"); - }); -} diff --git a/substrait/function/Extension.cpp b/substrait/function/Extension.cpp index 4d80713e..f7382df9 100644 --- a/substrait/function/Extension.cpp +++ b/substrait/function/Extension.cpp @@ -17,7 +17,7 @@ bool decodeFunctionVariant( std::string lastReturnType; while (std::getline(ss, lastReturnType, '\n')) { } - function.returnType = io::substrait::Type::decode(lastReturnType); + function.returnType = io::substrait::ParameterizedType::decode(lastReturnType); } const auto& args = node["args"]; if (args && args.IsSequence()) { @@ -77,7 +77,7 @@ struct YAML::convert { const auto& value = node["value"]; if (value && value.IsScalar()) { auto valueType = value.as(); - argument.type = io::substrait::Type::decode(valueType); + argument.type = io::substrait::ParameterizedType::decode(valueType); return true; } return false; diff --git a/substrait/type/Type.cpp b/substrait/type/Type.cpp index f8b676ac..b5457642 100644 --- a/substrait/type/Type.cpp +++ b/substrait/type/Type.cpp @@ -1,9 +1,9 @@ /* SPDX-License-Identifier: Apache-2.0 */ +#include "substrait/type/Type.h" #include #include #include -#include "substrait/type/Type.h" #include "substrait/common/Exceptions.h" namespace io::substrait { @@ -27,7 +27,7 @@ size_t findNextComma(const std::string& str, size_t start) { } // namespace -ParameterizedTypePtr ParameterizedType::decode(const std::string& rawType) { +ParameterizedTypePtr ParameterizedType::decode(const std::string& rawType, bool isParameterized) { std::string matchingType = rawType; std::transform( matchingType.begin(), @@ -97,38 +97,83 @@ ParameterizedTypePtr ParameterizedType::decode(const std::string& rawType) { auto commaPos = findNextComma(rawType, prevPos); while (commaPos != std::string::npos) { auto token = rawType.substr(prevPos, commaPos - prevPos); - nestedTypes.emplace_back(decode(token)); + nestedTypes.emplace_back(decode(token,isParameterized)); prevPos = commaPos + 1; commaPos = findNextComma(rawType, prevPos); } auto token = rawType.substr(prevPos, rightAngleBracketPos - prevPos); - nestedTypes.emplace_back(decode(token)); + nestedTypes.emplace_back(decode(token,isParameterized)); if (TypeTraits::typeString == baseType) { - return std::make_shared(nestedTypes[0], nullable); + if (isParameterized) { + return std::make_shared(nestedTypes[0], nullable); + } else { + return std::make_shared( + std::dynamic_pointer_cast(nestedTypes[0]), nullable); + } } else if (TypeTraits::typeString == baseType) { - return std::make_shared( - nestedTypes[0], nestedTypes[1], nullable); + if (isParameterized) { + return std::make_shared( + nestedTypes[0], nestedTypes[1], nullable); + } else { + return std::make_shared( + std::dynamic_pointer_cast(nestedTypes[0]), + std::dynamic_pointer_cast(nestedTypes[1]), + nullable); + } + } else if (TypeTraits::typeString == baseType) { - return std::make_shared(nestedTypes, nullable); + if (isParameterized) { + return std::make_shared( + nestedTypes, nullable); + } else { + std::vector types; + types.reserve(nestedTypes.size()); + for (int i = 0; i < nestedTypes.size(); i++) { + types.emplace_back( + std::dynamic_pointer_cast(nestedTypes.at(i))); + } + return std::make_shared(types, nullable); + } } else if (TypeTraits::typeString == baseType) { StringLiteralPtr precision = std::dynamic_pointer_cast(nestedTypes[0]); StringLiteralPtr scale = std::dynamic_pointer_cast(nestedTypes[1]); - return std::make_shared(precision, scale, nullable); + if (isParameterized) { + return std::make_shared( + precision, scale, nullable); + } else { + return std::make_shared( + std::stoi(precision->value()), std::stoi(scale->value()), nullable); + } } else if (TypeTraits::typeString == baseType) { auto length = std::dynamic_pointer_cast(nestedTypes[0]); - return std::make_shared(length, nullable); + if (isParameterized) { + return std::make_shared(length, nullable); + } else { + return std::make_shared(std::stoi(length->value()), nullable); + } + } else if (TypeTraits::typeString == baseType) { auto length = std::dynamic_pointer_cast(nestedTypes[0]); - return std::make_shared(length, nullable); + if (isParameterized) { + return std::make_shared(length, nullable); + } else { + return std::make_shared( + std::stoi(length->value()), nullable); + } } else if (TypeTraits::typeString == baseType) { auto length = std::dynamic_pointer_cast(nestedTypes[0]); - return std::make_shared(length, nullable); + if (isParameterized) { + return std::make_shared(length, nullable); + } else { + return std::make_shared( + std::stoi(length->value()), nullable); + } } else { SUBSTRAIT_UNSUPPORTED("Unsupported type: " + rawType); } diff --git a/third_party/googletest b/third_party/googletest index 3026483a..d1a0039b 160000 --- a/third_party/googletest +++ b/third_party/googletest @@ -1 +1 @@ -Subproject commit 3026483ae575e2de942db5e760cf95e973308dd5 +Subproject commit d1a0039b97291dd1dc14f123b906bb7622ffe07c