Skip to content

Commit

Permalink
[SPARK-48578][SQL] add UTF8 string validation related functions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Introduced 4 new string expressions in Spark SQL: `IsValidUTF8`, `MakeValidUTF8`, `ValidateUTF8`, `TryValidateUTF8`.

### Why are the changes needed?
These expressions offer a complete set of user-facing expressions that allow for UTF8String validation in Spark.

### Does this PR introduce _any_ user-facing change?
Yes, 4 new string expressions are available.

### How was this patch tested?
Unit tests in `UTF8StringSuite` and `CollationSupportSuite` and e2e sql tests in `string-functions.sql`.

### Was this patch authored or co-authored using generative AI tooling?
Yes.

Closes #46845 from uros-db/string-validation.

Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
uros-db authored and cloud-fan committed Jun 25, 2024
1 parent 8c4ca7e commit 068be4b
Show file tree
Hide file tree
Showing 14 changed files with 775 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ public void testLowerCaseCodePoints() {
// Surrogate pairs are treated as invalid UTF8 sequences
assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[]
{(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}),
UTF8String.fromString("\ufffd\ufffd"), false);
UTF8String.fromString("\uFFFD\uFFFD"), false);
assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[]
{(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}),
UTF8String.fromString("\ufffd\ufffd"), true);
UTF8String.fromString("\uFFFD\uFFFD"), true);
}

/**
Expand Down
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -2909,6 +2909,12 @@
],
"sqlState" : "42000"
},
"INVALID_UTF8_STRING" : {
"message" : [
"Invalid UTF8 byte sequence found in string: <str>."
],
"sqlState" : "22029"
},
"INVALID_VARIABLE_TYPE_FOR_QUERY_EXECUTE_IMMEDIATE" : {
"message" : [
"Variable type must be string type but got <varType>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,32 @@ public static boolean isLuhnNumber(UTF8String numberString) {
return checkSum % 10 == 0;
}

/**
* Function to validate a given UTF8 string according to Unicode rules.
*
* @param utf8String
* the input string to validate against possible invalid byte sequences
* @return
* the original string if the input string is a valid UTF8String, throw exception otherwise.
*/
public static UTF8String validateUTF8String(UTF8String utf8String) {
if (utf8String.isValid()) return utf8String;
else throw QueryExecutionErrors.invalidUTF8StringError(utf8String);
}

/**
* Function to try to validate a given UTF8 string according to Unicode rules.
*
* @param utf8String
* the input string to validate against possible invalid byte sequences
* @return
* the original string if the input string is a valid UTF8String, null otherwise.
*/
public static UTF8String tryValidateUTF8String(UTF8String utf8String) {
if (utf8String.isValid()) return utf8String;
else return null;
}

public static byte[] aesEncrypt(byte[] input,
byte[] key,
UTF8String mode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,10 @@ object FunctionRegistry {
expression[RegExpCount]("regexp_count"),
expression[RegExpSubStr]("regexp_substr"),
expression[RegExpInStr]("regexp_instr"),
expression[IsValidUTF8]("is_valid_utf8"),
expression[MakeValidUTF8]("make_valid_utf8"),
expression[ValidateUTF8]("validate_utf8"),
expression[TryValidateUTF8]("try_validate_utf8"),

// url functions
expression[UrlEncode]("url_encode"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
Expand Down Expand Up @@ -696,6 +696,189 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight)
}

/**
* A function that checks if a UTF8 string is valid.
*/
@ExpressionDescription(
usage = "_FUNC_(str) - Returns true if `str` is a valid UTF-8 string, otherwise returns false.",
arguments = """
Arguments:
* str - a string expression
""",
examples = """
Examples:
> SELECT _FUNC_('Spark');
true
> SELECT _FUNC_(x'61');
true
> SELECT _FUNC_(x'80');
false
> SELECT _FUNC_(x'61C262');
false
""",
since = "4.0.0",
group = "string_funcs")
case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {

override lazy val replacement: Expression = Invoke(input, "isValid", BooleanType)

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)

override def nodeName: String = "is_valid_utf8"

override def nullable: Boolean = true

override def child: Expression = input

override protected def withNewChildInternal(newChild: Expression): IsValidUTF8 = {
copy(input = newChild)
}

}

