Skip to content

Commit

Permalink
[SPARK-49683][SQL] Block trim collation
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Trim collation is currently in the implementation phase. This change blocks all code paths from using it; afterwards, as trim collation is enabled for individual expressions, they will be gradually whitelisted.

### Why are the changes needed?
Trim collation is currently in the implementation phase. This change blocks all code paths from using it; afterwards, as trim collation is enabled for individual expressions, they will be gradually whitelisted.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
No additional tests; this change just adds a field that is not used.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48336 from jovanpavl-db/block-collation-trim.

Lead-authored-by: Jovan Pavlovic <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
jovanpavl-db and HyukjinKwon committed Oct 5, 2024
1 parent d8c04cf commit 3e69b40
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,79 @@ import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

/**
* AbstractStringType is an abstract class for StringType with collation support.
* AbstractStringType is an abstract class for StringType with collation support. Because any
* collation can carry a trim specifier, this class is parameterized by whether it supports one.
*/
abstract class AbstractStringType extends AbstractDataType {
abstract class AbstractStringType(private[sql] val supportsTrimCollation: Boolean = false)
extends AbstractDataType {
// Default concrete type: the default string type reported by the current SqlApiConf.
override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
override private[sql] def simpleString: String = "string"
// True when this type supports trim collation, or when `other` does not use one.
// When supportsTrimCollation is false, `other` is cast to StringType without a check,
// so callers must verify `other.isInstanceOf[StringType]` before calling this method.
private[sql] def canUseTrimCollation(other: DataType): Boolean =
supportsTrimCollation || !other.asInstanceOf[StringType].usesTrimCollation
}

/**
* Use StringTypeBinary for expressions supporting only binary collation.
*/
case object StringTypeBinary extends AbstractStringType {
case class StringTypeBinary(override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
// Accepts only StringType values that support binary equality and that also pass
// canUseTrimCollation. The isInstanceOf check comes first so that the later casts
// are only evaluated for StringType arguments.
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality &&
canUseTrimCollation(other)
}

/**
* Companion object that is itself a StringTypeBinary with supportsTrimCollation = false.
* `apply` creates an instance with an explicitly chosen supportsTrimCollation flag.
*/
object StringTypeBinary extends StringTypeBinary(false) {
def apply(supportsTrimCollation: Boolean): StringTypeBinary = {
new StringTypeBinary(supportsTrimCollation)
}
}

/**
* Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation.
*/
case object StringTypeBinaryLcase extends AbstractStringType {
case class StringTypeBinaryLcase(override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
// Accepts StringType values that either support binary equality or report
// isUTF8LcaseCollation, and that also pass canUseTrimCollation. The isInstanceOf
// check comes first so that the later casts are only evaluated for StringType arguments.
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality ||
other.asInstanceOf[StringType].isUTF8LcaseCollation)
other.asInstanceOf[StringType].isUTF8LcaseCollation) && canUseTrimCollation(other)
}

/**
* Companion object that is itself a StringTypeBinaryLcase with supportsTrimCollation = false.
* `apply` creates an instance with an explicitly chosen supportsTrimCollation flag.
*/
object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) {
def apply(supportsTrimCollation: Boolean): StringTypeBinaryLcase = {
new StringTypeBinaryLcase(supportsTrimCollation)
}
}

