From d9223e5fc93178a1246e4705a1f768701a924fd9 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 4 Sep 2024 15:01:23 -0700 Subject: [PATCH 01/27] commit commit uniform expression commit commit commit --- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/randomExpressions.scala | 218 +++++++- .../sql-functions/sql-expression-schema.md | 2 + .../sql-tests/analyzer-results/random.sql.out | 459 +++++++++++++++ .../resources/sql-tests/inputs/random.sql | 42 +- .../sql-tests/results/random.sql.out | 527 ++++++++++++++++++ 6 files changed, 1247 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index dfe1bd12bb7ff..6f816aa4305f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -384,7 +384,9 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Rand]("random", true, Some("3.0.0")), expression[Randn]("randn"), + expression[RandStr]("randstr"), expression[Stack]("stack"), + expression[Uniform]("uniform"), expression[ZeroIfNull]("zeroifnull"), CaseWhen.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index f5db972a28643..65806459a3c39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -17,13 +17,19 @@ package org.apache.spark.sql.catalyst.expressions +import scala.util.Random + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, TreePattern} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike} +import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.random.XORShiftRandom /** @@ -181,3 +187,211 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG { object Randn { def apply(seed: Long): Randn = Randn(Literal(seed, LongType)) } + +@ExpressionDescription( + usage = """ + _FUNC_(min, max, seed) - Returns a random value with independent and identically + distributed (i.i.d.) values with the specified range of numbers. The random seed is optional. + The provided numbers specifying the minimum and maximum values of the range must be constant. + If both of these numbers are integers, then the result will also be an integer. 
Otherwise if + one or both of these are floating-point numbers, then the result will also be a floating-point + number. + """, + examples = """ + Examples: + > SELECT _FUNC_(0, 1); + -0.3254147983080288 + > SELECT _FUNC_(10, 20, 0); + 26.034991609278433 + """, + since = "4.0.0", + group = "math_funcs") +case class Uniform(min: Expression, max: Expression, seed: Expression) + extends RuntimeReplaceable with TernaryLike[Expression] with ExpressionWithRandomSeed { + def this(min: Expression, max: Expression) = + this(min, max, Literal(Uniform.random.nextLong(), LongType)) + + final override lazy val deterministic: Boolean = false + override val nodePatterns: Seq[TreePattern] = + Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED) + + override val dataType: DataType = { + val first = min.dataType + val second = max.dataType + (min.dataType, max.dataType) match { + case _ if !valid(min) || !valid(max) => NullType + case (_, LongType) | (LongType, _) if Seq(first, second).forall(integer) => LongType + case (_, IntegerType) | (IntegerType, _) if Seq(first, second).forall(integer) => IntegerType + case (_, ShortType) | (ShortType, _) if Seq(first, second).forall(integer) => ShortType + case (_, DoubleType) | (DoubleType, _) => DoubleType + case (_, FloatType) | (FloatType, _) => FloatType + case _ => NullType + } + } + + private def valid(e: Expression): Boolean = e.dataType match { + case _ if !e.foldable => false + case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType => true + case _ => false + } + + private def integer(t: DataType): Boolean = t match { + case _: ShortType | _: IntegerType | _: LongType => true + case _ => false + } + + override def checkInputDataTypes(): TypeCheckResult = { + var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess + Seq(min, max, seed).zipWithIndex.foreach { case (expr: Expression, index: Int) => + if (!valid(expr)) { + result = DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(index), + "requiredType" -> "constant value of integer or floating-point", + "inputSql" -> toSQLExpr(expr), + "inputType" -> toSQLType(expr.dataType))) + } + } + result + } + + override def first: Expression = min + override def second: Expression = max + override def third: Expression = seed + + override def seedExpression: Expression = seed + override def withNewSeed(newSeed: Long): Expression = + Uniform(min, max, Literal(newSeed, LongType)) + + override def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + Uniform(newFirst, newSecond, newThird) + + override def replacement: Expression = { + def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to) + cast(Add( + cast(min, DoubleType), + Multiply( + Subtract( + cast(max, DoubleType), + cast(min, DoubleType)), + Rand(seed))), + dataType) + } +} + +object Uniform { + lazy val random = new Random() +} + +@ExpressionDescription( + usage = """ + _FUNC_(length, seed) - Returns a string of the specified length whose characters are chosen + uniformly at random from the following pool of characters: 0-9, a-z, A-Z. The random seed is + optional. The string length must be a constant two-byte or four-byte integer (SMALLINT or INT, + respectively). 
+ """, + examples = + """ + Examples: + > SELECT _FUNC_(3, 0); + abc + """, + since = "4.0.0", + group = "math_funcs") +case class RandStr(length: Expression, override val seedExpression: Expression) + extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic { + def this(length: Expression) = this(length, Literal(Uniform.random.nextLong(), LongType)) + + override def nullable: Boolean = false + override def dataType: DataType = StringType + override def stateful: Boolean = true + override def left: Expression = length + override def right: Expression = seedExpression + + /** + * Record ID within each partition. By being transient, the Random Number Generator is + * reset every time we serialize and deserialize and initialize it. + */ + @transient protected var rng: XORShiftRandom = _ + + @transient protected lazy val seed: Long = seedExpression match { + case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int] + case e if e.dataType == LongType => e.eval().asInstanceOf[Long] + } + override protected def initializeInternal(partitionIndex: Int): Unit = { + rng = new XORShiftRandom(seed + partitionIndex) + } + + override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType)) + override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = + RandStr(newFirst, newSecond) + + override def checkInputDataTypes(): TypeCheckResult = { + var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess + Seq(length, seedExpression).zipWithIndex.foreach { case (expr: Expression, index: Int) => + val valid = expr.dataType match { + case _ if !expr.foldable => false + case _: ShortType | _: IntegerType => true + case _: LongType if index == 1 => true + case _ => false + } + if (!valid) { + result = DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(index), + "requiredType" -> "constant value of INT or SMALLINT", + "inputSql" -> toSQLExpr(expr), + "inputType" -> toSQLType(expr.dataType))) + } + } + result + } + + override def evalInternal(input: InternalRow): Any = { + val numChars: Int = length.eval(input).asInstanceOf[Int] + val bytes = new Array[Byte](numChars) + (0 until numChars).foreach { i => + val num = (rng.nextInt() % 30).abs + num match { + case _ if num < 10 => + bytes.update(i, ('0' + num).toByte) + case _ if num < 20 => + bytes.update(i, ('a' + num - 10).toByte) + case _ => + bytes.update(i, ('A' + num - 20).toByte) + } + } + val result: UTF8String = UTF8String.fromBytes(bytes.toArray) + result + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val className = classOf[XORShiftRandom].getName + val rngTerm = ctx.addMutableState(className, "rng") + ctx.addPartitionInitializationStatement( + s"$rngTerm = new $className(${seed}L + partitionIndex);") + val eval = length.genCode(ctx) + ev.copy(code = + code""" + |${eval.code} + |int length = (int)(${eval.value}); + |char[] chars = new char[length]; + |for (int i = 0; i < length; i++) { + | int v = Math.abs($rngTerm.nextInt() % 30); + | if (v < 10) { + | chars[i] = (char)('0' + v); + | } else if (v < 20) { + | chars[i] = (char)('a' + (v - 10)); + | } else { + | chars[i] = (char)('A' + (v - 20)); + | } + |} + |UTF8String ${ev.value} = UTF8String.fromString(new String(chars)); + |boolean ${ev.isNull} = false; + |""".stripMargin, + isNull = FalseLiteral) + } +} diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md 
b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 316e5e9676723..abad4fee0a81b 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -265,6 +265,7 @@ | org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder | raise_error | SELECT raise_error('custom error message') | struct | | org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct | | org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct | +| org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT randstr(3, 0) | struct | | org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct | | org.apache.spark.sql.catalyst.expressions.Rank | rank | SELECT a, b, rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | | org.apache.spark.sql.catalyst.expressions.RegExpCount | regexp_count | SELECT regexp_count('Steven Jones and Stephen Smith are the best players', 'Ste(v|ph)en') | struct | @@ -367,6 +368,7 @@ | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct | +| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(0, 1) | struct | | org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out index 3cacbdc141053..e975721c5c222 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out @@ -93,3 +93,462 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "rand('1')" } ] } + + +-- !query +SELECT uniform(0, 1, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(0, 10, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(0L, 10L, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(0, 10L, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(0, 10S, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10, 20, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10.0F, 20.0F, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10.0D, 20.0D, CAST(3 / 7 AS LONG)) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10, 20.0F, 0) AS result +-- !query analysis 
+[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10, 20, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(10, 20.0F) IS NOT NULL AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT uniform(NULL, 1, 0) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "first", + "requiredType" : "constant value of integer or floating-point", + "sqlExpr" : "\"uniform(NULL, 1, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "uniform(NULL, 1, 0)" + } ] +} + + +-- !query +SELECT uniform(0, NULL, 0) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "second", + "requiredType" : "constant value of integer or floating-point", + "sqlExpr" : "\"uniform(0, NULL, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "uniform(0, NULL, 0)" + } ] +} + + +-- !query +SELECT uniform(0, 1, NULL) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "third", + "requiredType" : "constant value of integer or floating-point", + "sqlExpr" : "\"uniform(0, 1, NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "uniform(0, 1, NULL)" + } ] +} + + +-- !query +SELECT uniform(10, 20, col) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"col\"", + "inputType" : "\"INT\"", + "paramIndex" : "third", + "requiredType" : "constant value of integer or floating-point", + "sqlExpr" : "\"uniform(10, 20, col)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 27, + "fragment" : "uniform(10, 20, col)" + } ] +} + + +-- !query +SELECT uniform(col, 10, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"col\"", + "inputType" : "\"INT\"", + "paramIndex" : "first", + "requiredType" : "constant value of integer or floating-point", + "sqlExpr" : "\"uniform(col, 10, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "uniform(col, 10, 0)" + } ] +} + + +-- !query +SELECT uniform(10) AS result +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + 
"messageParameters" : { + "actualNum" : "1", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[2, 3]", + "functionName" : "`uniform`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 18, + "fragment" : "uniform(10)" + } ] +} + + +-- !query +SELECT uniform(10, 20, 30, 40) AS result +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "4", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[2, 3]", + "functionName" : "`uniform`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 30, + "fragment" : "uniform(10, 20, 30, 40)" + } ] +} + + +-- !query +SELECT randstr(1, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(5, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(10, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(10S, 0) AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(10, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(10) IS NOT NULL AS result +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT randstr(10L, 0) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10\"", + "inputType" : "\"BIGINT\"", + "paramIndex" : "first", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(10, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "randstr(10L, 0)" + } ] +} + + +-- !query +SELECT randstr(10.0F, 0) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10.0\"", + "inputType" : "\"FLOAT\"", + "paramIndex" : "first", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(10.0, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10.0F, 0)" + } ] +} + + +-- !query +SELECT randstr(10.0D, 0) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10.0\"", + "inputType" : "\"DOUBLE\"", + "paramIndex" : "first", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(10.0, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10.0D, 0)" + } ] +} + + +-- !query +SELECT randstr(NULL, 0) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + 
"inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "first", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(NULL, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(NULL, 0)" + } ] +} + + +-- !query +SELECT randstr(0, NULL) AS result +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "second", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(0, NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(0, NULL)" + } ] +} + + +-- !query +SELECT randstr(col, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"col\"", + "inputType" : "\"INT\"", + "paramIndex" : "first", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(col, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "randstr(col, 0)" + } ] +} + + +-- !query +SELECT randstr(10, col) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"col\"", + "inputType" : "\"INT\"", + "paramIndex" : "second", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(10, col)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(10, col)" + } ] +} + + +-- !query +SELECT randstr(10, 0, 1) AS result +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "3", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[1, 2]", + "functionName" : "`randstr`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10, 0, 1)" + } ] +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/random.sql b/sql/core/src/test/resources/sql-tests/inputs/random.sql index a1aae7b8759dc..a71b0293295fc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/random.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/random.sql @@ -14,4 +14,44 @@ SELECT randn(NULL); SELECT randn(cast(NULL AS long)); -- randn unsupported data type -SELECT rand('1') +SELECT rand('1'); + +-- The uniform random number generation function supports generating random numbers within a +-- specified range. We use a seed of zero for these queries to keep tests deterministic. 
+SELECT uniform(0, 1, 0) AS result; +SELECT uniform(0, 10, 0) AS result; +SELECT uniform(0L, 10L, 0) AS result; +SELECT uniform(0, 10L, 0) AS result; +SELECT uniform(0, 10S, 0) AS result; +SELECT uniform(10, 20, 0) AS result; +SELECT uniform(10.0F, 20.0F, 0) AS result; +SELECT uniform(10.0D, 20.0D, CAST(3 / 7 AS LONG)) AS result; +SELECT uniform(10, 20.0F, 0) AS result; +SELECT uniform(10, 20, 0) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT uniform(10, 20.0F) IS NOT NULL AS result; +-- Negative test cases for the uniform random number generator. +SELECT uniform(NULL, 1, 0) AS result; +SELECT uniform(0, NULL, 0) AS result; +SELECT uniform(0, 1, NULL) AS result; +SELECT uniform(10, 20, col) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT uniform(col, 10, 0) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT uniform(10) AS result; +SELECT uniform(10, 20, 30, 40) AS result; + +-- The randstr random string generation function supports generating random strings within a +-- specified length. We use a seed of zero for these queries to keep tests deterministic. +SELECT randstr(1, 0) AS result; +SELECT randstr(5, 0) AS result; +SELECT randstr(10, 0) AS result; +SELECT randstr(10S, 0) AS result; +SELECT randstr(10, 0) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT randstr(10) IS NOT NULL AS result; +-- Negative test cases for the randstr random number generator. +SELECT randstr(10L, 0) AS result; +SELECT randstr(10.0F, 0) AS result; +SELECT randstr(10.0D, 0) AS result; +SELECT randstr(NULL, 0) AS result; +SELECT randstr(0, NULL) AS result; +SELECT randstr(col, 0) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT randstr(10, col) AS result FROM VALUES (0), (1), (2) tab(col); +SELECT randstr(10, 0, 1) AS result; diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index 16984de3ff257..78eb0c86b9d4d 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -113,3 +113,530 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "fragment" : "rand('1')" } ] } + + +-- !query +SELECT uniform(0, 1, 0) AS result +-- !query schema +struct +-- !query output +0 + + +-- !query +SELECT uniform(0, 10, 0) AS result +-- !query schema +struct +-- !query output +7 + + +-- !query +SELECT uniform(0L, 10L, 0) AS result +-- !query schema +struct +-- !query output +7 + + +-- !query +SELECT uniform(0, 10L, 0) AS result +-- !query schema +struct +-- !query output +7 + + +-- !query +SELECT uniform(0, 10S, 0) AS result +-- !query schema +struct +-- !query output +7 + + +-- !query +SELECT uniform(10, 20, 0) AS result +-- !query schema +struct +-- !query output +17 + + +-- !query +SELECT uniform(10.0F, 20.0F, 0) AS result +-- !query schema +struct +-- !query output +17.604954 + + +-- !query +SELECT uniform(10.0D, 20.0D, CAST(3 / 7 AS LONG)) AS result +-- !query schema +struct +-- !query output +17.604953758285916 + + +-- !query +SELECT uniform(10, 20.0F, 0) AS result +-- !query schema +struct +-- !query output +17.604954 + + +-- !query +SELECT uniform(10, 20, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct +-- !query output +15 +16 +17 + + +-- !query +SELECT uniform(10, 20.0F) IS NOT NULL AS result +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT uniform(NULL, 1, 0) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + 
"errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "first", + "requiredType" : "constant value of integer or floating-point", + "sqlExpr" : "\"uniform(NULL, 1, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "uniform(NULL, 1, 0)" + } ] +} + + +-- !query +SELECT uniform(0, NULL, 0) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "second", + "requiredType" : "constant value of integer or floating-point", + "sqlExpr" : "\"uniform(0, NULL, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "uniform(0, NULL, 0)" + } ] +} + + +-- !query +SELECT uniform(0, 1, NULL) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "third", + "requiredType" : "constant value of integer or floating-point", + "sqlExpr" : "\"uniform(0, 1, NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "uniform(0, 1, NULL)" + } ] +} + + +-- !query +SELECT uniform(10, 20, col) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"col\"", + "inputType" : "\"INT\"", + "paramIndex" : "third", + "requiredType" : "constant value of integer or floating-point", + "sqlExpr" : "\"uniform(10, 20, col)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 27, + "fragment" : "uniform(10, 20, col)" + } ] +} + + +-- !query +SELECT uniform(col, 10, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"col\"", + "inputType" : "\"INT\"", + "paramIndex" : "first", + "requiredType" : "constant value of integer or floating-point", + "sqlExpr" : "\"uniform(col, 10, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 26, + "fragment" : "uniform(col, 10, 0)" + } ] +} + + +-- !query +SELECT uniform(10) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "1", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[2, 3]", + "functionName" : "`uniform`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 18, + "fragment" : "uniform(10)" + } ] +} + + +-- !query +SELECT uniform(10, 20, 30, 40) AS result +-- 
!query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "4", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[2, 3]", + "functionName" : "`uniform`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 30, + "fragment" : "uniform(10, 20, 30, 40)" + } ] +} + + +-- !query +SELECT randstr(1, 0) AS result +-- !query schema +struct +-- !query output +8 + + +-- !query +SELECT randstr(5, 0) AS result +-- !query schema +struct +-- !query output +8i70B + + +-- !query +SELECT randstr(10, 0) AS result +-- !query schema +struct +-- !query output +8i70BBEJ6A + + +-- !query +SELECT randstr(10S, 0) AS result +-- !query schema +struct +-- !query output +8i70BBEJ6A + + +-- !query +SELECT randstr(10, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct +-- !query output +2iAhij5i3F +3AD1fIHj7B +8i70BBEJ6A + + +-- !query +SELECT randstr(10) IS NOT NULL AS result +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT randstr(10L, 0) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10\"", + "inputType" : "\"BIGINT\"", + "paramIndex" : "first", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(10, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "randstr(10L, 0)" + } ] +} + + +-- !query +SELECT randstr(10.0F, 0) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10.0\"", + "inputType" : "\"FLOAT\"", + "paramIndex" : "first", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(10.0, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10.0F, 0)" + } ] +} + + +-- !query +SELECT randstr(10.0D, 0) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"10.0\"", + "inputType" : "\"DOUBLE\"", + "paramIndex" : "first", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(10.0, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10.0D, 0)" + } ] +} + + +-- !query +SELECT randstr(NULL, 0) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "first", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(NULL, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(NULL, 0)" + } ] +} + + +-- !query +SELECT 
randstr(0, NULL) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"NULL\"", + "inputType" : "\"VOID\"", + "paramIndex" : "second", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(0, NULL)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(0, NULL)" + } ] +} + + +-- !query +SELECT randstr(col, 0) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"col\"", + "inputType" : "\"INT\"", + "paramIndex" : "first", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(col, 0)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "randstr(col, 0)" + } ] +} + + +-- !query +SELECT randstr(10, col) AS result FROM VALUES (0), (1), (2) tab(col) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"col\"", + "inputType" : "\"INT\"", + "paramIndex" : "second", + "requiredType" : "constant value of INT or SMALLINT", + "sqlExpr" : "\"randstr(10, col)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "randstr(10, col)" + } ] +} + + +-- !query +SELECT randstr(10, 0, 1) AS result +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "3", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[1, 2]", + "functionName" : "`randstr`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "randstr(10, 0, 1)" + } ] +} From c0a255163c2946f1853138355a2b97707a837cd7 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 9 Sep 2024 10:53:13 -0700 Subject: [PATCH 02/27] respond to code review comments --- .../spark/sql/catalyst/expressions/randomExpressions.scala | 6 +++--- .../test/resources/sql-functions/sql-expression-schema.md | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 65806459a3c39..6b23fdb0a9bc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -287,7 +287,7 @@ object Uniform { @ExpressionDescription( usage = """ - _FUNC_(length, seed) - Returns a string of the specified length whose characters are chosen + _FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen uniformly at random from the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. 
The string length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). @@ -296,10 +296,10 @@ object Uniform { """ Examples: > SELECT _FUNC_(3, 0); - abc + 8i7 """, since = "4.0.0", - group = "math_funcs") + group = "string_funcs") case class RandStr(length: Expression, override val seedExpression: Expression) extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic { def this(length: Expression) = this(length, Literal(Uniform.random.nextLong(), LongType)) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index abad4fee0a81b..4c9a073c2cebd 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -368,7 +368,7 @@ | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct | -| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(0, 1) | struct | +| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(0, 1) | struct | | org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | From f6ffde0b90e20eeb1b1cc52852a636e4d20e87ce Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 9 Sep 2024 10:55:33 -0700 Subject: [PATCH 03/27] respond to code review comments --- .../spark/sql/catalyst/expressions/randomExpressions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 6b23fdb0a9bc8..38344715c8e18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -190,7 +190,7 @@ object Randn { @ExpressionDescription( usage = """ - _FUNC_(min, max, seed) - Returns a random value with independent and identically + _FUNC_(min, max[, seed]) - Returns a random value with independent and identically distributed (i.i.d.) values with the specified range of numbers. The random seed is optional. The provided numbers specifying the minimum and maximum values of the range must be constant. If both of these numbers are integers, then the result will also be an integer. 
Otherwise if From c22fef26b75f52f1ea2c10e4c32b98bfb5865d3b Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 9 Sep 2024 13:26:10 -0700 Subject: [PATCH 04/27] respond to code review comments --- .../sql/catalyst/expressions/randomExpressions.scala | 8 ++++++++ .../test/resources/sql-functions/sql-expression-schema.md | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 38344715c8e18..1e8b3b32457bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -27,7 +27,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, RUNTIME_REPLACEABLE, TreePattern} +import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.random.XORShiftRandom @@ -268,6 +270,9 @@ case class Uniform(min: Expression, max: Expression, seed: Expression) newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = Uniform(newFirst, newSecond, newThird) + override def toString: String = prettyName + truncatedString( + Seq(min, max), "(", ", ", ")", SQLConf.get.maxToStringFields) + override def replacement: Expression = { def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to) cast(Add( @@ -328,6 +333,9 @@ case class RandStr(length: Expression, override val seedExpression: Expression) override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = RandStr(newFirst, newSecond) + override def toString: String = prettyName + truncatedString( + Seq(length), "(", ", ", ")", SQLConf.get.maxToStringFields) + override def checkInputDataTypes(): TypeCheckResult = { var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess Seq(length, seedExpression).zipWithIndex.foreach { case (expr: Expression, index: Int) => diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 4c9a073c2cebd..0cd2e14e6b7a0 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -368,7 +368,7 @@ | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct | -| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(0, 1) | struct | +| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(0, 1) | struct | | org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct | | 
org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | From c78d8f0e15716f8624f60fc53bd5ab15f98c0298 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 9 Sep 2024 16:36:13 -0700 Subject: [PATCH 05/27] fix test --- .../spark/sql/catalyst/expressions/randomExpressions.scala | 6 ++---- .../test/resources/sql-functions/sql-expression-schema.md | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 1e8b3b32457bd..85f43ffed33e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -201,10 +201,8 @@ object Randn { """, examples = """ Examples: - > SELECT _FUNC_(0, 1); - -0.3254147983080288 - > SELECT _FUNC_(10, 20, 0); - 26.034991609278433 + > SELECT _FUNC_(10, 20) > 0; + true """, since = "4.0.0", group = "math_funcs") diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 0cd2e14e6b7a0..1b83406f72a16 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -368,7 +368,7 @@ | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct | -| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(0, 1) | struct | +| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(0, 1) | struct | | org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | From 5b6194e1dcc7469dfcacebbfc551bd559908b7e0 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 9 Sep 2024 20:29:15 -0700 Subject: [PATCH 06/27] fix test --- .../src/test/resources/sql-functions/sql-expression-schema.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 1b83406f72a16..0a6b97580a91b 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -368,7 +368,7 @@ | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | 
struct | -| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(0, 1) | struct | +| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(10, 20) | struct | | org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | @@ -464,4 +464,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | From da390f97b7fbcea3fcb21b92876bd2330681b212 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 9 Sep 2024 20:29:48 -0700 Subject: [PATCH 07/27] fix test --- .../src/test/resources/sql-functions/sql-expression-schema.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 0a6b97580a91b..7a554cda99ca4 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -464,4 +464,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file From 836ec8ee0528c5db07c13fe9fc4cf904ab446a55 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 9 Sep 2024 20:34:49 -0700 Subject: [PATCH 08/27] commit commit --- .../src/test/resources/sql-functions/sql-expression-schema.md | 3 +-- .../scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 7a554cda99ca4..63722064faf0f 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -368,7 +368,6 @@ | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | 
org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct | -| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(10, 20) | struct | | org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | @@ -464,4 +463,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala index 8c0231fddf39f..f782d28b707ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala @@ -118,11 +118,11 @@ class ExpressionsSchemaSuite extends QueryTest with SharedSparkSession { // SET spark.sql.parser.escapedStringLiterals=true example.split(" > ").tail.filterNot(_.trim.startsWith("SET")).take(1).foreach { case _ if funcName == "from_avro" || funcName == "to_avro" || - funcName == "from_protobuf" || funcName == "to_protobuf" => + funcName == "from_protobuf" || funcName == "to_protobuf" || funcName == "uniform" => // Skip running the example queries for the from_avro, to_avro, from_protobuf and // to_protobuf functions because these functions dynamically load the // AvroDataToCatalyst or CatalystDataToAvro classes which are not available in this - // test. + // test, or use random numbers. 
case exampleRe(sql, _) => val df = spark.sql(sql) val escapedSql = sql.replaceAll("\\|", "|") From 4b41b34e1939c7473e986d7145558ed8292adfcc Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 9 Sep 2024 20:39:21 -0700 Subject: [PATCH 09/27] commit --- .../src/test/resources/sql-functions/sql-expression-schema.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 63722064faf0f..7f997c5a54573 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -463,4 +463,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('b1b2b3c1c2','a/b/text()') | structb1b2b3c1c2, a/b/text()):array> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('12', 'sum(a/b)') | struct12, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('12', 'sum(a/b)') | struct12, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('bcc','a/c') | structbcc, a/c):string> | \ No newline at end of file From 54a1a2e74bce9961c00d3098a16f660838f3c378 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 11 Sep 2024 11:25:44 -0700 Subject: [PATCH 10/27] respond to code review comments --- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/randomExpressions.scala | 93 ++++++++++++------- .../sql-functions/sql-expression-schema.md | 3 +- .../sql-tests/analyzer-results/random.sql.out | 37 ++++---- .../sql-tests/results/random.sql.out | 37 ++++---- .../spark/sql/ExpressionsSchemaSuite.scala | 4 +- 6 files changed, 101 insertions(+), 75 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6f816aa4305f8..f1723a69c1fcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -386,7 +386,7 @@ object FunctionRegistry { expression[Randn]("randn"), expression[RandStr]("randstr"), expression[Stack]("stack"), - expression[Uniform]("uniform"), + expressionBuilder("uniform", UniformExpressionBuilder), expression[ZeroIfNull]("zeroifnull"), CaseWhen.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 85f43ffed33e7..679d6e99d1fc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,16 +20,14 @@ package org.apache.spark.sql.catalyst.expressions import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed} +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedSeed} import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike} +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, RUNTIME_REPLACEABLE, TreePattern} -import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.random.XORShiftRandom @@ -201,16 +199,15 @@ object Randn { """, examples = """ Examples: - > SELECT _FUNC_(10, 20) > 0; + > SELECT _FUNC_(10, 20) > 0 AS result; true """, since = "4.0.0", group = "math_funcs") -case class Uniform(min: Expression, max: Expression, seed: Expression) - extends RuntimeReplaceable with TernaryLike[Expression] with ExpressionWithRandomSeed { - def this(min: Expression, max: Expression) = - this(min, max, Literal(Uniform.random.nextLong(), LongType)) +case class Uniform(min: Expression, max: Expression) + extends RuntimeReplaceable with BinaryLike[Expression] with ExpressionWithRandomSeed { + private var seed: Expression = Literal(Uniform.random.nextLong(), LongType) final override lazy val deterministic: Boolean = false override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED) @@ -230,7 +227,6 @@ case class Uniform(min: Expression, max: Expression, seed: Expression) } private def valid(e: Expression): Boolean = e.dataType match { - case _ if !e.foldable => false case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType => true case _ => false } @@ -242,34 +238,44 @@ case class Uniform(min: Expression, max: Expression, seed: Expression) override def checkInputDataTypes(): TypeCheckResult = { var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess - Seq(min, max, seed).zipWithIndex.foreach { case (expr: Expression, index: Int) => - if (!valid(expr)) { - result = DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> ordinalNumber(index), - "requiredType" -> "constant value of integer or floating-point", - "inputSql" -> toSQLExpr(expr), - "inputType" -> toSQLType(expr.dataType))) - } + def requiredType = "integer or floating-point" + Seq((min, "min", 0), + (max, "max", 1), + (seed, "seed", 2)).foreach { + case (expr: Expression, name: String, index: Int) => + if (!expr.foldable && result == TypeCheckResult.TypeCheckSuccess) { + result = DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> name, + "inputType" -> requiredType, + "inputExpr" -> toSQLExpr(expr))) + } else if (!valid(expr) && result == TypeCheckResult.TypeCheckSuccess) { + result = DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(index), + "functionName" -> prettyName, + "requiredType" -> requiredType, + "inputSql" -> toSQLExpr(expr), + "inputType" -> toSQLType(expr.dataType))) + } } result } - override def first: Expression = min - override def second: Expression = max - override def third: Expression = 
seed + override def left: Expression = min + override def right: Expression = max override def seedExpression: Expression = seed - override def withNewSeed(newSeed: Long): Expression = - Uniform(min, max, Literal(newSeed, LongType)) - - override def withNewChildrenInternal( - newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = - Uniform(newFirst, newSecond, newThird) + override def withNewSeed(newSeed: Long): Expression = { + val result = Uniform(min, max) + result.seed = Literal(newSeed, LongType) + result + } - override def toString: String = prettyName + truncatedString( - Seq(min, max), "(", ", ", ")", SQLConf.get.maxToStringFields) + override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = + Uniform(newFirst, newSecond) override def replacement: Expression = { def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to) @@ -286,6 +292,26 @@ case class Uniform(min: Expression, max: Expression, seed: Expression) object Uniform { lazy val random = new Random() + + def apply(min: Expression, max: Expression, seedExpression: Expression): Uniform = { + val result = Uniform(min, max) + result.seed = seedExpression + result + } +} + +object UniformExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + expressions match { + case Seq(min, max) => + Uniform(min, max) + case Seq(min, max, seed) => + Uniform(min, max, seed) + case _ => + throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(2, 3), numArgs) + } + } } @ExpressionDescription( @@ -298,7 +324,7 @@ object Uniform { examples = """ Examples: - > SELECT _FUNC_(3, 0); + > SELECT _FUNC_(3, 0) AS result; 8i7 """, since = "4.0.0", @@ -331,9 +357,6 @@ case class RandStr(length: Expression, override val seedExpression: Expression) override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = RandStr(newFirst, newSecond) - override def toString: String = prettyName + truncatedString( - Seq(length), "(", ", ", ")", SQLConf.get.maxToStringFields) - override def checkInputDataTypes(): TypeCheckResult = { var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess Seq(length, seedExpression).zipWithIndex.foreach { case (expr: Expression, index: Int) => diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 7f997c5a54573..120dfa6c7edc3 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -265,7 +265,7 @@ | org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder | raise_error | SELECT raise_error('custom error message') | struct | | org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct | | org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct | -| org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT randstr(3, 0) | struct | +| org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT randstr(3, 0) AS result | struct | | org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct | | org.apache.spark.sql.catalyst.expressions.Rank | rank | SELECT a, b, rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | | 
org.apache.spark.sql.catalyst.expressions.RegExpCount | regexp_count | SELECT regexp_count('Steven Jones and Stephen Smith are the best players', 'Ste(v|ph)en') | struct | @@ -368,6 +368,7 @@ | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct | +| org.apache.spark.sql.catalyst.expressions.UniformExpressionBuilder | uniform | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out index e975721c5c222..81afcc3eb54c4 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out @@ -169,11 +169,12 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { + "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "first", - "requiredType" : "constant value of integer or floating-point", - "sqlExpr" : "\"uniform(NULL, 1, 0)\"" + "requiredType" : "integer or floating-point", + "sqlExpr" : "\"uniform(NULL, 1)\"" }, "queryContext" : [ { "objectType" : "", @@ -193,11 +194,12 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { + "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "second", - "requiredType" : "constant value of integer or floating-point", - "sqlExpr" : "\"uniform(0, NULL, 0)\"" + "requiredType" : "integer or floating-point", + "sqlExpr" : "\"uniform(0, NULL)\"" }, "queryContext" : [ { "objectType" : "", @@ -217,11 +219,12 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { + "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "third", - "requiredType" : "constant value of integer or floating-point", - "sqlExpr" : "\"uniform(0, 1, NULL)\"" + "requiredType" : "integer or floating-point", + "sqlExpr" : "\"uniform(0, 1)\"" }, "queryContext" : [ { "objectType" : "", @@ -238,14 +241,13 @@ SELECT uniform(10, 20, col) AS result FROM VALUES (0), (1), (2) tab(col) -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", "sqlState" : "42K09", "messageParameters" : { - "inputSql" : "\"col\"", - "inputType" : "\"INT\"", - "paramIndex" : "third", - "requiredType" : "constant value of integer or floating-point", - "sqlExpr" : "\"uniform(10, 20, col)\"" + "inputExpr" : "\"col\"", + "inputName" : "seed", + "inputType" : "integer or floating-point", + "sqlExpr" : 
"\"uniform(10, 20)\"" }, "queryContext" : [ { "objectType" : "", @@ -262,14 +264,13 @@ SELECT uniform(col, 10, 0) AS result FROM VALUES (0), (1), (2) tab(col) -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", "sqlState" : "42K09", "messageParameters" : { - "inputSql" : "\"col\"", - "inputType" : "\"INT\"", - "paramIndex" : "first", - "requiredType" : "constant value of integer or floating-point", - "sqlExpr" : "\"uniform(col, 10, 0)\"" + "inputExpr" : "\"col\"", + "inputName" : "min", + "inputType" : "integer or floating-point", + "sqlExpr" : "\"uniform(col, 10)\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index 78eb0c86b9d4d..9e7cd60e91c19 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -215,11 +215,12 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { + "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "first", - "requiredType" : "constant value of integer or floating-point", - "sqlExpr" : "\"uniform(NULL, 1, 0)\"" + "requiredType" : "integer or floating-point", + "sqlExpr" : "\"uniform(NULL, 1)\"" }, "queryContext" : [ { "objectType" : "", @@ -241,11 +242,12 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { + "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "second", - "requiredType" : "constant value of integer or floating-point", - "sqlExpr" : "\"uniform(0, NULL, 0)\"" + "requiredType" : "integer or floating-point", + "sqlExpr" : "\"uniform(0, NULL)\"" }, "queryContext" : [ { "objectType" : "", @@ -267,11 +269,12 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { + "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "third", - "requiredType" : "constant value of integer or floating-point", - "sqlExpr" : "\"uniform(0, 1, NULL)\"" + "requiredType" : "integer or floating-point", + "sqlExpr" : "\"uniform(0, 1)\"" }, "queryContext" : [ { "objectType" : "", @@ -290,14 +293,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", "sqlState" : "42K09", "messageParameters" : { - "inputSql" : "\"col\"", - "inputType" : "\"INT\"", - "paramIndex" : "third", - "requiredType" : "constant value of integer or floating-point", - "sqlExpr" : "\"uniform(10, 20, col)\"" + "inputExpr" : "\"col\"", + "inputName" : "seed", + "inputType" : "integer or floating-point", + "sqlExpr" : "\"uniform(10, 20)\"" }, "queryContext" : [ { "objectType" : "", @@ -316,14 +318,13 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", "sqlState" : "42K09", "messageParameters" : { - "inputSql" : 
"\"col\"", - "inputType" : "\"INT\"", - "paramIndex" : "first", - "requiredType" : "constant value of integer or floating-point", - "sqlExpr" : "\"uniform(col, 10, 0)\"" + "inputExpr" : "\"col\"", + "inputName" : "min", + "inputType" : "integer or floating-point", + "sqlExpr" : "\"uniform(col, 10)\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala index f782d28b707ba..8c0231fddf39f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionsSchemaSuite.scala @@ -118,11 +118,11 @@ class ExpressionsSchemaSuite extends QueryTest with SharedSparkSession { // SET spark.sql.parser.escapedStringLiterals=true example.split(" > ").tail.filterNot(_.trim.startsWith("SET")).take(1).foreach { case _ if funcName == "from_avro" || funcName == "to_avro" || - funcName == "from_protobuf" || funcName == "to_protobuf" || funcName == "uniform" => + funcName == "from_protobuf" || funcName == "to_protobuf" => // Skip running the example queries for the from_avro, to_avro, from_protobuf and // to_protobuf functions because these functions dynamically load the // AvroDataToCatalyst or CatalystDataToAvro classes which are not available in this - // test, or use random numbers. + // test. case exampleRe(sql, _) => val df = spark.sql(sql) val escapedSql = sql.replaceAll("\\|", "|") From c37207501e9dcef8c876d44febcf53b9990715c2 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 11 Sep 2024 15:26:07 -0700 Subject: [PATCH 11/27] fix function description --- .../expressions/randomExpressions.scala | 28 +++++++++---------- .../sql-functions/sql-expression-schema.md | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 679d6e99d1fc8..945b35131b8e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -204,6 +204,20 @@ object Randn { """, since = "4.0.0", group = "math_funcs") +object UniformExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + val numArgs = expressions.length + expressions match { + case Seq(min, max) => + Uniform(min, max) + case Seq(min, max, seed) => + Uniform(min, max, seed) + case _ => + throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(2, 3), numArgs) + } + } +} + case class Uniform(min: Expression, max: Expression) extends RuntimeReplaceable with BinaryLike[Expression] with ExpressionWithRandomSeed { @@ -300,20 +314,6 @@ object Uniform { } } -object UniformExpressionBuilder extends ExpressionBuilder { - override def build(funcName: String, expressions: Seq[Expression]): Expression = { - val numArgs = expressions.length - expressions match { - case Seq(min, max) => - Uniform(min, max) - case Seq(min, max, seed) => - Uniform(min, max, seed) - case _ => - throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(2, 3), numArgs) - } - } -} - @ExpressionDescription( usage = """ _FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen diff --git 
a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 120dfa6c7edc3..ccdbb544b3608 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -368,7 +368,7 @@ | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct | -| org.apache.spark.sql.catalyst.expressions.UniformExpressionBuilder | uniform | N/A | N/A | +| org.apache.spark.sql.catalyst.expressions.UniformExpressionBuilder | uniform | SELECT uniform(10, 20) > 0 AS result | struct | | org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | From 515b046cf5e011fdd4a6fb94d933d1aebf583b3d Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 11 Sep 2024 20:16:38 -0700 Subject: [PATCH 12/27] fix test --- .../spark/sql/catalyst/expressions/randomExpressions.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 945b35131b8e2..01a37f58e73bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -269,7 +269,6 @@ case class Uniform(min: Expression, max: Expression) errorSubClass = "UNEXPECTED_INPUT_TYPE", messageParameters = Map( "paramIndex" -> ordinalNumber(index), - "functionName" -> prettyName, "requiredType" -> requiredType, "inputSql" -> toSQLExpr(expr), "inputType" -> toSQLType(expr.dataType))) From 350412470cf33a697984cd8c2060a537b5aefe1b Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 11 Sep 2024 20:27:31 -0700 Subject: [PATCH 13/27] fix test --- .../test/resources/sql-tests/analyzer-results/random.sql.out | 3 --- sql/core/src/test/resources/sql-tests/results/random.sql.out | 3 --- 2 files changed, 6 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out index 81afcc3eb54c4..49ac259c01c63 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out @@ -169,7 +169,6 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { - "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "first", @@ -194,7 +193,6 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { - "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", 
"paramIndex" : "second", @@ -219,7 +217,6 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { - "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "third", diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index 9e7cd60e91c19..0decea059caf2 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -215,7 +215,6 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { - "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "first", @@ -242,7 +241,6 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { - "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "second", @@ -269,7 +267,6 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", "sqlState" : "42K09", "messageParameters" : { - "functionName" : "uniform", "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "third", From 068331872a7bc50d6ab733b06b21fdcd889e4faf Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Thu, 12 Sep 2024 15:03:48 -0700 Subject: [PATCH 14/27] respond to code review comments respond to code review comments respond to code review comments respond to code review comments --- .../catalyst/analysis/FunctionRegistry.scala | 2 +- .../expressions/randomExpressions.scala | 75 +++++++------------ .../sql-functions/sql-expression-schema.md | 2 +- .../sql-tests/analyzer-results/random.sql.out | 10 +-- .../sql-tests/results/random.sql.out | 10 +-- 5 files changed, 40 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f1723a69c1fcf..6f816aa4305f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -386,7 +386,7 @@ object FunctionRegistry { expression[Randn]("randn"), expression[RandStr]("randstr"), expression[Stack]("stack"), - expressionBuilder("uniform", UniformExpressionBuilder), + expression[Uniform]("uniform"), expression[ZeroIfNull]("zeroifnull"), CaseWhen.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 01a37f58e73bc..b99893624dc92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedSeed} +import 
org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ @@ -39,8 +39,7 @@ import org.apache.spark.util.random.XORShiftRandom * * Since this expression is stateful, it cannot be a case object. */ -abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic - with ExpressionWithRandomSeed { +trait RDG extends Expression with ExpressionWithRandomSeed { /** * Record ID within each partition. By being transient, the Random Number Generator is * reset every time we serialize and deserialize and initialize it. @@ -49,12 +48,6 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm override def stateful: Boolean = true - override protected def initializeInternal(partitionIndex: Int): Unit = { - rng = new XORShiftRandom(seed + partitionIndex) - } - - override def seedExpression: Expression = child - @transient protected lazy val seed: Long = seedExpression match { case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int] case e if e.dataType == LongType => e.eval().asInstanceOf[Long] @@ -63,6 +56,15 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm override def nullable: Boolean = false override def dataType: DataType = DoubleType +} + +abstract class NondeterministicUnaryRDG + extends RDG with UnaryLike[Expression] with Nondeterministic with ExpectsInputTypes { + override def seedExpression: Expression = child + + override protected def initializeInternal(partitionIndex: Int): Unit = { + rng = new XORShiftRandom(seed + partitionIndex) + } override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType)) } @@ -105,7 +107,7 @@ private[catalyst] object ExpressionWithRandomSeed { since = "1.5.0", group = "math_funcs") // scalastyle:on line.size.limit -case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG { +case class Rand(child: Expression, hideSeed: Boolean = false) extends NondeterministicUnaryRDG { def this() = this(UnresolvedSeed, true) @@ -156,7 +158,7 @@ object Rand { since = "1.5.0", group = "math_funcs") // scalastyle:on line.size.limit -case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG { +case class Randn(child: Expression, hideSeed: Boolean = false) extends NondeterministicUnaryRDG { def this() = this(UnresolvedSeed, true) @@ -199,29 +201,16 @@ object Randn { """, examples = """ Examples: - > SELECT _FUNC_(10, 20) > 0 AS result; + > SELECT _FUNC_(10, 20, 0) > 0 AS result; true """, since = "4.0.0", group = "math_funcs") -object UniformExpressionBuilder extends ExpressionBuilder { - override def build(funcName: String, expressions: Seq[Expression]): Expression = { - val numArgs = expressions.length - expressions match { - case Seq(min, max) => - Uniform(min, max) - case Seq(min, max, seed) => - Uniform(min, max, 
seed) - case _ => - throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(2, 3), numArgs) - } - } -} - -case class Uniform(min: Expression, max: Expression) - extends RuntimeReplaceable with BinaryLike[Expression] with ExpressionWithRandomSeed { +case class Uniform(min: Expression, max: Expression, seedExpression: Expression) + extends RuntimeReplaceable with TernaryLike[Expression] with RDG { + def this(min: Expression, max: Expression) = + this(min, max, Literal(Uniform.random.nextLong(), LongType)) - private var seed: Expression = Literal(Uniform.random.nextLong(), LongType) final override lazy val deterministic: Boolean = false override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED) @@ -255,7 +244,7 @@ case class Uniform(min: Expression, max: Expression) def requiredType = "integer or floating-point" Seq((min, "min", 0), (max, "max", 1), - (seed, "seed", 2)).foreach { + (seedExpression, "seed", 2)).foreach { case (expr: Expression, name: String, index: Int) => if (!expr.foldable && result == TypeCheckResult.TypeCheckSuccess) { result = DataTypeMismatch( @@ -277,18 +266,16 @@ case class Uniform(min: Expression, max: Expression) result } - override def left: Expression = min - override def right: Expression = max + override def first: Expression = min + override def second: Expression = max + override def third: Expression = seedExpression - override def seedExpression: Expression = seed - override def withNewSeed(newSeed: Long): Expression = { - val result = Uniform(min, max) - result.seed = Literal(newSeed, LongType) - result - } + override def withNewSeed(newSeed: Long): Expression = + Uniform(min, max, Literal(newSeed, LongType)) - override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = - Uniform(newFirst, newSecond) + override def withNewChildrenInternal( + newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = + Uniform(newFirst, newSecond, newThird) override def replacement: Expression = { def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to) @@ -305,12 +292,6 @@ case class Uniform(min: Expression, max: Expression) object Uniform { lazy val random = new Random() - - def apply(min: Expression, max: Expression, seedExpression: Expression): Uniform = { - val result = Uniform(min, max) - result.seed = seedExpression - result - } } @ExpressionDescription( diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index ccdbb544b3608..90dbd55858b22 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -368,7 +368,7 @@ | org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct | | org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> | | org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct | -| org.apache.spark.sql.catalyst.expressions.UniformExpressionBuilder | uniform | SELECT uniform(10, 20) > 0 AS result | struct | +| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(10, 20, 0) > 0 AS result | struct | | org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct | | 
org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | | org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out index 49ac259c01c63..0c858592bf256 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out @@ -173,7 +173,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputType" : "\"VOID\"", "paramIndex" : "first", "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(NULL, 1)\"" + "sqlExpr" : "\"uniform(NULL, 1, 0)\"" }, "queryContext" : [ { "objectType" : "", @@ -197,7 +197,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputType" : "\"VOID\"", "paramIndex" : "second", "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(0, NULL)\"" + "sqlExpr" : "\"uniform(0, NULL, 0)\"" }, "queryContext" : [ { "objectType" : "", @@ -221,7 +221,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputType" : "\"VOID\"", "paramIndex" : "third", "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(0, 1)\"" + "sqlExpr" : "\"uniform(0, 1, NULL)\"" }, "queryContext" : [ { "objectType" : "", @@ -244,7 +244,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputExpr" : "\"col\"", "inputName" : "seed", "inputType" : "integer or floating-point", - "sqlExpr" : "\"uniform(10, 20)\"" + "sqlExpr" : "\"uniform(10, 20, col)\"" }, "queryContext" : [ { "objectType" : "", @@ -267,7 +267,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputExpr" : "\"col\"", "inputName" : "min", "inputType" : "integer or floating-point", - "sqlExpr" : "\"uniform(col, 10)\"" + "sqlExpr" : "\"uniform(col, 10, 0)\"" }, "queryContext" : [ { "objectType" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index 0decea059caf2..c84ceec75ca99 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -219,7 +219,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputType" : "\"VOID\"", "paramIndex" : "first", "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(NULL, 1)\"" + "sqlExpr" : "\"uniform(NULL, 1, 0)\"" }, "queryContext" : [ { "objectType" : "", @@ -245,7 +245,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputType" : "\"VOID\"", "paramIndex" : "second", "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(0, NULL)\"" + "sqlExpr" : "\"uniform(0, NULL, 0)\"" }, "queryContext" : [ { "objectType" : "", @@ -271,7 +271,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputType" : "\"VOID\"", "paramIndex" : "third", "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(0, 1)\"" + "sqlExpr" : "\"uniform(0, 1, NULL)\"" }, "queryContext" : [ { "objectType" : "", @@ -296,7 +296,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputExpr" : "\"col\"", "inputName" : "seed", "inputType" : "integer or floating-point", - "sqlExpr" : "\"uniform(10, 20)\"" + "sqlExpr" : "\"uniform(10, 20, col)\"" }, "queryContext" : [ { "objectType" : "", @@ -321,7 +321,7 
@@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputExpr" : "\"col\"", "inputName" : "min", "inputType" : "integer or floating-point", - "sqlExpr" : "\"uniform(col, 10)\"" + "sqlExpr" : "\"uniform(col, 10, 0)\"" }, "queryContext" : [ { "objectType" : "", From 7ec25a870ba8474f7309b4ec364f8cf0ec304449 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Fri, 13 Sep 2024 11:19:52 -0700 Subject: [PATCH 15/27] respond to code review comments --- .../expressions/randomExpressions.scala | 132 ++++++++++-------- .../catalyst/expressions/RandomSuite.scala | 24 ++++ .../sql-tests/analyzer-results/random.sql.out | 88 +++--------- .../sql-tests/results/random.sql.out | 94 +++---------- 4 files changed, 132 insertions(+), 206 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index b99893624dc92..e4f0dd4b050af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import scala.util.Random - +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -208,8 +207,7 @@ object Randn { group = "math_funcs") case class Uniform(min: Expression, max: Expression, seedExpression: Expression) extends RuntimeReplaceable with TernaryLike[Expression] with RDG { - def this(min: Expression, max: Expression) = - this(min, max, Literal(Uniform.random.nextLong(), LongType)) + def this(min: Expression, max: Expression) = this(min, max, UnresolvedSeed) final override lazy val deterministic: Boolean = false override val nodePatterns: Seq[TreePattern] = @@ -219,21 +217,23 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) val first = min.dataType val second = max.dataType (min.dataType, max.dataType) match { - case _ if !valid(min) || !valid(max) => NullType - case (_, LongType) | (LongType, _) if Seq(first, second).forall(integer) => LongType - case (_, IntegerType) | (IntegerType, _) if Seq(first, second).forall(integer) => IntegerType - case (_, ShortType) | (ShortType, _) if Seq(first, second).forall(integer) => ShortType + case _ if !seedExpression.resolved || seedExpression.dataType == NullType => + NullType + case (_, NullType) | (NullType, _) => NullType + case (_, LongType) | (LongType, _) + if Seq(first, second).forall(integer) => LongType + case (_, IntegerType) | (IntegerType, _) + if Seq(first, second).forall(integer) => IntegerType + case (_, ShortType) | (ShortType, _) + if Seq(first, second).forall(integer) => ShortType case (_, DoubleType) | (DoubleType, _) => DoubleType case (_, FloatType) | (FloatType, _) => FloatType - case _ => NullType + case _ => + throw SparkException.internalError( + s"Unexpected argument data types: ${min.dataType}, ${max.dataType}") } } - private def valid(e: Expression): Boolean = e.dataType match { - case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType => true - case _ => false - } - private def integer(t: DataType): Boolean = t match { case _: ShortType | _: IntegerType | _: LongType => true case _ => false @@ -246,21 +246,26 @@ case class 
Uniform(min: Expression, max: Expression, seedExpression: Expression) (max, "max", 1), (seedExpression, "seed", 2)).foreach { case (expr: Expression, name: String, index: Int) => - if (!expr.foldable && result == TypeCheckResult.TypeCheckSuccess) { - result = DataTypeMismatch( - errorSubClass = "NON_FOLDABLE_INPUT", - messageParameters = Map( - "inputName" -> name, - "inputType" -> requiredType, - "inputExpr" -> toSQLExpr(expr))) - } else if (!valid(expr) && result == TypeCheckResult.TypeCheckSuccess) { - result = DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> ordinalNumber(index), - "requiredType" -> requiredType, - "inputSql" -> toSQLExpr(expr), - "inputType" -> toSQLType(expr.dataType))) + if (result == TypeCheckResult.TypeCheckSuccess) { + if (!expr.foldable) { + result = DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> name, + "inputType" -> requiredType, + "inputExpr" -> toSQLExpr(expr))) + } else expr.dataType match { + case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | + _: NullType => + case _ => + result = DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(index), + "requiredType" -> requiredType, + "inputSql" -> toSQLExpr(expr), + "inputType" -> toSQLType(expr.dataType))) + } } } result @@ -278,22 +283,22 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) Uniform(newFirst, newSecond, newThird) override def replacement: Expression = { - def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to) - cast(Add( - cast(min, DoubleType), - Multiply( - Subtract( - cast(max, DoubleType), - cast(min, DoubleType)), - Rand(seed))), - dataType) + if (Seq(min, max, seedExpression).exists(_.dataType == NullType)) { + Literal(null) + } else { + def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to) + cast(Add( + cast(min, DoubleType), + Multiply( + Subtract( + cast(max, DoubleType), + cast(min, DoubleType)), + Rand(seed))), + dataType) + } } } -object Uniform { - lazy val random = new Random() -} - @ExpressionDescription( usage = """ _FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen @@ -311,7 +316,7 @@ object Uniform { group = "string_funcs") case class RandStr(length: Expression, override val seedExpression: Expression) extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic { - def this(length: Expression) = this(length, Literal(Uniform.random.nextLong(), LongType)) + def this(length: Expression) = this(length, UnresolvedSeed) override def nullable: Boolean = false override def dataType: DataType = StringType @@ -339,22 +344,31 @@ case class RandStr(length: Expression, override val seedExpression: Expression) override def checkInputDataTypes(): TypeCheckResult = { var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess - Seq(length, seedExpression).zipWithIndex.foreach { case (expr: Expression, index: Int) => - val valid = expr.dataType match { - case _ if !expr.foldable => false - case _: ShortType | _: IntegerType => true - case _: LongType if index == 1 => true - case _ => false - } - if (!valid) { - result = DataTypeMismatch( - errorSubClass = "UNEXPECTED_INPUT_TYPE", - messageParameters = Map( - "paramIndex" -> ordinalNumber(index), - "requiredType" -> "constant value of INT or SMALLINT", - "inputSql" 
-> toSQLExpr(expr), - "inputType" -> toSQLType(expr.dataType))) - } + def requiredType = "INT or SMALLINT" + Seq((length, "length", 0), + (seedExpression, "seedExpression", 1)).foreach { + case (expr: Expression, name: String, index: Int) => + if (result == TypeCheckResult.TypeCheckSuccess) { + if (!expr.foldable) { + result = DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> name, + "inputType" -> requiredType, + "inputExpr" -> toSQLExpr(expr))) + } else expr.dataType match { + case _: ShortType | _: IntegerType => + case _: LongType if index == 1 => + case _ => + result = DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(index), + "requiredType" -> requiredType, + "inputSql" -> toSQLExpr(expr), + "inputType" -> toSQLType(expr.dataType))) + } + } } result } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 2aa53f581555f..7e95df86eee05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types.{IntegerType, LongType} class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -41,4 +42,27 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { assert(Rand(Literal(1L), false).sql === "rand(1L)") assert(Randn(Literal(1L), false).sql === "randn(1L)") } + + test("SPARK-49505: Test the RANDSTR and UNIFORM SQL functions without codegen") { + // Note that we use a seed of zero in these tests to keep the results deterministic. 
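// A note on the determinism here (illustration, not part of the patch): the
// random generators in this file are built on
// org.apache.spark.util.random.XORShiftRandom, initialized from the seed value
// (plus the partition index at execution time), so a Literal(0) seed replays
// the identical stream -- and hence the identical expected values -- on every
// test run.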
+ def testRandStr(first: Any, result: Any): Unit = { + checkEvaluationWithoutCodegen( + RandStr(Literal(first), Literal(0)), CatalystTypeConverters.convertToCatalyst(result)) + } + testRandStr(1, "8") + testRandStr(5, "8i70B") + testRandStr(10, "8i70BBEJ6A") + testRandStr(10L, "8i70BBEJ6A") + + def testUniform(first: Any, second: Any, result: Any): Unit = { + checkEvaluationWithoutCodegen( + Uniform(Literal(first), Literal(second), Literal(0)), + CatalystTypeConverters.convertToCatalyst(result)) + } + testUniform(0, 1, 0) + testUniform(0, 10, 7) + testUniform(0L, 10L, 7L) + testUniform(10.0F, 20.0F, 17.604954F) + testUniform(10L, 20.0F, 17.604954F) + } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out index 0c858592bf256..133cd6a60a4fb 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out @@ -164,73 +164,19 @@ SELECT uniform(10, 20.0F) IS NOT NULL AS result -- !query SELECT uniform(NULL, 1, 0) AS result -- !query analysis -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", - "messageParameters" : { - "inputSql" : "\"NULL\"", - "inputType" : "\"VOID\"", - "paramIndex" : "first", - "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(NULL, 1, 0)\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 26, - "fragment" : "uniform(NULL, 1, 0)" - } ] -} +[Analyzer test output redacted due to nondeterminism] -- !query SELECT uniform(0, NULL, 0) AS result -- !query analysis -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", - "messageParameters" : { - "inputSql" : "\"NULL\"", - "inputType" : "\"VOID\"", - "paramIndex" : "second", - "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(0, NULL, 0)\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 26, - "fragment" : "uniform(0, NULL, 0)" - } ] -} +[Analyzer test output redacted due to nondeterminism] -- !query SELECT uniform(0, 1, NULL) AS result -- !query analysis -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", - "messageParameters" : { - "inputSql" : "\"NULL\"", - "inputType" : "\"VOID\"", - "paramIndex" : "third", - "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(0, 1, NULL)\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 26, - "fragment" : "uniform(0, 1, NULL)" - } ] -} +[Analyzer test output redacted due to nondeterminism] -- !query @@ -372,7 +318,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputSql" : "\"10\"", "inputType" : "\"BIGINT\"", "paramIndex" : "first", - "requiredType" : "constant value of INT or SMALLINT", + "requiredType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, 0)\"" }, "queryContext" : [ { @@ -396,7 +342,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputSql" : "\"10.0\"", "inputType" : "\"FLOAT\"", "paramIndex" : "first", - "requiredType" : "constant value of INT or SMALLINT", + "requiredType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10.0, 0)\"" }, "queryContext" : [ { @@ 
-420,7 +366,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputSql" : "\"10.0\"", "inputType" : "\"DOUBLE\"", "paramIndex" : "first", - "requiredType" : "constant value of INT or SMALLINT", + "requiredType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10.0, 0)\"" }, "queryContext" : [ { @@ -444,7 +390,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "first", - "requiredType" : "constant value of INT or SMALLINT", + "requiredType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(NULL, 0)\"" }, "queryContext" : [ { @@ -468,7 +414,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "second", - "requiredType" : "constant value of INT or SMALLINT", + "requiredType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(0, NULL)\"" }, "queryContext" : [ { @@ -486,13 +432,12 @@ SELECT randstr(col, 0) AS result FROM VALUES (0), (1), (2) tab(col) -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", "sqlState" : "42K09", "messageParameters" : { - "inputSql" : "\"col\"", - "inputType" : "\"INT\"", - "paramIndex" : "first", - "requiredType" : "constant value of INT or SMALLINT", + "inputExpr" : "\"col\"", + "inputName" : "length", + "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(col, 0)\"" }, "queryContext" : [ { @@ -510,13 +455,12 @@ SELECT randstr(10, col) AS result FROM VALUES (0), (1), (2) tab(col) -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", "sqlState" : "42K09", "messageParameters" : { - "inputSql" : "\"col\"", - "inputType" : "\"INT\"", - "paramIndex" : "second", - "requiredType" : "constant value of INT or SMALLINT", + "inputExpr" : "\"col\"", + "inputName" : "seedExpression", + "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, col)\"" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index c84ceec75ca99..c3ff36f54156c 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -208,79 +208,25 @@ true -- !query SELECT uniform(NULL, 1, 0) AS result -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", - "messageParameters" : { - "inputSql" : "\"NULL\"", - "inputType" : "\"VOID\"", - "paramIndex" : "first", - "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(NULL, 1, 0)\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 26, - "fragment" : "uniform(NULL, 1, 0)" - } ] -} +NULL -- !query SELECT uniform(0, NULL, 0) AS result -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", - "messageParameters" : { - "inputSql" : "\"NULL\"", - "inputType" : "\"VOID\"", - "paramIndex" : "second", - "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(0, NULL, 0)\"" - }, - "queryContext" : [ { - 
"objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 26, - "fragment" : "uniform(0, NULL, 0)" - } ] -} +NULL -- !query SELECT uniform(0, 1, NULL) AS result -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", - "messageParameters" : { - "inputSql" : "\"NULL\"", - "inputType" : "\"VOID\"", - "paramIndex" : "third", - "requiredType" : "integer or floating-point", - "sqlExpr" : "\"uniform(0, 1, NULL)\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 8, - "stopIndex" : 26, - "fragment" : "uniform(0, 1, NULL)" - } ] -} +NULL -- !query @@ -446,7 +392,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputSql" : "\"10\"", "inputType" : "\"BIGINT\"", "paramIndex" : "first", - "requiredType" : "constant value of INT or SMALLINT", + "requiredType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, 0)\"" }, "queryContext" : [ { @@ -472,7 +418,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputSql" : "\"10.0\"", "inputType" : "\"FLOAT\"", "paramIndex" : "first", - "requiredType" : "constant value of INT or SMALLINT", + "requiredType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10.0, 0)\"" }, "queryContext" : [ { @@ -498,7 +444,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputSql" : "\"10.0\"", "inputType" : "\"DOUBLE\"", "paramIndex" : "first", - "requiredType" : "constant value of INT or SMALLINT", + "requiredType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10.0, 0)\"" }, "queryContext" : [ { @@ -524,7 +470,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "first", - "requiredType" : "constant value of INT or SMALLINT", + "requiredType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(NULL, 0)\"" }, "queryContext" : [ { @@ -550,7 +496,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "inputSql" : "\"NULL\"", "inputType" : "\"VOID\"", "paramIndex" : "second", - "requiredType" : "constant value of INT or SMALLINT", + "requiredType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(0, NULL)\"" }, "queryContext" : [ { @@ -570,13 +516,12 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", "sqlState" : "42K09", "messageParameters" : { - "inputSql" : "\"col\"", - "inputType" : "\"INT\"", - "paramIndex" : "first", - "requiredType" : "constant value of INT or SMALLINT", + "inputExpr" : "\"col\"", + "inputName" : "length", + "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(col, 0)\"" }, "queryContext" : [ { @@ -596,13 +541,12 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", "sqlState" : "42K09", "messageParameters" : { - "inputSql" : "\"col\"", - "inputType" : "\"INT\"", - "paramIndex" : "second", - "requiredType" : "constant value of INT or SMALLINT", + "inputExpr" : "\"col\"", + "inputName" : "seedExpression", + "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, col)\"" }, "queryContext" : [ { From 1f5e866874f28f1b416301e31e622717c92767d3 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Fri, 13 Sep 2024 17:23:44 -0700 Subject: [PATCH 16/27] fix RandomSuite --- 
.../spark/sql/catalyst/expressions/randomExpressions.scala | 6 +++++- .../apache/spark/sql/catalyst/expressions/RandomSuite.scala | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index e4f0dd4b050af..6fb991efc06bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -374,7 +374,11 @@ case class RandStr(length: Expression, override val seedExpression: Expression) } override def evalInternal(input: InternalRow): Any = { - val numChars: Int = length.eval(input).asInstanceOf[Int] + val numChars: Int = length.eval(input) match { + case i: Int => i + case n: Long => n.toInt + case s: Short => s.toInt + } val bytes = new Array[Byte](numChars) (0 until numChars).foreach { i => val num = (rng.nextInt() % 30).abs diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 7e95df86eee05..013f083983d0c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -56,7 +56,7 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { def testUniform(first: Any, second: Any, result: Any): Unit = { checkEvaluationWithoutCodegen( - Uniform(Literal(first), Literal(second), Literal(0)), + Uniform(Literal(first), Literal(second), Literal(0)).replacement, CatalystTypeConverters.convertToCatalyst(result)) } testUniform(0, 1, 0) From 8b64f33e3ffe5d839f10b0feb5786addcd3ff371 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 16 Sep 2024 16:22:31 +0200 Subject: [PATCH 17/27] respond to code review comments --- .../expressions/randomExpressions.scala | 21 +++++++++---------- .../catalyst/expressions/RandomSuite.scala | 8 +++---- .../sql-tests/results/random.sql.out | 14 ++++++------- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 6fb991efc06bf..b67241756c077 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -374,21 +374,20 @@ case class RandStr(length: Expression, override val seedExpression: Expression) } override def evalInternal(input: InternalRow): Any = { - val numChars: Int = length.eval(input) match { - case i: Int => i - case n: Long => n.toInt - case s: Short => s.toInt - } + val numChars = length.eval(input).asInstanceOf[Number].intValue() val bytes = new Array[Byte](numChars) (0 until numChars).foreach { i => - val num = (rng.nextInt() % 30).abs + // We generate a random number between 0 and 61, inclusive. Between the 62 different choices + // we choose 0-9, a-z, or A-Z, where each category comprises 10 choices, 26 choices, or 26 + // choices, respectively (10 + 26 + 26 = 62). 
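// A few worked draws for the mapping below (illustration, not part of the patch):
//   num = 7  -> ('0' + 7)         = '7'
//   num = 12 -> ('a' + (12 - 10)) = 'c'
//   num = 40 -> ('A' + (40 - 36)) = 'E'
// Digits, lowercase, and uppercase thus cover num in [0, 10), [10, 36), and
// [36, 62) respectively, matching the 10 + 26 + 26 = 62 pool size.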
+ val num = (rng.nextInt() % 62).abs num match { case _ if num < 10 => bytes.update(i, ('0' + num).toByte) - case _ if num < 20 => + case _ if num < 36 => bytes.update(i, ('a' + num - 10).toByte) case _ => - bytes.update(i, ('A' + num - 20).toByte) + bytes.update(i, ('A' + num - 36).toByte) } } val result: UTF8String = UTF8String.fromBytes(bytes.toArray) @@ -407,13 +406,13 @@ case class RandStr(length: Expression, override val seedExpression: Expression) |int length = (int)(${eval.value}); |char[] chars = new char[length]; |for (int i = 0; i < length; i++) { - | int v = Math.abs($rngTerm.nextInt() % 30); + | int v = Math.abs($rngTerm.nextInt() % 62); | if (v < 10) { | chars[i] = (char)('0' + v); - | } else if (v < 20) { + | } else if (v < 36) { | chars[i] = (char)('a' + (v - 10)); | } else { - | chars[i] = (char)('A' + (v - 20)); + | chars[i] = (char)('A' + (v - 36)); | } |} |UTF8String ${ev.value} = UTF8String.fromString(new String(chars)); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 013f083983d0c..2d58d9d3136aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -49,10 +49,10 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluationWithoutCodegen( RandStr(Literal(first), Literal(0)), CatalystTypeConverters.convertToCatalyst(result)) } - testRandStr(1, "8") - testRandStr(5, "8i70B") - testRandStr(10, "8i70BBEJ6A") - testRandStr(10L, "8i70BBEJ6A") + testRandStr(1, "c") + testRandStr(5, "ceV0P") + testRandStr(10, "ceV0PXaR2I") + testRandStr(10L, "ceV0PXaR2I") def testUniform(first: Any, second: Any, result: Any): Unit = { checkEvaluationWithoutCodegen( diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index c3ff36f54156c..0b4e5e078ee15 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -334,7 +334,7 @@ SELECT randstr(1, 0) AS result -- !query schema struct -- !query output -8 +c -- !query @@ -342,7 +342,7 @@ SELECT randstr(5, 0) AS result -- !query schema struct -- !query output -8i70B +ceV0P -- !query @@ -350,7 +350,7 @@ SELECT randstr(10, 0) AS result -- !query schema struct -- !query output -8i70BBEJ6A +ceV0PXaR2I -- !query @@ -358,7 +358,7 @@ SELECT randstr(10S, 0) AS result -- !query schema struct -- !query output -8i70BBEJ6A +ceV0PXaR2I -- !query @@ -366,9 +366,9 @@ SELECT randstr(10, 0) AS result FROM VALUES (0), (1), (2) tab(col) -- !query schema struct -- !query output -2iAhij5i3F -3AD1fIHj7B -8i70BBEJ6A +ceV0PXaR2I +fYxVfArnv7 +iSIv0VT2XL -- !query From 0313fbe578e5bc4730bb0533696bad9184fe16b2 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 16 Sep 2024 18:29:27 +0200 Subject: [PATCH 18/27] commit --- .../reference/pyspark.sql/functions.rst | 2 + .../pyspark/sql/connect/functions/builtin.py | 20 ++++++ python/pyspark/sql/functions/builtin.py | 66 +++++++++++++++++++ python/pyspark/sql/tests/test_functions.py | 11 +++- 4 files changed, 98 insertions(+), 1 deletion(-) diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index dc4329c603241..0b2769a7a0b59 100644 --- 
a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -148,6 +148,7 @@ Mathematical Functions try_multiply try_subtract unhex + uniform width_bucket @@ -189,6 +190,7 @@ String Functions overlay position printf + randstr regexp_count regexp_extract regexp_extract_all diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index ad6dbbf58e48d..18bb37acb3550 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -1008,6 +1008,16 @@ def unhex(col: "ColumnOrName") -> Column: unhex.__doc__ = pysparkfuncs.unhex.__doc__ +def uniform(min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"]) -> Column: + if seed is None: + return _invoke_function_over_columns("uniform", min, max) + else: + return _invoke_function_over_columns("uniform", min, max, seed) + + +uniform.__doc__ = pysparkfuncs.uniform.__doc__ + + def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: warnings.warn("Deprecated in 3.4, use approx_count_distinct instead.", FutureWarning) return approx_count_distinct(col, rsd) @@ -2571,6 +2581,16 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp_like.__doc__ = pysparkfuncs.regexp_like.__doc__ +def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"]) -> Column: + if seed is None: + return _invoke_function_over_columns("randstr", length) + else: + return _invoke_function_over_columns("randstr", length, seed) + + +randstr.__doc__ = pysparkfuncs.randstr.__doc__ + + def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: return _invoke_function_over_columns("regexp_count", str, regexp) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 344ba8d009ac4..5a013378d7b17 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11841,6 +11841,37 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: return _invoke_function_over_columns("regexp_like", str, regexp) +@_try_remote_functions +def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"]) -> Column: + """Returns a string of the specified length whose characters are chosen uniformly at random from + the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length + must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). + + .. versionadded:: 4.0.0 + + Parameters + ---------- + length : :class:`~pyspark.sql.Column` or int + Number of characters in the string to generate. + seed : :class:`~pyspark.sql.Column` or int + Optional random number seed to use. + + Returns + ------- + :class:`~pyspark.sql.Column` + The generated random string with the specified length. 
+ + Examples + -------- + >>> spark.createDataFrame([('3',)], ['a']).select(randstr(5).alias('x')).select(isnull('x')).collect() + [Row(false)] + """ + if seed is None: + return _invoke_function_over_columns("randstr", length) + else: + return _invoke_function_over_columns("randstr", length, seed) + + @_try_remote_functions def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched @@ -12207,6 +12238,41 @@ def unhex(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("unhex", col) +@_try_remote_functions +def uniform(min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"]) -> Column: + """Returns a random value with independent and identically distributed (i.i.d.) values with the + specified range of numbers. The random seed is optional. The provided numbers specifying the + minimum and maximum values of the range must be constant. If both of these numbers are integers, + then the result will also be an integer. Otherwise if one or both of these are floating-point + numbers, then the result will also be a floating-point number. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + min : :class:`~pyspark.sql.Column`, int, or float + Minimum value in the range. + max : :class:`~pyspark.sql.Column`, int, or float + Maximum value in the range. + seed : :class:`~pyspark.sql.Column` or int + Optional random number seed to use. + + Returns + ------- + :class:`~pyspark.sql.Column` + The generated random number within the specified range. + + Examples + -------- + >>> spark.createDataFrame([('3',)], ['a']).select(uniform(0, 10).alias('x')).select(isnull('x')).collect() + [Row(false)] + """ + if seed is None: + return _invoke_function_over_columns("uniform", min, max) + else: + return _invoke_function_over_columns("uniform", min, max, seed) + + @_try_remote_functions def length(col: "ColumnOrName") -> Column: """Computes the character length of string data or number of bytes of binary data. 
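As an aside between these two files, the SQL-level behavior of the pair can be sketched as follows, assuming this series' build (concrete outputs depend on the seed handling at this point in the series):

    // Fixed seeds make the results reproducible for the same query plan:
    spark.sql("SELECT uniform(0, 10, 0) AS u, randstr(5, 0) AS s").show()
    // Per the golden files earlier in the series, a NULL bound or seed
    // propagates to a NULL result rather than raising an error:
    spark.sql("SELECT uniform(0, NULL, 0) AS u").show() // u is NULL
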
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index f7f2485a43e16..672d685e809d5 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -29,7 +29,7 @@ from pyspark.sql import Row, Window, functions as F, types from pyspark.sql.avro.functions import from_avro, to_avro from pyspark.sql.column import Column -from pyspark.sql.functions.builtin import nullifzero, zeroifnull +from pyspark.sql.functions.builtin import isnull, nullifzero, randstr, uniform, zeroifnull from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils from pyspark.testing.utils import have_numpy @@ -1603,6 +1603,15 @@ def test_nullifzero_zeroifnull(self): result = df.select(zeroifnull(df.a).alias("r")).collect() self.assertEqual([Row(r=0), Row(r=1)], result) + def test_randstr_uniform(self): + df = self.spark.createDataFrame([(0,), (1,)], ["a"]) + result = df.select(randstr(5).alias("x")).select(isnull("x")).collect() + self.assertEqual([Row(x=False)], result) + + df = self.spark.createDataFrame([(None,), (1,)], ["a"]) + result = df.select(uniform(0, 10).alias("x")).select(isnull("x")).collect() + self.assertEqual([Row(x=False)], result) + class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin): pass From 48f1ddea4c2111c1fccab5c46873e5d9c536558e Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Tue, 17 Sep 2024 18:13:04 +0200 Subject: [PATCH 19/27] sync --- .../org/apache/spark/sql/functions.scala | 22 +++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 24 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 86f8923f36b40..2170591424f8f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1884,6 +1884,16 @@ object functions { */ def randn(): Column = randn(SparkClassUtils.random.nextLong) + /** + * Returns a string of the specified length whose characters are chosen uniformly at random from + * the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length + * must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). + * + * @group string_funcs + * @since 4.0.0 + */ + def randstr(length: Column, seed: Column): Column = Column.fn("randstr", length, seed) + /** * Partition ID. * @@ -3728,6 +3738,18 @@ object functions { */ def stack(cols: Column*): Column = Column.fn("stack", cols: _*) + /** + * Returns a random value with independent and identically distributed (i.i.d.) values with the + * specified range of numbers. The random seed is optional. The provided numbers specifying the + * minimum and maximum values of the range must be constant. If both of these numbers are + * integers, then the result will also be an integer. Otherwise if one or both of these are + * floating-point numbers, then the result will also be a floating-point number. + * + * @group math_funcs + * @since 4.0.0 + */ + def uniform(min: Column, max: Column, seed: Column): Column = Column.fn("uniform", min, max, seed) + /** * Returns a random value with independent and identically distributed (i.i.d.) uniformly * distributed values in [0, 1). 
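A short usage sketch of the Scala API added above, assuming this series' build; `lit` and the new functions come from org.apache.spark.sql.functions:

    import org.apache.spark.sql.functions._

    val df = spark.range(1).select(
      randstr(lit(5), lit(0)).as("s"),           // five chars drawn from 0-9, a-z, A-Z
      uniform(lit(10), lit(20), lit(0)).as("u")) // an integer drawn from [10, 20)
    df.show()
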
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f16171940df21..ba8888b10d21f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -409,6 +409,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(nvl2(col("b"), col("a"), col("c"))), Seq(Row(null))) } + test("randstr function") { + withTable("t") { + sql("create table t(col int not null) using csv") + sql("insert into t values (0)") + val df = sql("select col from t") + checkAnswer( + df.select(randstr(lit(5), lit(0)).alias("x")).select(isnull(col("x"))), + Seq(Row(false))) + } + // TODO: add some error cases + } + + test("uniform function") { + withTable("t") { + sql("create table t(col int not null) using csv") + sql("insert into t values (0)") + val df = sql("select col from t") + checkAnswer( + df.select(uniform(lit(0), lit(10), lit(0)).alias("x")).select(isnull(col("x"))), + Seq(Row(false))) + } + // TODO: add some error cases + } + test("zeroifnull function") { withTable("t") { // Here we exercise a non-nullable, non-foldable column. From f496e861ca49d3fe17fc4301579b6aa68582af4f Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 18 Sep 2024 17:58:43 +0200 Subject: [PATCH 20/27] commit --- .../pyspark/sql/connect/functions/builtin.py | 6 +- python/pyspark/sql/functions/builtin.py | 6 +- python/pyspark/sql/tests/test_functions.py | 26 ++++-- .../org/apache/spark/sql/functions.scala | 31 ++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 92 +++++++++++++++++-- 5 files changed, 139 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index f2f23dd7497f6..721a8bc5878d8 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -1008,7 +1008,9 @@ def unhex(col: "ColumnOrName") -> Column: unhex.__doc__ = pysparkfuncs.unhex.__doc__ -def uniform(min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"]) -> Column: +def uniform( + min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"] = None +) -> Column: if seed is None: return _invoke_function_over_columns("uniform", min, max) else: @@ -2588,7 +2590,7 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp_like.__doc__ = pysparkfuncs.regexp_like.__doc__ -def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"]) -> Column: +def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"] = None) -> Column: if seed is None: return _invoke_function_over_columns("randstr", length) else: diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index cf65f05c45a07..33991659633e5 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11862,7 +11862,7 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: @_try_remote_functions -def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"]) -> Column: +def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"] = None) -> Column: """Returns a string of the specified length whose characters are chosen uniformly at random from the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. 
The string length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). @@ -12259,7 +12259,9 @@ def unhex(col: "ColumnOrName") -> Column: @_try_remote_functions -def uniform(min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"]) -> Column: +def uniform( + min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"] = None +) -> Column: """Returns a random value with independent and identically distributed (i.i.d.) values with the specified range of numbers. The random seed is optional. The provided numbers specifying the minimum and maximum values of the range must be constant. If both of these numbers are integers, diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 88870e877b28a..f57c3418b105f 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -29,7 +29,7 @@ from pyspark.sql import Row, Window, functions as F, types from pyspark.sql.avro.functions import from_avro, to_avro from pyspark.sql.column import Column -from pyspark.sql.functions.builtin import isnull, nullifzero, randstr, uniform, zeroifnull +from pyspark.sql.functions.builtin import isnull, length, nullifzero, randstr, uniform, zeroifnull from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils from pyspark.testing.utils import have_numpy @@ -1611,13 +1611,23 @@ def test_nullifzero_zeroifnull(self): self.assertEqual([Row(r=0), Row(r=1)], result) def test_randstr_uniform(self): - df = self.spark.createDataFrame([(0,), (1,)], ["a"]) - result = df.select(randstr(5).alias("x")).select(isnull("x")).collect() - self.assertEqual([Row(x=False)], result) - - df = self.spark.createDataFrame([(None,), (1,)], ["a"]) - result = df.select(uniform(0, 10).alias("x")).select(isnull("x")).collect() - self.assertEqual([Row(x=False)], result) + df = self.spark.createDataFrame([(0,)], ["a"]) + result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect() + self.assertEqual([Row(5)], result) + # The random seed is optional. + result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect() + self.assertEqual([Row(5)], result) + + df = self.spark.createDataFrame([(0,)], ["a"]) + result = ( + df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x")) + .selectExpr("x > 5") + .collect() + ) + self.assertEqual([Row(True)], result) + # The random seed is optional. + result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect() + self.assertEqual([Row(True)], result) class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin): diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 2170591424f8f..dbf94b30c47e2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1886,8 +1886,18 @@ object functions { /** * Returns a string of the specified length whose characters are chosen uniformly at random from - * the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length - * must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). + * the following pool of characters: 0-9, a-z, A-Z. The string length must be a constant + * two-byte or four-byte integer (SMALLINT or INT, respectively). 
+ * + * @group string_funcs + * @since 4.0.0 + */ + def randstr(length: Column): Column = Column.fn("randstr", length) + + /** + * Returns a string of the specified length whose characters are chosen uniformly at random from + * the following pool of characters: 0-9, a-z, A-Z, with the chosen random seed. The string + * length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). * * @group string_funcs * @since 4.0.0 @@ -3740,7 +3750,19 @@ object functions { /** * Returns a random value with independent and identically distributed (i.i.d.) values with the - * specified range of numbers. The random seed is optional. The provided numbers specifying the + * specified range of numbers. The provided numbers specifying the minimum and maximum values of + * the range must be constant. If both of these numbers are integers, then the result will also + * be an integer. Otherwise if one or both of these are floating-point numbers, then the result + * will also be a floating-point number. + * + * @group math_funcs + * @since 4.0.0 + */ + def uniform(min: Column, max: Column): Column = Column.fn("uniform", min, max) + + /** + * Returns a random value with independent and identically distributed (i.i.d.) values with the + * specified range of numbers, with the chosen random seed. The provided numbers specifying the * minimum and maximum values of the range must be constant. If both of these numbers are * integers, then the result will also be an integer. Otherwise if one or both of these are * floating-point numbers, then the result will also be a floating-point number. @@ -3748,7 +3770,8 @@ object functions { * @group math_funcs * @since 4.0.0 */ - def uniform(min: Column, max: Column, seed: Column): Column = Column.fn("uniform", min, max, seed) + def uniform(min: Column, max: Column, seed: Column): Column = + Column.fn("uniform", min, max, seed) /** * Returns a random value with independent and identically distributed (i.i.d.) uniformly diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ba8888b10d21f..ca8fa92a90f2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -415,10 +415,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { sql("insert into t values (0)") val df = sql("select col from t") checkAnswer( - df.select(randstr(lit(5), lit(0)).alias("x")).select(isnull(col("x"))), - Seq(Row(false))) + df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x"))), + Seq(Row(5))) + // The random seed is optional. + checkAnswer( + df.select(randstr(lit(5)).alias("x")).select(length(col("x"))), + Seq(Row(5))) } - // TODO: add some error cases + // Here we exercise some error cases. 
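+    // Two failure modes follow (noted here as an aside): a mistyped argument
+    // surfaces as DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE, while a non-constant
+    // column argument surfaces as DATATYPE_MISMATCH.NON_FOLDABLE_INPUT.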
+ val df = Seq((0)).toDF("a") + var expr = randstr(lit(10), lit("a")) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"randstr(10, a)\"", + "paramIndex" -> "second", + "inputSql" -> "\"a\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "INT or SMALLINT"), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "randstr", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + expr = randstr(col("a"), lit(10)) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "inputName" -> "length", + "inputType" -> "INT or SMALLINT", + "inputExpr" -> "\"a\"", + "sqlExpr" -> "\"randstr(a, 10)\""), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "randstr", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) } test("uniform function") { @@ -427,10 +467,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { sql("insert into t values (0)") val df = sql("select col from t") checkAnswer( - df.select(uniform(lit(0), lit(10), lit(0)).alias("x")).select(isnull(col("x"))), - Seq(Row(false))) + df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5"), + Seq(Row(true))) + // The random seed is optional. + checkAnswer( + df.select(uniform(lit(10), lit(20)).alias("x")).selectExpr("x > 5"), + Seq(Row(true))) } - // TODO: add some error cases + // Here we exercise some error cases. + val df = Seq((0)).toDF("a") + var expr = uniform(lit(10), lit("a")) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"uniform(10, a)\"", + "paramIndex" -> "second", + "inputSql" -> "\"a\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "INT or SMALLINT"), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "uniform", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + expr = uniform(col("a"), lit(10)) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "inputName" -> "length", + "inputType" -> "INT or SMALLINT", + "inputExpr" -> "\"a\"", + "sqlExpr" -> "\"uniform(a, 10)\""), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "uniform", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) } test("zeroifnull function") { From 160a5f44072fae857961b69c32db196e73453384 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Thu, 19 Sep 2024 12:11:39 +0200 Subject: [PATCH 21/27] respond to code review comments --- .../pyspark/sql/connect/functions/builtin.py | 12 +++-- python/pyspark/sql/functions/builtin.py | 26 +++++++--- .../expressions/randomExpressions.scala | 49 +++++++++++++++---- .../spark/sql/DataFrameFunctionsSuite.scala | 6 +-- 4 files changed, 71 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 721a8bc5878d8..be889bb02164a 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -1009,10 +1009,14 @@ def unhex(col: 
"ColumnOrName") -> Column: def uniform( - min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"] = None + min: Union[Column, int, float], + max: Union[Column, int, float], + seed: Optional[Union[Column, int]] = None, ) -> Column: if seed is None: - return _invoke_function_over_columns("uniform", min, max) + return _invoke_function_over_columns( + "uniform", min, max, lit(random.randint(0, sys.maxsize)) + ) else: return _invoke_function_over_columns("uniform", min, max, seed) @@ -2590,9 +2594,9 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp_like.__doc__ = pysparkfuncs.regexp_like.__doc__ -def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"] = None) -> Column: +def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column: if seed is None: - return _invoke_function_over_columns("randstr", length) + return _invoke_function_over_columns("randstr", length, lit(random.randint(0, sys.maxsize))) else: return _invoke_function_over_columns("randstr", length, seed) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 33991659633e5..991f3a4308831 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11862,7 +11862,7 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: @_try_remote_functions -def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"] = None) -> Column: +def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column: """Returns a string of the specified length whose characters are chosen uniformly at random from the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). @@ -11883,8 +11883,14 @@ def randstr(length: "ColumnOrName", seed: Optional["ColumnOrName"] = None) -> Co Examples -------- - >>> spark.createDataFrame([('3',)], ['a']).select(randstr(5).alias('x')).select(isnull('x')).collect() - [Row(false)] + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([('3',)], ['a']) + >>> df.select(sf.randstr(5, 0).alias('result')).show() + +------+ + |result| + +------+ + | ceV0P| + +------+ """ if seed is None: return _invoke_function_over_columns("randstr", length) @@ -12260,7 +12266,9 @@ def unhex(col: "ColumnOrName") -> Column: @_try_remote_functions def uniform( - min: "ColumnOrName", max: "ColumnOrName", seed: Optional["ColumnOrName"] = None + min: Union[Column, int, float], + max: Union[Column, int, float], + seed: Optional[Union[Column, int]] = None, ) -> Column: """Returns a random value with independent and identically distributed (i.i.d.) values with the specified range of numbers. The random seed is optional. 
The provided numbers specifying the @@ -12286,8 +12294,14 @@ def uniform( Examples -------- - >>> spark.createDataFrame([('3',)], ['a']).select(uniform(0, 10).alias('x')).select(isnull('x')).collect() - [Row(false)] + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([('3',)], ['a']) + >>> df.select(sf.uniform(0, 10, 0).alias('result')).show() + +------+ + |result| + +------+ + | 7| + +------+ """ if seed is None: return _invoke_function_over_columns("uniform", min, max) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index ea9ca451c2cb1..e4c0a3ad8091e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -205,15 +205,18 @@ object Randn { """, since = "4.0.0", group = "math_funcs") -case class Uniform(min: Expression, max: Expression, seedExpression: Expression) +case class Uniform(min: Expression, max: Expression, seedExpression: Expression, hideSeed: Boolean) extends RuntimeReplaceable with TernaryLike[Expression] with RDG { - def this(min: Expression, max: Expression) = this(min, max, UnresolvedSeed) + def this(min: Expression, max: Expression) = + this(min, max, UnresolvedSeed, hideSeed = true) + def this(min: Expression, max: Expression, seedExpression: Expression) = + this(min, max, seedExpression, hideSeed = false) final override lazy val deterministic: Boolean = false override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED) - override val dataType: DataType = { + override def dataType: DataType = { val first = min.dataType val second = max.dataType (min.dataType, max.dataType) match { @@ -239,6 +242,10 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) case _ => false } + override def sql: String = { + s"uniform(${min.sql}, ${max.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})" + } + override def checkInputDataTypes(): TypeCheckResult = { var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess def requiredType = "integer or floating-point" @@ -276,11 +283,11 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) override def third: Expression = seedExpression override def withNewSeed(newSeed: Long): Expression = - Uniform(min, max, Literal(newSeed, LongType)) + Uniform(min, max, Literal(newSeed, LongType), hideSeed) override def withNewChildrenInternal( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = - Uniform(newFirst, newSecond, newThird) + Uniform(newFirst, newSecond, newThird, hideSeed) override def replacement: Expression = { if (Seq(min, max, seedExpression).exists(_.dataType == NullType)) { @@ -299,6 +306,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) } } +object Uniform { + def apply(min: Expression, max: Expression): Uniform = + Uniform(min, max, UnresolvedSeed, hideSeed = true) + def apply(min: Expression, max: Expression, seedExpression: Expression): Uniform = + Uniform(min, max, seedExpression, hideSeed = false) +} + @ExpressionDescription( usage = """ _FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen @@ -314,9 +328,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) 
""", since = "4.0.0", group = "string_funcs") -case class RandStr(length: Expression, override val seedExpression: Expression) +case class RandStr( + length: Expression, override val seedExpression: Expression, hideSeed: Boolean) extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic { - def this(length: Expression) = this(length, UnresolvedSeed) + def this(length: Expression) = + this(length, UnresolvedSeed, hideSeed = true) + def this(length: Expression, seedExpression: Expression) = + this(length, seedExpression, hideSeed = false) override def nullable: Boolean = false override def dataType: DataType = StringType @@ -338,9 +356,14 @@ case class RandStr(length: Expression, override val seedExpression: Expression) rng = new XORShiftRandom(seed + partitionIndex) } - override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType)) + override def withNewSeed(newSeed: Long): Expression = + RandStr(length, Literal(newSeed, LongType), hideSeed) override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = - RandStr(newFirst, newSecond) + RandStr(newFirst, newSecond, hideSeed) + + override def sql: String = { + s"randstr(${length.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})" + } override def checkInputDataTypes(): TypeCheckResult = { var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess @@ -421,3 +444,11 @@ case class RandStr(length: Expression, override val seedExpression: Expression) isNull = FalseLiteral) } } + +object RandStr { + def apply(length: Expression): RandStr = + RandStr(length, UnresolvedSeed, hideSeed = true) + def apply(length: Expression, seedExpression: Expression): RandStr = + RandStr(length, seedExpression, hideSeed = false) +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ca8fa92a90f2d..3c35933f58a00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -485,7 +485,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "second", "inputSql" -> "\"a\"", "inputType" -> "\"STRING\"", - "requiredType" -> "INT or SMALLINT"), + "requiredType" -> "integer or floating-point"), context = ExpectedContext( contextType = QueryContextType.DataFrame, fragment = "uniform", @@ -499,8 +499,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( - "inputName" -> "length", - "inputType" -> "INT or SMALLINT", + "inputName" -> "min", + "inputType" -> "integer or floating-point", "inputExpr" -> "\"a\"", "sqlExpr" -> "\"uniform(a, 10)\""), context = ExpectedContext( From 795f5a6c415cf41fb28d410b78dfd7b9d409ffcc Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Thu, 19 Sep 2024 15:13:41 +0200 Subject: [PATCH 22/27] fix test --- python/pyspark/sql/functions/builtin.py | 10 ++++++---- python/pyspark/sql/tests/test_functions.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 991f3a4308831..e4c2cfc2f9745 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11885,12 +11885,13 @@ def randstr(length: Union[Column, 
int], seed: Optional[Union[Column, int]] = Non -------- >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([('3',)], ['a']) - >>> df.select(sf.randstr(5, 0).alias('result')).show() + >>> df.select(sf.randstr(lit(5), lit(0)).alias('result')).show() +------+ |result| +------+ - | ceV0P| + | nurJI| +------+ + """ if seed is None: return _invoke_function_over_columns("randstr", length) @@ -12296,12 +12297,13 @@ def uniform( -------- >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([('3',)], ['a']) - >>> df.select(sf.uniform(0, 10, 0).alias('result')).show() + >>> df.select(sf.uniform(lit(0), lit(10), lit(0)).alias('result')).show() +------+ |result| +------+ - | 7| + | 2| +------+ + """ if seed is None: return _invoke_function_over_columns("uniform", min, max) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index f57c3418b105f..4c7d512e5dde0 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -29,7 +29,7 @@ from pyspark.sql import Row, Window, functions as F, types from pyspark.sql.avro.functions import from_avro, to_avro from pyspark.sql.column import Column -from pyspark.sql.functions.builtin import isnull, length, nullifzero, randstr, uniform, zeroifnull +from pyspark.sql.functions.builtin import length, nullifzero, randstr, uniform, zeroifnull from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils from pyspark.testing.utils import have_numpy From 6bb123f3a27519dc15d59a5c93e262284bb8afef Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Thu, 19 Sep 2024 17:23:05 +0200 Subject: [PATCH 23/27] fix --- python/pyspark/sql/functions/builtin.py | 4 ++-- python/pyspark/sql/tests/test_functions.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index e4c2cfc2f9745..0ed9014c5101d 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11889,7 +11889,7 @@ def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = Non +------+ |result| +------+ - | nurJI| + | ceV0P| +------+ """ @@ -12301,7 +12301,7 @@ def uniform( +------+ |result| +------+ - | 2| + | 7| +------+ """ diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 4c7d512e5dde0..a51156e895c62 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -29,7 +29,7 @@ from pyspark.sql import Row, Window, functions as F, types from pyspark.sql.avro.functions import from_avro, to_avro from pyspark.sql.column import Column -from pyspark.sql.functions.builtin import length, nullifzero, randstr, uniform, zeroifnull +from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils from pyspark.testing.utils import have_numpy From 05a674c1e897a4f05c0307c756061d2e00a9d6a1 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Fri, 20 Sep 2024 13:59:59 +0200 Subject: [PATCH 24/27] commit --- .../sql/catalyst/parser/SqlBaseParser.g4 | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 20 +++++++++++------- .../sql/catalyst/parser/AstBuilder.scala | 4 +++- .../sql-tests/inputs/pipe-operators.sql | 6 ++++++ .../sql/execution/SparkSqlParserSuite.scala | 21 ++++++++++++++++--- 5 files changed, 40 insertions(+), 12 deletions(-) diff --git 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index f13dde773496a..023c3384d560b 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1488,6 +1488,7 @@ version operatorPipeRightSide : selectClause + | joinRelation ; // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 08c5b3531b4c8..c93e54f834a3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -550,14 +550,18 @@ abstract class TypeCoercionBase { object CaseWhenCoercion extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => - val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) - maybeCommonType.map { commonType => - val newBranches = c.branches.map { case (condition, value) => - (condition, castIfNotSameType(value, commonType)) - } - val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType)) - CaseWhen(newBranches, newElseValue) - }.getOrElse(c) + convert(c) + } + + def convert(c: CaseWhen): CaseWhen = { + val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) + maybeCommonType.map { commonType => + val newBranches = c.branches.map { case (condition, value) => + (condition, castIfNotSameType(value, commonType)) + } + val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType)) + CaseWhen(newBranches, newElseValue) + }.getOrElse(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cb0e0e35c3704..d075e806be6a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5854,7 +5854,9 @@ class AstBuilder extends DataTypeAstBuilder windowClause = null, relation = left, isPipeOperatorSelect = true) - }.get + }.getOrElse(Option(ctx.joinRelation()).map { c => + withJoinRelation(c, left) + }.get) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 7d0966e7f2095..597aea73aa7cb 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -95,6 +95,12 @@ table t table t |> select y, length(y) + sum(x) as result; +-- Joins are supported. +----------------------- + +-- Join operators: negative tests. +---------------------------------- + -- Cleanup. 
----------- drop table t; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index a80444feb68ae..e6e83b4646499 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, Un import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION} +import org.apache.spark.sql.catalyst.trees.TreePattern.{JOIN, LOCAL_RELATION, PROJECT, TreePattern, UNRESOLVED_RELATION} import org.apache.spark.sql.connector.catalog.TableCatalog import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource} @@ -887,14 +887,29 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { // Basic selection. // Here we check that every parsed plan contains a projection and a source relation or // inline table. - def checkPipeSelect(query: String): Unit = { + def checkPlan(query: String, pattern: TreePattern): Unit = { val plan: LogicalPlan = parser.parsePlan(query) - assert(plan.containsPattern(PROJECT)) + assert(plan.containsPattern(pattern)) assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) } + def checkPipeSelect(query: String): Unit = checkPlan(query, PROJECT) checkPipeSelect("TABLE t |> SELECT 1 AS X") checkPipeSelect("TABLE t |> SELECT 1 AS X, 2 AS Y |> SELECT X + Y AS Z") checkPipeSelect("VALUES (0), (1) tab(col) |> SELECT col * 2 AS result") + // Join operations. 
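+    // As an aside, the pipe form exercised below is roughly
+    //   TABLE t |> INNER JOIN other USING (col)
+    // and should parse to a plan containing a JOIN node over the piped input.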
+ def checkJoin(query: String): Unit = checkPlan(query, JOIN) + checkJoin( + """ + |SELECT * FROM VALUES + | ("dotNET", 15000, 48000, 22500), + | ("Java", 20000, 30000, NULL) + | AS courseEarnings(course, `2012`, `2013`, `2014`) + ||> INNER JOIN (SELECT * FROM VALUES + | ("dotNET", 15000, 48000, 22500), + | ("Java", 20000, 30000, NULL) + | AS otherCourseEarnings(course, `2012`, `2013`, `2014`)) + | USING (course) + |""".stripMargin) } } } From fd3262fc7e1840aa0289c677975340c992d3965d Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Fri, 20 Sep 2024 14:02:58 +0200 Subject: [PATCH 25/27] fix test --- python/pyspark/sql/functions/builtin.py | 28 ++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 0ed9014c5101d..ae0c7e54a170b 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11885,13 +11885,13 @@ def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = Non -------- >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([('3',)], ['a']) - >>> df.select(sf.randstr(lit(5), lit(0)).alias('result')).show() - +------+ - |result| - +------+ - | ceV0P| - +------+ - + >>> df = df.select(sf.randstr(lit(5), lit(0)).alias('result')) + >>> df.selectExpr("result != ''") + +------------+ + |result != ''| + +------------+ + | true| + +------------+ """ if seed is None: return _invoke_function_over_columns("randstr", length) @@ -12297,13 +12297,13 @@ def uniform( -------- >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([('3',)], ['a']) - >>> df.select(sf.uniform(lit(0), lit(10), lit(0)).alias('result')).show() - +------+ - |result| - +------+ - | 7| - +------+ - + >>> df = df.select(sf.uniform(lit(0), lit(10), lit(0)).alias('result')) + >>> df.selectExpr("result < 15").show() + +-----------+ + |result < 15| + +-----------+ + | true| + +-----------+ """ if seed is None: return _invoke_function_over_columns("uniform", min, max) From 4929fbd9a36d1c003ef36af393fdfdd14db9c1d7 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Fri, 20 Sep 2024 17:05:08 +0200 Subject: [PATCH 26/27] fix --- python/pyspark/sql/functions/builtin.py | 42 +++++++++---------- .../sql/catalyst/parser/SqlBaseParser.g4 | 1 - .../sql/catalyst/analysis/TypeCoercion.scala | 20 ++++----- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../sql-tests/inputs/pipe-operators.sql | 6 --- .../sql/execution/SparkSqlParserSuite.scala | 21 ++-------- 6 files changed, 32 insertions(+), 62 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index ae0c7e54a170b..e2918966ed1c3 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11883,20 +11883,19 @@ def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = Non Examples -------- - >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([('3',)], ['a']) - >>> df = df.select(sf.randstr(lit(5), lit(0)).alias('result')) - >>> df.selectExpr("result != ''") - +------------+ - |result != ''| - +------------+ - | true| - +------------+ + >>> spark.createDataFrame([('3',)], ['a']) \\ + ... .select(randstr(lit(5), lit(0)).alias('result')) \\ + ... 
.selectExpr("length(result) > 0").show() + +--------------------+ + |(length(result) > 0)| + +--------------------+ + | true| + +--------------------+ """ if seed is None: - return _invoke_function_over_columns("randstr", length) + return _invoke_function("randstr", lit(length)) else: - return _invoke_function_over_columns("randstr", length, seed) + return _invoke_function("randstr", lit(length), lit(seed)) @_try_remote_functions @@ -12295,20 +12294,19 @@ def uniform( Examples -------- - >>> from pyspark.sql import functions as sf - >>> df = spark.createDataFrame([('3',)], ['a']) - >>> df = df.select(sf.uniform(lit(0), lit(10), lit(0)).alias('result')) - >>> df.selectExpr("result < 15").show() - +-----------+ - |result < 15| - +-----------+ - | true| - +-----------+ + >>> spark.createDataFrame([('3',)], ['a']) \\ + ... .select(uniform(lit(0), lit(10), lit(0)).alias('result')) \\ + ... .selectExpr("result < 15").show() + +-------------+ + |(result < 15)| + +-------------+ + | true| + +-------------+ """ if seed is None: - return _invoke_function_over_columns("uniform", min, max) + return _invoke_function("uniform", lit(min), lit(max)) else: - return _invoke_function_over_columns("uniform", min, max, seed) + return _invoke_function("uniform", lit(min), lit(max), lit(seed)) @_try_remote_functions diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 023c3384d560b..f13dde773496a 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1488,7 +1488,6 @@ version operatorPipeRightSide : selectClause - | joinRelation ; // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c93e54f834a3a..08c5b3531b4c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -550,18 +550,14 @@ abstract class TypeCoercionBase { object CaseWhenCoercion extends TypeCoercionRule { override val transform: PartialFunction[Expression, Expression] = { case c: CaseWhen if c.childrenResolved && !haveSameType(c.inputTypesForMerging) => - convert(c) - } - - def convert(c: CaseWhen): CaseWhen = { - val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) - maybeCommonType.map { commonType => - val newBranches = c.branches.map { case (condition, value) => - (condition, castIfNotSameType(value, commonType)) - } - val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType)) - CaseWhen(newBranches, newElseValue) - }.getOrElse(c) + val maybeCommonType = findWiderCommonType(c.inputTypesForMerging) + maybeCommonType.map { commonType => + val newBranches = c.branches.map { case (condition, value) => + (condition, castIfNotSameType(value, commonType)) + } + val newElseValue = c.elseValue.map(castIfNotSameType(_, commonType)) + CaseWhen(newBranches, newElseValue) + }.getOrElse(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d075e806be6a7..cb0e0e35c3704 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5854,9 +5854,7 @@ class AstBuilder extends DataTypeAstBuilder windowClause = null, relation = left, isPipeOperatorSelect = true) - }.getOrElse(Option(ctx.joinRelation()).map { c => - withJoinRelation(c, left) - }.get) + }.get } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 597aea73aa7cb..7d0966e7f2095 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -95,12 +95,6 @@ table t table t |> select y, length(y) + sum(x) as result; --- Joins are supported. ------------------------ - --- Join operators: negative tests. ----------------------------------- - -- Cleanup. 
----------- drop table t; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index e6e83b4646499..a80444feb68ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, Un import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{JOIN, LOCAL_RELATION, PROJECT, TreePattern, UNRESOLVED_RELATION} +import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION} import org.apache.spark.sql.connector.catalog.TableCatalog import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource} @@ -887,29 +887,14 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { // Basic selection. // Here we check that every parsed plan contains a projection and a source relation or // inline table. - def checkPlan(query: String, pattern: TreePattern): Unit = { + def checkPipeSelect(query: String): Unit = { val plan: LogicalPlan = parser.parsePlan(query) - assert(plan.containsPattern(pattern)) + assert(plan.containsPattern(PROJECT)) assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) } - def checkPipeSelect(query: String): Unit = checkPlan(query, PROJECT) checkPipeSelect("TABLE t |> SELECT 1 AS X") checkPipeSelect("TABLE t |> SELECT 1 AS X, 2 AS Y |> SELECT X + Y AS Z") checkPipeSelect("VALUES (0), (1) tab(col) |> SELECT col * 2 AS result") - // Join operations. 
- def checkJoin(query: String): Unit = checkPlan(query, JOIN) - checkJoin( - """ - |SELECT * FROM VALUES - | ("dotNET", 15000, 48000, 22500), - | ("Java", 20000, 30000, NULL) - | AS courseEarnings(course, `2012`, `2013`, `2014`) - ||> INNER JOIN (SELECT * FROM VALUES - | ("dotNET", 15000, 48000, 22500), - | ("Java", 20000, 30000, NULL) - | AS otherCourseEarnings(course, `2012`, `2013`, `2014`)) - | USING (course) - |""".stripMargin) } } } From ef7f9e7f6b729094de8f67f4d17f11c0111e3b80 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Fri, 20 Sep 2024 17:25:30 +0200 Subject: [PATCH 27/27] fix --- python/pyspark/sql/functions/builtin.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index e2918966ed1c3..ef4f3465ca870 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11892,10 +11892,14 @@ def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = Non | true| +--------------------+ """ + length = _enum_to_value(length) + length = lit(length) if isinstance(length, int) else length if seed is None: - return _invoke_function("randstr", lit(length)) + return _invoke_function_over_columns("randstr", length) else: - return _invoke_function("randstr", lit(length), lit(seed)) + seed = _enum_to_value(seed) + seed = lit(seed) if isinstance(seed, int) else seed + return _invoke_function_over_columns("randstr", length, seed) @_try_remote_functions @@ -12303,10 +12307,16 @@ def uniform( | true| +-------------+ """ + min = _enum_to_value(min) + min = lit(min) if isinstance(min, int) else min + max = _enum_to_value(max) + max = lit(max) if isinstance(max, int) else max if seed is None: - return _invoke_function("uniform", lit(min), lit(max)) + return _invoke_function_over_columns("uniform", min, max) else: - return _invoke_function("uniform", lit(min), lit(max), lit(seed)) + seed = _enum_to_value(seed) + seed = lit(seed) if isinstance(seed, int) else seed + return _invoke_function_over_columns("uniform", min, max, seed) @_try_remote_functions