/**
* A function that converts an invalid UTF8 string to a valid UTF8 string by replacing invalid
* UTF-8 byte sequences with the Unicode replacement character (U+FFFD), according to the UNICODE
* standard rules (Section 3.9, Paragraph D86, Table 3-7). Valid strings remain unchanged.
*/
// scalastyle:off
@ExpressionDescription(
usage = "_FUNC_(str) - Returns the original string if `str` is a valid UTF-8 string, " +
"otherwise returns a new string whose invalid UTF8 byte sequences are replaced using the " +
"UNICODE replacement character U+FFFD.",
arguments = """
Arguments:
* str - a string expression
""",
examples = """
Examples:
> SELECT _FUNC_('Spark');
Spark
> SELECT _FUNC_(x'61');
a
> SELECT _FUNC_(x'80');
> SELECT _FUNC_(x'61C262');
a�b
""",
since = "4.0.0",
group = "string_funcs")
// scalastyle:on
case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {

override lazy val replacement: Expression = Invoke(
input, "makeValid", SQLConf.get.defaultStringType)

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)

override def nodeName: String = "make_valid_utf8"

override def nullable: Boolean = true

override def child: Expression = input

override protected def withNewChildInternal(newChild: Expression): MakeValidUTF8 = {
copy(input = newChild)
}

}

/**
* A function that validates a UTF8 string, throwing an exception if the string is invalid.
*/
// scalastyle:off
@ExpressionDescription(
usage = "_FUNC_(str) - Returns the original string if `str` is a valid UTF-8 string, " +
"otherwise throws an exception.",
arguments = """
Arguments:
* str - a string expression
""",
examples = """
Examples:
> SELECT _FUNC_('Spark');
Spark
> SELECT _FUNC_(x'61');
a
""",
since = "4.0.0",
group = "string_funcs")
// scalastyle:on
case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {

override lazy val replacement: Expression = StaticInvoke(
classOf[ExpressionImplUtils],
input.dataType,
"validateUTF8String",
Seq(input),
inputTypes)

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)

override def nodeName: String = "validate_utf8"

override def nullable: Boolean = true

override def child: Expression = input

override protected def withNewChildInternal(newChild: Expression): ValidateUTF8 = {
copy(input = newChild)
}

}

/**
* A function that tries to validate a UTF8 string, returning NULL if the string is invalid.
*/
// scalastyle:off
@ExpressionDescription(
usage = "_FUNC_(str) - Returns the original string if `str` is a valid UTF-8 string, " +
"otherwise returns NULL.",
arguments = """
Arguments:
* str - a string expression
""",
examples = """
Examples:
> SELECT _FUNC_('Spark');
Spark
> SELECT _FUNC_(x'61');
a
> SELECT _FUNC_(x'80');
NULL
> SELECT _FUNC_(x'61C262');
NULL
""",
since = "4.0.0",
group = "string_funcs")
// scalastyle:on
case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes
with UnaryLike[Expression] with NullIntolerant {

override lazy val replacement: Expression = StaticInvoke(
classOf[ExpressionImplUtils],
input.dataType,
"tryValidateUTF8String",
Seq(input),
inputTypes)

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)

override def nodeName: String = "try_validate_utf8"

override def nullable: Boolean = true

override def child: Expression = input

override protected def withNewChildInternal(newChild: Expression): TryValidateUTF8 = {
copy(input = newChild)
}

}

/**
* Replace all occurrences with string.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
summary = getSummary(context))
}

def invalidUTF8StringError(str: UTF8String): SparkIllegalArgumentException = {
new SparkIllegalArgumentException(
errorClass = "INVALID_UTF8_STRING",
messageParameters = Map(
"str" -> str.getBytes.map(byte => f"\\x$byte%02X").mkString
)
)
}

def invalidArrayIndexError(
index: Int,
numElements: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.lang3.{JavaVersion, SystemUtils}

import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.unsafe.types.UTF8String

class ExpressionImplUtilsSuite extends SparkFunSuite {
Expand Down Expand Up @@ -353,4 +353,65 @@ class ExpressionImplUtilsSuite extends SparkFunSuite {
parameters = t.errorParamsMap
)
}

test("Validate UTF8 string") {
def validateUTF8(str: UTF8String, expected: UTF8String, except: Boolean): Unit = {
if (except) {
checkError(
exception = intercept[SparkIllegalArgumentException] {
ExpressionImplUtils.validateUTF8String(str)
},
errorClass = "INVALID_UTF8_STRING",
parameters = Map(
"str" -> str.getBytes.map(byte => f"\\x$byte%02X").mkString
)
)
} else {
assert(ExpressionImplUtils.validateUTF8String(str)== expected)
}
}
validateUTF8(UTF8String.EMPTY_UTF8,
UTF8String.fromString(""), except = false)
validateUTF8(UTF8String.fromString(""),
UTF8String.fromString(""), except = false)
validateUTF8(UTF8String.fromString("aa"),
UTF8String.fromString("aa"), except = false)
validateUTF8(UTF8String.fromString("\u0061"),
UTF8String.fromString("\u0061"), except = false)
validateUTF8(UTF8String.fromString(""),
UTF8String.fromString(""), except = false)
validateUTF8(UTF8String.fromString("abc"),
UTF8String.fromString("abc"), except = false)
validateUTF8(UTF8String.fromString("hello"),
UTF8String.fromString("hello"), except = false)
validateUTF8(UTF8String.fromBytes(Array.empty[Byte]),
UTF8String.fromString(""), except = false)
validateUTF8(UTF8String.fromBytes(Array[Byte](0x41)),
UTF8String.fromString("A"), except = false)
validateUTF8(UTF8String.fromBytes(Array[Byte](0x61)),
UTF8String.fromString("a"), except = false)
validateUTF8(UTF8String.fromBytes(Array[Byte](0x80.toByte)),
UTF8String.fromString("\uFFFD"), except = true)
validateUTF8(UTF8String.fromBytes(Array[Byte](0xFF.toByte)),
UTF8String.fromString("\uFFFD"), except = true)
}

test("TryValidate UTF8 string") {
def tryValidateUTF8(str: UTF8String, expected: UTF8String): Unit = {
assert(ExpressionImplUtils.tryValidateUTF8String(str) == expected)
}
tryValidateUTF8(UTF8String.fromString(""), UTF8String.fromString(""))
tryValidateUTF8(UTF8String.fromString("aa"), UTF8String.fromString("aa"))
tryValidateUTF8(UTF8String.fromString("\u0061"), UTF8String.fromString("\u0061"))
tryValidateUTF8(UTF8String.EMPTY_UTF8, UTF8String.fromString(""))
tryValidateUTF8(UTF8String.fromString(""), UTF8String.fromString(""))
tryValidateUTF8(UTF8String.fromString("abc"), UTF8String.fromString("abc"))
tryValidateUTF8(UTF8String.fromString("hello"), UTF8String.fromString("hello"))
tryValidateUTF8(UTF8String.fromBytes(Array.empty[Byte]), UTF8String.fromString(""))
tryValidateUTF8(UTF8String.fromBytes(Array[Byte](0x41)), UTF8String.fromString("A"))
tryValidateUTF8(UTF8String.fromBytes(Array[Byte](0x61)), UTF8String.fromString("a"))
tryValidateUTF8(UTF8String.fromBytes(Array[Byte](0x80.toByte)), null)
tryValidateUTF8(UTF8String.fromBytes(Array[Byte](0xFF.toByte)), null)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@
| org.apache.spark.sql.catalyst.expressions.IsNaN | isnan | SELECT isnan(cast('NaN' as double)) | struct<isnan(CAST(NaN AS DOUBLE)):boolean> |
| org.apache.spark.sql.catalyst.expressions.IsNotNull | isnotnull | SELECT isnotnull(1) | struct<(1 IS NOT NULL):boolean> |
| org.apache.spark.sql.catalyst.expressions.IsNull | isnull | SELECT isnull(1) | struct<(1 IS NULL):boolean> |
| org.apache.spark.sql.catalyst.expressions.IsValidUTF8 | is_valid_utf8 | SELECT is_valid_utf8('Spark') | struct<is_valid_utf8(Spark):boolean> |
| org.apache.spark.sql.catalyst.expressions.JsonObjectKeys | json_object_keys | SELECT json_object_keys('{}') | struct<json_object_keys({}):array<string>> |
| org.apache.spark.sql.catalyst.expressions.JsonToStructs | from_json | SELECT from_json('{"a":1, "b":0.8}', 'a INT, b DOUBLE') | struct<from_json({"a":1, "b":0.8}):struct<a:int,b:double>> |
| org.apache.spark.sql.catalyst.expressions.JsonTuple | json_tuple | SELECT json_tuple('{"a":1, "b":2}', 'a', 'b') | struct<c0:string,c1:string> |
Expand Down Expand Up @@ -207,6 +208,7 @@
| org.apache.spark.sql.catalyst.expressions.MakeTimestamp | make_timestamp | SELECT make_timestamp(2014, 12, 28, 6, 30, 45.887) | struct<make_timestamp(2014, 12, 28, 6, 30, 45.887):timestamp> |
| org.apache.spark.sql.catalyst.expressions.MakeTimestampLTZExpressionBuilder | make_timestamp_ltz | SELECT make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887) | struct<make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887):timestamp> |
| org.apache.spark.sql.catalyst.expressions.MakeTimestampNTZExpressionBuilder | make_timestamp_ntz | SELECT make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887) | struct<make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887):timestamp_ntz> |
| org.apache.spark.sql.catalyst.expressions.MakeValidUTF8 | make_valid_utf8 | SELECT make_valid_utf8('Spark') | struct<make_valid_utf8(Spark):string> |
| org.apache.spark.sql.catalyst.expressions.MakeYMInterval | make_ym_interval | SELECT make_ym_interval(1, 2) | struct<make_ym_interval(1, 2):interval year to month> |
| org.apache.spark.sql.catalyst.expressions.MapConcat | map_concat | SELECT map_concat(map(1, 'a', 2, 'b'), map(3, 'c')) | struct<map_concat(map(1, a, 2, b), map(3, c)):map<int,string>> |
| org.apache.spark.sql.catalyst.expressions.MapContainsKey | map_contains_key | SELECT map_contains_key(map(1, 'a', 2, 'b'), 1) | struct<map_contains_key(map(1, a, 2, b), 1):boolean> |
Expand Down Expand Up @@ -357,6 +359,7 @@
| org.apache.spark.sql.catalyst.expressions.TryToBinary | try_to_binary | SELECT try_to_binary('abc', 'utf-8') | struct<try_to_binary(abc, utf-8):binary> |
| org.apache.spark.sql.catalyst.expressions.TryToNumber | try_to_number | SELECT try_to_number('454', '999') | struct<try_to_number(454, 999):decimal(3,0)> |
| org.apache.spark.sql.catalyst.expressions.TryToTimestampExpressionBuilder | try_to_timestamp | SELECT try_to_timestamp('2016-12-31 00:12:00') | struct<try_to_timestamp(2016-12-31 00:12:00):timestamp> |
| org.apache.spark.sql.catalyst.expressions.TryValidateUTF8 | try_validate_utf8 | SELECT try_validate_utf8('Spark') | struct<try_validate_utf8(Spark):string> |
| org.apache.spark.sql.catalyst.expressions.TypeOf | typeof | SELECT typeof(1) | struct<typeof(1):string> |
| org.apache.spark.sql.catalyst.expressions.UnBase64 | unbase64 | SELECT unbase64('U3BhcmsgU1FM') | struct<unbase64(U3BhcmsgU1FM):binary> |
| org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct<negative(1):int> |
Expand All @@ -372,6 +375,7 @@
| org.apache.spark.sql.catalyst.expressions.UrlDecode | url_decode | SELECT url_decode('https%3A%2F%2Fspark.apache.org') | struct<url_decode(https%3A%2F%2Fspark.apache.org):string> |
| org.apache.spark.sql.catalyst.expressions.UrlEncode | url_encode | SELECT url_encode('https://spark.apache.org') | struct<url_encode(https://spark.apache.org):string> |
| org.apache.spark.sql.catalyst.expressions.Uuid | uuid | SELECT uuid() | struct<uuid():string> |
| org.apache.spark.sql.catalyst.expressions.ValidateUTF8 | validate_utf8 | SELECT validate_utf8('Spark') | struct<validate_utf8(Spark):string> |
| org.apache.spark.sql.catalyst.expressions.WeekDay | weekday | SELECT weekday('2009-07-30') | struct<weekday(2009-07-30):int> |
| org.apache.spark.sql.catalyst.expressions.WeekOfYear | weekofyear | SELECT weekofyear('2008-02-20') | struct<weekofyear(2008-02-20):int> |
| org.apache.spark.sql.catalyst.expressions.WidthBucket | width_bucket | SELECT width_bucket(5.3, 0.2, 10.6, 5) | struct<width_bucket(5.3, 0.2, 10.6, 5):bigint> |
Expand Down
Loading

0 comments on commit 068be4b

Please sign in to comment.