From effcf22a029f2f61aa2513ae06554d171a774f5b Mon Sep 17 00:00:00 2001
From: Cheng Pan
Date: Thu, 29 Aug 2024 22:55:08 +0800
Subject: [PATCH 001/230] [SPARK-49457][BUILD] Remove uncommon curl option `--retry-all-errors`

### What changes were proposed in this pull request?

Remove the uncommon curl option `--retry-all-errors`, which was added in curl 7.71.0 (June 24, 2020); older versions cannot recognize this option.

### Why are the changes needed?

It causes `build/mvn` to fail on Ubuntu 20.04.

```
exec: curl --retry 3 --retry-all-errors --silent --show-error -L https://www.apache.org/dyn/closer.lua/maven/maven-3/3.9.9/binaries/apache-maven-3.9.9-bin.tar.gz?action=download
curl: option --retry-all-errors: is unknown
curl: try 'curl --help' or 'curl --manual' for more information
```

```
$ curl --version
curl 7.68.0 (aarch64-unknown-linux-gnu) libcurl/7.68.0 OpenSSL/1.1.1f zlib/1.2.11 brotli/1.0.7 libidn2/2.2.0 libpsl/0.21.0 (+libidn2/2.2.0) libssh/0.9.3/openssl/zlib nghttp2/1.40.0 librtmp/2.3
Release-Date: 2020-01-08
Protocols: dict file ftp ftps gopher http https imap imaps ldap ldaps pop3 pop3s rtmp rtsp scp sftp smb smbs smtp smtps telnet tftp
Features: AsynchDNS brotli GSS-API HTTP2 HTTPS-proxy IDN IPv6 Kerberos Largefile libz NTLM NTLM_WB PSL SPNEGO SSL TLS-SRP UnixSockets
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Manually tested.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47926 from pan3793/SPARK-49457.

Authored-by: Cheng Pan
Signed-off-by: yangjie01
---
 build/mvn | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/build/mvn b/build/mvn
index 28454c68fd128..060209ac1ac4d 100755
--- a/build/mvn
+++ b/build/mvn
@@ -58,7 +58,7 @@ install_app() {
   local local_checksum="${local_tarball}.${checksum_suffix}"
   local remote_checksum="https://archive.apache.org/dist/${url_path}.${checksum_suffix}"
 
-  local curl_opts="--retry 3 --retry-all-errors --silent --show-error -L"
+  local curl_opts="--retry 3 --silent --show-error -L"
   local wget_opts="--no-verbose"
 
   if [ ! -f "$binary" ]; then

From dad2b763f004a72613276b31738a958e80d02b37 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Thu, 29 Aug 2024 11:00:51 -0400
Subject: [PATCH 002/230] [SPARK-49419][CONNECT][SQL] Create shared DataFrameStatFunctions

### What changes were proposed in this pull request?

This PR creates an interface for DataFrameStatFunctions that is shared between Classic and Connect.

### Why are the changes needed?

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47898 from hvanhovell/SPARK-49419.
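As a rough illustration of the shared-interface pattern this change introduces (simplified, hypothetical class names; see the diff below for the real `api.DataFrameStatFunctions`): shared convenience overloads are written once in an abstract API, while each backend implements only the engine-specific primitive.

```
// Hypothetical, simplified sketch of the shared-API pattern (not the actual Spark classes).

// Shared abstract API: convenience overloads live here, written exactly once.
abstract class StatApi {
  // Engine-specific primitive each backend must provide.
  def corr(col1: String, col2: String, method: String): Double

  // Shared overload, inherited unchanged by every backend.
  def corr(col1: String, col2: String): Double = corr(col1, col2, "pearson")
}

// Stand-in for the Classic backend: computes the value locally.
class ClassicStat extends StatApi {
  override def corr(col1: String, col2: String, method: String): Double = {
    require(method == "pearson", "Only Pearson correlation is supported in this sketch.")
    0.0 // placeholder for a local computation
  }
}

// Stand-in for the Connect backend: would build a plan and ask the server.
class ConnectStat extends StatApi {
  override def corr(col1: String, col2: String, method: String): Double = {
    require(method == "pearson", "Only Pearson correlation is supported in this sketch.")
    0.0 // placeholder for a remote computation
  }
}

object SharedApiSketch {
  def main(args: Array[String]): Unit = {
    // Both backends expose the same user-facing surface, including the shared overload.
    Seq[StatApi](new ClassicStat, new ConnectStat).foreach { s =>
      println(s.corr("rand1", "rand2"))
    }
  }
}
```

The actual change applies the same idea: the documented abstract methods move to `sql/api`, and the Classic and Connect `DataFrameStatFunctions` mostly override them with `@inheritdoc`.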
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../spark/sql/DataFrameStatFunctions.scala | 588 ++---------------- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- sql/api/pom.xml | 5 + .../sql/api/DataFrameStatFunctions.scala | 514 +++++++++++++++ .../org/apache/spark/sql/api/Dataset.scala | 12 + .../spark/sql/DataFrameStatFunctions.scala | 494 +-------------- 6 files changed, 594 insertions(+), 1021 deletions(-) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 6365f387afce4..9f5ada0d7ec35 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -18,89 +18,23 @@ package org.apache.spark.sql import java.{lang => jl, util => ju} -import java.io.ByteArrayInputStream - -import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder} import org.apache.spark.sql.functions.lit -import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * Statistic functions for `DataFrame`s. * * @since 3.4.0 */ -final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, root: Relation) { - import sparkSession.RichColumn - - /** - * Calculates the approximate quantiles of a numerical column of a DataFrame. - * - * The result of this algorithm has the following deterministic bound: If the DataFrame has N - * elements and if we request the quantile at probability `p` up to error `err`, then the - * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is - * close to (p * N). More precisely, - * - * {{{ - * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N) - * }}} - * - * This method implements a variation of the Greenwald-Khanna algorithm (with some speed - * optimizations). The algorithm was first present in Space-efficient Online Computation of Quantile - * Summaries by Greenwald and Khanna. - * - * @param col - * the name of the numerical column - * @param probabilities - * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the - * minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError - * The relative target precision to achieve (greater than or equal to 0). If set to zero, the - * exact quantiles are computed, which could be very expensive. Note that values greater than - * 1 are accepted but give the same result as 1. - * @return - * the approximate quantiles at the given probabilities - * - * @note - * null and NaN values will be removed from the numerical column before calculation. If the - * dataframe is empty or the column only contains null or NaN, an empty array is returned. 
- * - * @since 3.4.0 - */ - def approxQuantile( - col: String, - probabilities: Array[Double], - relativeError: Double): Array[Double] = { - approxQuantile(Array(col), probabilities, relativeError).head - } +final class DataFrameStatFunctions private[sql] (protected val df: DataFrame) + extends api.DataFrameStatFunctions[Dataset] { + private def root: Relation = df.plan.getRoot + private val sparkSession: SparkSession = df.sparkSession - /** - * Calculates the approximate quantiles of numerical columns of a DataFrame. - * @see - * `approxQuantile(col:Str* approxQuantile)` for detailed description. - * - * @param cols - * the names of the numerical columns - * @param probabilities - * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the - * minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError - * The relative target precision to achieve (greater than or equal to 0). If set to zero, the - * exact quantiles are computed, which could be very expensive. Note that values greater than - * 1 are accepted but give the same result as 1. - * @return - * the approximate quantiles at the given probabilities of each column - * - * @note - * null and NaN values will be ignored in numerical columns before calculation. For columns - * only containing null or NaN values, an empty array is returned. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def approxQuantile( cols: Array[String], probabilities: Array[Double], @@ -120,24 +54,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo .head() } - /** - * Calculate the sample covariance of two numerical columns of a DataFrame. - * @param col1 - * the name of the first column - * @param col2 - * the name of the second column - * @return - * the covariance of the two columns. - * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.cov("rand1", "rand2") - * res1: Double = 0.065... - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def cov(col1: String, col2: String): Double = { sparkSession .newDataset(PrimitiveDoubleEncoder) { builder => @@ -146,27 +63,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo .head() } - /** - * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson - * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in - * MLlib's Statistics. - * - * @param col1 - * the name of the column - * @param col2 - * the name of the column to calculate the correlation against - * @return - * The Pearson Correlation Coefficient as a Double. - * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.corr("rand1", "rand2") - * res1: Double = 0.613... - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def corr(col1: String, col2: String, method: String): Double = { require( method == "pearson", @@ -179,289 +76,48 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo .head() } - /** - * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame. - * - * @param col1 - * the name of the column - * @param col2 - * the name of the column to calculate the correlation against - * @return - * The Pearson Correlation Coefficient as a Double. 
- * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.corr("rand1", "rand2", "pearson") - * res1: Double = 0.613... - * }}} - * - * @since 3.4.0 - */ - def corr(col1: String, col2: String): Double = { - corr(col1, col2, "pearson") - } - - /** - * Computes a pair-wise frequency table of the given columns. Also known as a contingency table. - * The first column of each row will be the distinct values of `col1` and the column names will - * be the distinct values of `col2`. The name of the first column will be `col1_col2`. Counts - * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts. - * Null elements will be replaced by "null", and back ticks will be dropped from elements if - * they exist. - * - * @param col1 - * The name of the first column. Distinct items will make the first item of each row. - * @param col2 - * The name of the second column. Distinct items will make the column names of the DataFrame. - * @return - * A DataFrame containing for the contingency table. - * - * {{{ - * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))) - * .toDF("key", "value") - * val ct = df.stat.crosstab("key", "value") - * ct.show() - * +---------+---+---+---+ - * |key_value| 1| 2| 3| - * +---------+---+---+---+ - * | 2| 2| 0| 1| - * | 1| 1| 1| 0| - * | 3| 0| 1| 1| - * +---------+---+---+---+ - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def crosstab(col1: String, col2: String): DataFrame = { sparkSession.newDataFrame { builder => builder.getCrosstabBuilder.setInput(root).setCol1(col1).setCol2(col2) } } - /** - * Finding frequent items for columns, possibly with false positives. Using the frequent element - * count algorithm described in here, - * proposed by Karp, Schenker, and Papadimitriou. The `support` should be greater than 1e-4. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. - * @param support - * The minimum frequency for an item to be considered `frequent`. Should be greater than 1e-4. - * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * {{{ - * val rows = Seq.tabulate(100) { i => - * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) - * } - * val df = spark.createDataFrame(rows).toDF("a", "b") - * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns - * // "a" and "b" - * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4) - * freqSingles.show() - * +-----------+-------------+ - * |a_freqItems| b_freqItems| - * +-----------+-------------+ - * | [1, 99]|[-1.0, -99.0]| - * +-----------+-------------+ - * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" - * val pairDf = df.select(struct("a", "b").as("a-b")) - * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1) - * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() - * +----------+ - * | freq_ab| - * +----------+ - * | [1,-1.0]| - * | ... 
| - * +----------+ - * }}} - * - * @since 3.4.0 - */ - def freqItems(cols: Array[String], support: Double): DataFrame = { - sparkSession.newDataFrame { builder => - val freqItemsBuilder = builder.getFreqItemsBuilder.setInput(root).setSupport(support) - cols.foreach(freqItemsBuilder.addCols) - } - } + /** @inheritdoc */ + override def freqItems(cols: Array[String], support: Double): DataFrame = + super.freqItems(cols, support) - /** - * Finding frequent items for columns, possibly with false positives. Using the frequent element - * count algorithm described in here, - * proposed by Karp, Schenker, and Papadimitriou. Uses a `default` support of 1%. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. - * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * @since 3.4.0 - */ - def freqItems(cols: Array[String]): DataFrame = { - freqItems(cols, 0.01) - } + /** @inheritdoc */ + override def freqItems(cols: Array[String]): DataFrame = super.freqItems(cols) + + /** @inheritdoc */ + override def freqItems(cols: Seq[String]): DataFrame = super.freqItems(cols) - /** - * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in here, proposed by Karp, Schenker, and - * Papadimitriou. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. - * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * {{{ - * val rows = Seq.tabulate(100) { i => - * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) - * } - * val df = spark.createDataFrame(rows).toDF("a", "b") - * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns - * // "a" and "b" - * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4) - * freqSingles.show() - * +-----------+-------------+ - * |a_freqItems| b_freqItems| - * +-----------+-------------+ - * | [1, 99]|[-1.0, -99.0]| - * +-----------+-------------+ - * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" - * val pairDf = df.select(struct("a", "b").as("a-b")) - * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1) - * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() - * +----------+ - * | freq_ab| - * +----------+ - * | [1,-1.0]| - * | ... | - * +----------+ - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def freqItems(cols: Seq[String], support: Double): DataFrame = { - freqItems(cols.toArray, support) + df.sparkSession.newDataFrame { builder => + val freqItemsBuilder = builder.getFreqItemsBuilder + .setInput(df.plan.getRoot) + .setSupport(support) + cols.foreach(freqItemsBuilder.addCols) + } } - /** - * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in here, proposed by Karp, Schenker, and - * Papadimitriou. Uses a `default` support of 1%. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. 
- * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * @since 3.4.0 - */ - def freqItems(cols: Seq[String]): DataFrame = { - freqItems(cols.toArray, 0.01) - } + /** @inheritdoc */ + override def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = + super.sampleBy(col, fractions, seed) - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. - * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * {{{ - * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), - * (3, 3))).toDF("key", "value") - * val fractions = Map(1 -> 1.0, 3 -> 0.5) - * df.stat.sampleBy("key", fractions, 36L).show() - * +---+-----+ - * |key|value| - * +---+-----+ - * | 1| 1| - * | 1| 2| - * | 3| 2| - * +---+-----+ - * }}} - * - * @since 3.4.0 - */ - def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { - sampleBy(Column(col), fractions, seed) - } + /** @inheritdoc */ + override def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = + super.sampleBy(col, fractions, seed) - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. - * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * @since 3.4.0 - */ - def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { - sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) - } + /** @inheritdoc */ + override def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = + super.sampleBy(col, fractions, seed) - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. 
- * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * The stratified sample can be performed over multiple columns: - * {{{ - * import org.apache.spark.sql.Row - * import org.apache.spark.sql.functions.struct - * - * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17), - * ("Alice", 10))).toDF("name", "age") - * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0) - * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show() - * +-----+---+ - * | name|age| - * +-----+---+ - * | Nico| 8| - * |Alice| 10| - * +-----+---+ - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = { + import sparkSession.RichColumn require( fractions.values.forall(p => p >= 0.0 && p <= 1.0), s"Fractions must be in [0, 1], but got $fractions.") @@ -479,180 +135,6 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo } } } - - /** - * (Java-specific) Returns a stratified sample without replacement based on the fraction given - * on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. - * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * @since 3.4.0 - */ - def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { - sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param colName - * name of the column over which the sketch is built - * @param depth - * depth of the sketch - * @param width - * width of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { - countMinSketch(Column(colName), depth, width, seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param colName - * name of the column over which the sketch is built - * @param eps - * relative error of the sketch - * @param confidence - * confidence of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch( - colName: String, - eps: Double, - confidence: Double, - seed: Int): CountMinSketch = { - countMinSketch(Column(colName), eps, confidence, seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param col - * the column over which the sketch is built - * @param depth - * depth of the sketch - * @param width - * width of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { - countMinSketch(col, eps = 2.0 / width, confidence = 1 - 1 / Math.pow(2, depth), seed) - } - - /** - * Builds a Count-min Sketch over a specified column. 
- * - * @param col - * the column over which the sketch is built - * @param eps - * relative error of the sketch - * @param confidence - * confidence of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { - val agg = Column.fn("count_min_sketch", col, lit(eps), lit(confidence), lit(seed)) - val ds = sparkSession.newDataset(BinaryEncoder) { builder => - builder.getProjectBuilder - .setInput(root) - .addExpressions(agg.expr) - } - CountMinSketch.readFrom(ds.head()) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param colName - * name of the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param fpp - * expected false positive probability of the filter. - * @since 3.5.0 - */ - def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { - bloomFilter(Column(colName), expectedNumItems, fpp) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param col - * the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param fpp - * expected false positive probability of the filter. - * @since 3.5.0 - */ - def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { - val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp) - bloomFilter(col, expectedNumItems, numBits) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param colName - * name of the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param numBits - * expected number of bits of the filter. - * @since 3.5.0 - */ - def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { - bloomFilter(Column(colName), expectedNumItems, numBits) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param col - * the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param numBits - * expected number of bits of the filter. 
- * @since 3.5.0 - */ - def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { - val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)) - val ds = sparkSession.newDataset(BinaryEncoder) { builder => - builder.getProjectBuilder - .setInput(root) - .addExpressions(agg.expr) - } - BloomFilter.readFrom(new ByteArrayInputStream(ds.head())) - } } private object DataFrameStatFunctions { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index a4d1c804685fe..37a182675b6cd 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -300,7 +300,7 @@ class Dataset[T] private[sql] ( * @group untypedrel * @since 3.4.0 */ - def stat: DataFrameStatFunctions = new DataFrameStatFunctions(sparkSession, plan.getRoot) + def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF()) private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = { checkSameSparkSession(right) diff --git a/sql/api/pom.xml b/sql/api/pom.xml index 9a63f73ab1918..54cdc96fc40a2 100644 --- a/sql/api/pom.xml +++ b/sql/api/pom.xml @@ -53,6 +53,11 @@ spark-unsafe_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sketch_${scala.binary.version} + ${project.version} + org.json4s json4s-jackson_${scala.binary.version} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala new file mode 100644 index 0000000000000..c3ecc7b90d5b4 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala @@ -0,0 +1,514 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.api + +import scala.jdk.CollectionConverters._ + +import _root_.java.{lang => jl, util => ju} + +import org.apache.spark.annotation.Stable +import org.apache.spark.sql.{Column, Row} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.BinaryEncoder +import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.functions.{count_min_sketch, lit} +import org.apache.spark.util.ArrayImplicits.SparkArrayOps +import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} + +/** + * Statistic functions for `DataFrame`s. + * + * @since 1.4.0 + */ +@Stable +abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { + protected def df: DS[Row] + + /** + * Calculates the approximate quantiles of a numerical column of a DataFrame. 
+ * + * The result of this algorithm has the following deterministic bound: + * If the DataFrame has N elements and if we request the quantile at probability `p` up to error + * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank + * of `x` is close to (p * N). + * More precisely, + * + * {{{ + * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N) + * }}} + * + * This method implements a variation of the Greenwald-Khanna algorithm (with some speed + * optimizations). + * The algorithm was first present in + * Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna. + * + * @param col the name of the numerical column + * @param probabilities a list of quantile probabilities + * Each number must belong to [0, 1]. + * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError The relative target precision to achieve (greater than or equal to 0). + * If set to zero, the exact quantiles are computed, which could be very expensive. + * Note that values greater than 1 are accepted but give the same result as 1. + * @return the approximate quantiles at the given probabilities + * + * @note null and NaN values will be removed from the numerical column before calculation. If + * the dataframe is empty or the column only contains null or NaN, an empty array is returned. + * + * @since 2.0.0 + */ + def approxQuantile( + col: String, + probabilities: Array[Double], + relativeError: Double): Array[Double] = withOrigin { + approxQuantile(Array(col), probabilities, relativeError).head + } + + /** + * Calculates the approximate quantiles of numerical columns of a DataFrame. + * @see `approxQuantile(col:Str* approxQuantile)` for detailed description. + * + * @param cols the names of the numerical columns + * @param probabilities a list of quantile probabilities + * Each number must belong to [0, 1]. + * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError The relative target precision to achieve (greater than or equal to 0). + * If set to zero, the exact quantiles are computed, which could be very expensive. + * Note that values greater than 1 are accepted but give the same result as 1. + * @return the approximate quantiles at the given probabilities of each column + * + * @note null and NaN values will be ignored in numerical columns before calculation. For + * columns only containing null or NaN values, an empty array is returned. + * + * @since 2.2.0 + */ + def approxQuantile( + cols: Array[String], + probabilities: Array[Double], + relativeError: Double): Array[Array[Double]] + + /** + * Calculate the sample covariance of two numerical columns of a DataFrame. + * + * @param col1 the name of the first column + * @param col2 the name of the second column + * @return the covariance of the two columns. + * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * df.stat.cov("rand1", "rand2") + * res1: Double = 0.065... + * }}} + * @since 1.4.0 + */ + def cov(col1: String, col2: String): Double + + /** + * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson + * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in + * MLlib's Statistics. + * + * @param col1 the name of the column + * @param col2 the name of the column to calculate the correlation against + * @return The Pearson Correlation Coefficient as a Double. 
+ * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * df.stat.corr("rand1", "rand2") + * res1: Double = 0.613... + * }}} + * @since 1.4.0 + */ + def corr(col1: String, col2: String, method: String): Double + + /** + * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame. + * + * @param col1 the name of the column + * @param col2 the name of the column to calculate the correlation against + * @return The Pearson Correlation Coefficient as a Double. + * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * df.stat.corr("rand1", "rand2", "pearson") + * res1: Double = 0.613... + * }}} + * @since 1.4.0 + */ + def corr(col1: String, col2: String): Double = { + corr(col1, col2, "pearson") + } + + /** + * Computes a pair-wise frequency table of the given columns. Also known as a contingency table. + * The first column of each row will be the distinct values of `col1` and the column names will + * be the distinct values of `col2`. The name of the first column will be `col1_col2`. Counts + * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts. + * Null elements will be replaced by "null", and back ticks will be dropped from elements if they + * exist. + * + * @param col1 The name of the first column. Distinct items will make the first item of + * each row. + * @param col2 The name of the second column. Distinct items will make the column names + * of the DataFrame. + * @return A DataFrame containing for the contingency table. + * + * {{{ + * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))) + * .toDF("key", "value") + * val ct = df.stat.crosstab("key", "value") + * ct.show() + * +---------+---+---+---+ + * |key_value| 1| 2| 3| + * +---------+---+---+---+ + * | 2| 2| 0| 1| + * | 1| 1| 1| 0| + * | 3| 0| 1| 1| + * +---------+---+---+---+ + * }}} + * + * @since 1.4.0 + */ + def crosstab(col1: String, col2: String): DS[Row] + + /** + * Finding frequent items for columns, possibly with false positives. Using the + * frequent element count algorithm described in + * here, proposed by Karp, + * Schenker, and Papadimitriou. + * The `support` should be greater than 1e-4. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting `DataFrame`. + * + * @param cols the names of the columns to search frequent items in. + * @param support The minimum frequency for an item to be considered `frequent`. Should be greater + * than 1e-4. + * @return A Local DataFrame with the Array of frequent items for each column. 
+ * + * {{{ + * val rows = Seq.tabulate(100) { i => + * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) + * } + * val df = spark.createDataFrame(rows).toDF("a", "b") + * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns + * // "a" and "b" + * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4) + * freqSingles.show() + * +-----------+-------------+ + * |a_freqItems| b_freqItems| + * +-----------+-------------+ + * | [1, 99]|[-1.0, -99.0]| + * +-----------+-------------+ + * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" + * val pairDf = df.select(struct("a", "b").as("a-b")) + * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1) + * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() + * +----------+ + * | freq_ab| + * +----------+ + * | [1,-1.0]| + * | ... | + * +----------+ + * }}} + * @since 1.4.0 + */ + def freqItems(cols: Array[String], support: Double): DS[Row] = + freqItems(cols.toImmutableArraySeq, support) + + /** + * Finding frequent items for columns, possibly with false positives. Using the + * frequent element count algorithm described in + * here, proposed by Karp, + * Schenker, and Papadimitriou. + * Uses a `default` support of 1%. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting `DataFrame`. + * + * @param cols the names of the columns to search frequent items in. + * @return A Local DataFrame with the Array of frequent items for each column. + * @since 1.4.0 + */ + def freqItems(cols: Array[String]): DS[Row] = freqItems(cols, 0.01) + + /** + * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the + * frequent element count algorithm described in + * here, proposed by Karp, Schenker, + * and Papadimitriou. + * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting `DataFrame`. + * + * @param cols the names of the columns to search frequent items in. + * @return A Local DataFrame with the Array of frequent items for each column. + * + * {{{ + * val rows = Seq.tabulate(100) { i => + * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) + * } + * val df = spark.createDataFrame(rows).toDF("a", "b") + * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns + * // "a" and "b" + * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4) + * freqSingles.show() + * +-----------+-------------+ + * |a_freqItems| b_freqItems| + * +-----------+-------------+ + * | [1, 99]|[-1.0, -99.0]| + * +-----------+-------------+ + * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" + * val pairDf = df.select(struct("a", "b").as("a-b")) + * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1) + * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() + * +----------+ + * | freq_ab| + * +----------+ + * | [1,-1.0]| + * | ... | + * +----------+ + * }}} + * + * @since 1.4.0 + */ + def freqItems(cols: Seq[String], support: Double): DS[Row] + + /** + * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the + * frequent element count algorithm described in + * here, proposed by Karp, Schenker, + * and Papadimitriou. + * Uses a `default` support of 1%. 
+ * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting `DataFrame`. + * + * @param cols the names of the columns to search frequent items in. + * @return A Local DataFrame with the Array of frequent items for each column. + * @since 1.4.0 + */ + def freqItems(cols: Seq[String]): DS[Row] = freqItems(cols, 0.01) + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new `DataFrame` that represents the stratified sample + * + * {{{ + * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), + * (3, 3))).toDF("key", "value") + * val fractions = Map(1 -> 1.0, 3 -> 0.5) + * df.stat.sampleBy("key", fractions, 36L).show() + * +---+-----+ + * |key|value| + * +---+-----+ + * | 1| 1| + * | 1| 2| + * | 3| 2| + * +---+-----+ + * }}} + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DS[Row] = { + sampleBy(Column(col), fractions, seed) + } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new `DataFrame` that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DS[Row] = { + sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) + } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new `DataFrame` that represents the stratified sample + * + * The stratified sample can be performed over multiple columns: + * {{{ + * import org.apache.spark.sql.Row + * import org.apache.spark.sql.functions.struct + * + * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17), + * ("Alice", 10))).toDF("name", "age") + * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0) + * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show() + * +-----+---+ + * | name|age| + * +-----+---+ + * | Nico| 8| + * |Alice| 10| + * +-----+---+ + * }}} + * + * @since 3.0.0 + */ + def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DS[Row] + + + /** + * (Java-specific) Returns a stratified sample without replacement based on the fraction given + * on each stratum. + * + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. 
+ * @param seed random seed + * @tparam T stratum type + * @return a new `DataFrame` that represents the stratified sample + * @since 3.0.0 + */ + def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DS[Row] = { + sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param colName name of the column over which the sketch is built + * @param depth depth of the sketch + * @param width width of the sketch + * @param seed random seed + * @return a `CountMinSketch` over column `colName` + * @since 2.0.0 + */ + def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { + countMinSketch(Column(colName), depth, width, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param colName name of the column over which the sketch is built + * @param eps relative error of the sketch + * @param confidence confidence of the sketch + * @param seed random seed + * @return a `CountMinSketch` over column `colName` + * @since 2.0.0 + */ + def countMinSketch( + colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + countMinSketch(Column(colName), eps, confidence, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param col the column over which the sketch is built + * @param depth depth of the sketch + * @param width width of the sketch + * @param seed random seed + * @return a `CountMinSketch` over column `colName` + * @since 2.0.0 + */ + def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { + val eps = 2.0 / width + val confidence = 1 - 1 / Math.pow(2, depth) + countMinSketch(col, eps, confidence, seed) + } + + /** + * Builds a Count-min Sketch over a specified column. + * + * @param col the column over which the sketch is built + * @param eps relative error of the sketch + * @param confidence confidence of the sketch + * @param seed random seed + * @return a `CountMinSketch` over column `colName` + * @since 2.0.0 + */ + def countMinSketch( + col: Column, + eps: Double, + confidence: Double, + seed: Int): CountMinSketch = withOrigin { + val cms = count_min_sketch(col, lit(eps), lit(confidence), lit(seed)) + val bytes: Array[Byte] = df.select(cms).as(BinaryEncoder).head() + CountMinSketch.readFrom(bytes) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { + bloomFilter(Column(colName), expectedNumItems, fpp) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param fpp expected false positive probability of the filter. + * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { + val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp) + bloomFilter(col, expectedNumItems, numBits) + } + + /** + * Builds a Bloom filter over a specified column. 
+ * + * @param colName name of the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. + * @since 2.0.0 + */ + def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { + bloomFilter(Column(colName), expectedNumItems, numBits) + } + + /** + * Builds a Bloom filter over a specified column. + * + * @param col the column over which the filter is built + * @param expectedNumItems expected number of items which will be put into the filter. + * @param numBits expected number of bits of the filter. + * @since 2.0.0 + */ + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin { + val bf = Column.internalFn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)) + val bytes: Array[Byte] = df.select(bf).as(BinaryEncoder).head() + BloomFilter.readFrom(bytes) + } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index 5b4ebed12c17c..16f15205cabea 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -538,6 +538,18 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { // scalastyle:off println def show(numRows: Int, truncate: Int, vertical: Boolean): Unit + /** + * Returns a [[DataFrameStatFunctions]] for working statistic functions support. + * {{{ + * // Finding frequent items in column with name 'a'. + * ds.stat.freqItems(Seq("a")) + * }}} + * + * @group untypedrel + * @since 1.6.0 + */ + def stat: DataFrameStatFunctions[DS] + /** * Join with another `DataFrame`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 9346739cbbd99..a5ab237bb7041 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -22,12 +22,10 @@ import java.{lang => jl, util => ju} import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable -import org.apache.spark.sql.Encoders.BINARY import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.execution.stat._ -import org.apache.spark.sql.functions.{col, count_min_sketch, lit} +import org.apache.spark.sql.functions.col import org.apache.spark.util.ArrayImplicits._ -import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * Statistic functions for `DataFrame`s. @@ -35,65 +33,10 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} * @since 1.4.0 */ @Stable -final class DataFrameStatFunctions private[sql](df: DataFrame) { +final class DataFrameStatFunctions private[sql](protected val df: DataFrame) + extends api.DataFrameStatFunctions[Dataset] { - /** - * Calculates the approximate quantiles of a numerical column of a DataFrame. - * - * The result of this algorithm has the following deterministic bound: - * If the DataFrame has N elements and if we request the quantile at probability `p` up to error - * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank - * of `x` is close to (p * N). 
- * More precisely, - * - * {{{ - * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N) - * }}} - * - * This method implements a variation of the Greenwald-Khanna algorithm (with some speed - * optimizations). - * The algorithm was first present in - * Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna. - * - * @param col the name of the numerical column - * @param probabilities a list of quantile probabilities - * Each number must belong to [0, 1]. - * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError The relative target precision to achieve (greater than or equal to 0). - * If set to zero, the exact quantiles are computed, which could be very expensive. - * Note that values greater than 1 are accepted but give the same result as 1. - * @return the approximate quantiles at the given probabilities - * - * @note null and NaN values will be removed from the numerical column before calculation. If - * the dataframe is empty or the column only contains null or NaN, an empty array is returned. - * - * @since 2.0.0 - */ - def approxQuantile( - col: String, - probabilities: Array[Double], - relativeError: Double): Array[Double] = withOrigin { - approxQuantile(Array(col), probabilities, relativeError).head - } - - /** - * Calculates the approximate quantiles of numerical columns of a DataFrame. - * @see `approxQuantile(col:Str* approxQuantile)` for detailed description. - * - * @param cols the names of the numerical columns - * @param probabilities a list of quantile probabilities - * Each number must belong to [0, 1]. - * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError The relative target precision to achieve (greater than or equal to 0). - * If set to zero, the exact quantiles are computed, which could be very expensive. - * Note that values greater than 1 are accepted but give the same result as 1. - * @return the approximate quantiles at the given probabilities of each column - * - * @note null and NaN values will be ignored in numerical columns before calculation. For - * columns only containing null or NaN values, an empty array is returned. - * - * @since 2.2.0 - */ + /** @inheritdoc */ def approxQuantile( cols: Array[String], probabilities: Array[Double], @@ -105,7 +48,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { relativeError).map(_.toArray).toArray } - /** * Python-friendly version of [[approxQuantile()]] */ @@ -117,304 +59,49 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { .map(_.toList.asJava).toList.asJava } - /** - * Calculate the sample covariance of two numerical columns of a DataFrame. - * @param col1 the name of the first column - * @param col2 the name of the second column - * @return the covariance of the two columns. - * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.cov("rand1", "rand2") - * res1: Double = 0.065... - * }}} - * - * @since 1.4.0 - */ + /** @inheritdoc */ def cov(col1: String, col2: String): Double = withOrigin { StatFunctions.calculateCov(df, Seq(col1, col2)) } - /** - * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson - * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in - * MLlib's Statistics. 
- * - * @param col1 the name of the column - * @param col2 the name of the column to calculate the correlation against - * @return The Pearson Correlation Coefficient as a Double. - * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.corr("rand1", "rand2") - * res1: Double = 0.613... - * }}} - * - * @since 1.4.0 - */ + /** @inheritdoc */ def corr(col1: String, col2: String, method: String): Double = withOrigin { require(method == "pearson", "Currently only the calculation of the Pearson Correlation " + "coefficient is supported.") StatFunctions.pearsonCorrelation(df, Seq(col1, col2)) } - /** - * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame. - * - * @param col1 the name of the column - * @param col2 the name of the column to calculate the correlation against - * @return The Pearson Correlation Coefficient as a Double. - * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.corr("rand1", "rand2", "pearson") - * res1: Double = 0.613... - * }}} - * - * @since 1.4.0 - */ - def corr(col1: String, col2: String): Double = { - corr(col1, col2, "pearson") - } - - /** - * Computes a pair-wise frequency table of the given columns. Also known as a contingency table. - * The first column of each row will be the distinct values of `col1` and the column names will - * be the distinct values of `col2`. The name of the first column will be `col1_col2`. Counts - * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts. - * Null elements will be replaced by "null", and back ticks will be dropped from elements if they - * exist. - * - * @param col1 The name of the first column. Distinct items will make the first item of - * each row. - * @param col2 The name of the second column. Distinct items will make the column names - * of the DataFrame. - * @return A DataFrame containing for the contingency table. - * - * {{{ - * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))) - * .toDF("key", "value") - * val ct = df.stat.crosstab("key", "value") - * ct.show() - * +---------+---+---+---+ - * |key_value| 1| 2| 3| - * +---------+---+---+---+ - * | 2| 2| 0| 1| - * | 1| 1| 1| 0| - * | 3| 0| 1| 1| - * +---------+---+---+---+ - * }}} - * - * @since 1.4.0 - */ + /** @inheritdoc */ def crosstab(col1: String, col2: String): DataFrame = withOrigin { StatFunctions.crossTabulate(df, col1, col2) } - /** - * Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in - * here, proposed by Karp, - * Schenker, and Papadimitriou. - * The `support` should be greater than 1e-4. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols the names of the columns to search frequent items in. - * @param support The minimum frequency for an item to be considered `frequent`. Should be greater - * than 1e-4. - * @return A Local DataFrame with the Array of frequent items for each column. 
- * - * {{{ - * val rows = Seq.tabulate(100) { i => - * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) - * } - * val df = spark.createDataFrame(rows).toDF("a", "b") - * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns - * // "a" and "b" - * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4) - * freqSingles.show() - * +-----------+-------------+ - * |a_freqItems| b_freqItems| - * +-----------+-------------+ - * | [1, 99]|[-1.0, -99.0]| - * +-----------+-------------+ - * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" - * val pairDf = df.select(struct("a", "b").as("a-b")) - * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1) - * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() - * +----------+ - * | freq_ab| - * +----------+ - * | [1,-1.0]| - * | ... | - * +----------+ - * }}} - * - * @since 1.4.0 - */ - def freqItems(cols: Array[String], support: Double): DataFrame = withOrigin { - FrequentItems.singlePassFreqItems(df, cols.toImmutableArraySeq, support) - } - - /** - * Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in - * here, proposed by Karp, - * Schenker, and Papadimitriou. - * Uses a `default` support of 1%. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols the names of the columns to search frequent items in. - * @return A Local DataFrame with the Array of frequent items for each column. - * - * @since 1.4.0 - */ - def freqItems(cols: Array[String]): DataFrame = withOrigin { - FrequentItems.singlePassFreqItems(df, cols.toImmutableArraySeq, 0.01) - } - - /** - * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in - * here, proposed by Karp, Schenker, - * and Papadimitriou. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols the names of the columns to search frequent items in. - * @return A Local DataFrame with the Array of frequent items for each column. - * - * {{{ - * val rows = Seq.tabulate(100) { i => - * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) - * } - * val df = spark.createDataFrame(rows).toDF("a", "b") - * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns - * // "a" and "b" - * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4) - * freqSingles.show() - * +-----------+-------------+ - * |a_freqItems| b_freqItems| - * +-----------+-------------+ - * | [1, 99]|[-1.0, -99.0]| - * +-----------+-------------+ - * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" - * val pairDf = df.select(struct("a", "b").as("a-b")) - * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1) - * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() - * +----------+ - * | freq_ab| - * +----------+ - * | [1,-1.0]| - * | ... | - * +----------+ - * }}} - * - * @since 1.4.0 - */ + /** @inheritdoc */ def freqItems(cols: Seq[String], support: Double): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, support) } - /** - * (Scala-specific) Finding frequent items for columns, possibly with false positives. 
Using the - * frequent element count algorithm described in - * here, proposed by Karp, Schenker, - * and Papadimitriou. - * Uses a `default` support of 1%. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols the names of the columns to search frequent items in. - * @return A Local DataFrame with the Array of frequent items for each column. - * - * @since 1.4.0 - */ - def freqItems(cols: Seq[String]): DataFrame = withOrigin { - FrequentItems.singlePassFreqItems(df, cols, 0.01) - } + /** @inheritdoc */ + override def freqItems(cols: Array[String], support: Double): DataFrame = + super.freqItems(cols, support) - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col column that defines strata - * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat - * its fraction as zero. - * @param seed random seed - * @tparam T stratum type - * @return a new `DataFrame` that represents the stratified sample - * - * {{{ - * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), - * (3, 3))).toDF("key", "value") - * val fractions = Map(1 -> 1.0, 3 -> 0.5) - * df.stat.sampleBy("key", fractions, 36L).show() - * +---+-----+ - * |key|value| - * +---+-----+ - * | 1| 1| - * | 1| 2| - * | 3| 2| - * +---+-----+ - * }}} - * - * @since 1.5.0 - */ - def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { - sampleBy(Column(col), fractions, seed) + /** @inheritdoc */ + override def freqItems(cols: Array[String]): DataFrame = super.freqItems(cols) + + /** @inheritdoc */ + override def freqItems(cols: Seq[String]): DataFrame = super.freqItems(cols) + + /** @inheritdoc */ + override def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { + super.sampleBy(col, fractions, seed) } - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col column that defines strata - * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat - * its fraction as zero. - * @param seed random seed - * @tparam T stratum type - * @return a new `DataFrame` that represents the stratified sample - * - * @since 1.5.0 - */ - def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { - sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) + /** @inheritdoc */ + override def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { + super.sampleBy(col, fractions, seed) } - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col column that defines strata - * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat - * its fraction as zero. 
- * @param seed random seed - * @tparam T stratum type - * @return a new `DataFrame` that represents the stratified sample - * - * The stratified sample can be performed over multiple columns: - * {{{ - * import org.apache.spark.sql.Row - * import org.apache.spark.sql.functions.struct - * - * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17), - * ("Alice", 10))).toDF("name", "age") - * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0) - * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show() - * +-----+---+ - * | name|age| - * +-----+---+ - * | Nico| 8| - * |Alice| 10| - * +-----+---+ - * }}} - * - * @since 3.0.0 - */ + /** @inheritdoc */ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = withOrigin { require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), s"Fractions must be in [0, 1], but got $fractions.") @@ -426,135 +113,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { df.filter(f(col, r)) } - /** - * (Java-specific) Returns a stratified sample without replacement based on the fraction given - * on each stratum. - * @param col column that defines strata - * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat - * its fraction as zero. - * @param seed random seed - * @tparam T stratum type - * @return a new `DataFrame` that represents the stratified sample - * - * @since 3.0.0 - */ - def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { - sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param colName name of the column over which the sketch is built - * @param depth depth of the sketch - * @param width width of the sketch - * @param seed random seed - * @return a `CountMinSketch` over column `colName` - * @since 2.0.0 - */ - def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { - countMinSketch(Column(colName), depth, width, seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param colName name of the column over which the sketch is built - * @param eps relative error of the sketch - * @param confidence confidence of the sketch - * @param seed random seed - * @return a `CountMinSketch` over column `colName` - * @since 2.0.0 - */ - def countMinSketch( - colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = { - countMinSketch(Column(colName), eps, confidence, seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param col the column over which the sketch is built - * @param depth depth of the sketch - * @param width width of the sketch - * @param seed random seed - * @return a `CountMinSketch` over column `colName` - * @since 2.0.0 - */ - def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { - val eps = 2.0 / width - val confidence = 1 - 1 / Math.pow(2, depth) - countMinSketch(col, eps, confidence, seed) - } - - /** - * Builds a Count-min Sketch over a specified column. 
- * - * @param col the column over which the sketch is built - * @param eps relative error of the sketch - * @param confidence confidence of the sketch - * @param seed random seed - * @return a `CountMinSketch` over column `colName` - * @since 2.0.0 - */ - def countMinSketch( - col: Column, - eps: Double, - confidence: Double, - seed: Int): CountMinSketch = withOrigin { - val cms = count_min_sketch(col, lit(eps), lit(confidence), lit(seed)) - val bytes = df.select(cms).as(BINARY).head() - CountMinSketch.readFrom(bytes) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param colName name of the column over which the filter is built - * @param expectedNumItems expected number of items which will be put into the filter. - * @param fpp expected false positive probability of the filter. - * @since 2.0.0 - */ - def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { - bloomFilter(Column(colName), expectedNumItems, fpp) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param col the column over which the filter is built - * @param expectedNumItems expected number of items which will be put into the filter. - * @param fpp expected false positive probability of the filter. - * @since 2.0.0 - */ - def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { - val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp) - bloomFilter(col, expectedNumItems, numBits) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param colName name of the column over which the filter is built - * @param expectedNumItems expected number of items which will be put into the filter. - * @param numBits expected number of bits of the filter. - * @since 2.0.0 - */ - def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { - bloomFilter(Column(colName), expectedNumItems, numBits) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param col the column over which the filter is built - * @param expectedNumItems expected number of items which will be put into the filter. - * @param numBits expected number of bits of the filter. - * @since 2.0.0 - */ - def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin { - val bf = Column.internalFn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)) - val bytes = df.select(bf).as(BINARY).head() - BloomFilter.readFrom(bytes) + /** @inheritdoc */ + override def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { + super.sampleBy(col, fractions, seed) } } From 1e6765910c44009cd09698f8136608edfe7ea098 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Thu, 29 Aug 2024 19:36:40 +0200 Subject: [PATCH 003/230] [SPARK-49410][SQL][TESTS] Update collation benchmarks ### What changes were proposed in this pull request? Updating collation benchmarks to compute relative time, instead of relative speed. Also, re-run all collation benchmarks (after introducing various perf improvements as part of the recent collation effort). ### Why are the changes needed? Relative speed is displayed in format: "x0.0" for collations that are slower than the referent collation (UTF8_BINARY). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47893 from uros-db/update-benchmarks. 
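
For clarity, a minimal standalone sketch (not part of this patch) of how the new `Relative time` column differs from the old `Relative` speed column, using the UTF8_BINARY and UNICODE_CI equalsFunction best times from the jdk21 results below:

```scala
// Sketch only: mirrors the formula added to Benchmark.run(), where `firstBest`
// is the best time of the first (reference) case in milliseconds.
object RelativeColumnSketch {
  def relative(firstBestMs: Double, bestMs: Double, relativeTime: Boolean): Double =
    if (relativeTime) bestMs / firstBestMs // relative time: less is better
    else firstBestMs / bestMs              // relative speed: more is better

  def main(args: Array[String]): Unit = {
    val utf8Binary = 1353.0  // UTF8_BINARY best time (ms), jdk21 equalsFunction run
    val unicodeCi = 16362.0  // UNICODE_CI best time (ms), same run
    println("%3.1fX".format(relative(utf8Binary, unicodeCi, relativeTime = false))) // 0.1X (old column)
    println("%3.1fX".format(relative(utf8Binary, unicodeCi, relativeTime = true)))  // 12.1X (new column)
  }
}
```

With relative speed, every slower collation collapses to a hard-to-read "0.0X"/"0.1X"; with relative time the same data reads as "12.1X slower", which is the motivation for the switch.
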
Authored-by: Uros Bojanic Signed-off-by: Max Gekk --- .../apache/spark/benchmark/Benchmark.scala | 12 ++-- .../CollationBenchmark-jdk21-results.txt | 60 +++++++++---------- .../benchmarks/CollationBenchmark-results.txt | 60 +++++++++---------- ...llationNonASCIIBenchmark-jdk21-results.txt | 60 +++++++++---------- .../CollationNonASCIIBenchmark-results.txt | 60 +++++++++---------- .../benchmark/CollationBenchmark.scala | 12 ++-- 6 files changed, 134 insertions(+), 130 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala index e7315d6119be0..7e88c7ee684bd 100644 --- a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala @@ -94,9 +94,11 @@ private[spark] class Benchmark( /** * Runs the benchmark and outputs the results to stdout. This should be copied and added as * a comment with the benchmark. Although the results vary from machine to machine, it should - * provide some baseline. + * provide some baseline. If `relativeTime` is set to `true`, the `Relative` column will be + * the relative time of each case relative to the first case (less is better). Otherwise, it + * will be the relative execution speed of each case relative to the first case (more is better). */ - def run(): Unit = { + def run(relativeTime: Boolean = false): Unit = { require(benchmarks.nonEmpty) // scalastyle:off println("Running benchmark: " + name) @@ -112,10 +114,12 @@ private[spark] class Benchmark( out.println(Benchmark.getJVMOSInfo()) out.println(Benchmark.getProcessorName()) val nameLen = Math.max(40, Math.max(name.length, benchmarks.map(_.name.length).max)) + val relativeHeader = if (relativeTime) "Relative time" else "Relative" out.printf(s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n", - name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)", "Rate(M/s)", "Per Row(ns)", "Relative") + name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)", "Rate(M/s)", "Per Row(ns)", relativeHeader) out.println("-" * (nameLen + 80)) results.zip(benchmarks).foreach { case (result, benchmark) => + val relative = if (relativeTime) result.bestMs / firstBest else firstBest / result.bestMs out.printf(s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n", benchmark.name, "%5.0f" format result.bestMs, @@ -123,7 +127,7 @@ private[spark] class Benchmark( "%5.0f" format result.stdevMs, "%10.1f" format result.bestRate, "%6.1f" format (1000 / result.bestRate), - "%3.1fX" format (firstBest / result.bestMs)) + "%3.1fX" format relative) } out.println() // scalastyle:on diff --git a/sql/core/benchmarks/CollationBenchmark-jdk21-results.txt b/sql/core/benchmarks/CollationBenchmark-jdk21-results.txt index c3ca6fb5e4f65..b2df218c8fbb4 100644 --- a/sql/core/benchmarks/CollationBenchmark-jdk21-results.txt +++ b/sql/core/benchmarks/CollationBenchmark-jdk21-results.txt @@ -1,54 +1,54 @@ OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - equalsFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - equalsFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time -------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY 1344 1345 1 0.1 13438.0 1.0X -UTF8_LCASE 2617 2619 3 0.0 26172.7 0.5X -UNICODE 16947 16976 41 0.0 169465.6 0.1X -UNICODE_CI 16500 16507 10 
0.0 164997.5 0.1X +UTF8_BINARY 1353 1353 1 0.1 13526.6 1.0X +UTF8_LCASE 2703 2705 3 0.0 27032.4 2.0X +UNICODE 16848 16894 65 0.0 168482.9 12.5X +UNICODE_CI 16362 16367 8 0.0 163615.6 12.1X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - compareFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - compareFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time --------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY 2555 2560 7 0.0 25553.1 1.0X -UTF8_LCASE 3539 3540 0 0.0 35392.8 0.7X -UNICODE 17154 17161 10 0.0 171541.5 0.1X -UNICODE_CI 16915 16926 17 0.0 169146.6 0.2X +UTF8_BINARY 2640 2642 3 0.0 26401.5 1.0X +UTF8_LCASE 3616 3618 2 0.0 36164.8 1.4X +UNICODE 17465 17470 7 0.0 174650.9 6.6X +UNICODE_CI 17251 17264 18 0.0 172510.9 6.5X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - hashFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - hashFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 2771 2771 0 0.0 27708.6 1.0X -UTF8_LCASE 5346 5347 2 0.0 53462.4 0.5X -UNICODE 67678 67692 21 0.0 676775.1 0.0X -UNICODE_CI 57978 57982 6 0.0 579780.7 0.0X +UTF8_BINARY 2843 2844 1 0.0 28427.2 1.0X +UTF8_LCASE 5417 5437 28 0.0 54170.7 1.9X +UNICODE 68601 68619 26 0.0 686010.8 24.1X +UNICODE_CI 56342 56361 26 0.0 563422.2 19.8X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - contains: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - contains: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 8793 8794 1 0.0 87929.3 1.0X -UTF8_LCASE 19382 19394 16 0.0 193824.8 0.5X -UNICODE 363790 363911 171 0.0 3637901.0 0.0X -UNICODE_CI 414597 415090 697 0.0 4145972.5 0.0X +UTF8_BINARY 7674 7674 1 0.0 76735.3 1.0X +UTF8_LCASE 20367 20376 14 0.0 203665.1 2.7X +UNICODE 377133 377909 1098 0.0 3771328.8 49.1X +UNICODE_CI 434710 435099 551 0.0 4347097.2 56.7X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - startsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - startsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 7692 7695 4 0.0 76921.8 1.0X -UTF8_LCASE 16451 16457 9 0.0 164507.8 0.5X -UNICODE 356828 358151 1871 0.0 3568280.7 0.0X -UNICODE_CI 417621 418820 1697 0.0 4176205.0 0.0X +UTF8_BINARY 6956 6959 4 0.0 69561.7 1.0X +UTF8_LCASE 14246 14262 23 0.0 142459.6 2.0X +UNICODE 369940 370072 186 0.0 3699400.9 53.2X +UNICODE_CI 442072 442365 414 0.0 4420718.1 63.6X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - endsWith: Best Time(ms) Avg Time(ms) Stdev(ms) 
Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - endsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 7175 7175 0 0.0 71748.8 1.0X -UTF8_LCASE 15267 15291 34 0.0 152674.6 0.5X -UNICODE 371582 371601 26 0.0 3715822.1 0.0X -UNICODE_CI 429637 430336 989 0.0 4296371.0 0.0X +UTF8_BINARY 6927 6927 0 0.0 69265.2 1.0X +UTF8_LCASE 15505 15514 12 0.0 155054.5 2.2X +UNICODE 382361 382426 93 0.0 3823606.6 55.2X +UNICODE_CI 449956 450063 151 0.0 4499562.9 65.0X diff --git a/sql/core/benchmarks/CollationBenchmark-results.txt b/sql/core/benchmarks/CollationBenchmark-results.txt index bd29e04b8d98f..a63b80f005ed0 100644 --- a/sql/core/benchmarks/CollationBenchmark-results.txt +++ b/sql/core/benchmarks/CollationBenchmark-results.txt @@ -1,54 +1,54 @@ OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - equalsFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - equalsFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time -------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY 1372 1373 1 0.1 13721.8 1.0X -UTF8_LCASE 3161 3163 2 0.0 31614.5 0.4X -UNICODE 20065 20074 13 0.0 200648.2 0.1X -UNICODE_CI 19950 19952 3 0.0 199497.1 0.1X +UTF8_BINARY 1372 1372 1 0.1 13718.5 1.0X +UTF8_LCASE 3115 3116 1 0.0 31154.4 2.3X +UNICODE 19813 19820 9 0.0 198132.2 14.4X +UNICODE_CI 19669 19686 24 0.0 196694.2 14.3X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - compareFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - compareFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time --------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY 1730 1730 0 0.1 17300.7 1.0X -UTF8_LCASE 3181 3183 3 0.0 31807.5 0.5X -UNICODE 18827 18845 26 0.0 188267.7 0.1X -UNICODE_CI 18669 18671 3 0.0 186692.7 0.1X +UTF8_BINARY 1727 1728 1 0.1 17271.3 1.0X +UTF8_LCASE 3034 3035 1 0.0 30337.2 1.8X +UNICODE 19230 19243 18 0.0 192301.2 11.1X +UNICODE_CI 19080 19082 3 0.0 190802.0 11.0X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - hashFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - hashFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 3080 3081 1 0.0 30796.9 1.0X -UTF8_LCASE 9640 9670 42 0.0 96402.5 0.3X -UNICODE 65966 66004 53 0.0 659660.4 0.0X -UNICODE_CI 57631 57813 256 0.0 576314.7 0.1X +UTF8_BINARY 3080 3080 0 0.0 30796.4 1.0X +UTF8_LCASE 6436 6454 25 0.0 64360.0 2.1X +UNICODE 68095 68167 101 0.0 680951.3 22.1X +UNICODE_CI 62122 62123 2 0.0 621215.8 20.2X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - contains: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - contains: Best Time(ms) Avg 
Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 8792 8794 2 0.0 87923.1 1.0X -UTF8_LCASE 23937 23942 8 0.0 239367.4 0.4X -UNICODE 371949 372083 189 0.0 3719485.9 0.0X -UNICODE_CI 431636 432296 933 0.0 4316361.9 0.0X +UTF8_BINARY 8260 8261 1 0.0 82604.0 1.0X +UTF8_LCASE 23629 23629 0 0.0 236286.4 2.9X +UNICODE 364843 366078 1747 0.0 3648427.9 44.2X +UNICODE_CI 425728 426449 1020 0.0 4257275.1 51.5X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - startsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - startsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 6865 6877 16 0.0 68652.5 1.0X -UTF8_LCASE 22100 22118 26 0.0 220997.8 0.3X -UNICODE 365404 365627 316 0.0 3654037.8 0.0X -UNICODE_CI 431409 432257 1199 0.0 4314085.3 0.0X +UTF8_BINARY 6844 6848 5 0.0 68440.4 1.0X +UTF8_LCASE 21849 21870 30 0.0 218486.3 3.2X +UNICODE 363474 363811 476 0.0 3634738.4 53.1X +UNICODE_CI 427563 428029 659 0.0 4275629.8 62.5X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - endsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - endsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 6998 6999 2 0.0 69980.7 1.0X -UTF8_LCASE 22332 22358 36 0.0 223323.5 0.3X -UNICODE 382527 382805 393 0.0 3825268.8 0.0X -UNICODE_CI 447621 447949 465 0.0 4476209.7 0.0X +UTF8_BINARY 6904 6907 4 0.0 69039.3 1.0X +UTF8_LCASE 22007 22009 3 0.0 220067.8 3.2X +UNICODE 376402 377858 2060 0.0 3764015.4 54.5X +UNICODE_CI 444485 444809 458 0.0 4444850.8 64.4X diff --git a/sql/core/benchmarks/CollationNonASCIIBenchmark-jdk21-results.txt b/sql/core/benchmarks/CollationNonASCIIBenchmark-jdk21-results.txt index 9882b1fced172..574e3c5359100 100644 --- a/sql/core/benchmarks/CollationNonASCIIBenchmark-jdk21-results.txt +++ b/sql/core/benchmarks/CollationNonASCIIBenchmark-jdk21-results.txt @@ -1,54 +1,54 @@ OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - equalsFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - equalsFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time -------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY 165 166 1 0.2 4129.8 1.0X -UTF8_LCASE 6978 6980 3 0.0 174454.3 0.0X -UNICODE 5542 5550 11 0.0 138543.0 0.0X -UNICODE_CI 5287 5289 3 0.0 132179.7 0.0X +UTF8_BINARY 165 165 0 0.2 4118.0 1.0X +UTF8_LCASE 6996 7019 33 0.0 174899.5 42.5X +UNICODE 5395 5407 18 0.0 134874.5 32.8X +UNICODE_CI 5670 5672 2 0.0 141756.7 34.4X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - compareFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - compareFunction: Best Time(ms) Avg 
Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time --------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY 301 301 0 0.1 7523.1 1.0X -UTF8_LCASE 6857 6871 19 0.0 171426.6 0.0X -UNICODE 5163 5174 16 0.0 129074.3 0.1X -UNICODE_CI 5106 5108 3 0.0 127640.2 0.1X +UTF8_BINARY 306 306 0 0.1 7656.1 1.0X +UTF8_LCASE 6950 6957 11 0.0 173739.0 22.7X +UNICODE 5120 5123 3 0.0 128010.6 16.7X +UNICODE_CI 5080 5099 27 0.0 127011.6 16.6X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - hashFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - hashFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 726 727 1 0.1 18158.8 1.0X -UTF8_LCASE 3483 3486 5 0.0 87069.0 0.2X -UNICODE 14715 14717 3 0.0 367874.8 0.0X -UNICODE_CI 11639 11648 12 0.0 290985.0 0.1X +UTF8_BINARY 384 384 1 0.1 9591.1 1.0X +UTF8_LCASE 3549 3550 2 0.0 88721.7 9.3X +UNICODE 14143 14145 3 0.0 353570.2 36.9X +UNICODE_CI 11925 11929 6 0.0 298126.4 31.1X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - contains: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - contains: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 1335 1335 0 0.0 33374.9 1.0X -UTF8_LCASE 9042 9070 40 0.0 226042.8 0.1X -UNICODE 69091 69108 24 0.0 1727283.6 0.0X -UNICODE_CI 77261 77295 49 0.0 1931515.1 0.0X +UTF8_BINARY 1375 1376 1 0.0 34375.4 1.0X +UTF8_LCASE 8740 8744 6 0.0 218504.1 6.4X +UNICODE 68707 68818 158 0.0 1717667.1 50.0X +UNICODE_CI 77167 77197 42 0.0 1929168.6 56.1X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - startsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - startsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 968 969 1 0.0 24189.2 1.0X -UTF8_LCASE 5756 5770 20 0.0 143900.0 0.2X -UNICODE 68123 68210 122 0.0 1703076.1 0.0X -UNICODE_CI 77853 78018 233 0.0 1946331.7 0.0X +UTF8_BINARY 1064 1065 2 0.0 26587.9 1.0X +UTF8_LCASE 5820 5827 10 0.0 145506.0 5.5X +UNICODE 67636 67675 54 0.0 1690904.3 63.6X +UNICODE_CI 77750 77796 65 0.0 1943738.2 73.1X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - endsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - endsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 990 991 2 0.0 24752.7 1.0X -UTF8_LCASE 5932 5935 3 0.0 148305.3 0.2X -UNICODE 74982 75051 98 0.0 1874541.8 0.0X -UNICODE_CI 83019 83060 58 0.0 2075485.7 0.0X +UTF8_BINARY 1090 1091 0 0.0 27260.9 1.0X +UTF8_LCASE 6049 6054 7 0.0 151221.3 5.5X 
+UNICODE 74589 74633 62 0.0 1864725.7 68.4X +UNICODE_CI 83674 83708 49 0.0 2091841.0 76.7X diff --git a/sql/core/benchmarks/CollationNonASCIIBenchmark-results.txt b/sql/core/benchmarks/CollationNonASCIIBenchmark-results.txt index 9c3f9b0552783..d4e70f29c245b 100644 --- a/sql/core/benchmarks/CollationNonASCIIBenchmark-results.txt +++ b/sql/core/benchmarks/CollationNonASCIIBenchmark-results.txt @@ -1,54 +1,54 @@ OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - equalsFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - equalsFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time -------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY 135 135 0 0.3 3377.6 1.0X -UTF8_LCASE 7196 7196 0 0.0 179892.8 0.0X -UNICODE 6133 6137 6 0.0 153325.1 0.0X -UNICODE_CI 5828 5828 1 0.0 145690.2 0.0X +UTF8_BINARY 133 133 1 0.3 3317.1 1.0X +UTF8_LCASE 7092 7097 6 0.0 177310.9 53.5X +UNICODE 5946 5966 29 0.0 148638.1 44.8X +UNICODE_CI 5715 5717 2 0.0 142885.1 43.1X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - compareFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - compareFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time --------------------------------------------------------------------------------------------------------------------------- -UTF8_BINARY 436 436 0 0.1 10892.8 1.0X -UTF8_LCASE 7193 7215 31 0.0 179823.7 0.1X -UNICODE 5985 5988 4 0.0 149636.2 0.1X -UNICODE_CI 5945 5946 1 0.0 148634.7 0.1X +UTF8_BINARY 433 435 2 0.1 10816.6 1.0X +UTF8_LCASE 7365 7369 5 0.0 184135.4 17.0X +UNICODE 5785 5790 7 0.0 144616.9 13.4X +UNICODE_CI 5742 5744 3 0.0 143557.1 13.3X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - hashFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - hashFunction: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 712 713 1 0.1 17808.5 1.0X -UTF8_LCASE 3594 3595 2 0.0 89843.6 0.2X -UNICODE 15549 15553 6 0.0 388714.3 0.0X -UNICODE_CI 13296 13311 22 0.0 332387.9 0.1X +UTF8_BINARY 410 411 1 0.1 10246.1 1.0X +UTF8_LCASE 3588 3589 1 0.0 89698.8 8.8X +UNICODE 15788 15802 20 0.0 394702.8 38.5X +UNICODE_CI 12179 12192 19 0.0 304466.6 29.7X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - contains: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - contains: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 1374 1376 3 0.0 34344.0 1.0X -UTF8_LCASE 10613 10615 2 0.0 265321.4 0.1X -UNICODE 65820 65876 79 0.0 1645497.8 0.0X -UNICODE_CI 74936 74964 39 0.0 1873403.8 0.0X +UTF8_BINARY 1367 1370 4 0.0 34182.9 1.0X +UTF8_LCASE 9644 9645 1 0.0 241101.2 7.1X +UNICODE 67169 67171 3 0.0 1679230.1 49.1X +UNICODE_CI 79077 79209 188 0.0 1976919.1 57.8X OpenJDK 
64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - startsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - startsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 1026 1027 1 0.0 25655.0 1.0X -UTF8_LCASE 6048 6052 5 0.0 151197.9 0.2X -UNICODE 64742 64763 30 0.0 1618539.6 0.0X -UNICODE_CI 74924 74925 1 0.0 1873110.8 0.0X +UTF8_BINARY 1064 1067 3 0.0 26608.1 1.0X +UTF8_LCASE 6487 6491 4 0.0 162186.5 6.1X +UNICODE 68473 68523 71 0.0 1711818.5 64.3X +UNICODE_CI 79374 79419 64 0.0 1984338.0 74.6X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor -collation unit benchmarks - endsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +collation unit benchmarks - endsWith: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative time ------------------------------------------------------------------------------------------------------------------------ -UTF8_BINARY 1045 1046 1 0.0 26126.0 1.0X -UTF8_LCASE 6047 6049 3 0.0 151163.3 0.2X -UNICODE 72297 72327 42 0.0 1807434.1 0.0X -UNICODE_CI 81904 81917 18 0.0 2047587.7 0.0X +UTF8_BINARY 1002 1004 2 0.0 25061.8 1.0X +UTF8_LCASE 6052 6052 0 0.0 151298.7 6.0X +UNICODE 74506 74551 64 0.0 1862644.2 74.3X +UNICODE_CI 83607 83756 211 0.0 2090164.5 83.4X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala index 86e9320ae9cde..59c2a2847fd08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala @@ -49,7 +49,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { } } } - benchmark.run() + benchmark.run(relativeTime = true) } def benchmarkUTFStringCompare(collationTypes: Seq[String], utf8Strings: Seq[UTF8String]): Unit = { @@ -73,7 +73,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { } } } - benchmark.run() + benchmark.run(relativeTime = true) } def benchmarkUTFStringHashFunction( @@ -99,7 +99,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { } } } - benchmark.run() + benchmark.run(relativeTime = true) } def benchmarkContains( @@ -127,7 +127,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { } } } - benchmark.run() + benchmark.run(relativeTime = true) } def benchmarkStartsWith( @@ -155,7 +155,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { } } } - benchmark.run() + benchmark.run(relativeTime = true) } def benchmarkEndsWith( @@ -183,7 +183,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { } } } - benchmark.run() + benchmark.run(relativeTime = true) } } From 8a85f22ebb8066be52dda420ef09861aa27a7421 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Thu, 29 Aug 2024 14:43:31 -0700 Subject: [PATCH 004/230] [SPARK-49454][SQL] Avoid double normalization in the cache process ### What changes were proposed in this pull request? This PR fixes the issue introduced in https://github.com/apache/spark/pull/46465, which is that normalization is applied twice during the cache process. 
Some normalization rules may not be idempotent, so applying them repeatedly may break the plan shape and cause an unexpected cache miss. ### Why are the changes needed? Fix a bug as stated above. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Trivial fix; run existing test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47923 from anchovYu/avoid-double-normalization. Authored-by: Xinyi Yu Signed-off-by: Gengliang Wang --- .../spark/sql/execution/CacheManager.scala | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index b96f257e6b5b6..aae424afcb8ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -94,7 +94,13 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { query: Dataset[_], tableName: Option[String], storageLevel: StorageLevel): Unit = { - cacheQueryInternal(query.sparkSession, query.queryExecution.normalized, tableName, storageLevel) + cacheQueryInternal( + query.sparkSession, + query.queryExecution.analyzed, + query.queryExecution.normalized, + tableName, + storageLevel + ) } /** @@ -107,23 +113,25 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { tableName: Option[String], storageLevel: StorageLevel): Unit = { val normalized = QueryExecution.normalize(spark, planToCache) - cacheQueryInternal(spark, normalized, tableName, storageLevel) + cacheQueryInternal(spark, planToCache, normalized, tableName, storageLevel) } - // The `planToCache` should have been normalized. + // The `normalizedPlan` should have been normalized. It is the cache key. private def cacheQueryInternal( spark: SparkSession, - planToCache: LogicalPlan, + unnormalizedPlan: LogicalPlan, + normalizedPlan: LogicalPlan, tableName: Option[String], storageLevel: StorageLevel): Unit = { if (storageLevel == StorageLevel.NONE) { // Do nothing for StorageLevel.NONE since it will not actually cache any data. - } else if (lookupCachedDataInternal(planToCache).nonEmpty) { + } else if (lookupCachedDataInternal(normalizedPlan).nonEmpty) { logWarning("Asked to cache already cached data.") } else { val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark) val inMemoryRelation = sessionWithConfigsOff.withActive { - val qe = sessionWithConfigsOff.sessionState.executePlan(planToCache) + // it creates query execution from unnormalizedPlan plan to avoid multiple normalization. 
+ val qe = sessionWithConfigsOff.sessionState.executePlan(unnormalizedPlan) InMemoryRelation( storageLevel, qe, @@ -131,10 +139,11 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } this.synchronized { - if (lookupCachedDataInternal(planToCache).nonEmpty) { + if (lookupCachedDataInternal(normalizedPlan).nonEmpty) { logWarning("Data has already been cached.") } else { - val cd = CachedData(planToCache, inMemoryRelation) + // the cache key is the normalized plan + val cd = CachedData(normalizedPlan, inMemoryRelation) cachedData = cd +: cachedData CacheManager.logCacheOperation(log"Added Dataframe cache entry:" + log"${MDC(DATAFRAME_CACHE_ENTRY, cd)}") From c7a9c1e0776e9a8df3af5141626e494aab8734d6 Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Fri, 30 Aug 2024 07:38:39 +0900 Subject: [PATCH 005/230] [SPARK-49021][SS] Add support for reading transformWithState value state variables with state data source reader ### What changes were proposed in this pull request? Add support for reading transformWithState value state variables with state data source reader Co-authored with jingz-db ### Why are the changes needed? Changes are needed to integrate reading state reading with new operator metadata and state schema format for the value state types used in state variables within transformWithState ### Does this PR introduce _any_ user-facing change? Yes Users can now read valueState variables used in the `transformWithState` operator using the state data source reader. ``` spark .read .format("statestore") .option("operatorId", ) .option("stateVarName", ) .load() ``` ### How was this patch tested? Added unit tests ``` ===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.streaming.TransformWithStateSuite, threads: ForkJoinPool.commonPool-worker-4 (daemon=true), Idle Worker Monitor for python3 (daemon=true), rpc-boss-3-1 (daemon=true), ForkJoinPool.commonPool-worker-5 (daemon=true), ForkJoinPool.commonPool-worker-3 (daemon=true), ForkJoinPool.commonPool-worker-2 (daemon=true), shuffle-boss-6-1 (daemon=true), ForkJoinPool.commonPool-worker-1 (daemon=true) ===== [info] Run completed in 2 minutes, 28 seconds. [info] Total number of tests run: 42 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 42, failed 0, canceled 0, ignored 1, pending 0 [info] All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #47574 from anishshri-db/task/SPARK-49021. 
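
As a further illustration of the new read path, a hedged usage sketch (not taken from this patch; the checkpoint path, operator id, and state variable name below are placeholders):

```scala
import org.apache.spark.sql.SparkSession

object ReadValueStateSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("read-tws-state").getOrCreate()

    // Placeholders: adjust the checkpoint location, the operator id, and the
    // name of the value state variable defined in your StatefulProcessor.
    val stateDf = spark.read
      .format("statestore")
      .option("operatorId", 0L)
      .option("stateVarName", "countState")
      .load("/tmp/checkpoints/tws-query")

    // For a value state variable the reader is expected to return one row per
    // grouping key, with `key`, `value` and `partition_id` columns.
    stateDf.show(truncate = false)
  }
}
```
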
Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../v2/state/StateDataSource.scala | 245 +++++++++++++++--- .../v2/state/StatePartitionReader.scala | 114 +++++--- .../v2/state/StateScanBuilder.scala | 17 +- .../datasources/v2/state/StateTable.scala | 40 +-- .../state/metadata/StateMetadataSource.scala | 55 ++-- .../v2/state/utils/SchemaUtil.scala | 126 ++++++++- .../v2/state/StateDataSourceReadSuite.scala | 55 +++- ...ateDataSourceTransformWithStateSuite.scala | 220 ++++++++++++++++ 8 files changed, 731 insertions(+), 141 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index acd5303350dec..83399e2cac01b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -24,26 +24,28 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.internal.Logging import org.apache.spark.sql.{RuntimeConfig, SparkSession} import org.apache.spark.sql.catalyst.DataSourceOptions import org.apache.spark.sql.connector.catalog.{Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues +import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, STATE_VAR_NAME} import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues -import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader -import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata} +import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry} +import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil +import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} -import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration /** * An implementation of [[TableProvider]] with [[DataSourceRegister]] for State Store data source. 
*/ -class StateDataSource extends TableProvider with DataSourceRegister { +class StateDataSource extends TableProvider with DataSourceRegister with Logging { private lazy val session: SparkSession = SparkSession.active private lazy val hadoopConf: Configuration = session.sessionState.newHadoopConf() @@ -58,24 +60,26 @@ class StateDataSource extends TableProvider with DataSourceRegister { properties: util.Map[String, String]): Table = { val sourceOptions = StateSourceOptions.apply(session, hadoopConf, properties) val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId) - // Read the operator metadata once to see if we can find the information for prefix scan - // encoder used in session window aggregation queries. - val allStateStoreMetadata = new StateMetadataPartitionReader( - sourceOptions.stateCheckpointLocation.getParent.toString, serializedHadoopConf, - sourceOptions.batchId) - .stateMetadata.toArray - val stateStoreMetadata = allStateStoreMetadata.filter { entry => - entry.operatorId == sourceOptions.operatorId && - entry.stateStoreName == sourceOptions.storeName + val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(sourceOptions) + + // The key state encoder spec should be available for all operators except stream-stream joins + val keyStateEncoderSpec = if (stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) { + stateStoreReaderInfo.keyStateEncoderSpecOpt.get + } else { + val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] + NoPrefixKeyStateEncoderSpec(keySchema) } - new StateTable(session, schema, sourceOptions, stateConf, stateStoreMetadata) + new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, + stateStoreReaderInfo.transformWithStateVariableInfoOpt, + stateStoreReaderInfo.stateStoreColFamilySchemaOpt) } override def inferSchema(options: CaseInsensitiveStringMap): StructType = { - val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA val sourceOptions = StateSourceOptions.apply(session, hadoopConf, options) + val stateStoreReaderInfo: StateStoreReaderInfo = getStoreMetadataAndRunChecks(sourceOptions) + val stateCheckpointLocation = sourceOptions.stateCheckpointLocation try { val (keySchema, valueSchema) = sourceOptions.joinSide match { @@ -88,34 +92,24 @@ class StateDataSource extends TableProvider with DataSourceRegister { sourceOptions.operatorId, RightSide) case JoinSideValues.none => - val storeId = new StateStoreId(stateCheckpointLocation.toString, sourceOptions.operatorId, - partitionId, sourceOptions.storeName) - val providerId = new StateStoreProviderId(storeId, UUID.randomUUID()) - val manager = new StateSchemaCompatibilityChecker(providerId, hadoopConf) - val stateSchema = manager.readSchemaFile().head - (stateSchema.keySchema, stateSchema.valueSchema) - } - - if (sourceOptions.readChangeFeed) { - new StructType() - .add("batch_id", LongType) - .add("change_type", StringType) - .add("key", keySchema) - .add("value", valueSchema) - .add("partition_id", IntegerType) - } else { - new StructType() - .add("key", keySchema) - .add("value", valueSchema) - .add("partition_id", IntegerType) + // we should have the schema for the state store if joinSide is none + require(stateStoreReaderInfo.stateStoreColFamilySchemaOpt.isDefined) + val resultSchema = stateStoreReaderInfo.stateStoreColFamilySchemaOpt.get + (resultSchema.keySchema, resultSchema.valueSchema) } + SchemaUtil.getSourceSchema(sourceOptions, keySchema, + valueSchema, + 
stateStoreReaderInfo.transformWithStateVariableInfoOpt, + stateStoreReaderInfo.stateStoreColFamilySchemaOpt) } catch { case NonFatal(e) => throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e) } } + override def supportsExternalMetadata(): Boolean = false + private def buildStateStoreConf(checkpointLocation: String, batchId: Long): StateStoreConf = { val offsetLog = new OffsetSeqLog(session, new Path(checkpointLocation, DIR_NAME_OFFSETS).toString) @@ -134,7 +128,161 @@ class StateDataSource extends TableProvider with DataSourceRegister { } } - override def supportsExternalMetadata(): Boolean = false + private def runStateVarChecks( + sourceOptions: StateSourceOptions, + stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = { + val twsShortName = "transformWithStateExec" + if (sourceOptions.stateVarName.isDefined) { + // Perform checks for transformWithState operator in case state variable name is provided + require(stateStoreMetadata.size == 1) + val opMetadata = stateStoreMetadata.head + if (opMetadata.operatorName != twsShortName) { + // if we are trying to query state source with state variable name, then the operator + // should be transformWithState + val errorMsg = "Providing state variable names is only supported with the " + + s"transformWithState operator. Found operator=${opMetadata.operatorName}. " + + s"Please remove this option and re-run the query." + throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, errorMsg) + } + + // if the operator is transformWithState, but the operator properties are empty, then + // the user has not defined any state variables for the operator + val operatorProperties = opMetadata.operatorPropertiesJson + if (operatorProperties.isEmpty) { + throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, + "No state variable names are defined for the transformWithState operator") + } + + // if the state variable is not one of the defined/available state variables, then we + // fail the query + val stateVarName = sourceOptions.stateVarName.get + val twsOperatorProperties = TransformWithStateOperatorProperties.fromJson(operatorProperties) + val stateVars = twsOperatorProperties.stateVariables + if (stateVars.filter(stateVar => stateVar.stateName == stateVarName).size != 1) { + throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, + s"State variable $stateVarName is not defined for the transformWithState operator.") + } + + // TODO: Support change feed and transformWithState together + if (sourceOptions.readChangeFeed) { + throw StateDataSourceErrors.conflictOptions(Seq(StateSourceOptions.READ_CHANGE_FEED, + StateSourceOptions.STATE_VAR_NAME)) + } + } else { + // if the operator is transformWithState, then a state variable argument is mandatory + if (stateStoreMetadata.size == 1 && + stateStoreMetadata.head.operatorName == twsShortName) { + throw StateDataSourceErrors.requiredOptionUnspecified("stateVarName") + } + } + } + + private def getStateStoreMetadata(stateSourceOptions: StateSourceOptions): + Array[StateMetadataTableEntry] = { + val allStateStoreMetadata = new StateMetadataPartitionReader( + stateSourceOptions.stateCheckpointLocation.getParent.toString, + serializedHadoopConf, stateSourceOptions.batchId).stateMetadata.toArray + val stateStoreMetadata = allStateStoreMetadata.filter { entry => + entry.operatorId == stateSourceOptions.operatorId && + entry.stateStoreName == stateSourceOptions.storeName + } + stateStoreMetadata + } + + private def getStoreMetadataAndRunChecks(sourceOptions: 
StateSourceOptions): + StateStoreReaderInfo = { + val storeMetadata = getStateStoreMetadata(sourceOptions) + runStateVarChecks(sourceOptions, storeMetadata) + var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None + var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema] = None + var transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo] = None + + if (sourceOptions.joinSide == JoinSideValues.none) { + val stateVarName = sourceOptions.stateVarName + .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME) + + // Read the schema file path from operator metadata version v2 onwards + // for the transformWithState operator + val oldSchemaFilePath = if (storeMetadata.length > 0 && storeMetadata.head.version == 2 + && storeMetadata.head.operatorName.contains("transformWithStateExec")) { + val storeMetadataEntry = storeMetadata.head + val operatorProperties = TransformWithStateOperatorProperties.fromJson( + storeMetadataEntry.operatorPropertiesJson) + val stateVarInfoList = operatorProperties.stateVariables + .filter(stateVar => stateVar.stateName == stateVarName) + require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " + + s"for state variable $stateVarName in operator ${sourceOptions.operatorId}") + val stateVarInfo = stateVarInfoList.head + transformWithStateVariableInfoOpt = Some(stateVarInfo) + val schemaFilePath = new Path(storeMetadataEntry.stateSchemaFilePath.get) + Some(schemaFilePath) + } else { + None + } + + try { + // Read the actual state schema from the provided path for v2 or from the dedicated path + // for v1 + val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA + val stateCheckpointLocation = sourceOptions.stateCheckpointLocation + val storeId = new StateStoreId(stateCheckpointLocation.toString, sourceOptions.operatorId, + partitionId, sourceOptions.storeName) + val providerId = new StateStoreProviderId(storeId, UUID.randomUUID()) + val manager = new StateSchemaCompatibilityChecker(providerId, hadoopConf, + oldSchemaFilePath = oldSchemaFilePath) + val stateSchema = manager.readSchemaFile() + + // Based on the version and read schema, populate the keyStateEncoderSpec used for + // reading the column families + val resultSchema = stateSchema.filter(_.colFamilyName == stateVarName).head + keyStateEncoderSpecOpt = Some(getKeyStateEncoderSpec(resultSchema, storeMetadata)) + stateStoreColFamilySchemaOpt = Some(resultSchema) + } catch { + case NonFatal(ex) => + throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, ex) + } + } + + StateStoreReaderInfo( + keyStateEncoderSpecOpt, + stateStoreColFamilySchemaOpt, + transformWithStateVariableInfoOpt + ) + } + + private def getKeyStateEncoderSpec( + colFamilySchema: StateStoreColFamilySchema, + storeMetadata: Array[StateMetadataTableEntry]): KeyStateEncoderSpec = { + // If operator metadata is not found, then log a warning and continue with using the no-prefix + // key state encoder + val keyStateEncoderSpec = if (storeMetadata.length == 0) { + logWarning("Metadata for state store not found, possible cause is this checkpoint " + + "is created by older version of spark. If the query has session window aggregation, " + + "the state can't be read correctly and runtime exception will be thrown. 
" + + "Run the streaming query in newer spark version to generate state metadata " + + "can fix the issue.") + NoPrefixKeyStateEncoderSpec(colFamilySchema.keySchema) + } else { + require(storeMetadata.length == 1) + val storeMetadataEntry = storeMetadata.head + // if version has metadata info, then use numColsPrefixKey as specified + if (storeMetadataEntry.version == 1 && storeMetadataEntry.numColsPrefixKey == 0) { + NoPrefixKeyStateEncoderSpec(colFamilySchema.keySchema) + } else if (storeMetadataEntry.version == 1 && storeMetadataEntry.numColsPrefixKey > 0) { + PrefixKeyScanStateEncoderSpec(colFamilySchema.keySchema, + storeMetadataEntry.numColsPrefixKey) + } else if (storeMetadataEntry.version == 2) { + // for version 2, we have the encoder spec recorded to the state schema file. so we just + // use that directly + require(colFamilySchema.keyStateEncoderSpec.isDefined) + colFamilySchema.keyStateEncoderSpec.get + } else { + throw StateDataSourceErrors.internalError(s"Failed to read " + + s"key state encoder spec for operator=${storeMetadataEntry.operatorId}") + } + } + keyStateEncoderSpec + } } case class FromSnapshotOptions( @@ -154,12 +302,14 @@ case class StateSourceOptions( joinSide: JoinSideValues, readChangeFeed: Boolean, fromSnapshotOptions: Option[FromSnapshotOptions], - readChangeFeedOptions: Option[ReadChangeFeedOptions]) { + readChangeFeedOptions: Option[ReadChangeFeedOptions], + stateVarName: Option[String]) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) override def toString: String = { var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + - s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide" + s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " + + s"stateVarName=${stateVarName.getOrElse("None")}" if (fromSnapshotOptions.isDefined) { desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}" desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}" @@ -183,6 +333,7 @@ object StateSourceOptions extends DataSourceOptions { val READ_CHANGE_FEED = newOption("readChangeFeed") val CHANGE_START_BATCH_ID = newOption("changeStartBatchId") val CHANGE_END_BATCH_ID = newOption("changeEndBatchId") + val STATE_VAR_NAME = newOption("stateVarName") object JoinSideValues extends Enumeration { type JoinSideValues = Value @@ -219,6 +370,10 @@ object StateSourceOptions extends DataSourceOptions { throw StateDataSourceErrors.invalidOptionValueIsEmpty(STORE_NAME) } + // Check if the state variable name is provided. Used with the transformWithState operator. + val stateVarName = Option(options.get(STATE_VAR_NAME)) + .map(_.trim) + val joinSide = try { Option(options.get(JOIN_SIDE)) .map(JoinSideValues.withName).getOrElse(JoinSideValues.none) @@ -322,7 +477,7 @@ object StateSourceOptions extends DataSourceOptions { StateSourceOptions( resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, - readChangeFeed, fromSnapshotOptions, readChangeFeedOptions) + readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName) } private def resolvedCheckpointLocation( @@ -342,3 +497,11 @@ object StateSourceOptions extends DataSourceOptions { } } } + +// Case class to store information around the key state encoder, col family schema and +// operator specific state used primarily for the transformWithState operator. 
+case class StateStoreReaderInfo( + keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec], + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema], + transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo] +) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 6201cf1157ab3..53576c335cb01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -20,8 +20,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} -import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil +import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType} import org.apache.spark.sql.types.StructType @@ -36,16 +36,21 @@ class StatePartitionReaderFactory( storeConf: StateStoreConf, hadoopConf: SerializableConfiguration, schema: StructType, - stateStoreMetadata: Array[StateMetadataTableEntry]) extends PartitionReaderFactory { + keyStateEncoderSpec: KeyStateEncoderSpec, + stateVariableInfoOpt: Option[TransformWithStateVariableInfo], + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) + extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition] if (stateStoreInputPartition.sourceOptions.readChangeFeed) { new StateStoreChangeDataPartitionReader(storeConf, hadoopConf, - stateStoreInputPartition, schema, stateStoreMetadata) + stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt, + stateStoreColFamilySchemaOpt) } else { new StatePartitionReader(storeConf, hadoopConf, - stateStoreInputPartition, schema, stateStoreMetadata) + stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt, + stateStoreColFamilySchemaOpt) } } } @@ -59,40 +64,44 @@ abstract class StatePartitionReaderBase( hadoopConf: SerializableConfiguration, partition: StateStoreInputPartition, schema: StructType, - stateStoreMetadata: Array[StateMetadataTableEntry]) + keyStateEncoderSpec: KeyStateEncoderSpec, + stateVariableInfoOpt: Option[TransformWithStateVariableInfo], + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) extends PartitionReader[InternalRow] with Logging { - private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] - private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType] + protected val keySchema = SchemaUtil.getSchemaAsDataType( + schema, "key").asInstanceOf[StructType] + protected val valueSchema = SchemaUtil.getSchemaAsDataType( + schema, "value").asInstanceOf[StructType] protected lazy val provider: StateStoreProvider = { val stateStoreId = 
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString, partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName) val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId) - val numColsPrefixKey = if (stateStoreMetadata.isEmpty) { - logWarning("Metadata for state store not found, possible cause is this checkpoint " + - "is created by older version of spark. If the query has session window aggregation, " + - "the state can't be read correctly and runtime exception will be thrown. " + - "Run the streaming query in newer spark version to generate state metadata " + - "can fix the issue.") - 0 - } else { - require(stateStoreMetadata.length == 1) - stateStoreMetadata.head.numColsPrefixKey - } - // TODO: currently we don't support RangeKeyScanStateEncoderSpec. Support for this will be - // added in the future along with state metadata changes. - // Filed JIRA here: https://issues.apache.org/jira/browse/SPARK-47524 - val keyStateEncoderType = if (numColsPrefixKey > 0) { - PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) + val useColFamilies = if (stateVariableInfoOpt.isDefined) { + true } else { - NoPrefixKeyStateEncoderSpec(keySchema) + false } - StateStoreProvider.createAndInit( - stateStoreProviderId, keySchema, valueSchema, keyStateEncoderType, - useColumnFamilies = false, storeConf, hadoopConf.value, + val provider = StateStoreProvider.createAndInit( + stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec, + useColumnFamilies = useColFamilies, storeConf, hadoopConf.value, useMultipleValuesPerKey = false) + + if (useColFamilies) { + val store = provider.getStore(partition.sourceOptions.batchId + 1) + require(stateStoreColFamilySchemaOpt.isDefined) + val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get + require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined) + store.createColFamilyIfAbsent( + stateStoreColFamilySchema.colFamilyName, + stateStoreColFamilySchema.keySchema, + stateStoreColFamilySchema.valueSchema, + stateStoreColFamilySchema.keyStateEncoderSpec.get, + useMultipleValuesPerKey = false) + } + provider } protected val iter: Iterator[InternalRow] @@ -126,8 +135,11 @@ class StatePartitionReader( hadoopConf: SerializableConfiguration, partition: StateStoreInputPartition, schema: StructType, - stateStoreMetadata: Array[StateMetadataTableEntry]) - extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, stateStoreMetadata) { + keyStateEncoderSpec: KeyStateEncoderSpec, + stateVariableInfoOpt: Option[TransformWithStateVariableInfo], + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) + extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, + keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt) { private lazy val store: ReadStateStore = { partition.sourceOptions.fromSnapshotOptions match { @@ -146,21 +158,40 @@ class StatePartitionReader( } override lazy val iter: Iterator[InternalRow] = { - store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value))) + val stateVarName = stateVariableInfoOpt + .map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME) + store + .iterator(stateVarName) + .map { pair => + stateVariableInfoOpt match { + case Some(stateVarInfo) => + val stateVarType = stateVarInfo.stateVariableType + val hasTTLEnabled = stateVarInfo.ttlEnabled + + stateVarType match { + case StateVariableType.ValueState => + if (hasTTLEnabled) { + SchemaUtil.unifyStateRowPairWithTTL((pair.key, 
pair.value), valueSchema, + partition.partition) + } else { + SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) + } + + case _ => + throw new IllegalStateException( + s"Unsupported state variable type: $stateVarType") + } + + case None => + SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) + } + } } override def close(): Unit = { store.abort() super.close() } - - private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = { - val row = new GenericInternalRow(3) - row.update(0, pair._1) - row.update(1, pair._2) - row.update(2, partition.partition) - row - } } /** @@ -172,8 +203,11 @@ class StateStoreChangeDataPartitionReader( hadoopConf: SerializableConfiguration, partition: StateStoreInputPartition, schema: StructType, - stateStoreMetadata: Array[StateMetadataTableEntry]) - extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, stateStoreMetadata) { + keyStateEncoderSpec: KeyStateEncoderSpec, + stateVariableInfoOpt: Option[TransformWithStateVariableInfo], + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) + extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, + keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt) { private lazy val changeDataReader: NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala index 01f966ae948ac..1bb992eb9addd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala @@ -25,9 +25,9 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues -import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} -import org.apache.spark.sql.execution.streaming.state.{StateStoreConf, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.TransformWithStateVariableInfo +import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors} import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -37,9 +37,11 @@ class StateScanBuilder( schema: StructType, sourceOptions: StateSourceOptions, stateStoreConf: StateStoreConf, - stateStoreMetadata: Array[StateMetadataTableEntry]) extends ScanBuilder { + keyStateEncoderSpec: KeyStateEncoderSpec, + stateVariableInfoOpt: Option[TransformWithStateVariableInfo], + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) extends ScanBuilder { override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf, - stateStoreMetadata) + keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt) } /** An implementation of [[InputPartition]] for State Store data source. 
*/ @@ -54,7 +56,10 @@ class StateScan( schema: StructType, sourceOptions: StateSourceOptions, stateStoreConf: StateStoreConf, - stateStoreMetadata: Array[StateMetadataTableEntry]) extends Scan with Batch { + keyStateEncoderSpec: KeyStateEncoderSpec, + stateVariableInfoOpt: Option[TransformWithStateVariableInfo], + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) extends Scan + with Batch { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it private val hadoopConfBroadcast = session.sparkContext.broadcast( @@ -123,7 +128,7 @@ class StateScan( case JoinSideValues.none => new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema, - stateStoreMetadata) + keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt) } override def toBatch: Batch = this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala index 2fc85cd8aa968..4069a52f38b13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala @@ -24,12 +24,11 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues -import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil -import org.apache.spark.sql.execution.streaming.state.StateStoreConf -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.execution.streaming.TransformWithStateVariableInfo +import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, StateStoreColFamilySchema, StateStoreConf} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.util.ArrayImplicits._ /** An implementation of [[Table]] with [[SupportsRead]] for State Store data source. */ class StateTable( @@ -37,12 +36,14 @@ class StateTable( override val schema: StructType, sourceOptions: StateSourceOptions, stateConf: StateStoreConf, - stateStoreMetadata: Array[StateMetadataTableEntry]) + keyStateEncoderSpec: KeyStateEncoderSpec, + stateVariableInfoOpt: Option[TransformWithStateVariableInfo], + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) extends Table with SupportsRead with SupportsMetadataColumns { import StateTable._ - if (!isValidSchema(schema)) { + if (!SchemaUtil.isValidSchema(sourceOptions, schema, stateVariableInfoOpt)) { throw StateDataSourceErrors.internalError( s"Invalid schema is provided. 
Provided schema: $schema for " + s"checkpoint location: ${sourceOptions.stateCheckpointLocation} , operatorId: " + @@ -77,34 +78,11 @@ class StateTable( override def capabilities(): util.Set[TableCapability] = CAPABILITY override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = - new StateScanBuilder(session, schema, sourceOptions, stateConf, stateStoreMetadata) + new StateScanBuilder(session, schema, sourceOptions, stateConf, keyStateEncoderSpec, + stateVariableInfoOpt, stateStoreColFamilySchemaOpt) override def properties(): util.Map[String, String] = Map.empty[String, String].asJava - private def isValidSchema(schema: StructType): Boolean = { - val expectedFieldNames = - if (sourceOptions.readChangeFeed) { - Seq("batch_id", "change_type", "key", "value", "partition_id") - } else { - Seq("key", "value", "partition_id") - } - val expectedTypes = Map( - "batch_id" -> classOf[LongType], - "change_type" -> classOf[StringType], - "key" -> classOf[StructType], - "value" -> classOf[StructType], - "partition_id" -> classOf[IntegerType]) - - if (schema.fieldNames.toImmutableArraySeq != expectedFieldNames) { - false - } else { - schema.fieldNames.forall { fieldName => - expectedTypes(fieldName).isAssignableFrom( - SchemaUtil.getSchemaAsDataType(schema, fieldName).getClass) - } - } - } - override def metadataColumns(): Array[MetadataColumn] = Array.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala index afd6a190b0ca5..64fdfb7997623 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala @@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow @@ -46,6 +47,7 @@ case class StateMetadataTableEntry( numPartitions: Int, minBatchId: Long, maxBatchId: Long, + version: Int, operatorPropertiesJson: String, numColsPrefixKey: Int, stateSchemaFilePath: Option[String]) { @@ -87,7 +89,7 @@ class StateMetadataSource extends TableProvider with DataSourceRegister { override def inferSchema(options: CaseInsensitiveStringMap): StructType = { // The schema of state metadata table is static. - StateMetadataTableEntry.schema + StateMetadataTableEntry.schema } } @@ -159,7 +161,7 @@ case class StateMetadataPartitionReaderFactory( class StateMetadataPartitionReader( checkpointLocation: String, serializedHadoopConf: SerializableConfiguration, - batchId: Long) extends PartitionReader[InternalRow] { + batchId: Long) extends PartitionReader[InternalRow] with Logging { override def next(): Boolean = { stateMetadata.hasNext @@ -205,26 +207,35 @@ class StateMetadataPartitionReader( // Need this to be accessible from IncrementalExecution for the planning rule. 
private[sql] def allOperatorStateMetadata: Array[OperatorStateMetadata] = { - val stateDir = new Path(checkpointLocation, "state") - val opIds = fileManager - .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted - opIds.map { opId => - val operatorIdPath = new Path(stateDir, opId.toString) - // check if OperatorStateMetadataV2 path exists, if it does, read it - // otherwise, fall back to OperatorStateMetadataV1 - val operatorStateMetadataV2Path = OperatorStateMetadataV2.metadataDirPath(operatorIdPath) - val operatorStateMetadataVersion = if (fileManager.exists(operatorStateMetadataV2Path)) { - 2 - } else { - 1 - } - - OperatorStateMetadataReader.createReader( - operatorIdPath, hadoopConf, operatorStateMetadataVersion, batchId).read() match { - case Some(metadata) => metadata - case None => throw StateDataSourceErrors.failedToReadOperatorMetadata(checkpointLocation, - batchId) + try { + val stateDir = new Path(checkpointLocation, "state") + val opIds = fileManager + .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted + opIds.map { opId => + val operatorIdPath = new Path(stateDir, opId.toString) + // check if OperatorStateMetadataV2 path exists, if it does, read it + // otherwise, fall back to OperatorStateMetadataV1 + val operatorStateMetadataV2Path = OperatorStateMetadataV2.metadataDirPath(operatorIdPath) + val operatorStateMetadataVersion = if (fileManager.exists(operatorStateMetadataV2Path)) { + 2 + } else { + 1 + } + OperatorStateMetadataReader.createReader( + operatorIdPath, hadoopConf, operatorStateMetadataVersion, batchId).read() match { + case Some(metadata) => metadata + case None => throw StateDataSourceErrors.failedToReadOperatorMetadata(checkpointLocation, + batchId) + } } + } catch { + // if the operator metadata is not present, catch the exception + // and return an empty array + case ex: Exception => + logWarning(log"Failed to find operator metadata for " + + log"path=${MDC(LogKeys.CHECKPOINT_LOCATION, checkpointLocation)} " + + log"with exception=${MDC(LogKeys.EXCEPTION, ex)}") + Array.empty } } @@ -242,6 +253,7 @@ class StateMetadataPartitionReader( stateStoreMetadata.numPartitions, if (batchIds.nonEmpty) batchIds.head else -1, if (batchIds.nonEmpty) batchIds.last else -1, + operatorStateMetadata.version, null, stateStoreMetadata.numColsPrefixKey, None @@ -255,6 +267,7 @@ class StateMetadataPartitionReader( stateStoreMetadata.numPartitions, if (batchIds.nonEmpty) batchIds.head else -1, if (batchIds.nonEmpty) batchIds.last else -1, + operatorStateMetadata.version, v2.operatorPropertiesJson, -1, // numColsPrefixKey is not available in OperatorStateMetadataV2 Some(stateStoreMetadata.stateSchemaFilePath) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index 54c6b34db9723..9dd357530ec40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -17,7 +17,13 @@ package org.apache.spark.sql.execution.datasources.v2.state.utils import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import 
org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions} +import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo} +import org.apache.spark.sql.execution.streaming.state.StateStoreColFamilySchema +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType} +import org.apache.spark.util.ArrayImplicits._ object SchemaUtil { def getSchemaAsDataType(schema: StructType, fieldName: String): DataType = { @@ -30,4 +36,122 @@ object SchemaUtil { "schema" -> schema.toString())) } } + + def getSourceSchema( + sourceOptions: StateSourceOptions, + keySchema: StructType, + valueSchema: StructType, + transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo], + stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]): StructType = { + if (sourceOptions.readChangeFeed) { + new StructType() + .add("batch_id", LongType) + .add("change_type", StringType) + .add("key", keySchema) + .add("value", valueSchema) + .add("partition_id", IntegerType) + } else if (transformWithStateVariableInfoOpt.isDefined) { + require(stateStoreColFamilySchemaOpt.isDefined) + generateSchemaForStateVar(transformWithStateVariableInfoOpt.get, + stateStoreColFamilySchemaOpt.get) + } else { + new StructType() + .add("key", keySchema) + .add("value", valueSchema) + .add("partition_id", IntegerType) + } + } + + def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow), partition: Int): InternalRow = { + val row = new GenericInternalRow(3) + row.update(0, pair._1) + row.update(1, pair._2) + row.update(2, partition) + row + } + + def unifyStateRowPairWithTTL( + pair: (UnsafeRow, UnsafeRow), + valueSchema: StructType, + partition: Int): InternalRow = { + val row = new GenericInternalRow(4) + row.update(0, pair._1) + row.update(1, pair._2.get(0, valueSchema)) + row.update(2, pair._2.get(1, LongType)) + row.update(3, partition) + row + } + + def isValidSchema( + sourceOptions: StateSourceOptions, + schema: StructType, + transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo]): Boolean = { + val expectedTypes = Map( + "batch_id" -> classOf[LongType], + "change_type" -> classOf[StringType], + "key" -> classOf[StructType], + "value" -> classOf[StructType], + "partition_id" -> classOf[IntegerType], + "expiration_timestamp" -> classOf[LongType]) + + val expectedFieldNames = if (sourceOptions.readChangeFeed) { + Seq("batch_id", "change_type", "key", "value", "partition_id") + } else if (transformWithStateVariableInfoOpt.isDefined) { + val stateVarInfo = transformWithStateVariableInfoOpt.get + val hasTTLEnabled = stateVarInfo.ttlEnabled + val stateVarType = stateVarInfo.stateVariableType + + stateVarType match { + case StateVariableType.ValueState => + if (hasTTLEnabled) { + Seq("key", "value", "expiration_timestamp", "partition_id") + } else { + Seq("key", "value", "partition_id") + } + + case _ => + throw StateDataSourceErrors + .internalError(s"Unsupported state variable type $stateVarType") + } + } else { + Seq("key", "value", "partition_id") + } + + if (schema.fieldNames.toImmutableArraySeq != expectedFieldNames) { + false + } else { + schema.fieldNames.forall { fieldName => + expectedTypes(fieldName).isAssignableFrom( + SchemaUtil.getSchemaAsDataType(schema, fieldName).getClass) + } + } + } + + private def generateSchemaForStateVar( + stateVarInfo: TransformWithStateVariableInfo, + stateStoreColFamilySchema: StateStoreColFamilySchema): StructType = { + val stateVarType = 
stateVarInfo.stateVariableType + val hasTTLEnabled = stateVarInfo.ttlEnabled + + stateVarType match { + case StateVariableType.ValueState => + if (hasTTLEnabled) { + val ttlValueSchema = SchemaUtil.getSchemaAsDataType( + stateStoreColFamilySchema.valueSchema, "value").asInstanceOf[StructType] + new StructType() + .add("key", stateStoreColFamilySchema.keySchema) + .add("value", ttlValueSchema) + .add("expiration_timestamp", LongType) + .add("partition_id", IntegerType) + } else { + new StructType() + .add("key", stateStoreColFamilySchema.keySchema) + .add("value", stateStoreColFamilySchema.valueSchema) + .add("partition_id", IntegerType) + } + + case _ => + throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index e6cdd0dce9efa..97c88037a7171 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, Offset import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{OutputMode, TimeMode, TransformWithStateSuiteUtils} import org.apache.spark.sql.types.{IntegerType, StructType} class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase { @@ -268,6 +268,25 @@ class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase { "message" -> s"value should be less than or equal to $endBatchId")) } } + + test("ERROR: trying to specify state variable name with " + + "non-transformWithState operator") { + withTempDir { tempDir => + runDropDuplicatesQuery(tempDir.getAbsolutePath) + + val exc = intercept[StateDataSourceInvalidOptionValue] { + spark.read.format("statestore") + // trick to bypass getting the last committed batch before validating operator ID + .option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.STATE_VAR_NAME, "test") + .load(tempDir.getAbsolutePath) + } + checkError(exc, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", Some("42616"), + Map("optionName" -> StateSourceOptions.STATE_VAR_NAME, + "message" -> ".*"), + matchPVals = true) + } + } } /** @@ -429,6 +448,40 @@ class RocksDBStateDataSourceReadSuite extends StateDataSourceReadSuite { spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled", "false") } + + test("ERROR: Do not provide state variable name with " + + "transformWithState operator") { + import testImplicits._ + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new StatefulProcessorWithSingleValueVar(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + StopStream + ) + 
+ val e = intercept[StateDataSourceUnspecifiedRequiredOption] { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .load() + } + checkError(e, "STDS_REQUIRED_OPTION_UNSPECIFIED", Some("42601"), + Map("optionName" -> "stateVarName")) + } + } + } } class RocksDBWithChangelogCheckpointStateDataSourceReaderSuite extends diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala new file mode 100644 index 0000000000000..ccd4e005756ad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.time.Duration + +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TimeMode, TimerValues, TransformWithStateSuiteUtils, TTLConfig, ValueState} + +/** Stateful processor of single value state var with non-primitive type */ +class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor { + @transient private var _valueState: ValueState[TestClass] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _valueState = getHandle.getValueState[TestClass]( + "valueState", Encoders.product[TestClass]) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _valueState.getOption().getOrElse(TestClass(0L, "dummyKey")).id + 1 + _valueState.update(TestClass(count, "dummyKey")) + Iterator((key, count.toString)) + } +} + +class StatefulProcessorWithTTL + extends StatefulProcessor[String, String, (String, String)] { + @transient protected var _countState: ValueState[Long] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle.getValueState[Long]("countState", + Encoders.scalaLong, TTLConfig(Duration.ofMillis(30000))) + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + 
val count = _countState.getOption().getOrElse(0L) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } +} + +/** + * Test suite to verify integration of state data source reader with the transformWithState operator + */ +class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest + with AlsoTestWithChangelogCheckpointingEnabled { + + import testImplicits._ + + test("state data source integration - value state with single variable") { + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new StatefulProcessorWithSingleValueVar(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + AddData(inputData, "b"), + CheckNewAnswer(("b", "1")), + StopStream + ) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "valueState") + .load() + + val resultDf = stateReaderDf.selectExpr( + "key.value AS groupingKey", + "value.id AS valueId", "value.name AS valueName", + "partition_id") + + checkAnswer(resultDf, + Seq(Row("a", 1L, "dummyKey", 0), Row("b", 1L, "dummyKey", 1))) + + // non existent state variable should fail + val ex = intercept[Exception] { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "non-exist") + .load() + } + assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue]) + assert(ex.getMessage.contains("State variable non-exist is not defined")) + + // TODO: this should be removed when readChangeFeed is supported for value state + val ex1 = intercept[Exception] { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "valueState") + .option(StateSourceOptions.READ_CHANGE_FEED, "true") + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .load() + } + assert(ex1.isInstanceOf[StateDataSourceConflictOptions]) + } + } + } + + test("state data source integration - value state with single variable and TTL") { + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new StatefulProcessorWithTTL(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, "a"), + AddData(inputData, "b"), + Execute { _ => + // wait for the batch to run since we are using processing time + Thread.sleep(5000) + }, + StopStream + ) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .load() + + val resultDf = stateReaderDf.selectExpr( + "key.value", 
"value.value", "expiration_timestamp", "partition_id") + + var count = 0L + resultDf.collect().foreach { row => + count = count + 1 + assert(row.getLong(2) > 0) + } + + // verify that 2 state rows are present + assert(count === 2) + + val answerDf = stateReaderDf.selectExpr( + "key.value AS groupingKey", + "value.value AS valueId", "partition_id") + checkAnswer(answerDf, + Seq(Row("a", 1L, 0), Row("b", 1L, 1))) + + // non existent state variable should fail + val ex = intercept[Exception] { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "non-exist") + .load() + } + assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue]) + assert(ex.getMessage.contains("State variable non-exist is not defined")) + + // TODO: this should be removed when readChangeFeed is supported for TTL based state + // variables + val ex1 = intercept[Exception] { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .option(StateSourceOptions.READ_CHANGE_FEED, "true") + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .load() + } + assert(ex1.isInstanceOf[StateDataSourceConflictOptions]) + } + } + } +} From ebe635eb1d44dee879623e8646bd3be7424b5676 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 30 Aug 2024 09:13:58 +0900 Subject: [PATCH 006/230] [SPARK-48482][PYTHON][FOLLOWUP] Revert dropDuplicates and dropDuplicatesWIthinWatermark should accept variable length args ### What changes were proposed in this pull request? Per conversation from https://github.com/apache/spark/pull/47835#issuecomment-2311082085, we will revert 560c08332b35941260169124b4f522bdc82b84d8 for API parity with Pandas API ### Why are the changes needed? Bug fix ### Does this PR introduce _any_ user-facing change? Yes, reverting the API would reenable user to use `dropDuplicates(subset=xxx)` ### How was this patch tested? Unit test ### Was this patch authored or co-authored using generative AI tooling? No Closes #47916 from WweiL/revert-dropDuplicates-api. Authored-by: Wei Liu Signed-off-by: Hyukjin Kwon --- .../structured-streaming-programming-guide.md | 6 +- python/pyspark/sql/classic/dataframe.py | 51 ++++++++------- python/pyspark/sql/connect/dataframe.py | 62 +++++++------------ python/pyspark/sql/connect/plan.py | 2 +- python/pyspark/sql/dataframe.py | 22 ++----- .../sql/tests/connect/test_connect_basic.py | 12 ---- .../sql/tests/connect/test_connect_plan.py | 20 ++---- python/pyspark/sql/tests/test_dataframe.py | 30 +++------ 8 files changed, 69 insertions(+), 136 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index aa20da6ae81d9..d266e761263b6 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -2082,12 +2082,12 @@ You can deduplicate records in data streams using a unique identifier in the eve streamingDf = spark.readStream. ... # Without watermark using guid column -streamingDf.dropDuplicates("guid") +streamingDf.dropDuplicates(["guid"]) # With watermark using guid and eventTime columns streamingDf \ .withWatermark("eventTime", "10 seconds") \ - .dropDuplicates("guid", "eventTime") + .dropDuplicates(["guid", "eventTime"]) {% endhighlight %} @@ -2163,7 +2163,7 @@ streamingDf = spark.readStream. ... 
# deduplicate using guid column with watermark based on eventTime column streamingDf \ .withWatermark("eventTime", "10 hours") \ - .dropDuplicatesWithinWatermark("guid") + .dropDuplicatesWithinWatermark(["guid"]) {% endhighlight %} diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index fb2bb3c227034..0e890e3343e66 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -1222,23 +1222,17 @@ def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame: def subtract(self, other: ParentDataFrame) -> ParentDataFrame: return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sparkSession) - def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame: - # Acceptable args should be str, ... or a single List[str] - # So if subset length is 1, it can be either single str, or a list of str - # if subset length is greater than 1, it must be a sequence of str - if not subset: + def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame: + if subset is not None and (not isinstance(subset, Iterable) or isinstance(subset, str)): + raise PySparkTypeError( + errorClass="NOT_LIST_OR_TUPLE", + messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__}, + ) + + if subset is None: jdf = self._jdf.dropDuplicates() - elif len(subset) == 1 and isinstance(subset[0], list): - item = subset[0] - for c in item: - if not isinstance(c, str): - raise PySparkTypeError( - errorClass="NOT_STR", - messageParameters={"arg_name": "subset", "arg_type": type(c).__name__}, - ) - jdf = self._jdf.dropDuplicates(self._jseq(item)) else: - for c in subset: # type: ignore[assignment] + for c in subset: if not isinstance(c, str): raise PySparkTypeError( errorClass="NOT_STR", @@ -1247,20 +1241,22 @@ def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame: jdf = self._jdf.dropDuplicates(self._jseq(subset)) return DataFrame(jdf, self.sparkSession) - drop_duplicates = dropDuplicates - - def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> ParentDataFrame: - # Acceptable args should be str, ... 
or a single List[str] - # So if subset length is 1, it can be either single str, or a list of str - # if subset length is greater than 1, it must be a sequence of str - if len(subset) > 1: - assert all(isinstance(c, str) for c in subset) + def dropDuplicatesWithinWatermark(self, subset: Optional[List[str]] = None) -> ParentDataFrame: + if subset is not None and (not isinstance(subset, Iterable) or isinstance(subset, str)): + raise PySparkTypeError( + errorClass="NOT_LIST_OR_TUPLE", + messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__}, + ) - if not subset: + if subset is None: jdf = self._jdf.dropDuplicatesWithinWatermark() - elif len(subset) == 1 and isinstance(subset[0], list): - jdf = self._jdf.dropDuplicatesWithinWatermark(self._jseq(subset[0])) else: + for c in subset: + if not isinstance(c, str): + raise PySparkTypeError( + errorClass="NOT_STR", + messageParameters={"arg_name": "subset", "arg_type": type(c).__name__}, + ) jdf = self._jdf.dropDuplicatesWithinWatermark(self._jseq(subset)) return DataFrame(jdf, self.sparkSession) @@ -1805,6 +1801,9 @@ def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "Grouped def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc] return self.groupBy(*cols) + def drop_duplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame: + return self.dropDuplicates(subset) + def writeTo(self, table: str) -> DataFrameWriterV2: return DataFrameWriterV2(self, table) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 846f1109a92d0..442157eef0b75 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -426,40 +426,29 @@ def repartitionByRange( # type: ignore[misc] "arg_type": type(numPartitions).__name__, }, ) - res._cached_schema = self._cached_schema return res - def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame: - # Acceptable args should be str, ... 
or a single List[str] - # So if subset length is 1, it can be either single str, or a list of str - # if subset length is greater than 1, it must be a sequence of str - if not subset: - res = DataFrame( - plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session + def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame: + if subset is not None and not isinstance(subset, (list, tuple)): + raise PySparkTypeError( + errorClass="NOT_LIST_OR_TUPLE", + messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__}, ) - elif len(subset) == 1 and isinstance(subset[0], list): - item = subset[0] - for c in item: - if not isinstance(c, str): - raise PySparkTypeError( - errorClass="NOT_STR", - messageParameters={"arg_name": "subset", "arg_type": type(c).__name__}, - ) + + if subset is None: res = DataFrame( - plan.Deduplicate(child=self._plan, column_names=item), - session=self._session, + plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session ) else: - for c in subset: # type: ignore[assignment] + for c in subset: if not isinstance(c, str): raise PySparkTypeError( errorClass="NOT_STR", messageParameters={"arg_name": "subset", "arg_type": type(c).__name__}, ) res = DataFrame( - plan.Deduplicate(child=self._plan, column_names=cast(List[str], subset)), - session=self._session, + plan.Deduplicate(child=self._plan, column_names=subset), session=self._session ) res._cached_schema = self._cached_schema @@ -467,30 +456,27 @@ def dropDuplicates(self, *subset: Union[str, List[str]]) -> ParentDataFrame: drop_duplicates = dropDuplicates - def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> ParentDataFrame: - # Acceptable args should be str, ... or a single List[str] - # So if subset length is 1, it can be either single str, or a list of str - # if subset length is greater than 1, it must be a sequence of str - if len(subset) > 1: - assert all(isinstance(c, str) for c in subset) + def dropDuplicatesWithinWatermark(self, subset: Optional[List[str]] = None) -> ParentDataFrame: + if subset is not None and not isinstance(subset, (list, tuple)): + raise PySparkTypeError( + errorClass="NOT_LIST_OR_TUPLE", + messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__}, + ) - if not subset: + if subset is None: return DataFrame( plan.Deduplicate(child=self._plan, all_columns_as_keys=True, within_watermark=True), session=self._session, ) - elif len(subset) == 1 and isinstance(subset[0], list): - return DataFrame( - plan.Deduplicate(child=self._plan, column_names=subset[0], within_watermark=True), - session=self._session, - ) else: + for c in subset: + if not isinstance(c, str): + raise PySparkTypeError( + errorClass="NOT_STR", + messageParameters={"arg_name": "subset", "arg_type": type(c).__name__}, + ) return DataFrame( - plan.Deduplicate( - child=self._plan, - column_names=cast(List[str], subset), - within_watermark=True, - ), + plan.Deduplicate(child=self._plan, column_names=subset, within_watermark=True), session=self._session, ) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index b9c60c04d0f0d..958626280e41c 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -686,7 +686,7 @@ def __init__( self, child: Optional["LogicalPlan"], all_columns_as_keys: bool = False, - column_names: Optional[Sequence[str]] = None, + column_names: Optional[List[str]] = None, within_watermark: bool = False, ) -> None: 
super().__init__(child) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 6f1afaba5f98c..7d3900c7afbc5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -4570,7 +4570,7 @@ def subtract(self, other: "DataFrame") -> "DataFrame": ... @dispatch_df_method - def dropDuplicates(self, *subset: Union[str, List[str]]) -> "DataFrame": + def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. @@ -4587,9 +4587,6 @@ def dropDuplicates(self, *subset: Union[str, List[str]]) -> "DataFrame": .. versionchanged:: 3.4.0 Supports Spark Connect. - .. versionchanged:: 4.0.0 - Supports variable-length argument - Parameters ---------- subset : list of column names, optional @@ -4621,7 +4618,7 @@ def dropDuplicates(self, *subset: Union[str, List[str]]) -> "DataFrame": Deduplicate values on 'name' and 'height' columns. - >>> df.dropDuplicates('name', 'height').show() + >>> df.dropDuplicates(['name', 'height']).show() +-----+---+------+ | name|age|height| +-----+---+------+ @@ -4631,7 +4628,7 @@ def dropDuplicates(self, *subset: Union[str, List[str]]) -> "DataFrame": ... @dispatch_df_method - def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> "DataFrame": + def dropDuplicatesWithinWatermark(self, subset: Optional[List[str]] = None) -> "DataFrame": """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns, within watermark. @@ -4648,9 +4645,6 @@ def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> "Data .. versionadded:: 3.5.0 - .. versionchanged:: 4.0.0 - Supports variable-length argument - Parameters ---------- subset : List of column names, optional @@ -4680,7 +4674,7 @@ def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> "Data Deduplicate values on 'value' columns. - >>> df.dropDuplicatesWithinWatermark('value') # doctest: +SKIP + >>> df.dropDuplicatesWithinWatermark(['value']) # doctest: +SKIP """ ... @@ -5937,17 +5931,11 @@ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": ... @dispatch_df_method - def drop_duplicates(self, *subset: Union[str, List[str]]) -> "DataFrame": + def drop_duplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": """ :func:`drop_duplicates` is an alias for :func:`dropDuplicates`. .. versionadded:: 1.4.0 - - .. versionchanged:: 3.4.0 - Supports Spark Connect - - .. versionchanged:: 4.0.0 - Supports variable-length argument """ ... 
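
For quick reference, here is a minimal PySpark sketch of the call pattern this revert restores. It is not part of the patch itself; the session and column names are illustrative and mirror the doctest above.

```
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("Alice", 5, 80), ("Alice", 5, 80), ("Alice", 10, 80)],
    ["name", "age", "height"],
)

# After the revert, `subset` must be a list (or omitted entirely), not varargs.
df.dropDuplicates().show()
df.dropDuplicates(["name", "height"]).show()

# A bare string such as df.dropDuplicates("name") once again raises
# PySparkTypeError (NOT_LIST_OR_TUPLE), as asserted in the test changes below.
```
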
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index f084601d2e7b0..f0637056ab8f9 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -670,18 +670,6 @@ def test_deduplicate(self): self.assert_eq( df.dropDuplicates(["name"]).toPandas(), df2.dropDuplicates(["name"]).toPandas() ) - self.assert_eq( - df.drop_duplicates(["name"]).toPandas(), df2.drop_duplicates(["name"]).toPandas() - ) - self.assert_eq( - df.dropDuplicates(["name", "id"]).toPandas(), - df2.dropDuplicates(["name", "id"]).toPandas(), - ) - self.assert_eq( - df.drop_duplicates(["name", "id"]).toPandas(), - df2.drop_duplicates(["name", "id"]).toPandas(), - ) - self.assert_eq(df.dropDuplicates("name").toPandas(), df2.dropDuplicates("name").toPandas()) def test_drop(self): # SPARK-41169: test drop diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py index 1b373b2e1944a..a03cd30c733fb 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan.py @@ -553,25 +553,13 @@ def test_deduplicate(self): self.assertEqual(deduplicate_on_all_columns_plan.root.deduplicate.all_columns_as_keys, True) self.assertEqual(len(deduplicate_on_all_columns_plan.root.deduplicate.column_names), 0) - deduplicate_on_subset_columns_plan_list_arg = df.dropDuplicates( - ["name", "height"] - )._plan.to_proto(self.connect) - self.assertEqual( - deduplicate_on_subset_columns_plan_list_arg.root.deduplicate.all_columns_as_keys, False - ) - self.assertEqual( - len(deduplicate_on_subset_columns_plan_list_arg.root.deduplicate.column_names), 2 - ) - - deduplicate_on_subset_columns_plan_var_arg = df.dropDuplicates( - "name", "height" - )._plan.to_proto(self.connect) - self.assertEqual( - deduplicate_on_subset_columns_plan_var_arg.root.deduplicate.all_columns_as_keys, False + deduplicate_on_subset_columns_plan = df.dropDuplicates(["name", "height"])._plan.to_proto( + self.connect ) self.assertEqual( - len(deduplicate_on_subset_columns_plan_var_arg.root.deduplicate.column_names), 2 + deduplicate_on_subset_columns_plan.root.deduplicate.all_columns_as_keys, False ) + self.assertEqual(len(deduplicate_on_subset_columns_plan.root.deduplicate.column_names), 2) def test_relation_alias(self): df = self.connect.readTable(table_name=self.tbl_name) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 4e2d3b9ba42a2..a214b874f5ec0 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -266,35 +266,28 @@ def test_ordering_of_with_columns_renamed(self): self.assertEqual(df2.columns, ["a"]) def test_drop_duplicates(self): + # SPARK-36034 test that drop duplicates throws a type error when in correct type provided df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"]) # shouldn't drop a non-null row self.assertEqual(df.dropDuplicates().count(), 2) self.assertEqual(df.dropDuplicates(["name"]).count(), 1) - self.assertEqual(df.dropDuplicates(["name", "age"]).count(), 2) - - self.assertEqual(df.drop_duplicates(["name"]).count(), 1) - self.assertEqual(df.drop_duplicates(["name", "age"]).count(), 2) - # SPARK-48482 dropDuplicates should take varargs - self.assertEqual(df.dropDuplicates("name").count(), 1) - self.assertEqual(df.dropDuplicates("name", "age").count(), 2) - 
self.assertEqual(df.drop_duplicates("name").count(), 1) - self.assertEqual(df.drop_duplicates("name", "age").count(), 2) + self.assertEqual(df.dropDuplicates(["name", "age"]).count(), 2) - # Should raise proper error when taking non-string values with self.assertRaises(PySparkTypeError) as pe: - df.dropDuplicates([None]).show() + df.dropDuplicates("name") self.check_error( exception=pe.exception, - errorClass="NOT_STR", - messageParameters={"arg_name": "subset", "arg_type": "NoneType"}, + errorClass="NOT_LIST_OR_TUPLE", + messageParameters={"arg_name": "subset", "arg_type": "str"}, ) + # Should raise proper error when taking non-string values with self.assertRaises(PySparkTypeError) as pe: - df.dropDuplicates(None).show() + df.dropDuplicates([None]).show() self.check_error( exception=pe.exception, @@ -311,15 +304,6 @@ def test_drop_duplicates(self): messageParameters={"arg_name": "subset", "arg_type": "int"}, ) - with self.assertRaises(PySparkTypeError) as pe: - df.dropDuplicates(1).show() - - self.check_error( - exception=pe.exception, - errorClass="NOT_STR", - messageParameters={"arg_name": "subset", "arg_type": "int"}, - ) - def test_drop_duplicates_with_ambiguous_reference(self): df1 = self.spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"]) df2 = self.spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")]) From cf487b9f81fb8b95d0596111187b863dc9037855 Mon Sep 17 00:00:00 2001 From: Siying Dong Date: Fri, 30 Aug 2024 11:24:20 +0900 Subject: [PATCH 007/230] [SPARK-49363][SS][TESTS] Add unit tests for potential RocksDB state store SST file mismatch ### What changes were proposed in this pull request? Add unit test to for RocksDB state store snapshot checkpointing for changelog. We intentionally add the same content in each batch, so that it is likely that SST files generated are all of the same size. We have some randomness on loading the existing version or move to the next, and whether maintenance task is executed. All three tests would fail for previous versions but not in master. ### Why are the changes needed? Recently we discovered some RocksDB state store file version ID mismatch issues. Although it happens to have been fixed by other change, we don't have test coverage for it. Add unit tests for them. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Run the tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #47850 from siying/cp_test. 
Authored-by: Siying Dong Signed-off-by: Jungtaek Lim --- .../streaming/state/RocksDBSuite.scala | 133 ++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 90b7c26040763..d07ce07c41e5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.language.implicitConversions +import scala.util.Random import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration @@ -1770,6 +1771,138 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } } + testWithChangelogCheckpointingEnabled("reloading the same version") { + // Keep executing the same batch for two or more times. Some queries with ForEachBatch + // will cause this behavior. + // The test was accidentally fixed by SPARK-48586 (https://github.com/apache/spark/pull/47130) + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(minDeltasForSnapshot = 2, compactOnCommit = false) + new File(remoteDir).delete() // to make sure that the directory gets created + withDB(remoteDir, conf = conf) { db => + // load the same version of pending snapshot uploading + // This is possible because after committing version x, we can continue to x+1, and replay + // x+1. The replay will load a checkpoint by version x. At this moment, the snapshot + // uploading may not be finished. + // Previously this generated a problem: new files generated by reloading are added to + // local -> cloud file map and the information is used to skip some files uploading, which is + // wrong because these files aren't a part of the RocksDB checkpoint. + // This test was accidentally fixed by + // SPARK-48931 (https://github.com/apache/spark/pull/47393) + + db.load(0) + db.put("foo", "bar") + // Snapshot checkpoint not needed + db.commit() + + // Continue using local DB + db.load(1) + db.put("foo", "bar") + // Should create a local RocksDB snapshot + db.commit() + // Upload the local RocksDB snapshot to the cloud with 2.zip + db.doMaintenance() + + // This will reload Db from the cloud. + db.load(1) + db.put("foo", "bar") + // Should create another local snapshot + db.commit() + + // Continue using local DB + db.load(2) + db.put("foo", "bar") + // Snapshot checkpoint not needed + db.commit() + + // Reload DB from the cloud, loading from 2.zip + db.load(2) + db.put("foo", "bar") + // Snapshot checkpoint not needed + db.commit() + + // Will upload local snapshot and overwrite 2.zip + db.doMaintenance() + + // Reload new 2.zip just uploaded to validate it is not corrupted. + db.load(2) + db.put("foo", "bar") + db.commit() + + // Test the maintenance thread is delayed even after the next snapshot is created. + // There will be two outstanding snapshots. + for (batchVersion <- 3 to 6) { + db.load(batchVersion) + db.put("foo", "bar") + // In batchVersion 3 and 5, it will generate a local snapshot but won't be uploaded. + db.commit() + } + db.doMaintenance() + + // Test the maintenance is called after each batch. This tests a common case where + // maintenance tasks finish quickly. 
+      for (batchVersion <- 7 to 10) {
+        for (j <- 0 to 1) {
+          db.load(batchVersion)
+          db.put("foo", "bar")
+          db.commit()
+          db.doMaintenance()
+        }
+      }
+    }
+  }
+
+  for (randomSeed <- 1 to 8) {
+    for (ifTestSkipBatch <- 0 to 1) {
+      testWithChangelogCheckpointingEnabled(
+        s"randomized snapshotting $randomSeed ifTestSkipBatch $ifTestSkipBatch") {
+        // The unit test simulates the case where batches can be reloaded and maintenance tasks
+        // can be delayed. After each batch, we randomly decide whether we move on to the
+        // next batch, and whether the maintenance task is executed.
+        val remoteDir = Utils.createTempDir().toString
+        val conf = dbConf.copy(minDeltasForSnapshot = 3, compactOnCommit = false)
+        new File(remoteDir).delete() // to make sure that the directory gets created
+        withDB(remoteDir, conf = conf) { db =>
+          // A second DB is opened to simulate another executor that runs some batches that
+          // are skipped in the current DB.
+          withDB(remoteDir, conf = conf) { db2 =>
+            val random = new Random(randomSeed)
+            var curVer: Int = 0
+            for (i <- 1 to 100) {
+              db.load(curVer)
+              db.put("foo", "bar")
+              db.commit()
+              // With a one-in-five chance, the maintenance task is executed. This simulates the
+              // case where a snapshot isn't immediately uploaded, or is even delayed until the
+              // next snapshot is ready. We create a snapshot every 3 batches, so with the 1/5
+              // chance, longer maintenance delays are more likely.
+              if (random.nextInt(5) == 0) {
+                db.doMaintenance()
+              }
+              // Half of the time we move to the next version, and half of the time we keep the
+              // same version. When the same version is kept, the DB will be reloaded.
+              if (random.nextInt(2) == 0) {
+                val inc = if (ifTestSkipBatch == 1) {
+                  random.nextInt(3)
+                } else {
+                  1
+                }
+                if (inc > 1) {
+                  // Create changelog files in the gap
+                  for (j <- 1 to inc - 1) {
+                    db2.load(curVer + j)
+                    db2.put("foo", "bar")
+                    db2.commit()
+                  }
+                }
+                curVer = curVer + inc
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   test("validate Rocks DB SST files do not have a VersionIdMismatch" +
     " when metadata file is not overwritten - scenario 1") {
     val fmClass = "org.apache.spark.sql.execution.streaming.state." +

From 2fc21adc636c351d2db0c7561e84c5d22d00cbb2 Mon Sep 17 00:00:00 2001
From: Neil Ramaswamy
Date: Fri, 30 Aug 2024 11:34:15 +0900
Subject: [PATCH 008/230] [SPARK-49456][DOCS] Fix hash fragment scrolling behavior on versioned Spark documentation

### Why are the changes needed?

Spark docs have never scrolled to the right place when a hash fragment is specified in the URL: the target element ends up hidden behind the fixed navbar.

We have a 12-year-old hack to fix this with JavaScript, but this isn't the right solution for 2024. [Most browsers now support](https://developer.mozilla.org/en-US/docs/Web/CSS/scroll-margin-top#browser_compatibility) `scroll-margin-top`, which allows you to specify a margin that the browser will keep between the top of the element and the top border of the viewport.

If a user's browser doesn't support this, their hyperlinks will be off by roughly 80 pixels, which is no worse than the UX today.

You can play with the live changes [here](https://second-spark-site.vercel.app/).

### Does this PR introduce _any_ user-facing change?

Yes. When users click or visit Spark doc links with a hash fragment, the page will actually scroll to where they want to go.

### How was this patch tested?

I tested this on Chromium 128.0.6613.85 and Safari 17.4.1.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47925 from neilramaswamy/spark-49456.
Authored-by: Neil Ramaswamy Signed-off-by: Hyukjin Kwon --- docs/css/custom.css | 17 ++++++++++++++++- docs/js/main.js | 17 ----------------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/docs/css/custom.css b/docs/css/custom.css index 9edb466606555..bcee4a54a0cb5 100644 --- a/docs/css/custom.css +++ b/docs/css/custom.css @@ -1,3 +1,7 @@ +:root { + --navbar-height: 80px; +} + body { color: #666666; font-family: "DM Sans", sans-serif; @@ -5,8 +9,8 @@ body { font-weight: 400; overflow-wrap: anywhere; overflow-x: auto; - padding-top: 80px; padding-bottom: 20px; + padding-top: var(--navbar-height); } a { @@ -25,6 +29,17 @@ a:focus { border-radius: 0; z-index: 9999; transition: none !important; + height: var(--navbar-height); +} + +/* +Any element with an id attribute can be scrolled to via the URL hash fragment. +But since the navbar is fixed at the top, elements will, by default, get hidden +by the navbar. To prevent this, we make sure that there's a margin above these +links. +*/ +*[id] { + scroll-margin-top: var(--navbar-height); } .navbar .nav-item:hover .dropdown-menu, diff --git a/docs/js/main.js b/docs/js/main.js index 1c601f5210ab5..220cf4026bdcc 100755 --- a/docs/js/main.js +++ b/docs/js/main.js @@ -87,15 +87,6 @@ function codeTabs() { } -// A script to fix internal hash links because we have an overlapping top bar. -// Based on https://github.com/twitter/bootstrap/issues/193#issuecomment-2281510 -function maybeScrollToHash() { - if (window.location.hash && $(window.location.hash).length) { - var newTop = $(window.location.hash).offset().top - 57; - $(window).scrollTop(newTop); - } -} - $(function() { codeTabs(); // Display anchor links when hovering over headers. For documentation of the @@ -105,14 +96,6 @@ $(function() { }; anchors.add(); - $(window).bind('hashchange', function() { - maybeScrollToHash(); - }); - - // Scroll now too in case we had opened the page on a hash, but wait a bit because some browsers - // will try to do *their* initial scroll after running the onReady handler. - $(window).on('load', function() { setTimeout(function() { maybeScrollToHash(); }, 25); }); - // Make dropdown menus in nav bars show on hover instead of click // using solution at https://webdesign.tutsplus.com/tutorials/how- // to-make-the-bootstrap-navbar-dropdown-work-on-hover--cms-33840 From 03180ece177d3ca9ea9ee6aa7a17979696e386ad Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 29 Aug 2024 22:59:55 -0400 Subject: [PATCH 009/230] [SPARK-49421][CONNECT][SQL] Create a shared RelationalGroupedDataset interface ### What changes were proposed in this pull request? This PR introduces a shared RelationalGroupedDataset interface. ### Why are the changes needed? We want to unify the Classic and Connect Scala DataFrame APIs. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47918 from hvanhovell/SPARK-49421. 
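To make the large diff below easier to scan, here is a trimmed sketch (Scaladoc and most members omitted) of the sharing pattern it introduces: the abstract class added to `sql/api` is parameterized by the concrete `Dataset` type and declares a small engine-specific hook, so the user-facing methods are written once and the Classic and Connect subclasses only supply the plumbing:

```
package org.apache.spark.sql.api

import org.apache.spark.sql.{functions, Column, Row}

abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] {
  type RGD <: RelationalGroupedDataset[DS]

  // Engine-specific hook: Classic builds a logical plan, Connect builds a proto relation.
  protected def toDF(aggCols: Seq[Column]): DS[Row]

  // Shared behavior, written once against the abstract hook.
  def count(): DS[Row] = toDF(functions.count(functions.lit(1)).as("count") :: Nil)
}
```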
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../scala/org/apache/spark/sql/Dataset.scala | 254 ++---------- .../spark/sql/RelationalGroupedDataset.scala | 378 +++--------------- .../org/apache/spark/sql/api/Dataset.scala | 218 ++++++++++ .../sql/api/RelationalGroupedDataset.scala | 338 ++++++++++++++++ .../query-tests/queries/cube_string.json | 6 +- .../query-tests/queries/cube_string.proto.bin | Bin 92 -> 96 bytes .../queries/groupby_agg_string.json | 6 +- .../queries/groupby_agg_string.proto.bin | Bin 103 -> 107 bytes .../query-tests/queries/groupby_avg.json | 6 +- .../query-tests/queries/groupby_avg.proto.bin | Bin 90 -> 94 bytes .../query-tests/queries/groupby_max.json | 6 +- .../query-tests/queries/groupby_max.proto.bin | Bin 90 -> 94 bytes .../query-tests/queries/groupby_mean.json | 6 +- .../queries/groupby_mean.proto.bin | Bin 90 -> 94 bytes .../query-tests/queries/groupby_min.json | 6 +- .../query-tests/queries/groupby_min.proto.bin | Bin 90 -> 94 bytes .../query-tests/queries/groupby_sum.json | 6 +- .../query-tests/queries/groupby_sum.proto.bin | Bin 90 -> 94 bytes .../queries/grouping_and_grouping_id.json | 6 +- .../grouping_and_grouping_id.proto.bin | Bin 138 -> 142 bytes .../resources/query-tests/queries/pivot.json | 3 +- .../query-tests/queries/pivot.proto.bin | Bin 97 -> 99 bytes .../queries/pivot_without_column_values.json | 3 +- .../pivot_without_column_values.proto.bin | Bin 85 -> 87 bytes .../query-tests/queries/rollup_string.json | 6 +- .../queries/rollup_string.proto.bin | Bin 92 -> 96 bytes .../scala/org/apache/spark/sql/Dataset.scala | 229 ++--------- .../spark/sql/RelationalGroupedDataset.scala | 364 +++-------------- 28 files changed, 775 insertions(+), 1066 deletions(-) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 37a182675b6cd..d05834c4fc6c8 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -135,6 +135,7 @@ class Dataset[T] private[sql] ( @DeveloperApi val plan: proto.Plan, val encoder: Encoder[T]) extends api.Dataset[T, Dataset] { + type RGD = RelationalGroupedDataset import sparkSession.RichColumn @@ -506,58 +507,12 @@ class Dataset[T] private[sql] ( } } - /** - * Groups the Dataset using the specified columns, so we can run aggregation on them. See - * [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy($"department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) } - /** - * Groups the Dataset using the specified columns, so that we can run aggregation on them. See - * [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * This is a variant of groupBy that can only group by existing columns using column names (i.e. - * cannot construct expressions). 
- * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy("department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def groupBy(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - new RelationalGroupedDataset( - toDF(), - colNames.map(colName => Column(colName)), - proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) - } - /** @inheritdoc */ def reduce(func: (T, T) => T): T = { val udf = SparkUserDefinedFunction( @@ -599,134 +554,19 @@ class Dataset[T] private[sql] ( def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(UdfUtils.mapFunctionToScalaFunc(func))(encoder) - /** - * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * {{{ - * // Compute the average for all numeric columns rolled up by department and group. - * ds.rollup($"department", $"group").avg() - * - * // Compute the max age and average salary, rolled up by department and gender. - * ds.rollup($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def rollup(cols: Column*): RelationalGroupedDataset = { new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) } - /** - * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * This is a variant of rollup that can only group by existing columns using column names (i.e. - * cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns rolled up by department and group. - * ds.rollup("department", "group").avg() - * - * // Compute the max age and average salary, rolled up by department and gender. - * ds.rollup($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def rollup(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - new RelationalGroupedDataset( - toDF(), - colNames.map(colName => Column(colName)), - proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) - } - - /** - * Create a multi-dimensional cube for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * {{{ - * // Compute the average for all numeric columns cubed by department and group. - * ds.cube($"department", $"group").avg() - * - * // Compute the max age and average salary, cubed by department and gender. 
- * ds.cube($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def cube(cols: Column*): RelationalGroupedDataset = { new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_CUBE) } - /** - * Create a multi-dimensional cube for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * This is a variant of cube that can only group by existing columns using column names (i.e. - * cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns cubed by department and group. - * ds.cube("department", "group").avg() - * - * // Compute the max age and average salary, cubed by department and gender. - * ds.cube($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def cube(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - new RelationalGroupedDataset( - toDF(), - colNames.map(colName => Column(colName)), - proto.Aggregate.GroupType.GROUP_TYPE_CUBE) - } - - /** - * Create multi-dimensional aggregation for the current Dataset using the specified grouping - * sets, so we can run aggregation on them. See [[RelationalGroupedDataset]] for all the - * available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns group by specific grouping sets. - * ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg() - * - * // Compute the max age and average salary, group by specific grouping sets. - * ds.groupingSets(Seq($"department", $"gender"), Seq()), $"department", $"group").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = { val groupingSetMsgs = groupingSets.map { groupingSet => @@ -743,61 +583,6 @@ class Dataset[T] private[sql] ( groupingSets = Some(groupingSetMsgs)) } - /** - * (Scala-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg("age" -> "max", "salary" -> "avg") - * ds.groupBy().agg("age" -> "max", "salary" -> "avg") - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs: _*) - } - - /** - * (Scala-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg(Map("age" -> "max", "salary" -> "avg")) - * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) - - /** - * (Java-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) 
- * ds.agg(Map("age" -> "max", "salary" -> "avg")) - * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) - - /** - * Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg(max($"age"), avg($"salary")) - * ds.groupBy().agg(max($"age"), avg($"salary")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs: _*) - /** @inheritdoc */ def unpivot( ids: Array[Column], @@ -1745,4 +1530,33 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ override def distinct(): Dataset[T] = super.distinct() + + /** @inheritdoc */ + @scala.annotation.varargs + override def groupBy(col1: String, cols: String*): RelationalGroupedDataset = + super.groupBy(col1, cols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def rollup(col1: String, cols: String*): RelationalGroupedDataset = + super.rollup(col1, cols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def cube(col1: String, cols: String*): RelationalGroupedDataset = + super.cube(col1, cols: _*) + + /** @inheritdoc */ + override def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = + super.agg(aggExpr, aggExprs: _*) + + /** @inheritdoc */ + override def agg(exprs: Map[String, String]): DataFrame = super.agg(exprs) + + /** @inheritdoc */ + override def agg(exprs: java.util.Map[String, String]): DataFrame = super.agg(exprs) + + /** @inheritdoc */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 0c8657e12d8df..c9b011ca4535b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.util.Locale - import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto @@ -36,42 +34,52 @@ import org.apache.spark.connect.proto * @since 3.4.0 */ class RelationalGroupedDataset private[sql] ( - private[sql] val df: DataFrame, + protected val df: DataFrame, private[sql] val groupingExprs: Seq[Column], groupType: proto.Aggregate.GroupType, pivot: Option[proto.Aggregate.Pivot] = None, - groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) { + groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) + extends api.RelationalGroupedDataset[Dataset] { + type RGD = RelationalGroupedDataset import df.sparkSession.RichColumn - private[this] def toDF(aggExprs: Seq[Column]): DataFrame = { + protected def toDF(aggExprs: Seq[Column]): DataFrame = { df.sparkSession.newDataFrame { builder => - builder.getAggregateBuilder + val aggBuilder = builder.getAggregateBuilder .setInput(df.plan.getRoot) - .addAllGroupingExpressions(groupingExprs.map(_.expr).asJava) - .addAllAggregateExpressions(aggExprs.map(e => e.typedExpr(df.encoder)).asJava) + groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(c.expr)) + aggExprs.foreach(c => aggBuilder.addAggregateExpressions(c.typedExpr(df.encoder))) groupType 
match { case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP => - builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) + aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) case proto.Aggregate.GroupType.GROUP_TYPE_CUBE => - builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE) + aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE) case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY => - builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) + aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT => assert(pivot.isDefined) - builder.getAggregateBuilder + aggBuilder .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT) .setPivot(pivot.get) case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS => assert(groupingSets.isDefined) - val aggBuilder = builder.getAggregateBuilder - .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS) + aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS) groupingSets.get.foreach(aggBuilder.addGroupingSets) case g => throw new UnsupportedOperationException(g.toString) } } } + protected def selectNumericColumns(colNames: Seq[String]): Seq[Column] = { + // This behaves different than the classic implementation. The classic implementation validates + // if a column is actually a number, and if it is not it throws an error immediately. In connect + // it depends on the input type (casting) rules for the method invoked. If the input violates + // the a different error will be thrown. However it is also possible to get a result for a + // non-numeric column in connect, for example when you use min/max. + colNames.map(df.col) + } + /** * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions of * current `RelationalGroupedDataset`. @@ -82,295 +90,71 @@ class RelationalGroupedDataset private[sql] ( KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs) } - /** - * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The - * resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg( - * "age" -> "max", - * "expense" -> "sum" - * ) - * }}} - * - * @since 3.4.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - toDF((aggExpr +: aggExprs).map { case (colName, expr) => - strToColumn(expr, df(colName)) - }) - } - - /** - * (Scala-specific) Compute aggregates by specifying a map from column name to aggregate - * methods. The resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. 
- * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg(Map( - * "age" -> "max", - * "expense" -> "sum" - * )) - * }}} - * - * @since 3.4.0 - */ - def agg(exprs: Map[String, String]): DataFrame = { - toDF(exprs.map { case (colName, expr) => - strToColumn(expr, df(colName)) - }.toSeq) - } + /** @inheritdoc */ + override def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = + super.agg(aggExpr, aggExprs: _*) - /** - * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods. - * The resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * import com.google.common.collect.ImmutableMap; - * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); - * }}} - * - * @since 3.4.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = { - agg(exprs.asScala.toMap) - } + /** @inheritdoc */ + override def agg(exprs: Map[String, String]): DataFrame = super.agg(exprs) - private[this] def strToColumn(expr: String, inputExpr: Column): Column = { - expr.toLowerCase(Locale.ROOT) match { - case "avg" | "average" | "mean" => functions.avg(inputExpr) - case "stddev" | "std" => functions.stddev(inputExpr) - case "count" | "size" => functions.count(inputExpr) - case name => Column.fn(name, inputExpr) - } - } + /** @inheritdoc */ + override def agg(exprs: java.util.Map[String, String]): DataFrame = super.agg(exprs) - /** - * Compute aggregates by specifying a series of aggregate columns. Note that this function by - * default retains the grouping columns in its output. To not retain grouping columns, set - * `spark.sql.retainGroupColumns` to false. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * Note that before Spark 1.4, the default behavior is to NOT retain grouping columns. To change - * to that behavior, set config variable `spark.sql.retainGroupColumns` to `false`. - * {{{ - * // Scala, 1.3.x: - * df.groupBy("department").agg($"department", max("age"), sum("expense")) - * - * // Java, 1.3.x: - * df.groupBy("department").agg(col("department"), max("age"), sum("expense")); - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = { - toDF(expr +: exprs) - } + override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) - /** - * Count the number of rows for each group. The resulting `DataFrame` will also contain the - * grouping columns. - * - * @since 3.4.0 - */ - def count(): DataFrame = toDF(Seq(functions.count(functions.lit(1)).alias("count"))) + /** @inheritdoc */ + override def count(): DataFrame = super.count() - /** - * Compute the average value for each numeric columns for each group. This is an alias for - * `avg`. The resulting `DataFrame` will also contain the grouping columns. 
When specified - * columns are given, only compute the average values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def mean(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.mean(colName))) - } + override def mean(colNames: String*): DataFrame = super.mean(colNames: _*) - /** - * Compute the max value for each numeric columns for each group. The resulting `DataFrame` will - * also contain the grouping columns. When specified columns are given, only compute the max - * values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def max(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.max(colName))) - } + override def max(colNames: String*): DataFrame = super.max(colNames: _*) - /** - * Compute the mean value for each numeric columns for each group. The resulting `DataFrame` - * will also contain the grouping columns. When specified columns are given, only compute the - * mean values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def avg(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.avg(colName))) - } + override def avg(colNames: String*): DataFrame = super.avg(colNames: _*) - /** - * Compute the min value for each numeric column for each group. The resulting `DataFrame` will - * also contain the grouping columns. When specified columns are given, only compute the min - * values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def min(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.min(colName))) - } + override def min(colNames: String*): DataFrame = super.min(colNames: _*) - /** - * Compute the sum for each numeric columns for each group. The resulting `DataFrame` will also - * contain the grouping columns. When specified columns are given, only compute the sum for - * them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def sum(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.sum(colName))) - } + override def sum(colNames: String*): DataFrame = super.sum(colNames: _*) - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * - * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine the - * resulting schema of the transformation. To avoid any eager computations, provide an explicit - * list of values via `pivot(pivotColumn: String, values: Seq[Any])`. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * Name of the column to pivot. - * @since 3.4.0 - */ - def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn)) + /** @inheritdoc */ + override def pivot(pivotColumn: String): RelationalGroupedDataset = super.pivot(pivotColumn) - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. There are - * two versions of pivot function: one that requires the caller to specify the list of distinct - * values to pivot on, and one that does not. The latter is more concise but less efficient, - * because Spark needs to first compute the list of distinct values internally. 
- * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * - * From Spark 3.0.0, values can be literal columns, for instance, struct. For pivoting by - * multiple columns, use the `struct` function to combine the columns and values: - * - * {{{ - * df.groupBy("year") - * .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts")))) - * .agg(sum($"earnings")) - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * Name of the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. - * @since 3.4.0 - */ - def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { - pivot(Column(pivotColumn), values) - } + /** @inheritdoc */ + override def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = + super.pivot(pivotColumn, values) - /** - * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. - * - * There are two versions of pivot function: one that requires the caller to specify the list of - * distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings"); - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * Name of the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. - * @since 3.4.0 - */ - def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { - pivot(Column(pivotColumn), values) + /** @inheritdoc */ + override def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = + super.pivot(pivotColumn, values) + + /** @inheritdoc */ + override def pivot( + pivotColumn: Column, + values: java.util.List[Any]): RelationalGroupedDataset = { + super.pivot(pivotColumn, values) } - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. This is an - * overloaded version of the `pivot` method with `pivotColumn` of the `String` type. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. 
- * @since 3.4.0 - */ + /** @inheritdoc */ def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY => - val valueExprs = values.map(_ match { + val valueExprs = values.map { case c: Column if c.expr.hasLiteral => c.expr.getLiteral case c: Column if !c.expr.hasLiteral => throw new IllegalArgumentException("values only accept literal Column") case v => functions.lit(v).expr.getLiteral - }) + } new RelationalGroupedDataset( df, groupingExprs, @@ -386,46 +170,8 @@ class RelationalGroupedDataset private[sql] ( } } - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * - * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine the - * resulting schema of the transformation. To avoid any eager computations, provide an explicit - * list of values via `pivot(pivotColumn: Column, values: Seq[Any])`. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course").sum($"earnings"); - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * he column to pivot. - * @since 3.4.0 - */ + /** @inheritdoc */ def pivot(pivotColumn: Column): RelationalGroupedDataset = { - pivot(pivotColumn, Seq()) - } - - /** - * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of the - * `String` type. - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. - * @since 3.4.0 - */ - def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = { - pivot(pivotColumn, values.asScala.toSeq) + pivot(pivotColumn, Nil) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index 16f15205cabea..2b071a384e0ac 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -120,6 +120,8 @@ import org.apache.spark.util.SparkClassUtils */ @Stable abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { + type RGD <: RelationalGroupedDataset[DS] + def sparkSession: SparkSession[DS] val encoder: Encoder[T] @@ -1137,6 +1139,222 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { */ def where(conditionExpr: String): DS[T] = filter(conditionExpr) + /** + * Groups the Dataset using the specified columns, so we can run aggregation on them. See + * [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * ds.groupBy($"department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * ds.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def groupBy(cols: Column*): RGD + + /** + * Groups the Dataset using the specified columns, so that we can run aggregation on them. 
+ * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * This is a variant of groupBy that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * ds.groupBy("department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * ds.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def groupBy(col1: String, cols: String*): RGD = groupBy((col1 +: cols).map(col): _*) + + /** + * Create a multi-dimensional rollup for the current Dataset using the specified columns, + * so we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns rolled up by department and group. + * ds.rollup($"department", $"group").avg() + * + * // Compute the max age and average salary, rolled up by department and gender. + * ds.rollup($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def rollup(cols: Column*): RGD + + /** + * Create a multi-dimensional rollup for the current Dataset using the specified columns, + * so we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * This is a variant of rollup that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns rolled up by department and group. + * ds.rollup("department", "group").avg() + * + * // Compute the max age and average salary, rolled up by department and gender. + * ds.rollup($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def rollup(col1: String, cols: String*): RGD = rollup((col1 +: cols).map(col): _*) + + /** + * Create a multi-dimensional cube for the current Dataset using the specified columns, + * so we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns cubed by department and group. + * ds.cube($"department", $"group").avg() + * + * // Compute the max age and average salary, cubed by department and gender. + * ds.cube($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def cube(cols: Column*): RGD + + /** + * Create a multi-dimensional cube for the current Dataset using the specified columns, + * so we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * This is a variant of cube that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns cubed by department and group. + * ds.cube("department", "group").avg() + * + * // Compute the max age and average salary, cubed by department and gender. 
+ * ds.cube($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def cube(col1: String, cols: String*): RGD = cube((col1 +: cols).map(col): _*) + + /** + * Create multi-dimensional aggregation for the current Dataset using the specified grouping sets, + * so we can run aggregation on them. + * See [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns group by specific grouping sets. + * ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg() + * + * // Compute the max age and average salary, group by specific grouping sets. + * ds.groupingSets(Seq($"department", $"gender"), Seq()), $"department", $"group").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 4.0.0 + */ + @scala.annotation.varargs + def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RGD + + /** + * (Scala-specific) Aggregates on the entire Dataset without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg("age" -> "max", "salary" -> "avg") + * ds.groupBy().agg("age" -> "max", "salary" -> "avg") + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DS[Row] = { + groupBy().agg(aggExpr, aggExprs: _*) + } + + /** + * (Scala-specific) Aggregates on the entire Dataset without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(Map("age" -> "max", "salary" -> "avg")) + * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + def agg(exprs: Map[String, String]): DS[Row] = groupBy().agg(exprs) + + /** + * (Java-specific) Aggregates on the entire Dataset without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(Map("age" -> "max", "salary" -> "avg")) + * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + def agg(exprs: util.Map[String, String]): DS[Row] = groupBy().agg(exprs) + + /** + * Aggregates on the entire Dataset without groups. + * {{{ + * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) + * ds.agg(max($"age"), avg($"salary")) + * ds.groupBy().agg(max($"age"), avg($"salary")) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DS[Row] = groupBy().agg(expr, exprs: _*) + /** * (Scala-specific) * Reduces the elements of this Dataset using the specified binary function. The given `func` diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala new file mode 100644 index 0000000000000..30b2992d43a00 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.api + +import scala.jdk.CollectionConverters._ + +import _root_.java.util + +import org.apache.spark.annotation.Stable +import org.apache.spark.sql.{functions, Column, Row} + +/** + * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], + * [[Dataset#cube cube]] or [[Dataset#rollup rollup]] (and also `pivot`). + * + * The main method is the `agg` function, which has multiple variants. This class also contains + * some first-order statistics such as `mean`, `sum` for convenience. + * + * @note This class was named `GroupedData` in Spark 1.x. + * @since 2.0.0 + */ +@Stable +abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { + type RGD <: RelationalGroupedDataset[DS] + + protected def df: DS[Row] + + /** + * Create a aggregation based on the grouping column, the grouping type, and the aggregations. + */ + protected def toDF(aggCols: Seq[Column]): DS[Row] + + protected def selectNumericColumns(colNames: Seq[String]): Seq[Column] + + /** + * Convert a name method tuple into a Column. + */ + private def toAggCol(colAndMethod: (String, String)): Column = { + val col = df.col(colAndMethod._1) + colAndMethod._2.toLowerCase(util.Locale.ROOT) match { + case "avg" | "average" | "mean" => functions.avg(col) + case "stddev" | "std" => functions.stddev(col) + case "count" | "size" => functions.count(col) + case name => Column.fn(name, col) + } + } + + private def aggregateNumericColumns( + colNames: Seq[String], + function: Column => Column): DS[Row] = { + toDF(selectNumericColumns(colNames).map(function)) + } + + /** + * (Scala-specific) Compute aggregates by specifying the column names and + * aggregate methods. The resulting `DataFrame` will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg( + * "age" -> "max", + * "expense" -> "sum" + * ) + * }}} + * + * @since 1.3.0 + */ + def agg(aggExpr: (String, String), aggExprs: (String, String)*): DS[Row] = + toDF((aggExpr +: aggExprs).map(toAggCol)) + + /** + * (Scala-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting `DataFrame` will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max", + * "expense" -> "sum" + * )) + * }}} + * + * @since 1.3.0 + */ + def agg(exprs: Map[String, String]): DS[Row] = toDF(exprs.map(toAggCol).toSeq) + + /** + * (Java-specific) Compute aggregates by specifying a map from column name to + * aggregate methods. The resulting `DataFrame` will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. 
+ * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * import com.google.common.collect.ImmutableMap; + * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); + * }}} + * + * @since 1.3.0 + */ + def agg(exprs: util.Map[String, String]): DS[Row] = { + agg(exprs.asScala.toMap) + } + + /** + * Compute aggregates by specifying a series of aggregate columns. Note that this function by + * default retains the grouping columns in its output. To not retain grouping columns, set + * `spark.sql.retainGroupColumns` to false. + * + * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. + * + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * + * // Scala: + * import org.apache.spark.sql.functions._ + * df.groupBy("department").agg(max("age"), sum("expense")) + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df.groupBy("department").agg(max("age"), sum("expense")); + * }}} + * + * Note that before Spark 1.4, the default behavior is to NOT retain grouping columns. To change + * to that behavior, set config variable `spark.sql.retainGroupColumns` to `false`. + * {{{ + * // Scala, 1.3.x: + * df.groupBy("department").agg($"department", max("age"), sum("expense")) + * + * // Java, 1.3.x: + * df.groupBy("department").agg(col("department"), max("age"), sum("expense")); + * }}} + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DS[Row] = toDF(expr +: exprs) + + /** + * Count the number of rows for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * + * @since 1.3.0 + */ + def count(): DS[Row] = toDF(functions.count(functions.lit(1)).as("count") :: Nil) + + /** + * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the average values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def mean(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.avg) + + + /** + * Compute the max value for each numeric columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the max values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def max(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.max) + + /** + * Compute the mean value for each numeric columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the mean values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def avg(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.avg) + + /** + * Compute the min value for each numeric column for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the min values for them. + * + * @since 1.3.0 + */ + @scala.annotation.varargs + def min(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.min) + + /** + * Compute the sum for each numeric columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. 
+ * + * @since 1.3.0 + */ + @scala.annotation.varargs + def sum(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.sum) + + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * + * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine + * the resulting schema of the transformation. To avoid any eager computations, provide an + * explicit list of values via `pivot(pivotColumn: String, values: Seq[Any])`. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, + * except for the aggregation. + * @param pivotColumn Name of the column to pivot. + * @since 1.6.0 + */ + def pivot(pivotColumn: String): RGD = pivot(df.col(pivotColumn)) + + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * + * From Spark 3.0.0, values can be literal columns, for instance, struct. For pivoting by + * multiple columns, use the `struct` function to combine the columns and values: + * + * {{{ + * df.groupBy("year") + * .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts")))) + * .agg(sum($"earnings")) + * }}} + * + * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, + * except for the aggregation. + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: Seq[Any]): RGD = + pivot(df.col(pivotColumn), values) + + /** + * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified + * aggregation. + * + * There are two versions of pivot function: one that requires the caller to specify the list + * of distinct values to pivot on, and one that does not. The latter is more concise but less + * efficient, because Spark needs to first compute the list of distinct values internally. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); + * + * // Or without specifying column values (less efficient) + * df.groupBy("year").pivot("course").sum("earnings"); + * }}} + * + * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, + * except for the aggregation. + * @param pivotColumn Name of the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 1.6.0 + */ + def pivot(pivotColumn: String, values: util.List[Any]): RGD = + pivot(df.col(pivotColumn), values) + + /** + * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified + * aggregation. 
This is an overloaded version of the `pivot` method with `pivotColumn` of + * the `String` type. + * + * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, + * except for the aggregation. + * @param pivotColumn the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. + * @since 2.4.0 + */ + def pivot(pivotColumn: Column, values: util.List[Any]): RGD = + pivot(pivotColumn, values.asScala.toSeq) + + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * + * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine + * the resulting schema of the transformation. To avoid any eager computations, provide an + * explicit list of values via `pivot(pivotColumn: Column, values: Seq[Any])`. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy($"year").pivot($"course").sum($"earnings"); + * }}} + * + * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, + * except for the aggregation. + * @param pivotColumn he column to pivot. + * @since 2.4.0 + */ + def pivot(pivotColumn: Column): RGD + + /** + * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type. + * + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") + * }}} + * + * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, + * except for the aggregation. + * @param pivotColumn the column to pivot. + * @param values List of values that will be translated to columns in the output DataFrame. 
+ * @since 2.4.0 + */ + def pivot(pivotColumn: Column, values: Seq[Any]): RGD +} diff --git a/sql/connect/common/src/test/resources/query-tests/queries/cube_string.json b/sql/connect/common/src/test/resources/query-tests/queries/cube_string.json index 5b9709ff06576..03625861d88f2 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/cube_string.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/cube_string.json @@ -14,11 +14,13 @@ "groupType": "GROUP_TYPE_CUBE", "groupingExpressions": [{ "unresolvedAttribute": { - "unparsedIdentifier": "a" + "unparsedIdentifier": "a", + "planId": "0" } }, { "unresolvedAttribute": { - "unparsedIdentifier": "b" + "unparsedIdentifier": "b", + "planId": "0" } }], "aggregateExpressions": [{ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin index d46e40b39dcfec95244ab4af93c4cd4905d60d02..59c7a55571201686bc8f434588aa7c8f2316554b 100644 GIT binary patch delta 33 lcmaz^VB=yEVDyTb$fhgEF2u^km?*#?1!g1(Fidoi0RT|g1eE{) delta 54 zcmYd@VdG*FVDt)`$fm2wD#Xmim?#CNl9Z&3B)Iscc)3`U^Gowegjl$k3>bwV901%q B2_*mk diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json index 26320d404835f..285c13f4bc8b3 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.json @@ -14,11 +14,13 @@ "groupType": "GROUP_TYPE_GROUPBY", "groupingExpressions": [{ "unresolvedAttribute": { - "unparsedIdentifier": "id" + "unparsedIdentifier": "id", + "planId": "0" } }, { "unresolvedAttribute": { - "unparsedIdentifier": "b" + "unparsedIdentifier": "b", + "planId": "0" } }], "aggregateExpressions": [{ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupby_agg_string.proto.bin index 818146f7f69356c4997a36e0f1c3cf1c055d3fe1..674d506fa4a07817d6c28a4f3a56e22ee5126f7f 100644 GIT binary patch delta 34 ocmYe#X5(TKVDw6z$fhgAAq3=PrU)=du?w+sF(wHxOmtKO0A6tgF#rGn delta 30 lcmc~!XX9cLVDw6y$fnE3Cd9(Ul$j#MD#Xmim^9H<4FF9u1qA>A diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.json b/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.json index 5785eee2cadb5..0ded46cf6cc7c 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.json @@ -22,7 +22,8 @@ "functionName": "avg", "arguments": [{ "unresolvedAttribute": { - "unparsedIdentifier": "a" + "unparsedIdentifier": "a", + "planId": "0" } }] } @@ -31,7 +32,8 @@ "functionName": "avg", "arguments": [{ "unresolvedAttribute": { - "unparsedIdentifier": "b" + "unparsedIdentifier": "b", + "planId": "0" } }] } diff --git a/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/groupby_avg.proto.bin index 4a18ea2d82d935fcc13bb2e74c2e1c9e83f001d2..444b0c3853f164c5227de9d6576baaabc3fd861f 100644 GIT binary patch delta 48 rcma!wW8-2HVDyTZ$Y!c6AjQYUoLH7F#4g0j#h56-poA)tB)|XwxXA|Y delta 44 ncma!xV&h^GVDt)^$Yv_fC&k0XoLH7F#45zh#h9ps%1;6So{||$i&ViI8VYM#iZ$IB|j%*B`}1*VcFx_SZtQTzoX diff --git 
a/sql/connect/common/src/test/resources/query-tests/queries/pivot.json b/sql/connect/common/src/test/resources/query-tests/queries/pivot.json index 30bff04c531db..2af86606b9fcb 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/pivot.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/pivot.json @@ -30,7 +30,8 @@ "pivot": { "col": { "unresolvedAttribute": { - "unparsedIdentifier": "a" + "unparsedIdentifier": "a", + "planId": "0" } }, "values": [{ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin index 67063209a184c022b9da77c315a286007d96a210..f545179e84968884a3fcdc5b3ca0b3bd259027c7 100644 GIT binary patch delta 35 ncmYdHX5(TKVDyTe$mT94%Ed0k%Eg!{z#znAzzD=lK+FsPS{noo delta 33 lcmYdJWaDBIVDyTf$mT90%*86i%*B`}#ALt-#7scU3;bwV901%q B2_*mk diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 03b6a8d6d737d..a28dfbdbf66a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -214,6 +214,7 @@ class Dataset[T] private[sql]( @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, @DeveloperApi @Unstable @transient val encoder: Encoder[T]) extends api.Dataset[T, Dataset] { + type RGD = RelationalGroupedDataset @transient lazy val sparkSession: SparkSession = { if (queryExecution == null || queryExecution.sparkSession == null) { @@ -891,73 +892,19 @@ class Dataset[T] private[sql]( RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) } - /** - * Create a multi-dimensional rollup for the current Dataset using the specified columns, - * so we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns rolled up by department and group. - * ds.rollup($"department", $"group").avg() - * - * // Compute the max age and average salary, rolled up by department and gender. - * ds.rollup($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def rollup(cols: Column*): RelationalGroupedDataset = { RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.RollupType) } - /** - * Create a multi-dimensional cube for the current Dataset using the specified columns, - * so we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns cubed by department and group. - * ds.cube($"department", $"group").avg() - * - * // Compute the max age and average salary, cubed by department and gender. - * ds.cube($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def cube(cols: Column*): RelationalGroupedDataset = { RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.CubeType) } - /** - * Create multi-dimensional aggregation for the current Dataset using the specified grouping sets, - * so we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. 
- * - * {{{ - * // Compute the average for all numeric columns group by specific grouping sets. - * ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg() - * - * // Compute the max age and average salary, group by specific grouping sets. - * ds.groupingSets(Seq($"department", $"gender"), Seq()), $"department", $"group").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = { RelationalGroupedDataset( @@ -966,33 +913,6 @@ class Dataset[T] private[sql]( RelationalGroupedDataset.GroupingSetsType(groupingSets.map(_.map(_.expr)))) } - /** - * Groups the Dataset using the specified columns, so that we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * This is a variant of groupBy that can only group by existing columns using column names - * (i.e. cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy("department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group untypedrel - * @since 2.0.0 - */ - @scala.annotation.varargs - def groupBy(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - RelationalGroupedDataset( - toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.GroupByType) - } - /** @inheritdoc */ def reduce(func: (T, T) => T): T = withNewRDDExecutionId("reduce") { rdd.reduce(func) @@ -1027,118 +947,6 @@ class Dataset[T] private[sql]( def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = groupByKey(func.call(_))(encoder) - /** - * Create a multi-dimensional rollup for the current Dataset using the specified columns, - * so we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * This is a variant of rollup that can only group by existing columns using column names - * (i.e. cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns rolled up by department and group. - * ds.rollup("department", "group").avg() - * - * // Compute the max age and average salary, rolled up by department and gender. - * ds.rollup($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ - @scala.annotation.varargs - def rollup(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - RelationalGroupedDataset( - toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.RollupType) - } - - /** - * Create a multi-dimensional cube for the current Dataset using the specified columns, - * so we can run aggregation on them. - * See [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * This is a variant of cube that can only group by existing columns using column names - * (i.e. cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns cubed by department and group. 
- * ds.cube("department", "group").avg() - * - * // Compute the max age and average salary, cubed by department and gender. - * ds.cube($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group untypedrel - * @since 2.0.0 - */ - @scala.annotation.varargs - def cube(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - RelationalGroupedDataset( - toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.CubeType) - } - - /** - * (Scala-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg("age" -> "max", "salary" -> "avg") - * ds.groupBy().agg("age" -> "max", "salary" -> "avg") - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs : _*) - } - - /** - * (Scala-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg(Map("age" -> "max", "salary" -> "avg")) - * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ - def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) - - /** - * (Java-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg(Map("age" -> "max", "salary" -> "avg")) - * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) - - /** - * Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) 
- * ds.agg(max($"age"), avg($"salary")) - * ds.groupBy().agg(max($"age"), avg($"salary")) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ - @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) - /** @inheritdoc */ def unpivot( ids: Array[Column], @@ -2223,6 +2031,35 @@ class Dataset[T] private[sql]( /** @inheritdoc */ override def distinct(): Dataset[T] = super.distinct() + /** @inheritdoc */ + @scala.annotation.varargs + override def groupBy(col1: String, cols: String*): RelationalGroupedDataset = + super.groupBy(col1, cols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def rollup(col1: String, cols: String*): RelationalGroupedDataset = + super.rollup(col1, cols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def cube(col1: String, cols: String*): RelationalGroupedDataset = + super.cube(col1, cols: _*) + + /** @inheritdoc */ + override def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = + super.agg(aggExpr, aggExprs: _*) + + /** @inheritdoc */ + override def agg(exprs: Map[String, String]): DataFrame = super.agg(exprs) + + /** @inheritdoc */ + override def agg(exprs: java.util.Map[String, String]): DataFrame = super.agg(exprs) + + /** @inheritdoc */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + //////////////////////////////////////////////////////////////////////////// // For Python API //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 3cafe0d98f1bf..777baa3e62687 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -17,15 +17,11 @@ package org.apache.spark.sql -import java.util.Locale - -import scala.jdk.CollectionConverters._ - import org.apache.spark.SparkRuntimeException import org.apache.spark.annotation.Stable import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedFunction} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -54,13 +50,19 @@ import org.apache.spark.util.ArrayImplicits._ */ @Stable class RelationalGroupedDataset protected[sql]( - private[sql] val df: DataFrame, + protected[sql] val df: DataFrame, private[sql] val groupingExprs: Seq[Expression], - groupType: RelationalGroupedDataset.GroupType) { + groupType: RelationalGroupedDataset.GroupType) + extends api.RelationalGroupedDataset[Dataset] { + type RGD = RelationalGroupedDataset import RelationalGroupedDataset._ import df.sparkSession._ - private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { + override protected def toDF(aggCols: Seq[Column]): DataFrame = { + val aggExprs = aggCols.map(expression).map { e => + withInputType(e, df.exprEnc, df.logicalPlan.output) + } + @scala.annotation.nowarn("cat=deprecation") val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) { groupingExprs match { @@ -98,9 +100,7 @@ class RelationalGroupedDataset protected[sql]( } } - 
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) - : DataFrame = { - + override protected def selectNumericColumns(colNames: Seq[String]): Seq[Column] = { val columnExprs = if (colNames.isEmpty) { // No columns specified. Use all numeric columns. df.numericColumns @@ -114,29 +114,9 @@ class RelationalGroupedDataset protected[sql]( namedExpr } } - toDF(columnExprs.map(expr => f(expr).toAggregateExpression())) + columnExprs.map(column) } - private[this] def strToExpr(expr: String): (Expression => Expression) = { - val exprToFunc: (Expression => Expression) = { - (inputExpr: Expression) => expr.toLowerCase(Locale.ROOT) match { - // We special handle a few cases that have alias that are not in function registry. - case "avg" | "average" | "mean" => - UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) - case "stddev" | "std" => - UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false) - // Also special handle count because we need to take care count(*). - case "count" | "size" => - // Turn count(*) into count(1) - inputExpr match { - case s: Star => Count(Literal(1)).toAggregateExpression() - case _ => Count(inputExpr).toAggregateExpression() - } - case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false) - } - } - (inputExpr: Expression) => exprToFunc(inputExpr) - } /** * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions @@ -159,296 +139,68 @@ class RelationalGroupedDataset protected[sql]( groupingAttributes) } - /** - * (Scala-specific) Compute aggregates by specifying the column names and - * aggregate methods. The resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg( - * "age" -> "max", - * "expense" -> "sum" - * ) - * }}} - * - * @since 1.3.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - toDF((aggExpr +: aggExprs).map { case (colName, expr) => - strToExpr(expr)(df(colName).expr) - }) - } + /** @inheritdoc */ + override def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = + super.agg(aggExpr, aggExprs: _*) - /** - * (Scala-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg(Map( - * "age" -> "max", - * "expense" -> "sum" - * )) - * }}} - * - * @since 1.3.0 - */ - def agg(exprs: Map[String, String]): DataFrame = { - toDF(exprs.map { case (colName, expr) => - strToExpr(expr)(df(colName).expr) - }.toSeq) - } + /** @inheritdoc */ + override def agg(exprs: Map[String, String]): DataFrame = super.agg(exprs) - /** - * (Java-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. 
- * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * import com.google.common.collect.ImmutableMap; - * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); - * }}} - * - * @since 1.3.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = { - agg(exprs.asScala.toMap) - } + /** @inheritdoc */ + override def agg(exprs: java.util.Map[String, String]): DataFrame = super.agg(exprs) - /** - * Compute aggregates by specifying a series of aggregate columns. Note that this function by - * default retains the grouping columns in its output. To not retain grouping columns, set - * `spark.sql.retainGroupColumns` to false. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * Note that before Spark 1.4, the default behavior is to NOT retain grouping columns. To change - * to that behavior, set config variable `spark.sql.retainGroupColumns` to `false`. - * {{{ - * // Scala, 1.3.x: - * df.groupBy("department").agg($"department", max("age"), sum("expense")) - * - * // Java, 1.3.x: - * df.groupBy("department").agg(col("department"), max("age"), sum("expense")); - * }}} - * - * @since 1.3.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map { - case typed: TypedColumn[_, _] => - withInputType(typed.expr, df.exprEnc, df.logicalPlan.output) - case c => c.expr - }) - } + override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) - /** - * Count the number of rows for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * - * @since 1.3.0 - */ - def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")())) + /** @inheritdoc */ + override def count(): DataFrame = super.count() - /** - * Compute the average value for each numeric columns for each group. This is an alias for `avg`. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the average values for them. - * - * @since 1.3.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Average(_)) - } + override def mean(colNames: String*): DataFrame = super.mean(colNames: _*) - /** - * Compute the max value for each numeric columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the max values for them. - * - * @since 1.3.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def max(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Max) - } + override def max(colNames: String*): DataFrame = super.max(colNames: _*) - /** - * Compute the mean value for each numeric columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the mean values for them. 
- * - * @since 1.3.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Average(_)) - } + override def avg(colNames: String*): DataFrame = super.avg(colNames: _*) - /** - * Compute the min value for each numeric column for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the min values for them. - * - * @since 1.3.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def min(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Min) - } + override def min(colNames: String*): DataFrame = super.min(colNames: _*) - /** - * Compute the sum for each numeric columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the sum for them. - * - * @since 1.3.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(Sum(_)) - } + override def sum(colNames: String*): DataFrame = super.sum(colNames: _*) - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * - * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine - * the resulting schema of the transformation. To avoid any eager computations, provide an - * explicit list of values via `pivot(pivotColumn: String, values: Seq[Any])`. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * - * @param pivotColumn Name of the column to pivot. - * @since 1.6.0 - */ - def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn)) + /** @inheritdoc */ + override def pivot(pivotColumn: String): RelationalGroupedDataset = super.pivot(pivotColumn) - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * - * From Spark 3.0.0, values can be literal columns, for instance, struct. For pivoting by - * multiple columns, use the `struct` function to combine the columns and values: - * - * {{{ - * df.groupBy("year") - * .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts")))) - * .agg(sum($"earnings")) - * }}} - * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * - * @param pivotColumn Name of the column to pivot. - * @param values List of values that will be translated to columns in the output DataFrame. 
- * @since 1.6.0 - */ - def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { - pivot(Column(pivotColumn), values) - } + /** @inheritdoc */ + override def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = + super.pivot(pivotColumn, values) - /** - * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. - * - * There are two versions of pivot function: one that requires the caller to specify the list - * of distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings"); - * }}} - * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * - * @param pivotColumn Name of the column to pivot. - * @param values List of values that will be translated to columns in the output DataFrame. - * @since 1.6.0 - */ - def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { - pivot(Column(pivotColumn), values) + /** @inheritdoc */ + override def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = + super.pivot(pivotColumn, values) + + /** @inheritdoc */ + override def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = { + super.pivot(pivotColumn, values) } - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * - * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine - * the resulting schema of the transformation. To avoid any eager computations, provide an - * explicit list of values via `pivot(pivotColumn: Column, values: Seq[Any])`. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course").sum($"earnings"); - * }}} - * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * - * @param pivotColumn he column to pivot. - * @since 2.4.0 - */ - def pivot(pivotColumn: Column): RelationalGroupedDataset = { + /** @inheritdoc */ + override def pivot(pivotColumn: Column): RelationalGroupedDataset = pivot(pivotColumn, collectPivotValues(df, pivotColumn)) - } - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") - * }}} - * - * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, - * except for the aggregation. - * - * @param pivotColumn the column to pivot. - * @param values List of values that will be translated to columns in the output DataFrame. 
- * @since 2.4.0
- */
+ /** @inheritdoc */
def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = {
groupType match {
case RelationalGroupedDataset.GroupByType =>
- val valueExprs = values.map(_ match {
+ val valueExprs = values.map {
case c: Column => c.expr
case v => try {
@@ -457,7 +209,7 @@ class RelationalGroupedDataset protected[sql](
case _: SparkRuntimeException =>
throw QueryExecutionErrors.pivotColumnUnsupportedError(v, pivotColumn.expr)
}
- })
+ }
new RelationalGroupedDataset(
df,
groupingExprs,
@@ -471,22 +223,6 @@ class RelationalGroupedDataset protected[sql](
}
}
- /**
- * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
- * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of
- * the `String` type.
- *
- * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
- * except for the aggregation.
- *
- * @param pivotColumn the column to pivot.
- * @param values List of values that will be translated to columns in the output DataFrame.
- * @since 2.4.0
- */
- def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = {
- pivot(pivotColumn, values.asScala.toSeq)
- }
-
/**
* Applies the given serialized R function `func` to each group of data. For each unique group,
* the function will be passed the group key and an iterator that contains all of the elements in

From 0602020eb3b346a8c50ad32eeda4e6dabb70c584 Mon Sep 17 00:00:00 2001
From: tianhanhu
Date: Fri, 30 Aug 2024 11:05:00 +0800
Subject: [PATCH 010/230] [SPARK-49252][CORE] Make `TaskSetExcludeList` and `HealthTracker` independent

### What changes were proposed in this pull request?

Make the change such that `TaskSetExcludeList` and `HealthTracker` can be enabled independently. When an application level `HealthTracker` is created but taskset level exclusion is not enabled, `TaskSetExcludeList` would be created in dry run mode, where it still records and reports task failure data to `HealthTracker` but does not participate in scheduler decision making.

### Why are the changes needed?

Currently, when `spark.excludeOnFailure.enabled` is set to true, task set level exclusion (`TaskSetExcludeList`) and application level exclusion (`HealthTracker`) are both enabled. In some cases, we only want to enable exclusion on a single dimension.

### Does this PR introduce _any_ user-facing change?

Yes, it introduces two new user-facing configs, `spark.excludeOnFailure.application.enabled` and `spark.excludeOnFailure.taskAndStage.enabled`, that allow enabling exclusion at the taskset and application levels individually.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47793 from tianhanhu/SPARK-49252_separate_exclusion. 
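For illustration only (not part of this patch, assuming a plain `SparkConf` setup), a minimal sketch of toggling the two levels independently; the combination below is the one that puts `TaskSetExcludeList` into dry run mode:
```
import org.apache.spark.SparkConf

// Enable only application-level exclusion; task/stage-level exclusion stays off,
// so the task set exclude list only records failures and does not affect scheduling.
val conf = new SparkConf()
  .set("spark.excludeOnFailure.application.enabled", "true")
  .set("spark.excludeOnFailure.taskAndStage.enabled", "false")
```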
Authored-by: tianhanhu Signed-off-by: Yi Wu --- .../spark/internal/config/package.scala | 12 +++++++ .../spark/scheduler/HealthTracker.scala | 14 ++++---- .../spark/scheduler/TaskSetExcludeList.scala | 29 +++++++++++++--- .../spark/scheduler/TaskSetManager.scala | 19 ++++++++--- .../spark/scheduler/HealthTrackerSuite.scala | 17 ++++++++++ .../spark/scheduler/TaskSetManagerSuite.scala | 33 +++++++++++++++++++ docs/configuration.md | 27 +++++++++++++++ 7 files changed, 135 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e9e411cc56b51..8224bcac28301 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -918,6 +918,18 @@ package object config { .booleanConf .createOptional + private[spark] val EXCLUDE_ON_FAILURE_ENABLED_APPLICATION = + ConfigBuilder("spark.excludeOnFailure.application.enabled") + .version("4.0.0") + .booleanConf + .createOptional + + private[spark] val EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE = + ConfigBuilder("spark.excludeOnFailure.taskAndStage.enabled") + .version("4.0.0") + .booleanConf + .createOptional + private[spark] val MAX_TASK_ATTEMPTS_PER_EXECUTOR = ConfigBuilder("spark.excludeOnFailure.task.maxTaskAttemptsPerExecutor") .version("3.1.0") diff --git a/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala index 1606072153906..82ec0ef91f4fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/HealthTracker.scala @@ -425,14 +425,16 @@ private[spark] object HealthTracker extends Logging { private val DEFAULT_TIMEOUT = "1h" /** - * Returns true if the excludeOnFailure is enabled, based on checking the configuration - * in the following order: - * 1. Is it specifically enabled or disabled? - * 2. Is it enabled via the legacy timeout conf? - * 3. Default is off + * Returns true if the excludeOnFailure is enabled on the application level, + * based on checking the configuration in the following order: + * 1. Is application level exclusion specifically enabled or disabled? + * 2. Is overall exclusion feature enabled or disabled? + * 3. Is it enabled via the legacy timeout conf? + * 4. Default is off */ def isExcludeOnFailureEnabled(conf: SparkConf): Boolean = { - conf.get(config.EXCLUDE_ON_FAILURE_ENABLED) match { + conf.get(config.EXCLUDE_ON_FAILURE_ENABLED_APPLICATION) + .orElse(conf.get(config.EXCLUDE_ON_FAILURE_ENABLED)) match { case Some(enabled) => enabled case None => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala index c9aa74e0852be..3637305293107 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetExcludeList.scala @@ -31,6 +31,9 @@ import org.apache.spark.util.Clock * which is handled by [[HealthTracker]]. Note that HealthTracker does not know anything * about task failures until a taskset completes successfully. * + * If isDryRun is true, then this class will only function to store information for application + * level exclusion, and will not actually exclude any tasks in task/stage level. 
+ * * THREADING: This class is a helper to [[TaskSetManager]]; as with the methods in * [[TaskSetManager]] this class is designed only to be called from code with a lock on the * TaskScheduler (e.g. its event handlers). It should not be called from other threads. @@ -40,7 +43,8 @@ private[scheduler] class TaskSetExcludelist( val conf: SparkConf, val stageId: Int, val stageAttemptId: Int, - val clock: Clock) extends Logging { + val clock: Clock, + val isDryRun: Boolean = false) extends Logging { private val MAX_TASK_ATTEMPTS_PER_EXECUTOR = conf.get(config.MAX_TASK_ATTEMPTS_PER_EXECUTOR) private val MAX_TASK_ATTEMPTS_PER_NODE = conf.get(config.MAX_TASK_ATTEMPTS_PER_NODE) @@ -80,13 +84,13 @@ private[scheduler] class TaskSetExcludelist( * of the scheduler, where those filters will have already been applied. */ def isExecutorExcludedForTask(executorId: String, index: Int): Boolean = { - execToFailures.get(executorId).exists { execFailures => + !isDryRun && execToFailures.get(executorId).exists { execFailures => execFailures.getNumTaskFailures(index) >= MAX_TASK_ATTEMPTS_PER_EXECUTOR } } def isNodeExcludedForTask(node: String, index: Int): Boolean = { - nodeToExcludedTaskIndexes.get(node).exists(_.contains(index)) + !isDryRun && nodeToExcludedTaskIndexes.get(node).exists(_.contains(index)) } /** @@ -96,11 +100,11 @@ private[scheduler] class TaskSetExcludelist( * scheduler, where those filters will already have been applied. */ def isExecutorExcludedForTaskSet(executorId: String): Boolean = { - excludedExecs.contains(executorId) + !isDryRun && excludedExecs.contains(executorId) } def isNodeExcludedForTaskSet(node: String): Boolean = { - excludedNodes.contains(node) + !isDryRun && excludedNodes.contains(node) } private[scheduler] def updateExcludedForFailedTask( @@ -163,3 +167,18 @@ private[scheduler] class TaskSetExcludelist( } } } + +private[scheduler] object TaskSetExcludelist { + + /** + * Returns true if the excludeOnFailure is enabled on the task/stage level, + * based on checking the configuration in the following order: + * 1. Is taskset level exclusion specifically enabled or disabled? + * 2. Is overall exclusion feature enabled or disabled? + * 3. Default is off + */ + def isExcludeOnFailureEnabled(conf: SparkConf): Boolean = { + conf.get(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE) + .orElse(conf.get(config.EXCLUDE_ON_FAILURE_ENABLED)).getOrElse(false) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 7dba4a6dc8fc4..a3d074ddd56cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -143,8 +143,18 @@ private[spark] class TaskSetManager( private var calculatedTasks = 0 private[scheduler] val taskSetExcludelistHelperOpt: Option[TaskSetExcludelist] = { - healthTracker.map { _ => - new TaskSetExcludelist(sched.sc.listenerBus, conf, stageId, taskSet.stageAttemptId, clock) + if (TaskSetExcludelist.isExcludeOnFailureEnabled(conf)) { + Some(new TaskSetExcludelist(sched.sc.listenerBus, conf, stageId, + taskSet.stageAttemptId, clock)) + } else if (healthTracker.isDefined) { + // If we enabled exclusion at application level but not at taskset level exclusion, we create + // TaskSetExcludelist in dry run mode. + // In this mode, TaskSetExcludeList would not exclude any executors but only store + // task failure information. 
+ Some(new TaskSetExcludelist(sched.sc.listenerBus, conf, stageId, + taskSet.stageAttemptId, clock, isDryRun = true)) + } else { + None } } @@ -698,7 +708,6 @@ private[spark] class TaskSetManager( private[scheduler] def getCompletelyExcludedTaskIfAny( hostToExecutors: HashMap[String, HashSet[String]]): Option[Int] = { taskSetExcludelistHelperOpt.flatMap { taskSetExcludelist => - val appHealthTracker = healthTracker.get // Only look for unschedulable tasks when at least one executor has registered. Otherwise, // task sets will be (unnecessarily) aborted in cases when no executors have registered yet. if (hostToExecutors.nonEmpty) { @@ -725,7 +734,7 @@ private[spark] class TaskSetManager( hostToExecutors.forall { case (host, execsOnHost) => // Check if the task can run on the node val nodeExcluded = - appHealthTracker.isNodeExcluded(host) || + healthTracker.exists(_.isNodeExcluded(host)) || taskSetExcludelist.isNodeExcludedForTaskSet(host) || taskSetExcludelist.isNodeExcludedForTask(host, indexInTaskSet) if (nodeExcluded) { @@ -733,7 +742,7 @@ private[spark] class TaskSetManager( } else { // Check if the task can run on any of the executors execsOnHost.forall { exec => - appHealthTracker.isExecutorExcluded(exec) || + healthTracker.exists(_.isExecutorExcluded(exec)) || taskSetExcludelist.isExecutorExcludedForTaskSet(exec) || taskSetExcludelist.isExecutorExcludedForTask(exec, indexInTaskSet) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala index e7a57c22ef66e..478e578130fcb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala @@ -441,6 +441,23 @@ class HealthTrackerSuite extends SparkFunSuite with MockitoSugar with LocalSpark assert(1000 === HealthTracker.getExcludeOnFailureTimeout(conf)) } + test("SPARK-49252: check exclusion enabling config on the application level") { + val conf = new SparkConf().setMaster("local") + assert(!HealthTracker.isExcludeOnFailureEnabled(conf)) + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED, true) + assert(HealthTracker.isExcludeOnFailureEnabled(conf)) + // Turn off taskset level exclusion, application level healthtracker should still be enabled. + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, false) + assert(HealthTracker.isExcludeOnFailureEnabled(conf)) + // Turn off the application level exclusion specifically, this overrides the global setting. + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_APPLICATION, false) + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, false) + assert(!HealthTracker.isExcludeOnFailureEnabled(conf)) + // Turn on application level exclusion, health tracker should be enabled. 
+ conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_APPLICATION, true)
+ assert(HealthTracker.isExcludeOnFailureEnabled(conf))
+ }
+
test("check exclude configuration invariants") {
val conf = new SparkConf().setMaster("yarn").set(config.SUBMIT_DEPLOY_MODE, "cluster")
Seq(
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ab2c00e368468..7607d4d9fe6d9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -2725,6 +2725,39 @@ class TaskSetManagerSuite
assert(executorMonitor.isExecutorIdle("exec2"))
}
+ test("SPARK-49252: TaskSetExcludeList can be created without HealthTracker") {
+ // When the excludeOnFailure.enabled is set to true, the TaskSetManager should create a
+ // TaskSetExcludelist even if the application level HealthTracker is not defined.
+ val conf = new SparkConf().set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, true)
+
+ // Create a task with two executors.
+ sc = new SparkContext("local", "test", conf)
+ sched = new FakeTaskScheduler(sc)
+ val taskSet = FakeTask.createTaskSet(1)
+
+ val taskSetManager = new TaskSetManager(sched, taskSet, 1,
+ // No application level HealthTracker.
+ healthTracker = None)
+ assert(taskSetManager.taskSetExcludelistHelperOpt.isDefined)
+ }
+
+ test("SPARK-49252: TaskSetExcludeList will be running in dry run mode when " +
+ "excludeOnFailure at taskset level is disabled but health tracker is enabled") {
+ // Disable the excludeOnFailure.enabled at taskset level.
+ val conf = new SparkConf().set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, false)
+
+ // Create a task with two executors.
+ sc = new SparkContext("local", "test", conf)
+ sched = new FakeTaskScheduler(sc)
+ val taskSet = FakeTask.createTaskSet(1)
+
+ val taskSetManager = new TaskSetManager(sched, taskSet, 1,
+ // Enable the application level HealthTracker.
+ healthTracker = Some(new HealthTracker(sc, None)))
+ assert(taskSetManager.taskSetExcludelistHelperOpt.isDefined)
+ assert(taskSetManager.taskSetExcludelistHelperOpt.get.isDryRun)
+ }
+
}
class FakeLongTasks(stageId: Int, partitionId: Int) extends FakeTask(stageId, partitionId) {
diff --git a/docs/configuration.md b/docs/configuration.md
index ff2f21d282a5e..2da099a6c5ed2 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -2839,9 +2839,36 @@ Apart from these, the following properties are also available, and may be useful
If set to "true", prevent Spark from scheduling tasks on executors that have been excluded
due to too many task failures. The algorithm used to exclude executors and nodes can be
further controlled by the other "spark.excludeOnFailure" configuration options.
+ This config will be overridden by "spark.excludeOnFailure.application.enabled" and
+ "spark.excludeOnFailure.taskAndStage.enabled" to specify exclusion enablement on individual
+ levels. 2.1.0 + + spark.excludeOnFailure.application.enabled + + false + +
+ If set to "true", enables excluding executors for the entire application due to too many task
+ failures and prevents Spark from scheduling tasks on them. 
+ This config overrides "spark.excludeOnFailure.enabled". + + 4.0.0 + spark.excludeOnFailure.timeout 1h
From 53c1f31dc26bb56d56e0b71b144910df5d376a76 Mon Sep 17 00:00:00 2001
From: panbingkun
Date: Fri, 30 Aug 2024 16:15:01 +0800
Subject: [PATCH 011/230] [SPARK-49119][SQL] Fix the inconsistency of syntax `show columns` between v1 and v2

### What changes were proposed in this pull request?

This PR aims to
- fix the `inconsistency` of syntax `show columns` between `v1` and `v2`.
- assign a name `SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE` to the error condition `_LEGACY_ERROR_TEMP_1057`.
- unify v1 and v2 `SHOW COLUMNS ...` tests.
- move some UT related to `SHOW COLUMNS` from `DDLSuite` to `command/ShowColumnsSuiteBase` or `v1/ShowColumnsSuiteBase`.
- move some UT related to `SHOW COLUMNS` from `DDLParserSuite` and `ErrorParserSuite` to `ShowColumnsParserSuite`.

### Why are the changes needed?

In `AstBuilder`, we have a comment that explains as follows: https://github.com/apache/spark/blob/2a752105091ef95f994526b15bae2159657c8ed0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala#L5054-L5055

However, the v2 implementation of the `show columns` syntax did not perform the above check, as shown below:
```
withNamespaceAndTable("ns", "tbl") { t =>
  sql(s"CREATE TABLE $t (col1 int, col2 string) $defaultUsing")
  sql(s"SHOW COLUMNS IN $t IN ns1")
}
```
- Before (inconsistent: v1 fails, but v2 succeeds)

v1:
```
[SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE] SHOW COLUMNS with conflicting namespace: `ns1` != `ns`.
```
v2:
```
Executes successfully.
```
#### So, we should fix it.
- After (consistent: v1 and v2 both fail)

v1:
```
[SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE] SHOW COLUMNS with conflicting namespace: `ns1` != `ns`.
```
v2:
```
[SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE] SHOW COLUMNS with conflicting namespace: `ns1` != `ns`.
```

### Does this PR introduce _any_ user-facing change?

Yes. For v2 tables, in the syntax `SHOW COLUMNS {FROM | IN} {tableName} {FROM | IN} {namespace}`, if the namespace (`second parameter`) differs from the namespace of the table (`first parameter`), the command silently succeeded before this PR; after this PR, it reports an error.

### How was this patch tested?

Add new UT.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47628 from panbingkun/SPARK-49119. 
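For illustration only (not part of this patch, assuming a table `ns`.`tbl` like the one in the example above), a minimal sketch of the resulting behavior:
```
// The namespace after the second IN/FROM must now match the table's namespace on v2 as well.
spark.sql("SHOW COLUMNS IN ns.tbl IN ns")   // namespaces match: returns the table's columns
spark.sql("SHOW COLUMNS IN ns.tbl IN ns1")  // fails with SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE
```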
Lead-authored-by: panbingkun Co-authored-by: Kent Yao Signed-off-by: yangjie01 --- .../resources/error/error-conditions.json | 11 +- .../sql/errors/QueryCompilationErrors.scala | 11 +- .../sql/catalyst/parser/DDLParserSuite.scala | 23 ---- .../catalyst/parser/ErrorParserSuite.scala | 4 - .../analysis/ResolveSessionCatalog.scala | 3 +- .../datasources/v2/DataSourceV2Strategy.scala | 13 ++- ...sTableExec.scala => ShowColumnsExec.scala} | 4 +- .../analyzer-results/show_columns.sql.out | 7 +- .../sql-tests/results/show_columns.sql.out | 7 +- .../sql/connector/DataSourceV2SQLSuite.scala | 10 -- .../sql/execution/command/DDLSuite.scala | 33 ------ .../command/ShowColumnsParserSuite.scala | 55 ++++++++++ .../command/ShowColumnsSuiteBase.scala | 100 ++++++++++++++++++ .../command/v1/ShowColumnsSuite.scala | 55 ++++++++++ .../command/v2/ShowColumnsSuite.scala | 25 +++++ .../execution/command/ShowColumnsSuite.scala | 26 +++++ 16 files changed, 297 insertions(+), 90 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/{ShowColumnsTableExec.scala => ShowColumnsExec.scala} (92%) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsParserSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsSuiteBase.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowColumnsSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowColumnsSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowColumnsSuite.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 89d2627ef32ee..496a90e5db347 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3866,6 +3866,12 @@ ], "sqlState" : "42K08" }, + "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE" : { + "message" : [ + "SHOW COLUMNS with conflicting namespaces: != ." + ], + "sqlState" : "42K05" + }, "SORT_BY_WITHOUT_BUCKETING" : { "message" : [ "sortBy must be used together with bucketBy." @@ -5685,11 +5691,6 @@ "ADD COLUMN with v1 tables cannot specify NOT NULL." ] }, - "_LEGACY_ERROR_TEMP_1057" : { - "message" : [ - "SHOW COLUMNS with conflicting databases: '' != ''." - ] - }, "_LEGACY_ERROR_TEMP_1058" : { "message" : [ "Cannot create table with both USING and ." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 51ab2eb063233..613e7cff1e42e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1045,13 +1045,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } - def showColumnsWithConflictDatabasesError( - db: Seq[String], v1TableName: TableIdentifier): Throwable = { + def showColumnsWithConflictNamespacesError( + namespaceA: Seq[String], + namespaceB: Seq[String]): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1057", + errorClass = "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE", messageParameters = Map( - "dbA" -> db.head, - "dbB" -> v1TableName.database.get)) + "namespaceA" -> toSQLId(namespaceA), + "namespaceB" -> toSQLId(namespaceB))) } def cannotCreateTableWithBothProviderAndSerdeError( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 756ec95c70d2f..d514f777e5544 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -2406,29 +2406,6 @@ class DDLParserSuite extends AnalysisTest { RefreshTable(UnresolvedTableOrView(Seq("a", "b", "c"), "REFRESH TABLE", true))) } - test("show columns") { - val sql1 = "SHOW COLUMNS FROM t1" - val sql2 = "SHOW COLUMNS IN db1.t1" - val sql3 = "SHOW COLUMNS FROM t1 IN db1" - val sql4 = "SHOW COLUMNS FROM db1.t1 IN db1" - - val parsed1 = parsePlan(sql1) - val expected1 = ShowColumns(UnresolvedTableOrView(Seq("t1"), "SHOW COLUMNS", true), None) - val parsed2 = parsePlan(sql2) - val expected2 = ShowColumns(UnresolvedTableOrView(Seq("db1", "t1"), "SHOW COLUMNS", true), None) - val parsed3 = parsePlan(sql3) - val expected3 = - ShowColumns(UnresolvedTableOrView(Seq("db1", "t1"), "SHOW COLUMNS", true), Some(Seq("db1"))) - val parsed4 = parsePlan(sql4) - val expected4 = - ShowColumns(UnresolvedTableOrView(Seq("db1", "t1"), "SHOW COLUMNS", true), Some(Seq("db1"))) - - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - comparePlans(parsed4, expected4) - } - test("alter view: add partition (not supported)") { val sql = """ALTER VIEW a.b.c ADD IF NOT EXISTS PARTITION diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index cd1556a2e7916..e4f9b54680dc7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -141,10 +141,6 @@ class ErrorParserSuite extends AnalysisTest { exception = parseException("SHOW TABLE EXTENDED IN hyphen-db LIKE \"str\""), errorClass = "INVALID_IDENTIFIER", parameters = Map("ident" -> "hyphen-db")) - checkError( - exception = parseException("SHOW COLUMNS IN t FROM test-db"), - errorClass = "INVALID_IDENTIFIER", - parameters = Map("ident" -> "test-db")) checkError( exception = parseException("DESC SCHEMA EXTENDED test-db"), errorClass = "INVALID_IDENTIFIER", diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 20e3b4e980f2a..d569f1ed484cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -330,7 +330,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) val resolver = conf.resolver val db = ns match { case Some(db) if v1TableName.database.exists(!resolver(_, db.head)) => - throw QueryCompilationErrors.showColumnsWithConflictDatabasesError(db, v1TableName) + throw QueryCompilationErrors.showColumnsWithConflictNamespacesError( + Seq(db.head), Seq(v1TableName.database.get)) case _ => ns.map(_.head) } ShowColumnsCommand(db, v1TableName, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 89882997681ca..112ee2c5450b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation, PushableColumnAndNestedColumn} import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.storage.StorageLevel @@ -477,7 +478,17 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat Seq(part).asResolvedPartitionSpecs.head, recacheTable(r)) :: Nil - case ShowColumns(resolvedTable: ResolvedTable, _, output) => + case ShowColumns(resolvedTable: ResolvedTable, ns, output) => + ns match { + case Some(namespace) => + val tableNamespace = resolvedTable.identifier.namespace() + if (namespace.length != tableNamespace.length || + !namespace.zip(tableNamespace).forall(SQLConf.get.resolver.tupled)) { + throw QueryCompilationErrors.showColumnsWithConflictNamespacesError( + namespace, tableNamespace.toSeq) + } + case _ => + } ShowColumnsExec(output, resolvedTable) :: Nil case r @ ShowPartitions( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowColumnsTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowColumnsExec.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowColumnsTableExec.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowColumnsExec.scala index e7a608938a04e..e92607aa87164 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowColumnsTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowColumnsExec.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.execution.LeafExecNode * Physical plan node for show columns from table. 
*/ case class ShowColumnsExec( - output: Seq[Attribute], - resolvedTable: ResolvedTable) extends V2CommandExec with LeafExecNode { + output: Seq[Attribute], + resolvedTable: ResolvedTable) extends V2CommandExec with LeafExecNode { override protected def run(): Seq[InternalRow] = { resolvedTable.table.columns().map(f => toCatalystRow(f.name())).toSeq } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/show_columns.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/show_columns.sql.out index 27e75187cdba7..76c3b88a3ce6b 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/show_columns.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/show_columns.sql.out @@ -94,10 +94,11 @@ SHOW COLUMNS IN showdb.showcolumn1 FROM baddb -- !query analysis org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1057", + "errorClass" : "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE", + "sqlState" : "42K05", "messageParameters" : { - "dbA" : "baddb", - "dbB" : "showdb" + "namespaceA" : "`baddb`", + "namespaceB" : "`showdb`" } } diff --git a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out index 9a0d82d3617af..bb4e7e08c6f5b 100644 --- a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out @@ -123,10 +123,11 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1057", + "errorClass" : "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE", + "sqlState" : "42K05", "messageParameters" : { - "dbA" : "baddb", - "dbB" : "showdb" + "namespaceA" : "`baddb`", + "namespaceB" : "`showdb`" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index a61a266c1ed58..1d37c6aa4eb7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -2390,16 +2390,6 @@ class DataSourceV2SQLSuiteV1Filter sql(s"UNCACHE TABLE IF EXISTS $t") } - test("SHOW COLUMNS") { - val t = "testcat.ns1.ns2.tbl" - withTable(t) { - spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") - checkAnswer(sql(s"SHOW COLUMNS FROM $t IN testcat.ns1.ns2"), Seq(Row("id"), Row("data"))) - checkAnswer(sql(s"SHOW COLUMNS in $t"), Seq(Row("id"), Row("data"))) - checkAnswer(sql(s"SHOW COLUMNS FROM $t"), Seq(Row("id"), Row("data"))) - } - } - test("ALTER TABLE ... SET [SERDE|SERDEPROPERTIES]") { val t = "testcat.ns1.ns2.tbl" withTable(t) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index c06f44d0dd042..5c1090c288ed5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1376,39 +1376,6 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { } } - test("show columns - negative test") { - // When case sensitivity is true, the user supplied database name in table identifier - // should match the supplied database name in case sensitive way. 
- withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - withTempDatabase { db => - val tabName = s"$db.showcolumn" - withTable(tabName) { - sql(s"CREATE TABLE $tabName(col1 int, col2 string) USING parquet ") - checkError( - exception = intercept[AnalysisException] { - sql(s"SHOW COLUMNS IN $db.showcolumn FROM ${db.toUpperCase(Locale.ROOT)}") - }, - errorClass = "_LEGACY_ERROR_TEMP_1057", - parameters = Map("dbA" -> db.toUpperCase(Locale.ROOT), "dbB" -> db) - ) - } - } - } - } - - test("show columns - invalid db name") { - withTable("tbl") { - sql("CREATE TABLE tbl(col1 int, col2 string) USING parquet ") - checkError( - exception = intercept[AnalysisException] { - sql("SHOW COLUMNS IN tbl FROM a.b.c") - }, - errorClass = "REQUIRES_SINGLE_PART_NAMESPACE", - parameters = Map("sessionCatalog" -> "spark_catalog", "namespace" -> "`a`.`b`.`c`") - ) - } - } - test("SPARK-18009 calling toLocalIterator on commands") { import scala.jdk.CollectionConverters._ val df = sql("show databases") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsParserSuite.scala new file mode 100644 index 0000000000000..17a6df87aa0e4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsParserSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedTableOrView} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.plans.logical.ShowColumns + +class ShowColumnsParserSuite extends AnalysisTest { + + test("show columns") { + comparePlans( + parsePlan("SHOW COLUMNS IN a.b.c"), + ShowColumns( + UnresolvedTableOrView(Seq("a", "b", "c"), "SHOW COLUMNS", allowTempView = true), + None)) + comparePlans( + parsePlan("SHOW COLUMNS FROM a.b.c"), + ShowColumns( + UnresolvedTableOrView(Seq("a", "b", "c"), "SHOW COLUMNS", allowTempView = true), + None)) + comparePlans( + parsePlan("SHOW COLUMNS IN a.b.c FROM a.b"), + ShowColumns(UnresolvedTableOrView(Seq("a", "b", "c"), "SHOW COLUMNS", allowTempView = true), + Some(Seq("a", "b")))) + comparePlans( + parsePlan("SHOW COLUMNS FROM a.b.c IN a.b"), + ShowColumns(UnresolvedTableOrView(Seq("a", "b", "c"), "SHOW COLUMNS", allowTempView = true), + Some(Seq("a", "b")))) + } + + test("illegal characters in unquoted identifier") { + checkError( + exception = parseException(parsePlan)("SHOW COLUMNS IN t FROM test-db"), + errorClass = "INVALID_IDENTIFIER", + sqlState = "42602", + parameters = Map("ident" -> "test-db") + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsSuiteBase.scala new file mode 100644 index 0000000000000..c6f4e0bbd01a1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowColumnsSuiteBase.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import java.util.Locale + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf + +/** + * This base suite contains unified tests for the `SHOW COLUMNS ...` command that + * check V1 and V2 table catalogs. The tests that cannot run for all supported catalogs are + * located in more specific test suites: + * + * - V2 table catalog tests: + * `org.apache.spark.sql.execution.command.v2.ShowColumnsSuite` + * - V1 table catalog tests: + * `org.apache.spark.sql.execution.command.v1.ShowColumnsSuiteBase` + * - V1 In-Memory catalog: + * `org.apache.spark.sql.execution.command.v1.ShowColumnsSuite` + * - V1 Hive External catalog: + * `org.apache.spark.sql.hive.execution.command.ShowColumnsSuite` + */ +trait ShowColumnsSuiteBase extends QueryTest with DDLCommandTestUtils { + override val command = "SHOW COLUMNS ..." 
+ + test("basic test") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t(col1 int, col2 string) $defaultUsing") + val expected = Seq(Row("col1"), Row("col2")) + checkAnswer(sql(s"SHOW COLUMNS FROM $t IN ns"), expected) + checkAnswer(sql(s"SHOW COLUMNS IN $t FROM ns"), expected) + checkAnswer(sql(s"SHOW COLUMNS IN $t"), expected) + } + } + + test("negative test - the table does not exist") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t(col1 int, col2 string) $defaultUsing") + + checkError( + exception = intercept[AnalysisException] { + sql(s"SHOW COLUMNS IN tbl IN ns1") + }, + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + parameters = Map("relationName" -> "`ns1`.`tbl`"), + context = ExpectedContext(fragment = "tbl", start = 16, stop = 18) + ) + } + } + + test("the namespace of the table conflicts with the specified namespace") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t(col1 int, col2 string) $defaultUsing") + + val sqlText1 = s"SHOW COLUMNS IN $t IN ns1" + val sqlText2 = s"SHOW COLUMNS IN $t FROM ${"ns".toUpperCase(Locale.ROOT)}" + + checkError( + exception = intercept[AnalysisException] { + sql(sqlText1) + }, + errorClass = "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE", + parameters = Map( + "namespaceA" -> s"`ns1`", + "namespaceB" -> s"`ns`" + ) + ) + // When case sensitivity is true, the user supplied namespace name in table identifier + // should match the supplied namespace name in case-sensitive way. + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkError( + exception = intercept[AnalysisException] { + sql(sqlText2) + }, + errorClass = "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE", + parameters = Map( + "namespaceA" -> s"`${"ns".toUpperCase(Locale.ROOT)}`", + "namespaceB" -> "`ns`" + ) + ) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowColumnsSuite.scala new file mode 100644 index 0000000000000..e9459a224486c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowColumnsSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v1 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.command + +/** + * This base suite contains unified tests for the `SHOW COLUMNS ...` command that check V1 table + * catalogs. 
The tests that cannot run for all V1 catalogs are located in more specific test suites: + * + * - V1 In-Memory catalog: + * `org.apache.spark.sql.execution.command.v1.ShowColumnsSuite` + * - V1 Hive External catalog: + * `org.apache.spark.sql.hive.execution.command.ShowColumnsSuite` + */ +trait ShowColumnsSuiteBase extends command.ShowColumnsSuiteBase { + + test("invalid db name") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t(col1 int, col2 string) $defaultUsing") + checkError( + exception = intercept[AnalysisException] { + sql("SHOW COLUMNS IN tbl FROM a.b.c") + }, + errorClass = "REQUIRES_SINGLE_PART_NAMESPACE", + parameters = Map( + "sessionCatalog" -> catalog, + "namespace" -> "`a`.`b`.`c`" + ) + ) + } + } +} + +/** + * The class contains tests for the `SHOW COLUMNS ...` command to check V1 In-Memory + * table catalog. + */ +class ShowColumnsSuite extends ShowColumnsSuiteBase with CommandSuiteBase diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowColumnsSuite.scala new file mode 100644 index 0000000000000..64ddce85658e8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowColumnsSuite.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.execution.command + +/** + * The class contains tests for the `SHOW COLUMNS ...` command to check V2 table catalogs. + */ +class ShowColumnsSuite extends command.ShowColumnsSuiteBase with CommandSuiteBase diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowColumnsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowColumnsSuite.scala new file mode 100644 index 0000000000000..4b36d00455aff --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowColumnsSuite.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution.command + +import org.apache.spark.sql.execution.command.v1 + +/** + * The class contains tests for the `SHOW COLUMNS ...` command to check V1 Hive external + * table catalog. + */ +class ShowColumnsSuite extends v1.ShowColumnsSuiteBase with CommandSuiteBase From 493ca987ca930ce4f92eb459895d90ee7f7ee67c Mon Sep 17 00:00:00 2001 From: Neil Ramaswamy Date: Fri, 30 Aug 2024 17:52:15 +0800 Subject: [PATCH 012/230] [SPARK-49378][DOCS][SS] Break apart the Structured Streaming Programming Guide MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? These changes break the Structured Streaming Programming Guide into smaller sub-pages **without changing any content**. You can see a preview of it [here](https://nr-spark-site.vercel.app/). I broke up the pages by `h1` tag; within pages, the sub-sections on the left menu are broken up by `h2`. The SS programming guide now will resemble the SQL programming guide and the MLLib programming guide. Additionally, to avoid cluttering the top-level namespace (there are dozens of `sql-*` files for the SQL reference), we nest all streaming docs in by one directory, namely the `/streaming/`. This has the side-effect of breaking links from our `_layouts`, since we assume a flat top-level namespace. To fix this issue, URLs in global layout files now all use absolute paths. This move to `/streaming/` has the consequence that bookmarks of `https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html` will not refer to the actual programming guide content. In anticipation of this, I have kept all pages for existing URLs present with links to the pages in their new locations. This includes the new state data source and the Kafka integration guide. In the future, we'll be able to quite easily (and in-parallel) break the programming guide apart further. This PR does all of the plumbing to make it work. ![image](https://github.com/user-attachments/assets/3eca87d4-9fb7-453c-a74a-20bd5c504d87) It is future work to fix the oddly-sized left-navigation bar for our menus. ### Why are the changes needed? One of the major hurdles that users have with Structured Streaming is that our guide is exceptionally long—it feels insurmountable, especially compared to other engines like Flink, which has many sub-pages. Google also has a very tricky time indexing the single large page; if you Google "[structured streaming output mode](https://www.google.com/search?q=structured+streaming+output+mode)" and you click on the link to our programming guide... nothing happens. You aren't taken to the actual content, since Google has trouble with indexing to specific heading tags. ### Does this PR introduce _any_ user-facing change? The structure of the website, with respect to Structured Streaming-related pages, is now changed. See the earlier parts of the PR description for the specific changes. However, **no** content is changed. This should make reviewing the changes much easier. ### How was this patch tested? I have used automated tools (e.g. [Lychee](https://github.com/lycheeverse/lychee)) and manual verification (i.e. clicking on every link) to make sure that I didn't break any links. It isn't fool-proof, though. ### Was this patch authored or co-authored using generative AI tooling? No. 
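As a rough sketch of the redirect approach described above (illustrative only; the exact wording, titles, and front matter in the diff below may differ), a page kept at an old URL can be reduced to a short pointer to its new home under `/streaming/`, reusing the `global` layout that the docs already use:

```
---
layout: global
title: Structured Streaming Programming Guide
---

This guide has been split into several pages. Start with the
[Structured Streaming overview](streaming/index.html) or jump straight to
[getting started](streaming/getting-started.html).
```
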
Closes #47864 from neilramaswamy/nr/streaming-guide-breakapart. Lead-authored-by: Neil Ramaswamy Co-authored-by: Kent Yao Signed-off-by: Kent Yao --- docs/_data/menu-streaming.yaml | 57 + .../_includes/nav-left-wrapper-streaming.html | 22 + docs/_includes/nav-left.html | 2 +- docs/_layouts/global.html | 93 +- docs/index.md | 2 +- docs/migration-guide.md | 2 +- docs/sparkr.md | 2 +- docs/ss-migration-guide.md | 40 +- docs/streaming-programming-guide.md | 2 +- docs/streaming/additional-information.md | 58 + .../apis-on-dataframes-and-datasets.md | 3592 ++++++++++++++ docs/streaming/getting-started.md | 508 ++ docs/streaming/index.md | 28 + docs/streaming/performance-tips.md | 174 + docs/streaming/ss-migration-guide.md | 56 + .../structured-streaming-kafka-integration.md | 1173 +++++ .../structured-streaming-state-data-source.md | 0 .../structured-streaming-kafka-integration.md | 1155 +---- .../structured-streaming-programming-guide.md | 4268 +---------------- 19 files changed, 5728 insertions(+), 5506 deletions(-) create mode 100644 docs/_data/menu-streaming.yaml create mode 100644 docs/_includes/nav-left-wrapper-streaming.html create mode 100644 docs/streaming/additional-information.md create mode 100644 docs/streaming/apis-on-dataframes-and-datasets.md create mode 100644 docs/streaming/getting-started.md create mode 100644 docs/streaming/index.md create mode 100644 docs/streaming/performance-tips.md create mode 100644 docs/streaming/ss-migration-guide.md create mode 100644 docs/streaming/structured-streaming-kafka-integration.md rename docs/{ => streaming}/structured-streaming-state-data-source.md (100%) diff --git a/docs/_data/menu-streaming.yaml b/docs/_data/menu-streaming.yaml new file mode 100644 index 0000000000000..b1dd024451125 --- /dev/null +++ b/docs/_data/menu-streaming.yaml @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +- text: Overview + url: streaming/index.html +- text: Getting Started + url: streaming/getting-started.html + subitems: + - text: Quick Example + url: streaming/getting-started.html#quick-example + - text: Programming Model + url: streaming/getting-started.html#programming-model +- text: APIs on DataFrames and Datasets + url: streaming/apis-on-dataframes-and-datasets.html + subitems: + - text: Creating Streaming DataFrames and Streaming Datasets + url: streaming/apis-on-dataframes-and-datasets.html#creating-streaming-dataframes-and-streaming-datasets + - text: Operations on Streaming DataFrames/Datasets + url: streaming/apis-on-dataframes-and-datasets.html#operations-on-streaming-dataframesdatasets + - text: Starting Streaming Queries + url: streaming/apis-on-dataframes-and-datasets.html#starting-streaming-queries + - text: Managing Streaming Queries + url: streaming/apis-on-dataframes-and-datasets.html#managing-streaming-queries + - text: Monitoring Streaming Queries + url: streaming/apis-on-dataframes-and-datasets.html#monitoring-streaming-queries + - text: Recovering from Failures with Checkpointing + url: streaming/apis-on-dataframes-and-datasets.html#recovering-from-failures-with-checkpointing + - text: Recovery Semantics after Changes in a Streaming Query + url: streaming/apis-on-dataframes-and-datasets.html#recovery-semantics-after-changes-in-a-streaming-query +- text: Performance Tips + url: streaming/performance-tips.html + subitems: + - text: Asynchronous Progress Tracking + url: streaming/performance-tips.html#asynchronous-progress-tracking + - text: Continuous Processing + url: streaming/performance-tips.html#continuous-processing +- text: Additional Information + url: streaming/additional-information.html + subitems: + - text: Miscellaneous Notes + url: streaming/additional-information.html#miscellaneous-notes + - text: Related Resources + url: streaming/additional-information.html#related-resources + - text: Migration Guide + url: streaming/additional-information.html#migration-guide diff --git a/docs/_includes/nav-left-wrapper-streaming.html b/docs/_includes/nav-left-wrapper-streaming.html new file mode 100644 index 0000000000000..82849f8140f5d --- /dev/null +++ b/docs/_includes/nav-left-wrapper-streaming.html @@ -0,0 +1,22 @@ +{% comment %} +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +{% endcomment %} +
+
+

Structured Streaming Programming Guide

+ {% include nav-left.html nav=include.nav-streaming %} +
+
diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html index 19d68fd191635..935ed0c732ee6 100644 --- a/docs/_includes/nav-left.html +++ b/docs/_includes/nav-left.html @@ -2,7 +2,7 @@