diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..eab4576f --- /dev/null +++ b/.clang-format @@ -0,0 +1,87 @@ +--- +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false 
+PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... diff --git a/.gitignore b/.gitignore index 259148fa..ca9d349b 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,5 @@ *.exe *.out *.app + +src/proto/substrait diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..aa615775 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,12 @@ +[submodule "third_party/yaml-cpp"] + path = third_party/yaml-cpp + url = https://github.com/jbeder/yaml-cpp.git +[submodule "third_party/googletest"] + path = third_party/googletest + url = https://github.com/google/googletest.git +[submodule "third_party/substrait"] + path = third_party/substrait + url = https://github.com/substrait-io/substrait.git +[submodule "third_party/fmt"] + path = third_party/fmt + url = https://github.com/fmtlib/fmt diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..ab8efc31 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 + +cmake_minimum_required(VERSION 3.10) + +# set the project name +project(substrait-cpp) + +message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED True) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +option( + BUILD_TESTING + "Enable substrait-cpp tests. This will enable all other build options automatically." 
+ ON) + +find_package(Protobuf REQUIRED) +include_directories(${PROTOBUF_INCLUDE_DIRS}) + +add_subdirectory(third_party) +include_directories(include) +add_subdirectory(substrait) diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 261eeb9e..00000000 --- a/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. 
Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative 
Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..dd855043 --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 + +.PHONY: all clean build debug release + +BUILD_TYPE := Release + +all: debug + +clean: + @rm -rf build-* + +build-common: + @mkdir -p build-${BUILD_TYPE} + @cd build-${BUILD_TYPE} && \ + cmake -Wno-dev \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DPREFER_STATIC_LIBS=OFF \ + $(FORCE_COLOR) \ + .. 
+ +build: + VERBOSE=1 cmake --build build-${BUILD_TYPE} -j $${CPU_COUNT:-`nproc`} || \ + cmake --build build-${BUILD_TYPE} + +debug: + @$(MAKE) build-common BUILD_TYPE=Debug + @$(MAKE) build BUILD_TYPE=Debug + +release: + @$(MAKE) build-common BUILD_TYPE=Release + @$(MAKE) build BUILD_TYPE=Release diff --git a/README.md b/README.md index 36ca729e..87090451 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,30 @@ # substrait-cpp Planned home for CPP libraries to help build/consume Substrait query plans. + +## Getting Started + +We provide scripts to help developers set up and install substrait-cpp dependencies. + +### Get the substrait-cpp Source +``` +git clone --recursive https://github.com/substrait-io/substrait-cpp.git +cd substrait-cpp +# if you are updating an existing checkout +git submodule sync --recursive +git submodule update --init --recursive +``` + +### Setting up on Linux (Ubuntu 20.04 or later) + +Once you have checked out substrait-cpp, you can set up and build like so: + +```shell +$ ./scripts/setup-ubuntu.sh +$ make +``` + +## License + +substrait-cpp is licensed under the Apache 2.0 License. 
A copy of the license
+[can be found here.](https://www.apache.org/licenses/LICENSE-2.0.html)
\ No newline at end of file
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 00000000..fe1505ae
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+FROM ubuntu:20.04
+
+SHELL ["/bin/bash", "-o", "pipefail", "-c"]
+
+WORKDIR /substrait
+
+# Build-time settings, exported to every RUN step below. A `VAR=value` prefix
+# only applies to the single command it precedes, so writing these inline in
+# front of `apt-get update` did not cover the `apt-get upgrade`,
+# `apt-get install` and `dpkg-reconfigure` commands that followed.
+ARG DEBIAN_FRONTEND=noninteractive
+ARG TZ=America/New_York
+
+# Base packages, including tzdata which is reconfigured right after.
+RUN apt-get update -y && apt-get upgrade -y \
+    && apt-get install -y sudo apt-utils tzdata
+RUN dpkg-reconfigure tzdata
+
+# Toolchain needed to fetch and build the project.
+RUN apt-get update -y && apt-get install -y git build-essential cmake
+
+# Fetch the sources (and submodules) into /substrait/substrait-cpp.
+RUN git clone https://github.com/substrait-io/substrait-cpp.git \
+    && cd substrait-cpp \
+    && git submodule sync --recursive \
+    && git submodule update --init --recursive
+
+RUN cd substrait-cpp && ./scripts/setup-ubuntu.sh
+
+RUN cd substrait-cpp && make
+
+ENTRYPOINT ["/bin/bash"]
diff --git a/docker/README.md b/docker/README.md
new file mode 100644
index 00000000..0f30717f
--- /dev/null
+++ b/docker/README.md
@@ -0,0 +1,21 @@
+# Setup Docker Container
+
+## Build
+
+```bash
+docker build -t substrait-cpp .
+```
+
+## Run
+
+```bash
+docker run -it substrait-cpp
+```
+
+## Evaluate
+
+Run function tests
+
+```bash
+./substrait-cpp/build-Debug/substrait/function/tests/substrait_function_test
+```
diff --git a/include/substrait/common/Exceptions.h b/include/substrait/common/Exceptions.h
new file mode 100644
index 00000000..91c7768e
--- /dev/null
+++ b/include/substrait/common/Exceptions.h
@@ -0,0 +1,136 @@
+/* SPDX-License-Identifier: Apache-2.0 */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace io::substrait::common {
+namespace error_code {
+
+//====================== User Error Codes ======================:
+
+// An error raised when an argument verification fails
+inline constexpr const char* kInvalidArgument = "INVALID_ARGUMENT";
+
+// An error raised when a requested operation is not supported.
+inline constexpr const char* kUnsupported = "UNSUPPORTED";
+
+//====================== Runtime Error Codes ======================:
+
+// An error raised when the current state of a component is invalid.
+inline constexpr const char* kInvalidState = "INVALID_STATE";
+
+// An error raised when unreachable code point was executed.
+inline constexpr const char* kUnreachableCode = "UNREACHABLE_CODE";
+
+// An error raised when a requested operation is not implemented.
+inline constexpr const char* kNotImplemented = "NOT_IMPLEMENTED";
+
+} // namespace error_code
+
+/// Base exception type declared by this header. SubstraitUserError and
+/// SubstraitRuntimeError below derive from it. what() returns the stored
+/// message; the constructor is only declared here and defined elsewhere.
+class SubstraitException : public std::exception {
+ public:
+
+  enum class Type {
+    // Errors where the root cause of the problem is either because of bad input
+    // or an unsupported pattern of use are classified with USER. Examples
+    // of errors in this category include syntax errors, unavailable names or
+    // objects.
+    kUser = 0,
+
+    // Errors where the root cause of the problem is some unreliable aspect of the
+    // system are classified with SYSTEM.
+    kSystem = 1
+  };
+
+  SubstraitException(
+      const std::string& exceptionCode,
+      const std::string& exceptionMessage,
+      Type exceptionType = Type::kSystem,
+      const std::string& exceptionName = "SubstraitException");
+
+  // Inherited
+  [[nodiscard]] const char* what() const noexcept override {
+    return msg_.c_str();
+  }
+
+ private:
+  const std::string msg_;
+};
+
+/// SubstraitException that is always constructed with Type::kUser.
+class SubstraitUserError : public SubstraitException {
+ public:
+  SubstraitUserError(
+      const std::string& exceptionCode,
+      const std::string& exceptionMessage,
+      const std::string& exceptionName = "SubstraitUserError")
+      : SubstraitException(
+            exceptionCode,
+            exceptionMessage,
+            Type::kUser,
+            exceptionName) {}
+};
+
+/// SubstraitException that is always constructed with Type::kSystem.
+class SubstraitRuntimeError final : public SubstraitException {
+ public:
+  SubstraitRuntimeError(
+      const std::string& exceptionCode,
+      const std::string& exceptionMessage,
+      const std::string& exceptionName = "SubstraitRuntimeError")
+      : SubstraitException(
+            exceptionCode,
+            exceptionMessage,
+            Type::kSystem,
+            exceptionName) {}
+};
+
+/// Formats `args` into the format string `fmt` with fmt::vformat and returns
+/// the resulting string.
+template
+std::string errorMessage(fmt::string_view fmt, const Args&... args) {
+  return fmt::vformat(fmt, fmt::make_format_args(args...));
+}
+
+// Builds a message from the variadic arguments with errorMessage() and throws
+// `exception(errorCode, message)`.
+//
+// Every name used by the macros below is qualified from the global namespace
+// (`::io::substrait::common::`). Macros expand at the call site, so a partial
+// qualification such as `substrait::common::` only resolves when the caller
+// happens to be inside namespace `io`, and `::substrait::common::` names a
+// namespace this header never declares.
+#define SUBSTRAIT_THROW(exception, errorCode, ...) \
+  { \
+    auto message = ::io::substrait::common::errorMessage(__VA_ARGS__); \
+    throw exception(errorCode, message); \
+  }
+
+#define SUBSTRAIT_UNSUPPORTED(...) \
+  SUBSTRAIT_THROW( \
+      ::io::substrait::common::SubstraitUserError, \
+      ::io::substrait::common::error_code::kUnsupported, \
+      ##__VA_ARGS__)
+
+#define SUBSTRAIT_UNREACHABLE(...) \
+  SUBSTRAIT_THROW( \
+      ::io::substrait::common::SubstraitRuntimeError, \
+      ::io::substrait::common::error_code::kUnreachableCode, \
+      ##__VA_ARGS__)
+
+#define SUBSTRAIT_FAIL(...) \
+  SUBSTRAIT_THROW( \
+      ::io::substrait::common::SubstraitRuntimeError, \
+      ::io::substrait::common::error_code::kInvalidState, \
+      ##__VA_ARGS__)
+
+#define SUBSTRAIT_USER_FAIL(...) \
+  SUBSTRAIT_THROW( \
+      ::io::substrait::common::SubstraitUserError, \
+      ::io::substrait::common::error_code::kInvalidState, \
+      ##__VA_ARGS__)
+
+#define SUBSTRAIT_NYI(...) \
+  SUBSTRAIT_THROW( \
+      ::io::substrait::common::SubstraitRuntimeError, \
+      ::io::substrait::common::error_code::kNotImplemented, \
+      ##__VA_ARGS__)
+
+#define SUBSTRAIT_INVALID_ARGUMENT(...) \
+  SUBSTRAIT_THROW( \
+      ::io::substrait::common::SubstraitUserError, \
+      ::io::substrait::common::error_code::kInvalidArgument, \
+      ##__VA_ARGS__)
+
+// Misspelled original name, kept so that existing callers keep compiling.
+#define SUBSTRAIT_IVALID_ARGUMENT(...) SUBSTRAIT_INVALID_ARGUMENT(__VA_ARGS__)
+
+} // namespace io::substrait::common
diff --git a/include/substrait/function/Extension.h b/include/substrait/function/Extension.h
new file mode 100644
index 00000000..611e5ad9
--- /dev/null
+++ b/include/substrait/function/Extension.h
@@ -0,0 +1,83 @@
+/* SPDX-License-Identifier: Apache-2.0 */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include "substrait/function/Function.h"
+#include "substrait/function/FunctionSignature.h"
+#include "substrait/type/Type.h"
+
+namespace io::substrait {
+
+struct TypeVariant {
+  std::string name;
+  std::string uri;
+};
+
+using TypeVariantPtr = std::shared_ptr;
+
+using FunctionImplMap =
+    std::unordered_map>;
+
+using TypeVariantMap = std::unordered_map;
+
+class Extension {
+ public:
+  /// Deserialize default substrait extension by given basePath
+  /// @throws exception if file not found
+  static std::shared_ptr load(const std::string& basePath);
+
+  /// Deserialize substrait extension by given basePath and extensionFiles.
+  static std::shared_ptr load(
+      const std::string& basePath,
+      const std::vector& extensionFiles);
+
+  /// Deserialize substrait extension by given extensionFiles.
+  static std::shared_ptr load(
+      const std::vector& extensionFiles);
+
+  /// Add a scalar function implementation.
+  void addScalarFunctionImpl(const FunctionImplementationPtr& functionImpl);
+
+  /// Add an aggregate function implementation.
+  void addAggregateFunctionImpl(const FunctionImplementationPtr& functionImpl);
+
+  /// Add a window function implementation.
+  void addWindowFunctionImpl(const FunctionImplementationPtr& functionImpl);
+
+  /// Add a type variant.
+  void addTypeVariant(const TypeVariantPtr& typeVariant);
+
+  /// Lookup type variant by given type name.
+  /// @return matched type variant
+  TypeVariantPtr lookupType(const std::string& typeName) const;
+
+  /// Scalar function implementations held in scalarFunctionImplMap_.
+  const FunctionImplMap& scalarFunctionImplMap() const {
+    return scalarFunctionImplMap_;
+  }
+
+  /// Misspelled original name of scalarFunctionImplMap(); kept so that
+  /// existing callers keep compiling.
+  const FunctionImplMap& scalaFunctionImplMap() const {
+    return scalarFunctionImplMap_;
+  }
+
+  /// Window function implementations held in windowFunctionImplMap_.
+  const FunctionImplMap& windowFunctionImplMap() const {
+    return windowFunctionImplMap_;
+  }
+
+  /// Aggregate function implementations held in aggregateFunctionImplMap_.
+  const FunctionImplMap& aggregateFunctionImplMap() const {
+    return aggregateFunctionImplMap_;
+  }
+
+ private:
+  FunctionImplMap scalarFunctionImplMap_;
+
+  FunctionImplMap aggregateFunctionImplMap_;
+
+  FunctionImplMap windowFunctionImplMap_;
+
+  TypeVariantMap typeVariantMap_;
+};
+
+using ExtensionPtr = std::shared_ptr;
+
+} // namespace io::substrait
diff --git a/include/substrait/function/Function.h b/include/substrait/function/Function.h
new file mode 100644
index 00000000..a198c593
--- /dev/null
+++ b/include/substrait/function/Function.h
@@ -0,0 +1,116 @@
+/* SPDX-License-Identifier: Apache-2.0 */
+
+#pragma once
+
+#include "substrait/function/FunctionSignature.h"
+#include "substrait/type/Type.h"
+
+namespace io::substrait {
+
+/// Interface of a declared function argument. EnumArgument, TypeArgument and
+/// ValueArgument below implement it; the is*Argument() queries default to
+/// false and each implementation overrides the one that applies to it.
+struct FunctionArgument {
+  // Virtual so that destroying a derived argument through a FunctionArgument
+  // pointer is well defined.
+  virtual ~FunctionArgument() = default;
+
+  [[nodiscard]] virtual bool isRequired() const = 0;
+
+  /// Convert argument type to short type string based on
+  /// https://substrait.io/extensions/#function-signature-compound-names
+  [[nodiscard]] virtual std::string toTypeString() const = 0;
+
+  [[nodiscard]] virtual bool isWildcardType() const {
+    return false;
+  }
+
+  [[nodiscard]] virtual bool isValueArgument() const {
+    return false;
+  }
+
+  [[nodiscard]] virtual bool isEnumArgument() const {
+    return false;
+  }
+
+  [[nodiscard]] virtual bool isTypeArgument() const {
+    return false;
+  }
+};
+
+using FunctionArgumentPtr = std::shared_ptr;
+
+/// Enumeration argument; its type string is "req" when `required` is set and
+/// "opt" otherwise.
+struct EnumArgument : public FunctionArgument {
+  bool required{};
+
+  [[nodiscard]] bool isRequired() const override {
+    return required;
+  }
+
+  [[nodiscard]] std::string toTypeString() const override {
+    return required ? "req" : "opt";
+  }
+
+  [[nodiscard]] bool isEnumArgument() const override {
+    return true;
+  }
+};
+
+/// Type argument; always required, with the fixed type string "type".
+struct TypeArgument : public FunctionArgument {
+  [[nodiscard]] std::string toTypeString() const override {
+    return "type";
+  }
+
+  [[nodiscard]] bool isRequired() const override {
+    return true;
+  }
+
+  [[nodiscard]] bool isTypeArgument() const override {
+    return true;
+  }
+};
+
+/// Value argument; always required. The type string and the wildcard query
+/// are forwarded to `type`, which therefore must be set before either is
+/// called.
+struct ValueArgument : public FunctionArgument {
+  ParameterizedTypePtr type;
+
+  [[nodiscard]] std::string toTypeString() const override {
+    return type->signature();
+  }
+
+  [[nodiscard]] bool isRequired() const override {
+    return true;
+  }
+
+  [[nodiscard]] bool isWildcardType() const override {
+    return type->isWildcard();
+  }
+
+  [[nodiscard]] bool isValueArgument() const override {
+    return true;
+  }
+};
+
+/// Variadic bounds attached to a function implementation.
+// NOTE(review): presumably the allowed repetition count of the variadic
+// argument - confirm against the extension loader.
+struct FunctionVariadic {
+  // Value-initialized like EnumArgument::required so that a
+  // default-constructed instance never exposes an indeterminate value.
+  int min{};
+  std::optional max;
+};
+
+/// A single declared implementation (overload) of a function.
+struct FunctionImplementation {
+  std::string name;
+  std::string uri;
+  std::vector arguments;
+  ParameterizedTypePtr returnType;
+  std::optional variadic;
+
+  // Virtual so that destroying a derived implementation through a
+  // FunctionImplementation pointer is well defined.
+  virtual ~FunctionImplementation() = default;
+
+  /// Test if the actual types matched with this function's implementation.
+  virtual bool tryMatch(const FunctionSignature& signature);
+
+  /// Create function signature by function name and arguments.
+  [[nodiscard]] std::string signature() const;
+};
+
+using FunctionImplementationPtr = std::shared_ptr;
+
+struct ScalarFunctionImplementation : public FunctionImplementation {};
+
+struct AggregateFunctionImplementation : public FunctionImplementation {
+  ParameterizedTypePtr intermediate;
+  bool deterministic{};
+
+  bool tryMatch(const FunctionSignature& signature) override;
+};
+
+} // namespace io::substrait
diff --git a/include/substrait/function/FunctionLookup.h b/include/substrait/function/FunctionLookup.h
new file mode 100644
index 00000000..90f98426
--- /dev/null
+++ b/include/substrait/function/FunctionLookup.h
@@ -0,0 +1,61 @@
+/* SPDX-License-Identifier: Apache-2.0 */
+
+#pragma once
+
+#include "substrait/function/Extension.h"
+#include "substrait/function/FunctionSignature.h"
+
+namespace io::substrait {
+
+/// Base class for looking up a function implementation in an Extension.
+/// Subclasses choose which implementation map of the extension is used by
+/// overriding getFunctionImpls(); lookupFunction() is defined out of line.
+class FunctionLookup {
+ public:
+  explicit FunctionLookup(ExtensionPtr extension)
+      : extension_(std::move(extension)) {}
+
+  [[nodiscard]] virtual FunctionImplementationPtr lookupFunction(
+      const FunctionSignature& signature) const;
+
+  virtual ~FunctionLookup() = default;
+
+ protected:
+  [[nodiscard]] virtual FunctionImplMap getFunctionImpls() const = 0;
+
+  ExtensionPtr extension_{};
+};
+
+using FunctionLookupPtr = std::shared_ptr;
+
+/// Lookup over the scalar function implementations of the extension.
+class ScalarFunctionLookup : public FunctionLookup {
+ public:
+  // explicit, like the aggregate and window lookups below, so that an
+  // ExtensionPtr is never silently converted into a lookup object.
+  explicit ScalarFunctionLookup(const ExtensionPtr& extension)
+      : FunctionLookup(extension) {}
+
+ protected:
+  [[nodiscard]] FunctionImplMap getFunctionImpls() const override {
+    return extension_->scalaFunctionImplMap();
+  }
+};
+
+/// Lookup over the aggregate function implementations of the extension.
+class AggregateFunctionLookup : public FunctionLookup {
+ public:
+  explicit AggregateFunctionLookup(const ExtensionPtr& extension)
+      : FunctionLookup(extension) {}
+
+ protected:
+  [[nodiscard]] FunctionImplMap getFunctionImpls() const override {
+    return extension_->aggregateFunctionImplMap();
+  }
+};
+
+/// Lookup over the window function implementations of the extension.
+class WindowFunctionLookup : public FunctionLookup {
+ public:
+  explicit WindowFunctionLookup(const ExtensionPtr& extension)
+      : FunctionLookup(extension) {}
+
+ protected:
+  [[nodiscard]] FunctionImplMap getFunctionImpls() const override {
+    return extension_->windowFunctionImplMap();
+  }
+};
+
+} // namespace io::substrait
diff --git a/include/substrait/function/FunctionSignature.h b/include/substrait/function/FunctionSignature.h
new file mode 100644
index 00000000..5ebf54f3
--- /dev/null
+++ b/include/substrait/function/FunctionSignature.h
@@ -0,0 +1,15 @@
+/* SPDX-License-Identifier: Apache-2.0 */
+
+#pragma once
+
+#include "substrait/type/Type.h"
+
+namespace io::substrait {
+
+/// A concrete call signature: function name, argument list and return type.
+struct FunctionSignature {
+  std::string name;
+  std::vector arguments;
+  TypePtr returnType;
+};
+
+} // namespace io::substrait
diff --git a/include/substrait/type/Type.h b/include/substrait/type/Type.h
new file mode 100644
index 00000000..93208f9f
--- /dev/null
+++ b/include/substrait/type/Type.h
@@ -0,0 +1,656 @@
+/* SPDX-License-Identifier: Apache-2.0 */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+namespace io::substrait {
+
+enum class TypeKind : int8_t {
+  kBool = 1,
+  kI8 = 2,
+  kI16 = 3,
+  kI32 = 4,
+  kI64 = 5,
+  kFp32 = 6,
+  kFp64 = 7,
+  kString = 8,
+  kBinary = 9,
+  kTimestamp = 10,
+  kDate = 11,
+  kTime = 12,
+  kIntervalYear = 13,
+  kIntervalDay = 14,
+  kTimestampTz = 15,
+  kUuid = 16,
+  kFixedChar = 17,
+  kVarchar = 18,
+  kFixedBinary = 19,
+  kDecimal = 20,
+  kStruct = 21,
+  kList = 22,
+  kMap = 23,
+  KIND_NOT_SET = 0,
+};
+
+/// Compile-time names of a type kind. Each specialization below provides
+/// `signature`, the abbreviated name (e.g. "str"), and `typeString`, the
+/// spelled-out name (e.g. "string").
+template
+struct TypeTraits {};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "bool";
+  static constexpr const char* typeString = "boolean";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "i8";
+  static constexpr const char* typeString = "i8";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "i16";
+  static constexpr const char* typeString = "i16";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "i32";
+  static constexpr const char* typeString = "i32";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "i64";
+  static constexpr const char* typeString = "i64";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "fp32";
+  static constexpr const char* typeString = "fp32";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "fp64";
+  static constexpr const char* typeString = "fp64";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "str";
+  static constexpr const char* typeString = "string";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "vbin";
+  static constexpr const char* typeString = "binary";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "ts";
+  static constexpr const char* typeString = "timestamp";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "tstz";
+  static constexpr const char* typeString = "timestamp_tz";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "date";
+  static constexpr const char* typeString = "date";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "time";
+  static constexpr const char* typeString = "time";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "iyear";
+  static constexpr const char* typeString = "interval_year";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "iday";
+  static constexpr const char* typeString = "interval_day";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "uuid";
+  static constexpr const char* typeString = "uuid";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "fchar";
+  static constexpr const char* typeString = "fixedchar";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "vchar";
+  static constexpr const char* typeString = "varchar";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "fbin";
+  static constexpr const char* typeString = "fixedbinary";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "dec";
+  static constexpr const char* typeString = "decimal";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "struct";
+  static constexpr const char* typeString = "struct";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "list";
+  static constexpr const char* typeString = "list";
+};
+
+template <>
+struct TypeTraits {
+  static constexpr const char* signature = "map";
+  static constexpr const char* typeString = "map";
+};
+
+/// Abstract root of the type hierarchy. It stores the nullability flag and
+/// declares the signature(), kind() and isMatch() interface that derived
+/// types implement.
+class ParameterizedType {
+ public:
+  explicit ParameterizedType(bool nullable = false) : nullable_(nullable) {}
+
+  // Virtual so that destroying a derived type through a ParameterizedType
+  // pointer is well defined.
+  virtual ~ParameterizedType() = default;
+
+  [[nodiscard]] virtual std::string signature() const = 0;
+
+  [[nodiscard]] virtual TypeKind kind() const = 0;
+
+  /// Deserialize substrait raw type string into Substrait extension type.
+  /// @param rawType - substrait extension raw string type
+  static std::shared_ptr decode(
+      const std::string& rawType);
+
+  [[nodiscard]] const bool& nullable() const {
+    return nullable_;
+  }
+
+  /// True when this type is nullable, or when both types have the same
+  /// nullability.
+  [[nodiscard]] bool nullMatch(
+      const std::shared_ptr& type) const {
+    return nullable() || nullable() == type->nullable();
+  }
+
+  /// Test type is a Wildcard type or not.
+  [[nodiscard]] virtual bool isWildcard() const {
+    return false;
+  }
+
+  [[nodiscard]] virtual bool isMatch(
+      const std::shared_ptr& type) const = 0;
+
+ private:
+  const bool nullable_;
+};
+
+using ParameterizedTypePtr = std::shared_ptr;
+
+class Type : public ParameterizedType {
+ public:
+  explicit Type(bool nullable = false) : ParameterizedType(nullable) {}
+};
+
+using TypePtr = std::shared_ptr;
+
+/// Types used in function argument declarations.
+template +class TypeBase : public Type { + public: + explicit TypeBase(bool nullable = false) : Type(nullable) {} + + [[nodiscard]] std::string signature() const override { + return TypeTraits::signature; + } + + [[nodiscard]] TypeKind kind() const override { + return Kind; + } + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override { + return kind() == type->kind() && nullMatch(type); + } +}; + +template +class ScalarType : public TypeBase { + public: + explicit ScalarType(bool nullable) : TypeBase(nullable) {} +}; + +class Decimal : public TypeBase { + public: + Decimal(int precision, int scale, bool nullable = false) + : TypeBase(nullable), + precision_(precision), + scale_(scale) {} + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] const int& precision() const { + return precision_; + } + + [[nodiscard]] const int& scale() const { + return scale_; + } + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const int precision_; + const int scale_; +}; + +class FixedBinary : public TypeBase { + public: + explicit FixedBinary(int length, bool nullable = false) + : TypeBase(nullable), length_(length) {} + + [[nodiscard]] const int& length() const { + return length_; + } + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const int length_; +}; + +class FixedChar : public TypeBase { + public: + explicit FixedChar(int length, bool nullable = false) + : TypeBase(nullable), length_(length){}; + + [[nodiscard]] const int& length() const { + return length_; + } + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const int length_; +}; + +class Varchar : public TypeBase { + public: + explicit Varchar(int length, bool nullable = false) + : TypeBase(nullable), length_(length){}; + + 
[[nodiscard]] const int& length() const { + return length_; + } + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const int length_; +}; + +class List : public TypeBase { + public: + explicit List(TypePtr elementType, bool nullable = false) + : TypeBase(nullable), + elementType_(std::move(elementType)){}; + + [[nodiscard]] const TypePtr& elementType() const { + return elementType_; + } + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const TypePtr elementType_; +}; + +class Struct : public TypeBase { + public: + explicit Struct(std::vector types, bool nullable = false) + : TypeBase(nullable), children_(std::move(types)) {} + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] const std::vector& children() const { + return children_; + } + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const std::vector children_; +}; + +class Map : public TypeBase { + public: + Map(TypePtr keyType, TypePtr valueType, bool nullable = false) + : TypeBase(nullable), + keyType_(std::move(keyType)), + valueType_(std::move(valueType)) {} + + [[nodiscard]] const TypePtr& keyType() const { + return keyType_; + } + + [[nodiscard]] const TypePtr& valueType() const { + return valueType_; + } + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const TypePtr keyType_; + const TypePtr valueType_; +}; + +class ParameterizedTypeBase : public ParameterizedType { + public: + explicit ParameterizedTypeBase(bool nullable = false) + : ParameterizedType(nullable) {} +}; + +/// A string literal type can present the 'any1'. 
+class StringLiteral : public ParameterizedTypeBase { + public: + explicit StringLiteral(std::string value) + : ParameterizedTypeBase(false), value_(std::move(value)) {} + + [[nodiscard]] std::string signature() const override { + return value_; + } + + [[nodiscard]] TypeKind kind() const override { + return TypeKind::KIND_NOT_SET; + } + + [[nodiscard]] const std::string& value() const { + return value_; + } + + [[nodiscard]] bool isWildcard() const override { + return value_.find("any") == 0 || value_ == "T"; + } + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const std::string value_; +}; + +using StringLiteralPtr = std::shared_ptr; + +class ParameterizedDecimal : public ParameterizedTypeBase { + public: + ParameterizedDecimal( + StringLiteralPtr precision, + StringLiteralPtr scale, + bool nullable = false) + : ParameterizedTypeBase(nullable), + precision_(std::move(precision)), + scale_(std::move(scale)) {} + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] const StringLiteralPtr& precision() const { + return precision_; + } + + [[nodiscard]] TypeKind kind() const override { + return TypeKind::kDecimal; + } + + [[nodiscard]] const StringLiteralPtr& scale() const { + return scale_; + } + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + StringLiteralPtr precision_; + StringLiteralPtr scale_; +}; + +class ParameterizedFixedBinary : public ParameterizedTypeBase { + public: + explicit ParameterizedFixedBinary( + StringLiteralPtr length, + bool nullable = false) + : ParameterizedTypeBase(nullable), length_(std::move(length)) {} + + [[nodiscard]] const StringLiteralPtr& length() const { + return length_; + } + + [[nodiscard]] TypeKind kind() const override { + return TypeKind::kFixedBinary; + } + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const 
StringLiteralPtr length_; +}; + +class ParameterizedFixedChar : public ParameterizedTypeBase { + public: + explicit ParameterizedFixedChar( + StringLiteralPtr length, + bool nullable = false) + : ParameterizedTypeBase(nullable), length_(std::move(length)) {} + + [[nodiscard]] const StringLiteralPtr& length() const { + return length_; + } + + [[nodiscard]] TypeKind kind() const override { + return TypeKind::kFixedChar; + } + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const StringLiteralPtr length_; +}; + +class ParameterizedVarchar : public ParameterizedTypeBase { + public: + explicit ParameterizedVarchar(StringLiteralPtr length, bool nullable = false) + : ParameterizedTypeBase(nullable), length_(std::move(length)) {} + + [[nodiscard]] const StringLiteralPtr& length() const { + return length_; + } + + [[nodiscard]] TypeKind kind() const override { + return TypeKind::kVarchar; + } + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const StringLiteralPtr length_; +}; + +class ParameterizedList : public ParameterizedTypeBase { + public: + explicit ParameterizedList( + ParameterizedTypePtr elementType, + bool nullable = false) + : ParameterizedTypeBase(nullable), elementType_(std::move(elementType)){}; + + [[nodiscard]] const ParameterizedTypePtr& elementType() const { + return elementType_; + } + + [[nodiscard]] TypeKind kind() const override { + return TypeKind::kList; + } + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const ParameterizedTypePtr elementType_; +}; + +class ParameterizedStruct : public ParameterizedTypeBase { + public: + explicit ParameterizedStruct( + std::vector types, + bool nullable = false) + : ParameterizedTypeBase(nullable), 
children_(std::move(types)) {} + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] const std::vector& children() const { + return children_; + } + + [[nodiscard]] TypeKind kind() const override { + return TypeKind::kStruct; + } + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const std::vector children_; +}; + +class ParameterizedMap : public ParameterizedTypeBase { + public: + ParameterizedMap( + ParameterizedTypePtr keyType, + ParameterizedTypePtr valueType, + bool nullable = false) + : ParameterizedTypeBase(nullable), + keyType_(std::move(keyType)), + valueType_(std::move(valueType)) {} + + [[nodiscard]] const ParameterizedTypePtr& keyType() const { + return keyType_; + } + + [[nodiscard]] TypeKind kind() const override { + return TypeKind::kMap; + } + [[nodiscard]] const ParameterizedTypePtr& valueType() const { + return valueType_; + } + + [[nodiscard]] std::string signature() const override; + + [[nodiscard]] bool isMatch( + const std::shared_ptr& type) const override; + + private: + const ParameterizedTypePtr keyType_; + const ParameterizedTypePtr valueType_; +}; + +std::shared_ptr> BOOL(); + +std::shared_ptr> TINYINT(); + +std::shared_ptr> SMALLINT(); + +std::shared_ptr> INTEGER(); + +std::shared_ptr> BIGINT(); + +std::shared_ptr> FLOAT(); + +std::shared_ptr> DOUBLE(); + +std::shared_ptr> STRING(); + +std::shared_ptr> BINARY(); + +std::shared_ptr> TIMESTAMP(); + +std::shared_ptr> TIMESTAMP_TZ(); + +std::shared_ptr> DATE(); + +std::shared_ptr> TIME(); + +std::shared_ptr> INTERVAL_YEAR(); + +std::shared_ptr> INTERVAL_DAY(); + +std::shared_ptr> UUID(); + +std::shared_ptr DECIMAL(int precision, int scale); + +std::shared_ptr VARCHAR(int len); + +std::shared_ptr FIXED_CHAR(int len); + +std::shared_ptr FIXED_BINARY(int len); + +std::shared_ptr LIST(const TypePtr& elementType); + +std::shared_ptr MAP( + const TypePtr& keyType, + const TypePtr& valueType); + +std::shared_ptr STRUCT(const 
std::vector& children); + +} // namespace io::substrait diff --git a/scripts/setup-helper-functions.sh b/scripts/setup-helper-functions.sh new file mode 100755 index 00000000..8c2c4c7a --- /dev/null +++ b/scripts/setup-helper-functions.sh @@ -0,0 +1,129 @@ +#!/bin/bash +# SPDX-License-Identifier: Apache-2.0 + +# github_checkout $REPO $VERSION $GIT_CLONE_PARAMS clones or re-uses an existing clone of the +# specified repo, checking out the requested version. +function github_checkout { + local REPO=$1 + shift + local VERSION=$1 + shift + local GIT_CLONE_PARAMS=$@ + local DIRNAME=$(basename $REPO) + cd "${DEPENDENCY_DIR}" + if [ -z "${DIRNAME}" ]; then + echo "Failed to get repo name from ${REPO}" + exit 1 + fi + if [ -d "${DIRNAME}" ] && prompt "${DIRNAME} already exists. Delete?"; then + rm -rf "${DIRNAME}" + fi + if [ ! -d "${DIRNAME}" ]; then + git clone -q -b $VERSION $GIT_CLONE_PARAMS "https://github.com/${REPO}.git" + fi + cd "${DIRNAME}" +} + + +# get_cxx_flags [$CPU_ARCH] +# Sets and exports the variable VELOX_CXX_FLAGS with appropriate compiler flags. +# If $CPU_ARCH is set then we use that else we determine best possible set of flags +# to use based on current cpu architecture. +# The goal of this function is to consolidate all architecture specific flags to one +# location. +# The values that CPU_ARCH can take are as follows: +# arm64 : Target Apple silicon. +# aarch64: Target general 64 bit arm cpus. +# avx: Target Intel CPUs with AVX. +# sse: Target Intel CPUs with sse. 
+# Echo's the appropriate compiler flags which can be captured as so +# CXX_FLAGS=$(get_cxx_flags) or +# CXX_FLAGS=$(get_cxx_flags "avx") + +function get_cxx_flags { + local CPU_ARCH=$1 + + local OS + OS=$(uname) + local MACHINE + MACHINE=$(uname -m) + + if [ -z "$CPU_ARCH" ]; then + + if [ "$OS" = "Darwin" ]; then + + if [ "$MACHINE" = "x86_64" ]; then + local CPU_CAPABILITIES + CPU_CAPABILITIES=$(sysctl -a | grep machdep.cpu.features | awk '{print tolower($0)}') + + if [[ $CPU_CAPABILITIES =~ "avx" ]]; then + CPU_ARCH="avx" + else + CPU_ARCH="sse" + fi + + elif [[ $(sysctl -a | grep machdep.cpu.brand_string) =~ "Apple" ]]; then + # Apple silicon. + CPU_ARCH="arm64" + fi + else [ "$OS" = "Linux" ]; + + local CPU_CAPABILITIES + CPU_CAPABILITIES=$(cat /proc/cpuinfo | grep flags | head -n 1| awk '{print tolower($0)}') + + if [[ "$CPU_CAPABILITIES" =~ "avx" ]]; then + CPU_ARCH="avx" + elif [[ "$CPU_CAPABILITIES" =~ "sse" ]]; then + CPU_ARCH="sse" + elif [ "$MACHINE" = "aarch64" ]; then + CPU_ARCH="aarch64" + fi + fi + fi + + case $CPU_ARCH in + + "arm64") + echo -n "-mcpu=apple-m1+crc -std=c++17" + ;; + + "avx") + echo -n "-mavx2 -mfma -mavx -mf16c -mlzcnt -std=c++17" + ;; + + "sse") + echo -n "-msse4.2 -std=c++17" + ;; + + "aarch64") + echo -n "-mcpu=neoverse-n1 -std=c++17" + ;; + *) + echo -n "Architecture not supported!" 
+ esac + +} + +function cmake_install { + local NAME=$(basename "$(pwd)") + local BINARY_DIR=_build + if [ -d "${BINARY_DIR}" ] && prompt "Do you want to rebuild ${NAME}?"; then + rm -rf "${BINARY_DIR}" + fi + mkdir -p "${BINARY_DIR}" + CPU_TARGET="${CPU_TARGET:-avx}" + COMPILER_FLAGS=$(get_cxx_flags $CPU_TARGET) + + # CMAKE_POSITION_INDEPENDENT_CODE is required so that Velox can be built into dynamic libraries \ + cmake -Wno-dev -B"${BINARY_DIR}" \ + -GNinja \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_CXX_STANDARD=17 \ + "${INSTALL_PREFIX+-DCMAKE_PREFIX_PATH=}${INSTALL_PREFIX-}" \ + "${INSTALL_PREFIX+-DCMAKE_INSTALL_PREFIX=}${INSTALL_PREFIX-}" \ + -DCMAKE_CXX_FLAGS="$COMPILER_FLAGS" \ + -DBUILD_TESTING=OFF \ + "$@" + ninja -C "${BINARY_DIR}" install +} + diff --git a/scripts/setup-ubuntu.sh b/scripts/setup-ubuntu.sh new file mode 100755 index 00000000..438de79a --- /dev/null +++ b/scripts/setup-ubuntu.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# SPDX-License-Identifier: Apache-2.0 + +# Minimal setup for Ubuntu 20.04. +set -eufx -o pipefail +SCRIPTDIR=$(dirname "${BASH_SOURCE[0]}") +source $SCRIPTDIR/setup-helper-functions.sh + +CPU_TARGET="${CPU_TARGET:-avx}" +export COMPILER_FLAGS=$(get_cxx_flags $CPU_TARGET) +NPROC=$(getconf _NPROCESSORS_ONLN) +DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)} + +# Install all dependencies. 
+sudo --preserve-env apt install -y \ + wget \ + g++ \ + cmake \ + ccache \ + ninja-build \ + checkinstall \ + git \ + wget + +function run_and_time { + time "$@" + { echo "+ Finished running $*"; } 2> /dev/null +} + +function prompt { + ( + while true; do + local input="${PROMPT_ALWAYS_RESPOND:-}" + echo -n "$(tput bold)$* [Y, n]$(tput sgr0) " + [[ -z "${input}" ]] && read input + if [[ "${input}" == "Y" || "${input}" == "y" || "${input}" == "" ]]; then + return 0 + elif [[ "${input}" == "N" || "${input}" == "n" ]]; then + return 1 + fi + done + ) 2> /dev/null +} + +function install_protobuf { + wget https://github.com/protocolbuffers/protobuf/releases/download/v21.4/protobuf-all-21.4.tar.gz + tar -xzf protobuf-all-21.4.tar.gz + cd protobuf-21.4 + ./configure --prefix=/usr + make "-j$(nproc)" + make install + ldconfig +} + +function install_deps { + run_and_time install_protobuf +} + +(return 2> /dev/null) && return # If script was sourced, don't run commands. + +( + if [[ $# -ne 0 ]]; then + for cmd in "$@"; do + run_and_time "${cmd}" + done + else + install_deps + fi +) + +echo "All deps installed! 
Now try \"make\"" diff --git a/substrait/CMakeLists.txt b/substrait/CMakeLists.txt new file mode 100644 index 00000000..2d18d0b2 --- /dev/null +++ b/substrait/CMakeLists.txt @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + +add_subdirectory(common) +add_subdirectory(type) +add_subdirectory(function) diff --git a/substrait/common/CMakeLists.txt b/substrait/common/CMakeLists.txt new file mode 100644 index 00000000..8c6936ca --- /dev/null +++ b/substrait/common/CMakeLists.txt @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 + +find_package(fmt) + +add_library( + substrait_common + Exceptions.cpp) + +target_link_libraries( + substrait_common + fmt::fmt-header-only) + diff --git a/substrait/common/Exceptions.cpp b/substrait/common/Exceptions.cpp new file mode 100644 index 00000000..15537cf4 --- /dev/null +++ b/substrait/common/Exceptions.cpp @@ -0,0 +1,24 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include +#include "substrait/common/Exceptions.h" + +namespace io::substrait::common { + +SubstraitException::SubstraitException( + const std::string& exceptionCode, + const std::string& exceptionMessage, + Type exceptionType, + const std::string& exceptionName) + : msg_(fmt::format( + "Exception: {}\nError Code: {}\nType: {}\nReason: {}\n" + "Function: {}\nFile: {}\n:Line: {}\n", + exceptionName, + exceptionCode, + exceptionType == Type::kSystem ? 
"system" : "user", + exceptionMessage, + __FUNCTION__, + __FILE__, + std::to_string(__LINE__))) {} + +} // namespace io::substrait::common diff --git a/substrait/function/CMakeLists.txt b/substrait/function/CMakeLists.txt new file mode 100644 index 00000000..fd124c60 --- /dev/null +++ b/substrait/function/CMakeLists.txt @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 + +set(FUNCTION_SRCS + Function.cpp + Extension.cpp + FunctionLookup.cpp) + +add_library(substrait_function ${FUNCTION_SRCS}) + +target_link_libraries( + substrait_function + substrait_type + yaml-cpp) + +if (${BUILD_TESTING}) + add_subdirectory(tests) +endif () \ No newline at end of file diff --git a/substrait/function/Extension.cpp b/substrait/function/Extension.cpp new file mode 100644 index 00000000..d2778f98 --- /dev/null +++ b/substrait/function/Extension.cpp @@ -0,0 +1,280 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include +#include "substrait/function/Extension.h" + +bool decodeFunctionImpl( + const YAML::Node& node, + io::substrait::FunctionImplementation& function) { + const auto& returnType = node["return"]; + if (returnType && returnType.IsScalar()) { + /// Return type can be an expression. + const auto& returnExpr = returnType.as(); + std::stringstream ss(returnExpr); + + // TODO: currently we only parse the last sentence of type definition, use + // ANTLR in future. 
+ std::string lastReturnType; + while (std::getline(ss, lastReturnType, '\n')) { + } + function.returnType = io::substrait::Type::decode(lastReturnType); + } + const auto& args = node["args"]; + if (args && args.IsSequence()) { + for (auto& arg : args) { + if (arg["options"]) { // enum argument + auto enumArgument = std::make_shared( + arg.as()); + function.arguments.emplace_back(enumArgument); + } else if (arg["value"]) { // value argument + auto valueArgument = std::make_shared( + arg.as()); + function.arguments.emplace_back(valueArgument); + } else { // type argument + auto typeArgument = std::make_shared( + arg.as()); + function.arguments.emplace_back(typeArgument); + } + } + } + + const auto& variadic = node["variadic"]; + if (variadic) { + auto& min = variadic["min"]; + auto& max = variadic["max"]; + if (min) { + function.variadic = std::make_optional( + {min.as(), + max ? std::make_optional(max.as()) : std::nullopt}); + } else { + function.variadic = std::nullopt; + } + } else { + function.variadic = std::nullopt; + } + + return true; +} + +template <> +struct YAML::convert { + static bool decode(const Node& node, io::substrait::EnumArgument& argument) { + // 'options' is required property + const auto& options = node["options"]; + if (options && options.IsSequence()) { + auto& required = node["required"]; + argument.required = required && required.as(); + return true; + } else { + return false; + } + } +}; + +template <> +struct YAML::convert { + static bool decode(const Node& node, io::substrait::ValueArgument& argument) { + const auto& value = node["value"]; + if (value && value.IsScalar()) { + auto valueType = value.as(); + argument.type = io::substrait::Type::decode(valueType); + return true; + } + return false; + } +}; + +template <> +struct YAML::convert { + static bool decode( + const YAML::Node& node, + io::substrait::TypeArgument& argument) { + // no properties need to populate for type argument, just return true if + // 'type' element exists. 
+ if (node["type"]) { + return true; + } + return false; + } +}; + +template <> +struct YAML::convert { + static bool decode( + const Node& node, + io::substrait::ScalarFunctionImplementation& function) { + return decodeFunctionImpl(node, function); + }; +}; + +template <> +struct YAML::convert { + static bool decode( + const Node& node, + io::substrait::AggregateFunctionImplementation& function) { + const auto& res = decodeFunctionImpl(node, function); + if (res) { + const auto& intermediate = node["intermediate"]; + if (intermediate) { + function.intermediate = + io::substrait::ParameterizedType::decode(intermediate.as()); + } + } + return res; + } +}; + +template <> +struct YAML::convert { + static bool decode(const Node& node, io::substrait::TypeVariant& typeAnchor) { + const auto& name = node["name"]; + if (name && name.IsScalar()) { + typeAnchor.name = name.as(); + return true; + } + return false; + } +}; + +namespace io::substrait { + +std::shared_ptr Extension::load(const std::string& basePath) { + static const std::vector extensionFiles{ + "functions_aggregate_approx.yaml", + "functions_aggregate_generic.yaml", + "functions_arithmetic.yaml", + "functions_arithmetic_decimal.yaml", + "functions_boolean.yaml", + "functions_comparison.yaml", + "functions_datetime.yaml", + "functions_logarithmic.yaml", + "functions_rounding.yaml", + "functions_string.yaml", + "functions_set.yaml", + }; + return load(basePath, extensionFiles); +} + +std::shared_ptr Extension::load( + const std::string& basePath, + const std::vector& extensionFiles) { + std::vector yamlExtensionFiles; + yamlExtensionFiles.reserve(extensionFiles.size()); + for (auto& extensionFile : extensionFiles) { + auto const pos = basePath.find_last_of('/'); + const auto& extensionUri = basePath.substr(0, pos) + "/" + extensionFile; + yamlExtensionFiles.emplace_back(extensionUri); + } + return load(yamlExtensionFiles); +} + +std::shared_ptr Extension::load( + const std::vector& extensionFiles) { + auto 
extension = std::make_shared(); + for (const auto& extensionUri : extensionFiles) { + const auto& node = YAML::LoadFile(extensionUri); + + const auto& scalarFunctions = node["scalar_functions"]; + if (scalarFunctions && scalarFunctions.IsSequence()) { + for (auto& scalarFunctionNode : scalarFunctions) { + const auto functionName = scalarFunctionNode["name"].as(); + for (auto& scalaFunctionImplNode : scalarFunctionNode["impls"]) { + auto scalarFunctionImpl = + scalaFunctionImplNode.as(); + scalarFunctionImpl.name = functionName; + scalarFunctionImpl.uri = extensionUri; + extension->addScalarFunctionImpl( + std::make_shared( + scalarFunctionImpl)); + } + } + } + + const auto& aggregateFunctions = node["aggregate_functions"]; + if (aggregateFunctions && aggregateFunctions.IsSequence()) { + for (auto& aggregateFunctionNode : aggregateFunctions) { + const auto functionName = + aggregateFunctionNode["name"].as(); + for (auto& aggregateFunctionImplNode : + aggregateFunctionNode["impls"]) { + auto aggregateFunctionImpl = + aggregateFunctionImplNode.as(); + aggregateFunctionImpl.name = functionName; + aggregateFunctionImpl.uri = extensionUri; + extension->addAggregateFunctionImpl( + std::make_shared( + aggregateFunctionImpl)); + } + } + } + + const auto& types = node["types"]; + if (types && types.IsSequence()) { + for (auto& type : types) { + auto typeAnchor = type.as(); + typeAnchor.uri = extensionUri; + extension->addTypeVariant(std::make_shared(typeAnchor)); + } + } + } + return extension; +} + +void Extension::addWindowFunctionImpl( + const FunctionImplementationPtr& functionImpl) { + const auto& functionImpls = + windowFunctionImplMap_.find(functionImpl->name); + if (functionImpls != windowFunctionImplMap_.end()) { + auto& impls = functionImpls->second; + impls.emplace_back(functionImpl); + } else { + std::vector impls; + impls.emplace_back(functionImpl); + windowFunctionImplMap_.insert( + {functionImpl->name, std::move(impls)}); + } +} + +void 
Extension::addTypeVariant(const TypeVariantPtr& typeVariant) { + typeVariantMap_.insert({typeVariant->name, typeVariant}); +} + +TypeVariantPtr Extension::lookupType(const std::string& typeName) const { + auto typeVariantIter = typeVariantMap_.find(typeName); + if (typeVariantIter != typeVariantMap_.end()) { + return typeVariantIter->second; + } + return nullptr; +} + +void Extension::addScalarFunctionImpl( + const FunctionImplementationPtr& functionImpl) { + const auto& functionImpls = + scalarFunctionImplMap_.find(functionImpl->name); + if (functionImpls != scalarFunctionImplMap_.end()) { + auto& impls = functionImpls->second; + impls.emplace_back(functionImpl); + } else { + std::vector impls; + impls.emplace_back(functionImpl); + scalarFunctionImplMap_.insert( + {functionImpl->name, std::move(impls)}); + } +} + +void Extension::addAggregateFunctionImpl( + const FunctionImplementationPtr& functionImpl) { + const auto& functionImpls = + aggregateFunctionImplMap_.find(functionImpl->name); + if (functionImpls != aggregateFunctionImplMap_.end()) { + auto& impls = functionImpls->second; + impls.emplace_back(functionImpl); + } else { + std::vector impls; + impls.emplace_back(functionImpl); + aggregateFunctionImplMap_.insert( + {functionImpl->name, std::move(impls)}); + } +} + +} // namespace io::substrait diff --git a/substrait/function/Function.cpp b/substrait/function/Function.cpp new file mode 100644 index 00000000..d7c0eee7 --- /dev/null +++ b/substrait/function/Function.cpp @@ -0,0 +1,86 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include +#include "substrait/function/Function.h" + +namespace io::substrait { + +bool FunctionImplementation::tryMatch(const FunctionSignature& signature) { + const auto& actualTypes = signature.arguments; + if (variadic.has_value()) { + // return false if actual types length less than min of variadic + const auto max = variadic->max; + if ((actualTypes.size() < variadic->min) || + (max.has_value() && actualTypes.size() > 
max.value())) { + return false; + } + + const auto& variadicArgument = arguments[0]; + // actual type must same as the variadicArgument + if (const auto& variadicValueArgument = + std::dynamic_pointer_cast(variadicArgument)) { + for (auto& actualType : actualTypes) { + if (!variadicValueArgument->type->isMatch(actualType)) { + return false; + } + } + } + } else { + std::vector> valueArguments; + for (const auto& argument : arguments) { + if (const auto& variadicValueArgument = + std::dynamic_pointer_cast(argument)) { + valueArguments.emplace_back(variadicValueArgument); + } + } + // return false if size of actual types not equal to size of value + // arguments. + if (valueArguments.size() != actualTypes.size()) { + return false; + } + + for (auto i = 0; i < actualTypes.size(); i++) { + const auto& valueArgument = valueArguments[i]; + if (!valueArgument->type->isMatch(actualTypes[i])) { + return false; + } + } + } + const auto& sigReturnType = signature.returnType; + if (this->returnType && sigReturnType) { + return returnType->isMatch(sigReturnType); + } else { + return true; + } +} + +std::string FunctionImplementation::signature() const { + std::stringstream ss; + ss << name; + if (!arguments.empty()) { + ss << ":"; + for (auto it = arguments.begin(); it != arguments.end(); ++it) { + const auto& typeSign = (*it)->toTypeString(); + if (it == arguments.end() - 1) { + ss << typeSign; + } else { + ss << typeSign << "_"; + } + } + } + + return ss.str(); +} + +bool AggregateFunctionImplementation::tryMatch(const FunctionSignature& signature) { + bool matched = FunctionImplementation::tryMatch(signature); + if (!matched && intermediate) { + const auto& actualTypes = signature.arguments; + if (actualTypes.size() == 1) { + return intermediate->isMatch(actualTypes[0]); + } + } + return matched; +} + +} // namespace io::substrait diff --git a/substrait/function/FunctionLookup.cpp b/substrait/function/FunctionLookup.cpp new file mode 100644 index 00000000..7f24c47a --- 
/dev/null +++ b/substrait/function/FunctionLookup.cpp @@ -0,0 +1,22 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "substrait/function/FunctionLookup.h" + +namespace io::substrait { + +FunctionImplementationPtr FunctionLookup::lookupFunction( + const FunctionSignature& signature) const { + + const auto& functionImpls = getFunctionImpls(); + auto functionImplsIter = functionImpls.find(signature.name); + if (functionImplsIter != functionImpls.end()) { + for (const auto& candidateFunctionImpl : functionImplsIter->second) { + if (candidateFunctionImpl->tryMatch(signature)) { + return candidateFunctionImpl; + } + } + } + return nullptr; +} + +} // namespace io::substrait diff --git a/substrait/function/tests/CMakeLists.txt b/substrait/function/tests/CMakeLists.txt new file mode 100644 index 00000000..94835efe --- /dev/null +++ b/substrait/function/tests/CMakeLists.txt @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 + +add_executable( + substrait_function_test + FunctionLookupTest.cpp) + +add_test( + substrait_function_test + substrait_function_test) + +target_link_libraries( + substrait_function_test + substrait_function + gtest + gtest_main) diff --git a/substrait/function/tests/FunctionLookupTest.cpp b/substrait/function/tests/FunctionLookupTest.cpp new file mode 100644 index 00000000..03232b22 --- /dev/null +++ b/substrait/function/tests/FunctionLookupTest.cpp @@ -0,0 +1,112 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include +#include +#include "substrait/function/FunctionLookup.h" + +using namespace io::substrait; + +class FunctionLookupTest : public ::testing::Test { + protected: + static std::string getExtensionAbsolutePath() { + const std::string absolute_path = __FILE__; + auto const pos = absolute_path.find_last_of('/'); + return absolute_path.substr(0, pos) + + "/../../../third_party/substrait/extensions/"; + } + + void SetUp() override { + ExtensionPtr extension_ = Extension::load(getExtensionAbsolutePath()); + scalarFunctionLookup_ 
= + std::make_shared(extension_); + aggregateFunctionLookup_ = + std::make_shared(extension_); + } + + void testScalarFunctionLookup( + const FunctionSignature& inputSignature, + const std::string& outputSignature) { + const auto& functionImpl = + scalarFunctionLookup_->lookupFunction(inputSignature); + + ASSERT_TRUE(functionImpl != nullptr); + ASSERT_EQ(functionImpl->signature(), outputSignature); + } + + void testAggregateFunctionLookup( + const FunctionSignature& inputSignature, + const std::string& outputSignature) { + const auto& functionImpl = + aggregateFunctionLookup_->lookupFunction(inputSignature); + + ASSERT_TRUE(functionImpl != nullptr); + ASSERT_EQ(functionImpl->signature(), outputSignature); + } + + private: + FunctionLookupPtr scalarFunctionLookup_; + FunctionLookupPtr aggregateFunctionLookup_; +}; + +TEST_F(FunctionLookupTest, compare_function) { + testScalarFunctionLookup( + {"lt", {TINYINT(), TINYINT()}, BOOL()}, "lt:any1_any1"); + + testScalarFunctionLookup( + {"lt", {SMALLINT(), SMALLINT()}, BOOL()}, "lt:any1_any1"); + + testScalarFunctionLookup( + {"lt", {INTEGER(), INTEGER()}, BOOL()}, "lt:any1_any1"); + + testScalarFunctionLookup( + {"lt", {BIGINT(), BIGINT()}, BOOL()}, "lt:any1_any1"); + + testScalarFunctionLookup({"lt", {FLOAT(), FLOAT()}, BOOL()}, "lt:any1_any1"); + + testScalarFunctionLookup( + {"lt", {DOUBLE(), DOUBLE()}, BOOL()}, "lt:any1_any1"); + testScalarFunctionLookup( + {"between", {TINYINT(), TINYINT(), TINYINT()}, BOOL()}, + "between:any1_any1_any1"); +} + +TEST_F(FunctionLookupTest, arithmetic_function) { + testScalarFunctionLookup( + {"add", {TINYINT(), TINYINT()}, TINYINT()}, "add:opt_i8_i8"); + + testScalarFunctionLookup( + {"divide", + { + FLOAT(), + FLOAT(), + }, + FLOAT()}, + "divide:opt_opt_opt_fp32_fp32"); +} + +TEST_F(FunctionLookupTest, aggregate) { + // for intermediate type + testAggregateFunctionLookup( + {"avg", {STRUCT({DOUBLE(), BIGINT()})}, FLOAT()}, "avg:opt_fp32"); +} + +TEST_F(FunctionLookupTest, logical) { 
+ testScalarFunctionLookup({"and", {}, BOOL()}, "and:bool"); + testScalarFunctionLookup({"and", {BOOL()}, BOOL()}, "and:bool"); + testScalarFunctionLookup({"and", {BOOL(), BOOL()}, BOOL()}, "and:bool"); + + testScalarFunctionLookup({"or", {BOOL(), BOOL()}, BOOL()}, "or:bool"); + testScalarFunctionLookup({"not", {BOOL()}, BOOL()}, "not:bool"); + testScalarFunctionLookup({"xor", {BOOL(), BOOL()}, BOOL()}, "xor:bool_bool"); +} + +TEST_F(FunctionLookupTest, string_function) { + testScalarFunctionLookup( + {"like", {STRING(), STRING()}, BOOL()}, "like:opt_str_str"); + testScalarFunctionLookup( + {"like", {VARCHAR(3), VARCHAR(4)}, BOOL()}, + "like:opt_vchar_vchar"); + testScalarFunctionLookup( + {"substring", {STRING(), INTEGER(), INTEGER()}, STRING()}, + "substring:str_i32_i32"); +} diff --git a/substrait/type/CMakeLists.txt b/substrait/type/CMakeLists.txt new file mode 100644 index 00000000..3aeb3e96 --- /dev/null +++ b/substrait/type/CMakeLists.txt @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 + +set(TYPE_SRCS + Type.cpp) + +add_library(substrait_type ${TYPE_SRCS}) + +target_link_libraries( + substrait_type + substrait_common) + +if (${BUILD_TESTING}) + add_subdirectory(tests) +endif () \ No newline at end of file diff --git a/substrait/type/Type.cpp b/substrait/type/Type.cpp new file mode 100644 index 00000000..f8b676ac --- /dev/null +++ b/substrait/type/Type.cpp @@ -0,0 +1,507 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include +#include +#include +#include "substrait/type/Type.h" +#include "substrait/common/Exceptions.h" + +namespace io::substrait { + +namespace { + +size_t findNextComma(const std::string& str, size_t start) { + int cnt = 0; + for (auto i = start; i < str.size(); i++) { + if (str[i] == '<') { + cnt++; + } else if (str[i] == '>') { + cnt--; + } else if (cnt == 0 && str[i] == ',') { + return i; + } + } + + return std::string::npos; +} + +} // namespace + +ParameterizedTypePtr ParameterizedType::decode(const std::string& rawType) { + 
std::string matchingType = rawType; + std::transform( + matchingType.begin(), + matchingType.end(), + matchingType.begin(), + [](unsigned char c) { return std::tolower(c); }); + + const auto& questionMaskPos = matchingType.find_last_of('?'); + + bool nullable = questionMaskPos != std::string::npos; + + const auto& leftAngleBracketPos = matchingType.find('<'); + if (leftAngleBracketPos == std::string::npos) { + // deal with type and with a question mask like "i32?". + const auto& baseType = nullable + ? matchingType = matchingType.substr(0, questionMaskPos) + : matchingType; + + if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>( + nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>( + nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>( + nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared>(nullable); + } else if (TypeTraits::typeString == baseType) { + return 
std::make_shared>(nullable); + } else { + return std::make_shared(rawType); + } + } else { + const auto& rightAngleBracketPos = rawType.rfind('>'); + const auto& baseTypePos = nullable + ? std::min(leftAngleBracketPos, questionMaskPos) + : leftAngleBracketPos; + + const auto& baseType = matchingType.substr(0, baseTypePos); + + std::vector nestedTypes; + auto prevPos = leftAngleBracketPos + 1; + auto commaPos = findNextComma(rawType, prevPos); + while (commaPos != std::string::npos) { + auto token = rawType.substr(prevPos, commaPos - prevPos); + nestedTypes.emplace_back(decode(token)); + prevPos = commaPos + 1; + commaPos = findNextComma(rawType, prevPos); + } + auto token = rawType.substr(prevPos, rightAngleBracketPos - prevPos); + nestedTypes.emplace_back(decode(token)); + + if (TypeTraits::typeString == baseType) { + return std::make_shared(nestedTypes[0], nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared( + nestedTypes[0], nestedTypes[1], nullable); + } else if (TypeTraits::typeString == baseType) { + return std::make_shared(nestedTypes, nullable); + } else if (TypeTraits::typeString == baseType) { + StringLiteralPtr precision = + std::dynamic_pointer_cast(nestedTypes[0]); + StringLiteralPtr scale = + std::dynamic_pointer_cast(nestedTypes[1]); + return std::make_shared(precision, scale, nullable); + } else if (TypeTraits::typeString == baseType) { + auto length = + std::dynamic_pointer_cast(nestedTypes[0]); + return std::make_shared(length, nullable); + } else if (TypeTraits::typeString == baseType) { + auto length = + std::dynamic_pointer_cast(nestedTypes[0]); + return std::make_shared(length, nullable); + } else if (TypeTraits::typeString == baseType) { + auto length = + std::dynamic_pointer_cast(nestedTypes[0]); + return std::make_shared(length, nullable); + } else { + SUBSTRAIT_UNSUPPORTED("Unsupported type: " + rawType); + } + } +} + +std::string Decimal::signature() const { + std::stringstream sign; + sign << 
TypeBase::signature(); + sign << "<" << precision_ << "," << scale_ << ">"; + return sign.str(); +} + +bool Decimal::isMatch( + const std::shared_ptr& type) const { + if (auto decimalType = std::dynamic_pointer_cast(type)) { + return TypeBase::isMatch(type) && precision_ == decimalType->precision() && + scale_ == decimalType->scale(); + } + + return false; +} + +std::string FixedBinary::signature() const { + std::stringstream sign; + sign << TypeBase::signature(); + sign << "<" << length() << ">"; + return sign.str(); +} +bool FixedBinary::isMatch( + const std::shared_ptr& type) const { + if (auto fBinaryType = std::dynamic_pointer_cast(type)) { + return TypeBase::isMatch(type) && length_ == fBinaryType->length(); + } + + return false; +} + +std::string FixedChar::signature() const { + std::stringstream sign; + sign << TypeBase::signature(); + sign << "<" << length() << ">"; + return sign.str(); +} + +bool FixedChar::isMatch( + const std::shared_ptr& type) const { + if (auto fBinaryType = std::dynamic_pointer_cast(type)) { + return TypeBase::isMatch(type) && length_ == fBinaryType->length(); + } + return false; +} + +std::string Varchar::signature() const { + std::stringstream sign; + sign << TypeBase::signature(); + sign << "<" << length() << ">"; + return sign.str(); +} + +bool Varchar::isMatch( + const std::shared_ptr& type) const { + if (auto varcharType = std::dynamic_pointer_cast(type)) { + return TypeBase::isMatch(type) && length_ == varcharType->length(); + } + return false; +} + +std::string List::signature() const { + std::stringstream sign; + sign << TypeBase::signature(); + sign << "<" << elementType_->signature() << ">"; + return sign.str(); +} + +bool List::isMatch(const std::shared_ptr& type) const { + if (auto listType = std::dynamic_pointer_cast(type)) { + return TypeBase::isMatch(type) && + elementType()->isMatch(listType->elementType()); + } + return false; +} + +std::string Struct::signature() const { + std::stringstream sign; + sign << 
TypeBase::signature(); + sign << "<"; + for (auto it = children_.begin(); it != children_.end(); ++it) { + const auto& typeSign = (*it)->signature(); + if (it == children_.end() - 1) { + sign << typeSign; + } else { + sign << typeSign << ","; + } + } + sign << ">"; + return sign.str(); +} +bool Struct::isMatch( + const std::shared_ptr& type) const { + if (auto structType = std::dynamic_pointer_cast(type)) { + bool sameSize = structType->children_.size() == children_.size(); + if (sameSize) { + for (int i = 0; i < children_.size(); i++) { + if (!children_[i]->isMatch(structType->children_[i])) { + return false; + } + } + return true; + } + } + return false; +} + +std::string Map::signature() const { + std::stringstream sign; + sign << TypeBase::signature(); + sign << "<"; + sign << keyType()->signature(); + sign << ","; + sign << valueType()->signature(); + sign << ">"; + return sign.str(); +} + +bool Map::isMatch(const std::shared_ptr& type) const { + if (auto mapType = std::dynamic_pointer_cast(type)) { + return TypeBase::isMatch(type) && keyType()->isMatch(mapType->keyType()) && + valueType()->isMatch(mapType->valueType()); + } + return false; +} + +std::string ParameterizedFixedBinary::signature() const { + std::stringstream sign; + sign << TypeTraits::signature; + sign << "<" << length_->value() << ">"; + return sign.str(); +} + +bool ParameterizedFixedBinary::isMatch( + const std::shared_ptr& type) const { + if (auto parameterizedFixedBinary = + std::dynamic_pointer_cast(type)) { + return length()->isMatch(parameterizedFixedBinary->length()) && + nullMatch(type); + } + + return false; +} + +std::string ParameterizedDecimal::signature() const { + std::stringstream sign; + sign << TypeTraits::signature; + sign << "<" << precision_->value() << "," << scale_->value() << ">"; + return sign.str(); +} + +bool ParameterizedDecimal::isMatch( + const std::shared_ptr& type) const { + if (auto decimal = std::dynamic_pointer_cast(type)) { + return nullMatch(type); + } + + 
return false; +} + +std::string ParameterizedFixedChar::signature() const { + std::stringstream sign; + sign << TypeTraits::signature; + sign << "<" << length_->value() << ">"; + return sign.str(); +} +bool ParameterizedFixedChar::isMatch( + const std::shared_ptr& type) const { + if (auto fixedChar = std::dynamic_pointer_cast(type)) { + return nullMatch(type); + } + + return false; +} + +std::string ParameterizedVarchar::signature() const { + std::stringstream sign; + sign << TypeTraits::signature; + sign << "<" << length_->value() << ">"; + return sign.str(); +} + +bool ParameterizedVarchar::isMatch( + const std::shared_ptr& type) const { + if (auto varchar = std::dynamic_pointer_cast(type)) { + return nullMatch(type); + } + + return false; +} + +std::string ParameterizedList::signature() const { + std::stringstream sign; + sign << TypeTraits::signature; + sign << "<" << elementType()->signature() << ">"; + return sign.str(); +} + +bool ParameterizedList::isMatch( + const std::shared_ptr& type) const { + if (auto list = std::dynamic_pointer_cast(type)) { + return elementType()->isMatch(list->elementType()) && nullMatch(type); + } + + return false; +} + +std::string ParameterizedStruct::signature() const { + std::stringstream sign; + sign << TypeTraits::signature; + sign << "<"; + for (auto it = children_.begin(); it != children_.end(); ++it) { + const auto& typeSign = (*it)->signature(); + if (it == children_.end() - 1) { + sign << typeSign; + } else { + sign << typeSign << ","; + } + } + sign << ">"; + return sign.str(); +} + +bool ParameterizedStruct::isMatch( + const std::shared_ptr& type) const { + if (auto structType = std::dynamic_pointer_cast(type)) { + bool sameSize = structType->children().size() == children_.size(); + if (sameSize) { + for (int i = 0; i < children_.size(); i++) { + if (!children_[i]->isMatch(structType->children()[i])) { + return false; + } + } + return nullMatch(type); + } + } + return false; +} + +std::string 
ParameterizedMap::signature() const { + std::stringstream sign; + sign << TypeTraits::signature; + sign << "<"; + sign << keyType()->signature(); + sign << ","; + sign << valueType()->signature(); + sign << ">"; + return sign.str(); +} + +bool ParameterizedMap::isMatch( + const std::shared_ptr& type) const { + if (auto mapType = std::dynamic_pointer_cast(type)) { + return keyType()->isMatch(mapType->keyType()) && + valueType()->isMatch(mapType->valueType()) && nullMatch(type); + } + return false; +} + +std::shared_ptr> BOOL() { + return std::make_shared>(false); +} + +std::shared_ptr> TINYINT() { + return std::make_shared>(false); +} + +std::shared_ptr> SMALLINT() { + return std::make_shared>(false); +} + +std::shared_ptr> INTEGER() { + return std::make_shared>(false); +} + +std::shared_ptr> BIGINT() { + return std::make_shared>(false); +} +std::shared_ptr> FLOAT() { + return std::make_shared>(false); +} + +std::shared_ptr> DOUBLE() { + return std::make_shared>(false); +} + +std::shared_ptr> STRING() { + return std::make_shared>(false); +} + +std::shared_ptr> BINARY() { + return std::make_shared>(false); +} + +std::shared_ptr> TIMESTAMP() { + return std::make_shared>(false); +} + +std::shared_ptr> DATE() { + return std::make_shared>(false); +} + +std::shared_ptr> TIME() { + return std::make_shared>(false); +} + +std::shared_ptr> INTERVAL_YEAR() { + return std::make_shared>(false); +} + +std::shared_ptr> INTERVAL_DAY() { + return std::make_shared>(false); +} + +std::shared_ptr> TIMESTAMP_TZ() { + return std::make_shared>(false); +} + +std::shared_ptr> UUID() { + return std::make_shared>(false); +} + +std::shared_ptr DECIMAL(int precision, int scale) { + return std::make_shared(precision, scale, false); +} + +std::shared_ptr VARCHAR(int len) { + return std::make_shared(len, false); +} +std::shared_ptr FCHAR(int len) { + return std::make_shared(len, false); +} + +std::shared_ptr FIXED_BINARY(int len) { + return std::make_shared(len, false); +} + +std::shared_ptr 
LIST(const TypePtr& elementType) { + return std::make_shared(elementType, false); +} + +std::shared_ptr MAP( + const TypePtr& keyType, + const TypePtr& valueType) { + return std::make_shared(keyType, valueType, false); +} + +std::shared_ptr STRUCT(const std::vector& children) { + return std::make_shared(children, false); +} + +std::shared_ptr FIXED_CHAR(int len) { + return std::make_shared(len); +} + +bool StringLiteral::isMatch( + const std::shared_ptr& type) const { + if (isWildcard()) { + return true; + } else { + if (auto stringLiteral = + std::dynamic_pointer_cast(type)) { + return value_ == stringLiteral->value_; + } + return false; + } +} + +} // namespace io::substrait diff --git a/substrait/type/tests/CMakeLists.txt b/substrait/type/tests/CMakeLists.txt new file mode 100644 index 00000000..781cd5cd --- /dev/null +++ b/substrait/type/tests/CMakeLists.txt @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 + +add_executable( + substrait_type_test + TypeTest.cpp) + +add_test( + substrait_type_test + substrait_type_test) + +target_link_libraries( + substrait_type_test + substrait_type + gtest + gtest_main) diff --git a/substrait/type/tests/TypeTest.cpp b/substrait/type/tests/TypeTest.cpp new file mode 100644 index 00000000..2e30a14d --- /dev/null +++ b/substrait/type/tests/TypeTest.cpp @@ -0,0 +1,158 @@ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include +#include "substrait/type/Type.h" + +using namespace io::substrait; + +class TypeTest : public ::testing::Test { + protected: + template + void testDecode(const std::string& rawType, const std::string& signature) { + const auto& type = ParameterizedType::decode(rawType); + ASSERT_TRUE(type->kind() == kind); + ASSERT_EQ(type->signature(), signature); + } + + static void testType( + const ParameterizedTypePtr& type, + TypeKind kind, + const std::string& signature) { + ASSERT_EQ(type->kind(), kind); + ASSERT_EQ(type->signature(), signature); + } + + template + void testDecode( + const std::string& rawType, 
+ const std::function&)>& + typeCallBack) { + const auto& type = ParameterizedType::decode(rawType); + if (typeCallBack) { + typeCallBack(std::dynamic_pointer_cast(type)); + } + } +}; + +TEST_F(TypeTest, typeCreator) { + testType(BOOL(), TypeKind::kBool, "bool"); + testType(TINYINT(), TypeKind::kI8, "i8"); + testType(SMALLINT(), TypeKind::kI16, "i16"); + testType(INTEGER(), TypeKind::kI32, "i32"); + testType(BIGINT(), TypeKind::kI64, "i64"); + testType(FLOAT(), TypeKind::kFp32, "fp32"); + testType(DOUBLE(), TypeKind::kFp64, "fp64"); + testType(BINARY(), TypeKind::kBinary, "vbin"); + testType(TIMESTAMP(), TypeKind::kTimestamp, "ts"); + testType(STRING(), TypeKind::kString, "str"); + testType(TIMESTAMP_TZ(), TypeKind::kTimestampTz, "tstz"); + testType(DATE(), TypeKind::kDate, "date"); + testType(TIME(), TypeKind::kTime, "time"); + testType(INTERVAL_DAY(), TypeKind::kIntervalDay, "iday"); + testType(INTERVAL_YEAR(), TypeKind::kIntervalYear, "iyear"); + testType(UUID(), TypeKind::kUuid, "uuid"); + testType(FIXED_CHAR(12), TypeKind::kFixedChar, "fchar<12>"); + testType(FIXED_BINARY(12), TypeKind::kFixedBinary, "fbin<12>"); + testType(VARCHAR(12), TypeKind::kVarchar, "vchar<12>"); + testType(DECIMAL(12,23), TypeKind::kDecimal, "dec<12,23>"); + testType(LIST(FLOAT()), TypeKind::kList, "list"); + testType(MAP(STRING(),FLOAT()), TypeKind::kMap, "map"); + testType(STRUCT({STRING(), FLOAT()}), TypeKind::kStruct, "struct"); +} + +TEST_F(TypeTest, decodeTest) { + testDecode("i32?", "i32"); + testDecode("BOOLEAN", "bool"); + testDecode("boolean", "bool"); + testDecode("i8", "i8"); + testDecode("i16", "i16"); + testDecode("i32", "i32"); + testDecode("i64", "i64"); + testDecode("fp32", "fp32"); + testDecode("fp64", "fp64"); + testDecode("binary", "vbin"); + testDecode("timestamp", "ts"); + testDecode("string", "str"); + testDecode("timestamp_tz", "tstz"); + testDecode("date", "date"); + testDecode("time", "time"); + testDecode("interval_day", "iday"); + testDecode("interval_year", 
"iyear"); + testDecode("uuid", "uuid"); + + testDecode( + "fixedchar", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->length()->value(), "L1"); + ASSERT_EQ(typePtr->signature(), "fchar"); + }); + + testDecode( + "fixedbinary", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->length()->value(), "L1"); + ASSERT_EQ(typePtr->signature(), "fbin"); + }); + + testDecode( + "varchar", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "vchar"); + ASSERT_EQ(typePtr->length()->value(), "L1"); + }); + + testDecode( + "decimal", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "dec"); + ASSERT_EQ(typePtr->precision()->value(), "P"); + ASSERT_EQ(typePtr->scale()->value(), "S"); + }); + + testDecode( + "struct", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "struct"); + }); + + testDecode( + "struct>", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "struct>"); + }); + + testDecode( + "list", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "list"); + }); + testDecode( + "LIST?", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "list"); + }); + + testDecode( + "map", + [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "map"); + }); + + testDecode( + "any1", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "any1"); + ASSERT_TRUE(typePtr->isWildcard()); + }); + + testDecode( + "any", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "any"); + ASSERT_TRUE(typePtr->isWildcard()); + }); + + testDecode( + "T", [](const std::shared_ptr& typePtr) { + ASSERT_EQ(typePtr->signature(), "T"); + ASSERT_TRUE(typePtr->isWildcard()); + }); +} diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt new file mode 100644 index 00000000..5780a71d --- /dev/null +++ b/third_party/CMakeLists.txt @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 + 
+add_subdirectory(fmt) +add_subdirectory(googletest) + +set(YAML_CPP_BUILD_TESTS OFF CACHE BOOL "Enable testing") +include_directories(yaml-cpp/include) +add_subdirectory(yaml-cpp) diff --git a/third_party/fmt b/third_party/fmt new file mode 160000 index 00000000..9e8b86fd --- /dev/null +++ b/third_party/fmt @@ -0,0 +1 @@ +Subproject commit 9e8b86fd2d9806672cc73133d21780dd182bfd24 diff --git a/third_party/googletest b/third_party/googletest new file mode 160000 index 00000000..d1a0039b --- /dev/null +++ b/third_party/googletest @@ -0,0 +1 @@ +Subproject commit d1a0039b97291dd1dc14f123b906bb7622ffe07c diff --git a/third_party/substrait b/third_party/substrait new file mode 160000 index 00000000..f3f6bdc9 --- /dev/null +++ b/third_party/substrait @@ -0,0 +1 @@ +Subproject commit f3f6bdc947e689e800279666ff33f118e42d2146 diff --git a/third_party/yaml-cpp b/third_party/yaml-cpp new file mode 160000 index 00000000..c90c08cc --- /dev/null +++ b/third_party/yaml-cpp @@ -0,0 +1 @@ +Subproject commit c90c08ccc9a08abcca609064fb9a856dfdbbb7b4