/**
* Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary
* and ICU) but limited to using case and accent sensitivity specifiers.
*/
case object StringTypeWithCaseAccentSensitivity extends AbstractStringType {
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
case class StringTypeWithCaseAccentSensitivity(
override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
// Accepts any StringType that passes canUseTrimCollation. The isInstanceOf check
// comes first so that canUseTrimCollation is only evaluated for StringType arguments.
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && canUseTrimCollation(other)
}

/**
* Companion object that is itself a StringTypeWithCaseAccentSensitivity with
* supportsTrimCollation = false. `apply` creates an instance with an explicitly chosen
* supportsTrimCollation flag.
*/
object StringTypeWithCaseAccentSensitivity extends StringTypeWithCaseAccentSensitivity(false) {
def apply(supportsTrimCollation: Boolean): StringTypeWithCaseAccentSensitivity = {
new StringTypeWithCaseAccentSensitivity(supportsTrimCollation)
}
}

/**
* Use StringTypeNonCSAICollation for expressions supporting all possible collation types except
* CS_AI collation types.
*/
case object StringTypeNonCSAICollation extends AbstractStringType {
case class StringTypeNonCSAICollation(override val supportsTrimCollation: Boolean = false)
extends AbstractStringType(supportsTrimCollation) {
// Accepts StringType values that report isNonCSAI and that also pass
// canUseTrimCollation. The isInstanceOf check comes first so that the later casts
// are only evaluated for StringType arguments.
override private[sql] def acceptsType(other: DataType): Boolean =
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI
other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI &&
canUseTrimCollation(other)
}

/**
* Companion object that is itself a StringTypeNonCSAICollation with
* supportsTrimCollation = false. `apply` creates an instance with an explicitly chosen
* supportsTrimCollation flag.
*/
object StringTypeNonCSAICollation extends StringTypeNonCSAICollation(false) {
def apply(supportsTrimCollation: Boolean): StringTypeNonCSAICollation = {
new StringTypeNonCSAICollation(supportsTrimCollation)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ
private[sql] def isNonCSAI: Boolean =
!CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId)

private[sql] def usesTrimCollation: Boolean =
CollationFactory.usesTrimCollation(collationId)

private[sql] def isUTF8BinaryCollation: Boolean =
collationId == CollationFactory.UTF8_BINARY_COLLATION_ID

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
override def dataType: DataType = BinaryType

final lazy val collationId: Int = expr.dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ case class HllSketchAgg(

override def inputTypes: Seq[AbstractDataType] =
Seq(
TypeCollection(IntegerType, LongType, StringTypeWithCaseAccentSensitivity, BinaryType),
TypeCollection(
IntegerType,
LongType,
StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true),
BinaryType),
IntegerType)

override def dataType: DataType = BinaryType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ case class Collate(child: Expression, collationName: String)
extends UnaryExpression with ExpectsInputTypes {
private val collationId = CollationFactory.collationNameToId(collationName)
override def dataType: DataType = StringType(collationId)
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))

override protected def withNewChildInternal(
newChild: Expression): Expression = copy(newChild)
Expand Down Expand Up @@ -115,5 +116,6 @@ case class Collation(child: Expression)
val collationName = CollationFactory.fetchCollation(collationId).collationName
Literal.create(collationName, SQLConf.get.defaultStringType)
}
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true))
}
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,11 @@ class CollationSQLExpressionsSuite
StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI",
Map("1" -> "A", "2" -> "B", "3" -> "C"))
)
val unsupportedTestCase = StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null)
val unsupportedTestCases = Seq(
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null),
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_RTRIM", null),
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_BINARY_RTRIM", null),
StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_LCASE_RTRIM", null))
testCases.foreach(t => {
// Unit test.
val text = Literal.create(t.text, StringType(t.collation))
Expand All @@ -998,28 +1002,30 @@ class CollationSQLExpressionsSuite
}
})
// Test unsupported collation.
withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) {
val query =
s"select str_to_map('${unsupportedTestCase.text}', '${unsupportedTestCase.pairDelim}', " +
s"'${unsupportedTestCase.keyValueDelim}')"
checkError(
exception = intercept[AnalysisException] {
sql(query).collect()
},
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
sqlState = Some("42K09"),
parameters = Map(
"sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate UNICODE_AI, " +
"'?' collate UNICODE_AI, '?' collate UNICODE_AI)\""),
"paramIndex" -> "first",
"inputSql" -> "\"'a:1,b:2,c:3' collate UNICODE_AI\"",
"inputType" -> "\"STRING COLLATE UNICODE_AI\"",
"requiredType" -> "\"STRING\""),
context = ExpectedContext(
fragment = "str_to_map('a:1,b:2,c:3', '?', '?')",
start = 7,
stop = 41))
}
unsupportedTestCases.foreach(t => {
withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collation) {
val query =
s"select str_to_map('${t.text}', '${t.pairDelim}', " +
s"'${t.keyValueDelim}')"
checkError(
exception = intercept[AnalysisException] {
sql(query).collect()
},
condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
sqlState = Some("42K09"),
parameters = Map(
"sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate " + s"${t.collation}, " +
"'?' collate " + s"${t.collation}, '?' collate ${t.collation})" + "\""),
"paramIndex" -> "first",
"inputSql" -> ("\"'a:1,b:2,c:3' collate " + s"${t.collation}" + "\""),
"inputType" -> ("\"STRING COLLATE " + s"${t.collation}" + "\""),
"requiredType" -> "\"STRING\""),
context = ExpectedContext(
fragment = "str_to_map('a:1,b:2,c:3', '?', '?')",
start = 7,
stop = 41))
}
})
}

test("Support RaiseError misc expression with collation") {
Expand Down

0 comments on commit 3e69b40

Please sign in to comment.