Skip to content

Commit

Permalink
[SPARK-49352][SQL] Avoid redundant array transform for identical expr…
Browse files Browse the repository at this point in the history
…ession

### What changes were proposed in this pull request?

This patch avoids wrapping the input in an `ArrayTransform` in the `resolveArrayType` function when the resolved element expression is the same as the input lambda parameter.

### Why are the changes needed?

Our customer encountered a significant performance regression when migrating from Spark 3.2 to Spark 3.4 on an `Insert Into` query, which is analyzed as an `AppendData` on an Iceberg table.
We found that the root cause is that, in Spark 3.4, `TableOutputResolver` resolves the query with an additional `ArrayTransform` on an `ArrayType` field. The `ArrayTransform`'s lambda function is actually an identity function, i.e., the transformation is redundant.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test and manual e2e test

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47843 from viirya/fix_redundant_array_transform.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
viirya authored and dongjoon-hyun committed Aug 23, 2024
1 parent af45052 commit ee97caa
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,16 @@ object TableOutputResolver extends SQLConfHelper with Logging {
resolveColumnsByPosition(tableName, Seq(param), Seq(fakeAttr), conf, addError, colPath)
}
if (res.length == 1) {
val func = LambdaFunction(res.head, Seq(param))
Some(Alias(ArrayTransform(nullCheckedInput, func), expected.name)())
if (res.head == param) {
// If the element type is the same, we can reuse the input array directly.
Some(
Alias(nullCheckedInput, expected.name)(
nonInheritableMetadataKeys =
Seq(CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)))
} else {
val func = LambdaFunction(res.head, Seq(param))
Some(Alias(ArrayTransform(nullCheckedInput, func), expected.name)())
}
} else {
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.Locale

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal}
import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
Expand Down Expand Up @@ -303,6 +303,36 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {

def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan

test("SPARK-49352: Avoid redundant array transform for identical expression") {
  /**
   * Analyzes a by-name write of an `arr` column of `fromType` into a table whose
   * `arr` column is `toType`, then asserts that an `ArrayTransform` appears in the
   * expressions of the write plan's first child exactly when `hasTransform` is true.
   */
  def assertArrayField(fromType: ArrayType, toType: ArrayType, hasTransform: Boolean): Unit = {
    // The query lists its columns in the opposite order from the table.
    val table = TestRelation(Seq($"a".int, $"arr".array(toType)))
    val query = TestRelation(Seq($"arr".array(fromType), $"a".int))

    val writePlan = byName(table, query).analyze

    assertResolved(writePlan)
    checkAnalysis(writePlan, writePlan)

    // Search every expression tree of the write plan's first child for an ArrayTransform node.
    val childExpressions = writePlan.children.head.expressions
    val containsTransform = childExpressions.exists { expr =>
      expr.collectFirst { case t: ArrayTransform => t }.nonEmpty
    }
    assert(containsTransform == hasTransform)
  }

  // Same element type on both sides: no ArrayTransform is expected.
  assertArrayField(ArrayType(LongType), ArrayType(LongType), hasTransform = false)

  // Struct elements with reordered fields and an int -> byte field: ArrayTransform is expected.
  assertArrayField(
    ArrayType(new StructType().add("x", "int").add("y", "int")),
    ArrayType(new StructType().add("y", "int").add("x", "byte")),
    hasTransform = true)
}

test("SPARK-33136: output resolved on complex types for V2 write commands") {
def assertTypeCompatibility(name: String, fromType: DataType, toType: DataType): Unit = {
val table = TestRelation(StructType(Seq(StructField("a", toType))))
Expand Down

0 comments on commit ee97caa

Please sign in to comment.