diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 2342722c0bb14..1d23774a51692 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.{SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.variant.ParseJson import org.apache.spark.sql.internal.SqlApiConf @@ -46,7 +47,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputEntry - List of all input entries that need to be generated * @param collationType - Flag defining collation type to use - * @return + * @return - List of data generated for expression instance creation */ def generateData( inputEntry: Seq[Any], @@ -54,23 +55,11 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputEntry.map(generateSingleEntry(_, collationType)) } - /** - * Helper function to generate single entry of data as a string. - * @param inputEntry - Single input entry that requires generation - * @param collationType - Flag defining collation type to use - * @return - */ - def generateDataAsStrings( - inputEntry: Seq[AbstractDataType], - collationType: CollationType): Seq[Any] = { - inputEntry.map(generateInputAsString(_, collationType)) - } - /** * Helper function to generate single entry of data. 
* @param inputEntry - Single input entry that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - Single input entry data */ def generateSingleEntry( inputEntry: Any, @@ -100,7 +89,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input literal type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - Literal/Expression containing expression ready for evaluation */ def generateLiterals( inputType: AbstractDataType, @@ -116,6 +105,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } case BooleanType => Literal(true) case _: DatetimeType => Literal(Timestamp.valueOf("2009-07-30 12:58:59")) + case DecimalType => Literal((new Decimal).set(5)) case _: DecimalType => Literal((new Decimal).set(5)) case _: DoubleType => Literal(5.0) case IntegerType | NumericType | IntegralType => Literal(5) @@ -158,11 +148,15 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case MapType => val key = generateLiterals(StringTypeAnyCollation, collationType) val value = generateLiterals(StringTypeAnyCollation, collationType) - Literal.create(Map(key -> value)) + CreateMap(Seq(key, value)) case MapType(keyType, valueType, _) => val key = generateLiterals(keyType, collationType) val value = generateLiterals(valueType, collationType) - Literal.create(Map(key -> value)) + CreateMap(Seq(key, value)) + case AbstractMapType(keyType, valueType) => + val key = generateLiterals(keyType, collationType) + val value = generateLiterals(valueType, collationType) + CreateMap(Seq(key, value)) case StructType => CreateNamedStruct( Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, collationType), @@ -174,7 +168,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input type that requires 
generation * @param collationType - Flag defining collation type to use - * @return + * @return - String representation of an input ready for SQL query */ def generateInputAsString( inputType: AbstractDataType, collationType: CollationType): String = inputType match { @@ -189,6 +183,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } case BooleanType => "True" case _: DatetimeType => "date'2016-04-08'" + case DecimalType => "5.0" case _: DecimalType => "5.0" case _: DoubleType => "5.0" case IntegerType | NumericType | IntegralType => "5" @@ -221,6 +216,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case MapType(keyType, valueType, _) => "map(" + generateInputAsString(keyType, collationType) + ", " + generateInputAsString(valueType, collationType) + ")" + case AbstractMapType(keyType, valueType) => + "map(" + generateInputAsString(keyType, collationType) + ", " + + generateInputAsString(valueType, collationType) + ")" case StructType => "named_struct( 'start', " + generateInputAsString(StringTypeAnyCollation, collationType) + ", 'end', " + generateInputAsString(StringTypeAnyCollation, collationType) + ")" @@ -234,7 +232,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - String representation for SQL query of an inputType */ def generateInputTypeAsStrings( inputType: AbstractDataType, collationType: CollationType): String = inputType match { @@ -244,6 +242,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case BinaryType => "BINARY" case BooleanType => "BOOLEAN" case _: DatetimeType => "DATE" + case DecimalType => "DECIMAL(2, 1)" case _: DecimalType => "DECIMAL(2, 1)" case _: DoubleType => "DOUBLE" case IntegerType | NumericType | IntegralType => "INT" @@ -275,6 +274,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case MapType(keyType, valueType, _) => 
"map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + generateInputTypeAsStrings(valueType, collationType) + ">" + case AbstractMapType(keyType, valueType) => + "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + + generateInputTypeAsStrings(valueType, collationType) + ">" case StructType => "struct hasStringType(elementType) case TypeCollection(typeCollection) => typeCollection.exists(hasStringType) - case StructType => true case StructType(fields) => fields.exists(sf => hasStringType(sf.dataType)) case _ => false } @@ -310,7 +311,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * Helper function to replace expected parameters with expected input types. * @param inputTypes - Input types generated by ExpectsInputType.inputTypes * @param params - Parameters that are read from expression info - * @return + * @return - List of parameters where Expressions are replaced with input types */ def replaceExpressions(inputTypes: Seq[AbstractDataType], params: Seq[Class[_]]): Seq[Any] = { (inputTypes, params) match { @@ -325,7 +326,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi /** * Helper method to extract relevant expressions that can be walked over. - * @return + * @return - (List of relevant expressions that expect input, List of expressions to skip) */ def extractRelevantExpressions(): (Array[ExpressionInfo], List[String]) = { var expressionCounter = 0 @@ -384,6 +385,47 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi (funInfos, toSkip) } + /** + * Helper method to extract relevant expressions that can be walked over but are built with + * expression builder. 
+ * + * @return - (List of expressions that are relevant builders, List of expressions to skip) + */ + def extractRelevantBuilders(): (Array[ExpressionInfo], List[String]) = { + var builderExpressionCounter = 0 + val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => + spark.sessionState.catalog.lookupFunctionInfo(funcId) + }.filter(funInfo => { + // make sure that there is a constructor. + val cl = Utils.classForName(funInfo.getClassName) + cl.isAssignableFrom(classOf[ExpressionBuilder]) + }).filter(funInfo => { + builderExpressionCounter = builderExpressionCounter + 1 + val cl = Utils.classForName(funInfo.getClassName) + val method = cl.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + var input: Seq[Expression] = Seq.empty + var i = 0 + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + try { + method.invoke(null, funInfo.getClassName, input).asInstanceOf[ExpectsInputTypes] + } + catch { + case _: Exception => i = i + 1 + } + } + if (i == 10) false + else true + }).toArray + + logInfo("Total number of expression that are built: " + builderExpressionCounter) + logInfo("Number of extracted expressions of relevance: " + funInfos.length) + + (funInfos, List()) + } + /** * Helper function to generate string of an expression suitable for execution. 
* @param expr - Expression that needs to be converted @@ -441,10 +483,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * 5) Otherwise, check if exceptions are the same */ test("SPARK-48280: Expression Walker for expression evaluation") { - val (funInfos, toSkip) = extractRelevantExpressions() + val (funInfosExpr, toSkip) = extractRelevantExpressions() + val (funInfosBuild, _) = extractRelevantBuilders() + val funInfos = funInfosExpr ++ funInfosBuild for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - val cl = Utils.classForName(f.getClassName) + val TempCl = Utils.classForName(f.getClassName) + val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) { + val clTemp = Utils.classForName(f.getClassName) + val method = clTemp.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + val instance = { + var input: Seq[Expression] = Seq.empty + var result: Expression = null + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + try { + val tempResult = method.invoke(null, f.getClassName, input) + if (result == null) result = tempResult.asInstanceOf[Expression] + } + catch { + case _: Exception => + } + } + result + } + instance.getClass + } + else Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType) @@ -526,10 +594,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * 5) Otherwise, check if exceptions are the same */ test("SPARK-48280: Expression Walker for codeGen generation") { - val (funInfos, toSkip) = extractRelevantExpressions() + val (funInfosExpr, toSkip) = extractRelevantExpressions() + val (funInfosBuild, _) = extractRelevantBuilders() + val funInfos = funInfosExpr ++ funInfosBuild for (f <- funInfos.filter(f => 
!toSkip.contains(f.getName))) { - val cl = Utils.classForName(f.getClassName) + val TempCl = Utils.classForName(f.getClassName) + val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) { + val clTemp = Utils.classForName(f.getClassName) + val method = clTemp.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + val instance = { + var input: Seq[Expression] = Seq.empty + var result: Expression = null + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + try { + val tempResult = method.invoke(null, f.getClassName, input) + if (result == null) result = tempResult.asInstanceOf[Expression] + } + catch { + case _: Exception => + } + } + result + } + instance.getClass + } + else Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType)