[SPARK-48280][SQL][FOLLOW-UP] Add expressions that are built via expressionBuilder to Expression Walker

### What changes were proposed in this pull request?
Addition of new expressions (those built via expression builders) to the expression walker. This PR also improves the descriptions of methods in the suite.
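
For context, an expression registered through an `ExpressionBuilder` has no directly usable constructor: the function registry stores the builder class, and the concrete `Expression` is produced by the builder's `build(funcName, expressions)` method, so the walker needs a separate reflective path for these entries. Below is a minimal sketch of that pattern; it mirrors the reflective calls in the new helper and is shown for illustration only (`buildViaBuilder` is a hypothetical name, not part of the patch).

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.util.Utils

// Sketch: resolve the builder class by name and call its `build` method
// reflectively, the same way the new extractRelevantBuilders helper does.
def buildViaBuilder(
    builderClassName: String,
    funcName: String,
    args: Seq[Expression]): Expression = {
  val cl = Utils.classForName(builderClassName)
  val build = cl.getMethod("build",
    Utils.classForName("java.lang.String"),
    Utils.classForName("scala.collection.Seq"))
  // Null receiver: the builder object exposes `build` via a static forwarder.
  build.invoke(null, funcName, args).asInstanceOf[Expression]
}
```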

### Why are the changes needed?
It was noticed while debugging that startsWith, endsWith, and contains are not tested by this suite, even though these expressions represent the core of collation testing.
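
As a quick illustration (a minimal sketch, not part of the patch, assuming a running `spark` session and the collation support this suite targets, where the default collation is the binary `UTF8_BINARY` and `UTF8_LCASE` compares case-insensitively):

```scala
// With the default binary collation, the prefix check is case-sensitive.
spark.sql("SELECT startswith('Spark', 'SPARK')").show()    // expected: false

// Under a case-insensitive collation the same check should succeed.
spark.sql(
  "SELECT startswith('Spark' COLLATE UTF8_LCASE, 'SPARK' COLLATE UTF8_LCASE)"
).show()                                                    // expected: true
```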

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Test only.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48162 from mihailom-db/expressionwalkerfollowup.

Authored-by: Mihailo Milosevic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
mihailom-db authored and cloud-fan committed Sep 19, 2024
1 parent 492d1b1 commit ac34f1d
Showing 1 changed file with 121 additions and 27 deletions.
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.sql.Timestamp

import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.variant.ParseJson
import org.apache.spark.sql.internal.SqlApiConf
@@ -46,31 +47,19 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
*
* @param inputEntry - List of all input entries that need to be generated
* @param collationType - Flag defining collation type to use
* @return
* @return - List of data generated for expression instance creation
*/
def generateData(
inputEntry: Seq[Any],
collationType: CollationType): Seq[Any] = {
inputEntry.map(generateSingleEntry(_, collationType))
}

/**
* Helper function to generate single entry of data as a string.
* @param inputEntry - Single input entry that requires generation
* @param collationType - Flag defining collation type to use
* @return
*/
def generateDataAsStrings(
inputEntry: Seq[AbstractDataType],
collationType: CollationType): Seq[Any] = {
inputEntry.map(generateInputAsString(_, collationType))
}

/**
* Helper function to generate single entry of data.
* @param inputEntry - Single input entry that requires generation
* @param collationType - Flag defining collation type to use
* @return
* @return - Single input entry data
*/
def generateSingleEntry(
inputEntry: Any,
@@ -100,7 +89,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
*
* @param inputType - Single input literal type that requires generation
* @param collationType - Flag defining collation type to use
* @return
* @return - Literal/Expression containing expression ready for evaluation
*/
def generateLiterals(
inputType: AbstractDataType,
@@ -116,6 +105,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
}
case BooleanType => Literal(true)
case _: DatetimeType => Literal(Timestamp.valueOf("2009-07-30 12:58:59"))
case DecimalType => Literal((new Decimal).set(5))
case _: DecimalType => Literal((new Decimal).set(5))
case _: DoubleType => Literal(5.0)
case IntegerType | NumericType | IntegralType => Literal(5)
@@ -158,11 +148,15 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
case MapType =>
val key = generateLiterals(StringTypeAnyCollation, collationType)
val value = generateLiterals(StringTypeAnyCollation, collationType)
Literal.create(Map(key -> value))
CreateMap(Seq(key, value))
case MapType(keyType, valueType, _) =>
val key = generateLiterals(keyType, collationType)
val value = generateLiterals(valueType, collationType)
Literal.create(Map(key -> value))
CreateMap(Seq(key, value))
case AbstractMapType(keyType, valueType) =>
val key = generateLiterals(keyType, collationType)
val value = generateLiterals(valueType, collationType)
CreateMap(Seq(key, value))
case StructType =>
CreateNamedStruct(
Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, collationType),
@@ -174,7 +168,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
*
* @param inputType - Single input type that requires generation
* @param collationType - Flag defining collation type to use
* @return
* @return - String representation of a input ready for SQL query
*/
def generateInputAsString(
inputType: AbstractDataType,
@@ -189,6 +183,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
}
case BooleanType => "True"
case _: DatetimeType => "date'2016-04-08'"
case DecimalType => "5.0"
case _: DecimalType => "5.0"
case _: DoubleType => "5.0"
case IntegerType | NumericType | IntegralType => "5"
@@ -221,6 +216,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
case MapType(keyType, valueType, _) =>
"map(" + generateInputAsString(keyType, collationType) + ", " +
generateInputAsString(valueType, collationType) + ")"
case AbstractMapType(keyType, valueType) =>
"map(" + generateInputAsString(keyType, collationType) + ", " +
generateInputAsString(valueType, collationType) + ")"
case StructType =>
"named_struct( 'start', " + generateInputAsString(StringTypeAnyCollation, collationType) +
", 'end', " + generateInputAsString(StringTypeAnyCollation, collationType) + ")"
@@ -234,7 +232,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
*
* @param inputType - Single input type that requires generation
* @param collationType - Flag defining collation type to use
* @return
* @return - String representation for SQL query of a inputType
*/
def generateInputTypeAsStrings(
inputType: AbstractDataType,
@@ -244,6 +242,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
case BinaryType => "BINARY"
case BooleanType => "BOOLEAN"
case _: DatetimeType => "DATE"
case DecimalType => "DECIMAL(2, 1)"
case _: DecimalType => "DECIMAL(2, 1)"
case _: DoubleType => "DOUBLE"
case IntegerType | NumericType | IntegralType => "INT"
@@ -275,6 +274,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
case MapType(keyType, valueType, _) =>
"map<" + generateInputTypeAsStrings(keyType, collationType) + ", " +
generateInputTypeAsStrings(valueType, collationType) + ">"
case AbstractMapType(keyType, valueType) =>
"map<" + generateInputTypeAsStrings(keyType, collationType) + ", " +
generateInputTypeAsStrings(valueType, collationType) + ">"
case StructType =>
"struct<start:" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) +
", end:" +
@@ -287,7 +289,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
/**
* Helper function to extract types of relevance
* @param inputType
* @return
* @return - Boolean that represents if inputType has/is a StringType
*/
def hasStringType(inputType: AbstractDataType): Boolean = {
inputType match {
@@ -300,7 +302,6 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
case AbstractArrayType(elementType) => hasStringType(elementType)
case TypeCollection(typeCollection) =>
typeCollection.exists(hasStringType)
case StructType => true
case StructType(fields) => fields.exists(sf => hasStringType(sf.dataType))
case _ => false
}
@@ -310,7 +311,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
* Helper function to replace expected parameters with expected input types.
* @param inputTypes - Input types generated by ExpectsInputType.inputTypes
* @param params - Parameters that are read from expression info
* @return
* @return - List of parameters where Expressions are replaced with input types
*/
def replaceExpressions(inputTypes: Seq[AbstractDataType], params: Seq[Class[_]]): Seq[Any] = {
(inputTypes, params) match {
@@ -325,7 +326,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession

/**
* Helper method to extract relevant expressions that can be walked over.
* @return
* @return - (List of relevant expressions that expect input, List of expressions to skip)
*/
def extractRelevantExpressions(): (Array[ExpressionInfo], List[String]) = {
var expressionCounter = 0
@@ -384,6 +385,47 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
(funInfos, toSkip)
}

/**
* Helper method to extract relevant expressions that can be walked over but are built with
* expression builder.
*
* @return - (List of expressions that are relevant builders, List of expressions to skip)
*/
def extractRelevantBuilders(): (Array[ExpressionInfo], List[String]) = {
var builderExpressionCounter = 0
val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId =>
spark.sessionState.catalog.lookupFunctionInfo(funcId)
}.filter(funInfo => {
// make sure that there is a constructor.
val cl = Utils.classForName(funInfo.getClassName)
cl.isAssignableFrom(classOf[ExpressionBuilder])
}).filter(funInfo => {
builderExpressionCounter = builderExpressionCounter + 1
val cl = Utils.classForName(funInfo.getClassName)
val method = cl.getMethod("build",
Utils.classForName("java.lang.String"),
Utils.classForName("scala.collection.Seq"))
var input: Seq[Expression] = Seq.empty
var i = 0
for (_ <- 1 to 10) {
input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
try {
method.invoke(null, funInfo.getClassName, input).asInstanceOf[ExpectsInputTypes]
}
catch {
case _: Exception => i = i + 1
}
}
if (i == 10) false
else true
}).toArray

logInfo("Total number of expression that are built: " + builderExpressionCounter)
logInfo("Number of extracted expressions of relevance: " + funInfos.length)

(funInfos, List())
}

/**
* Helper function to generate string of an expression suitable for execution.
* @param expr - Expression that needs to be converted
@@ -441,10 +483,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
* 5) Otherwise, check if exceptions are the same
*/
test("SPARK-48280: Expression Walker for expression evaluation") {
val (funInfos, toSkip) = extractRelevantExpressions()
val (funInfosExpr, toSkip) = extractRelevantExpressions()
val (funInfosBuild, _) = extractRelevantBuilders()
val funInfos = funInfosExpr ++ funInfosBuild

for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) {
val cl = Utils.classForName(f.getClassName)
val TempCl = Utils.classForName(f.getClassName)
val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) {
val clTemp = Utils.classForName(f.getClassName)
val method = clTemp.getMethod("build",
Utils.classForName("java.lang.String"),
Utils.classForName("scala.collection.Seq"))
val instance = {
var input: Seq[Expression] = Seq.empty
var result: Expression = null
for (_ <- 1 to 10) {
input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
try {
val tempResult = method.invoke(null, f.getClassName, input)
if (result == null) result = tempResult.asInstanceOf[Expression]
}
catch {
case _: Exception =>
}
}
result
}
instance.getClass
}
else Utils.classForName(f.getClassName)

val headConstructor = cl.getConstructors
.zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1
val params = headConstructor.getParameters.map(p => p.getType)
@@ -526,10 +594,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSession
* 5) Otherwise, check if exceptions are the same
*/
test("SPARK-48280: Expression Walker for codeGen generation") {
val (funInfos, toSkip) = extractRelevantExpressions()
val (funInfosExpr, toSkip) = extractRelevantExpressions()
val (funInfosBuild, _) = extractRelevantBuilders()
val funInfos = funInfosExpr ++ funInfosBuild

for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) {
val cl = Utils.classForName(f.getClassName)
val TempCl = Utils.classForName(f.getClassName)
val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) {
val clTemp = Utils.classForName(f.getClassName)
val method = clTemp.getMethod("build",
Utils.classForName("java.lang.String"),
Utils.classForName("scala.collection.Seq"))
val instance = {
var input: Seq[Expression] = Seq.empty
var result: Expression = null
for (_ <- 1 to 10) {
input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
try {
val tempResult = method.invoke(null, f.getClassName, input)
if (result == null) result = tempResult.asInstanceOf[Expression]
}
catch {
case _: Exception =>
}
}
result
}
instance.getClass
}
else Utils.classForName(f.getClassName)

val headConstructor = cl.getConstructors
.zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1
val params = headConstructor.getParameters.map(p => p.getType)