Skip to content

Commit

Permalink
[feat][test](nereids)(vec) support the sinh, asinh, acosh and atanh f…
Browse files Browse the repository at this point in the history
…unction and test(apache#48203)
  • Loading branch information
ChenMiaoi committed Feb 28, 2025
1 parent 2835a52 commit cb6f78c
Show file tree
Hide file tree
Showing 10 changed files with 468 additions and 1 deletion.
28 changes: 28 additions & 0 deletions be/src/vec/functions/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ struct AcosName {
using FunctionAcos =
FunctionMathUnaryAlwayNullable<UnaryFunctionPlainAlwayNullable<AcosName, std::acos>>;

struct AcoshName {
static constexpr auto name = "acosh";
static constexpr bool is_invalid_input(Float64 x) { return x < 1; }
};
using FunctionAcosh =
FunctionMathUnaryAlwayNullable<UnaryFunctionPlainAlwayNullable<AcoshName, std::acosh>>;

struct AsinName {
static constexpr auto name = "asin";
// https://dev.mysql.com/doc/refman/8.4/en/mathematical-functions.html#function_asin
Expand All @@ -68,11 +75,23 @@ struct AsinName {
using FunctionAsin =
FunctionMathUnaryAlwayNullable<UnaryFunctionPlainAlwayNullable<AsinName, std::asin>>;

struct AsinhName {
static constexpr auto name = "asinh";
};
using FunctionAsinh = FunctionMathUnary<UnaryFunctionPlain<AsinhName, std::asinh>>;

struct AtanName {
static constexpr auto name = "atan";
};
using FunctionAtan = FunctionMathUnary<UnaryFunctionPlain<AtanName, std::atan>>;

struct AtanhName {
static constexpr auto name = "atanh";
static constexpr bool is_invalid_input(Float64 x) { return x <= -1 || x >= 1; }
};
using FunctionAtanh =
FunctionMathUnaryAlwayNullable<UnaryFunctionPlainAlwayNullable<AtanhName, std::atanh>>;

template <typename A, typename B>
struct Atan2Impl {
using ResultType = double;
Expand Down Expand Up @@ -247,6 +266,11 @@ struct UnaryFunctionPlainSin {

using FunctionSin = FunctionMathUnary<UnaryFunctionPlainSin>;

struct SinhName {
static constexpr auto name = "sinh";
};
using FunctionSinh = FunctionMathUnary<UnaryFunctionPlain<SinhName, std::sinh>>;

struct SqrtName {
static constexpr auto name = "sqrt";
// https://dev.mysql.com/doc/refman/8.4/en/mathematical-functions.html#function_sqrt
Expand Down Expand Up @@ -427,8 +451,11 @@ class FunctionNormalCdf : public IFunction {
// so mush. Split it to speed up compile time in the future
void register_function_math(SimpleFunctionFactory& factory) {
factory.register_function<FunctionAcos>();
factory.register_function<FunctionAcosh>();
factory.register_function<FunctionAsin>();
factory.register_function<FunctionAsinh>();
factory.register_function<FunctionAtan>();
factory.register_function<FunctionAtanh>();
factory.register_function<FunctionAtan2>();
factory.register_function<FunctionCos>();
factory.register_function<FunctionCosh>();
Expand All @@ -445,6 +472,7 @@ void register_function_math(SimpleFunctionFactory& factory) {
factory.register_function<FunctionNegative>();
factory.register_function<FunctionPositive>();
factory.register_function<FunctionSin>();
factory.register_function<FunctionSinh>();
factory.register_function<FunctionSqrt>();
factory.register_alias("sqrt", "dsqrt");
factory.register_function<FunctionCbrt>();
Expand Down
56 changes: 56 additions & 0 deletions be/test/vec/function/function_math_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ TEST(MathFunctionTest, acos_test) {
static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

TEST(MathFunctionTest, acosh_test) {
std::string func_name = "acosh"; // acosh(x) = ln(x + sqrt(x^2 - 1)), x ∈ [1, +∞)

InputTypeSet input_types = {TypeIndex::Float64};

DataSet data_set = {{{1.0}, 0.0},
{{2.0}, 1.3169578969248168},
{{3.0}, 1.7627471740390861},
{{10.0}, 2.9932228461263808},
{{100.0}, 5.2982923656104850}};

static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

TEST(MathFunctionTest, asin_test) {
std::string func_name = "asin"; //[-1,1] -->[-pi_2, pi_2]

Expand All @@ -54,6 +68,20 @@ TEST(MathFunctionTest, asin_test) {
static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

TEST(MathFunctionTest, asinh_test) {
std::string func_name = "asinh"; // asinh(x) = ln(x + sqrt(x^2 + 1)), x ∈ (-∞, +∞)

InputTypeSet input_types = {TypeIndex::Float64};

DataSet data_set = {{{0.0}, 0.0},
{{1.0}, 0.8813735870195430},
{{-1.0}, -0.8813735870195430},
{{2.0}, 1.4436354751788103},
{{-2.0}, -1.4436354751788103}};

static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

TEST(MathFunctionTest, atan_test) {
std::string func_name = "atan"; //[-,+] -->(pi_2,pi_2)

Expand All @@ -67,6 +95,20 @@ TEST(MathFunctionTest, atan_test) {
static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

TEST(MathFunctionTest, atanh_test) {
std::string func_name = "atanh"; // atanh(x) = 0.5 * ln((1 + x) / (1 - x)), x ∈ (-1, 1)

InputTypeSet input_types = {TypeIndex::Float64};

DataSet data_set = {{{0.0}, 0.0},
{{0.5}, 0.5493061443340549},
{{-0.5}, -0.5493061443340549},
{{0.9}, 1.4722194895832204},
{{-0.9}, -1.4722194895832204}};

static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

TEST(MathFunctionTest, cos_test) {
std::string func_name = "cos";

Expand Down Expand Up @@ -95,6 +137,20 @@ TEST(MathFunctionTest, sin_test) {
static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

TEST(MathFunctionTest, sinh_test) {
std::string func_name = "sinh"; // sinh(x) = (e^x - e^(-x)) / 2, x ∈ (-∞, +∞)

InputTypeSet input_types = {TypeIndex::Float64};

DataSet data_set = {{{0.0}, 0.0},
{{1.0}, 1.1752011936438014},
{{-1.0}, -1.1752011936438014},
{{2.0}, 3.6268604078470186},
{{-2.0}, -3.6268604078470186}};

static_cast<void>(check_function<DataTypeFloat64, true>(func_name, input_types, data_set));
}

TEST(MathFunctionTest, sqrt_test) {
std::string func_name = "sqrt"; //sqrt(x) x>=0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.trees.expressions.Regexp;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Acos;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Acosh;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesDecrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AesEncrypt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AppendTrailingCharIfAbsent;
Expand Down Expand Up @@ -75,9 +76,11 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraysOverlap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Ascii;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Asin;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Asinh;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AssertTrue;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Atan;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Atan2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Atanh;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AutoPartitionName;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Bin;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitCount;
Expand Down Expand Up @@ -382,6 +385,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sha2;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sign;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sin;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sinh;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sleep;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm3;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Sm3sum;
Expand Down Expand Up @@ -498,6 +502,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
public final List<ScalarFunc> scalarFunctions = ImmutableList.of(
scalar(Abs.class, "abs"),
scalar(Acos.class, "acos"),
scalar(Acosh.class, "acosh"),
scalar(AesDecrypt.class, "aes_decrypt"),
scalar(AesEncrypt.class, "aes_encrypt"),
scalar(AppendTrailingCharIfAbsent.class, "append_trailing_char_if_absent"),
Expand Down Expand Up @@ -552,8 +557,10 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ArraysOverlap.class, "arrays_overlap"),
scalar(Ascii.class, "ascii"),
scalar(Asin.class, "asin"),
scalar(Asinh.class, "asinh"),
scalar(AssertTrue.class, "assert_true"),
scalar(Atan.class, "atan"),
scalar(Atanh.class, "atanh"),
scalar(Atan2.class, "atan2"),
scalar(AutoPartitionName.class, "auto_partition_name"),
scalar(Bin.class, "bin"),
Expand Down Expand Up @@ -877,6 +884,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Sha2.class, "sha2"),
scalar(Sign.class, "sign"),
scalar(Sin.class, "sin"),
scalar(Sinh.class, "sinh"),
scalar(Sleep.class, "sleep"),
scalar(StructElement.class, "struct_element"),
scalar(Sm3.class, "sm3"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DoubleType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ScalarFunction Acosh
*/
public class Acosh extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable, PropagateNullLiteral {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
);

/**
* constructor with 1 argument.
*/
public Acosh(Expression arg) {
super("acosh", arg);
}

/**
* withChildren.
*/
@Override
public Acosh withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new Acosh(children.get(0));
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAcosh(this, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DoubleType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ScalarFunction asinh
*/
public class Asinh extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
);

/**
* constructor with 1 argument.
*/
public Asinh(Expression arg) {
super("asinh", arg);
}

/**
* withChildren.
*/
@Override
public Asinh withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new Asinh(children.get(0));
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAsinh(this, context);
}
}
Loading

0 comments on commit cb6f78c

Please sign in to comment.