diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 0bf7e57c364e4..2b459e4c73bbb 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -304,7 +304,7 @@ jobs: uses: actions/upload-artifact@v4 with: name: unit-tests-log-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} - path: "**/target/unit-tests.log" + path: "**/target/*.log" infra-image: name: "Base image build" @@ -723,7 +723,7 @@ jobs: # See 'ipython_genutils' in SPARK-38517 # See 'docutils<0.18.0' in SPARK-39421 python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ - ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ + ipython ipython_genutils sphinx_plotly_directive 'numpy==1.26.4' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \ 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml index 3ac1a0117e41b..f668d813ef26e 100644 --- a/.github/workflows/build_python_connect.yml +++ b/.github/workflows/build_python_connect.yml @@ -71,7 +71,7 @@ jobs: python packaging/connect/setup.py sdist cd dist pip install pyspark*connect-*.tar.gz - pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting + pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting 'plotly>=4.8' - name: Run tests env: SPARK_TESTING: 1 diff --git a/.github/workflows/build_sparkr_window.yml b/.github/workflows/build_sparkr_window.yml index b6656351d431f..ddaf60ad3e71a 100644 --- a/.github/workflows/build_sparkr_window.yml +++ b/.github/workflows/build_sparkr_window.yml @@ -85,7 +85,7 @@ jobs: shell: cmd env: NOT_CRAN: true - SPARK_TESTING: 1 + SPARKR_SUPPRESS_DEPRECATION_WARNING: 1 # See SPARK-27848. Currently installing some dependent packages causes # "(converted from warning) unable to identify current timezone 'C':" for an unknown reason. # This environment variable works around to test SparkR against a higher version. 
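Note on the SPARKR_SUPPRESS_DEPRECATION_WARNING switch above: the Windows SparkR job now sets this dedicated variable to silence the SparkR deprecation notice, matching the R/pkg/R/sparkR.R and R/run-tests.sh changes further down in this diff. A minimal sketch of how the gate behaves from the R side (illustration only, not part of the patch; the session arguments are placeholders):

  # sparkR.session() warns unless the variable is set to any non-empty value.
  Sys.setenv(SPARKR_SUPPRESS_DEPRECATION_WARNING = "1")   # suppress the Spark 4.0.0 deprecation warning
  library(SparkR)
  sparkR.session(master = "local[1]")                     # starts without emitting the warning
  Sys.unsetenv("SPARKR_SUPPRESS_DEPRECATION_WARNING")     # a later new session would warn again

R/run-tests.sh sets the same variable alongside SPARK_TESTING, so the warning suppression is decoupled from test mode and stays opt-in for end users as well.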
diff --git a/.github/workflows/test_report.yml b/.github/workflows/test_report.yml index c6225e6a1abe5..9ab69af42c818 100644 --- a/.github/workflows/test_report.yml +++ b/.github/workflows/test_report.yml @@ -30,14 +30,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Download test results to report - uses: dawidd6/action-download-artifact@09385b76de790122f4da9c82b17bccf858b9557c # pin@v2 + uses: dawidd6/action-download-artifact@bf251b5aa9c2f7eeb574a96ee720e24f801b7c11 # pin @v6 with: github_token: ${{ secrets.GITHUB_TOKEN }} workflow: ${{ github.event.workflow_run.workflow_id }} commit: ${{ github.event.workflow_run.head_commit.id }} workflow_conclusion: completed - name: Publish test report - uses: scacap/action-surefire-report@482f012643ed0560e23ef605a79e8e87ca081648 # pin@v1 + uses: scacap/action-surefire-report@a2911bd1a4412ec18dde2d93b1758b3e56d2a880 # pin @v1.8.0 with: check_name: Report test results github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 787eb6180c35c..0a4138ec26948 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ *.swp *~ .java-version +.python-version .DS_Store .ammonite .bloop @@ -26,6 +27,7 @@ .scala_dependencies .settings .vscode +artifacts/ /lib/ R-unit-tests.log R/unit-tests.out diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index fc2ab8de1eca0..29c05b0db7c2d 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -404,7 +404,7 @@ sparkR.session <- function( enableHiveSupport = TRUE, ...) { - if (Sys.getenv("SPARK_TESTING") == "") { + if (Sys.getenv("SPARKR_SUPPRESS_DEPRECATION_WARNING") == "") { warning( "SparkR is deprecated from Apache Spark 4.0.0 and will be removed in a future version.") } diff --git a/R/run-tests.sh b/R/run-tests.sh index 90a60eda03871..3a90b44c2b659 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -30,9 +30,9 @@ if [[ $(echo $SPARK_AVRO_JAR_PATH | wc -l) -eq 1 ]]; then fi if [ -z "$SPARK_JARS" ]; then - SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE + SPARKR_SUPPRESS_DEPRECATION_WARNING=1 SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE else - SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --jars $SPARK_JARS --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE + SPARKR_SUPPRESS_DEPRECATION_WARNING=1 SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --jars $SPARK_JARS --driver-java-options "-Dlog4j.configurationFile=file:$FWDIR/log4j2.properties" --conf 
spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true -Xss4M" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE fi FAILED=$((PIPESTATUS[0]||$FAILED)) diff --git a/assembly/pom.xml b/assembly/pom.xml index e5628ce90fa90..01bd324efc118 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -123,7 +123,7 @@ com.google.guava @@ -200,7 +200,7 @@ cp - ${basedir}/../connector/connect/client/jvm/target/spark-connect-client-jvm_${scala.binary.version}-${version}.jar + ${basedir}/../connector/connect/client/jvm/target/spark-connect-client-jvm_${scala.binary.version}-${project.version}.jar ${basedir}/target/scala-${scala.binary.version}/jars/connect-repl @@ -339,6 +339,14 @@ + + + jjwt + + compile + + + diff --git a/build/mvn b/build/mvn index 28454c68fd128..060209ac1ac4d 100755 --- a/build/mvn +++ b/build/mvn @@ -58,7 +58,7 @@ install_app() { local local_checksum="${local_tarball}.${checksum_suffix}" local remote_checksum="https://archive.apache.org/dist/${url_path}.${checksum_suffix}" - local curl_opts="--retry 3 --retry-all-errors --silent --show-error -L" + local curl_opts="--retry 3 --silent --show-error -L" local wget_opts="--no-verbose" if [ ! -f "$binary" ]; then diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java index f9c0c60c2f2c6..62fcda701d948 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -19,10 +19,7 @@ import java.io.*; import java.util.concurrent.TimeUnit; -import java.util.zip.Adler32; -import java.util.zip.CRC32; -import java.util.zip.CheckedInputStream; -import java.util.zip.Checksum; +import java.util.zip.*; import com.google.common.io.ByteStreams; @@ -66,6 +63,13 @@ private static Checksum[] getChecksumsByAlgorithm(int num, String algorithm) { } } + case "CRC32C" -> { + checksums = new CRC32C[num]; + for (int i = 0; i < num; i++) { + checksums[i] = new CRC32C(); + } + } + default -> throw new UnsupportedOperationException( "Unsupported shuffle checksum algorithm: " + algorithm); } diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index e3821a0b85989..fb610a5d96f17 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.util; import com.ibm.icu.lang.UCharacter; +import com.ibm.icu.lang.UProperty; import com.ibm.icu.text.BreakIterator; import com.ibm.icu.text.Collator; import com.ibm.icu.text.RuleBasedCollator; @@ -48,6 +49,16 @@ public class CollationAwareUTF8String { */ private static final int MATCH_NOT_FOUND = -1; + /** + * `COMBINED_ASCII_SMALL_I_COMBINING_DOT` is an internal representation of the combined + * lowercase code point for ASCII lowercase letter i with an additional combining dot character + * (U+0307). 
This integer value is not a valid code point itself, but rather an artificial code + * point marker used to represent the two lowercase characters that are the result of converting + * the uppercase Turkish dotted letter I with a combining dot character (U+0130) to lowercase. + */ + private static final int COMBINED_ASCII_SMALL_I_COMBINING_DOT = + SpecialCodePointConstants.ASCII_SMALL_I << 16 | SpecialCodePointConstants.COMBINING_DOT; + /** * Returns whether the target string starts with the specified prefix, starting from the * specified position (0-based index referring to character position in UTF8String), with respect @@ -98,16 +109,16 @@ private static int lowercaseMatchLengthFrom( } // Compare the characters in the target and pattern strings. int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint; - while (targetIterator.hasNext() && patternIterator.hasNext()) { + while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) { if (codePointBuffer != -1) { targetCodePoint = codePointBuffer; codePointBuffer = -1; } else { // Use buffered lowercase code point iteration to handle one-to-many case mappings. targetCodePoint = getLowercaseCodePoint(targetIterator.next()); - if (targetCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) { - targetCodePoint = CODE_POINT_LOWERCASE_I; - codePointBuffer = CODE_POINT_COMBINING_DOT; + if (targetCodePoint == COMBINED_ASCII_SMALL_I_COMBINING_DOT) { + targetCodePoint = SpecialCodePointConstants.ASCII_SMALL_I; + codePointBuffer = SpecialCodePointConstants.COMBINING_DOT; } ++matchLength; } @@ -200,16 +211,16 @@ private static int lowercaseMatchLengthUntil( } // Compare the characters in the target and pattern strings. int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint; - while (targetIterator.hasNext() && patternIterator.hasNext()) { + while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) { if (codePointBuffer != -1) { targetCodePoint = codePointBuffer; codePointBuffer = -1; } else { // Use buffered lowercase code point iteration to handle one-to-many case mappings. targetCodePoint = getLowercaseCodePoint(targetIterator.next()); - if (targetCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) { - targetCodePoint = CODE_POINT_COMBINING_DOT; - codePointBuffer = CODE_POINT_LOWERCASE_I; + if (targetCodePoint == COMBINED_ASCII_SMALL_I_COMBINING_DOT) { + targetCodePoint = SpecialCodePointConstants.COMBINING_DOT; + codePointBuffer = SpecialCodePointConstants.ASCII_SMALL_I; } ++matchLength; } @@ -461,28 +472,16 @@ private static UTF8String toLowerCaseSlow(final UTF8String target, final int col */ private static void appendLowercaseCodePoint(final int codePoint, final StringBuilder sb) { int lowercaseCodePoint = getLowercaseCodePoint(codePoint); - if (lowercaseCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) { + if (lowercaseCodePoint == COMBINED_ASCII_SMALL_I_COMBINING_DOT) { // Latin capital letter I with dot above is mapped to 2 lowercase characters. - sb.appendCodePoint(0x0069); - sb.appendCodePoint(0x0307); + sb.appendCodePoint(SpecialCodePointConstants.ASCII_SMALL_I); + sb.appendCodePoint(SpecialCodePointConstants.COMBINING_DOT); } else { // All other characters should follow context-unaware ICU single-code point case mapping. 
sb.appendCodePoint(lowercaseCodePoint); } } - /** - * `CODE_POINT_COMBINED_LOWERCASE_I_DOT` is an internal representation of the combined lowercase - * code point for ASCII lowercase letter i with an additional combining dot character (U+0307). - * This integer value is not a valid code point itself, but rather an artificial code point - * marker used to represent the two lowercase characters that are the result of converting the - * uppercase Turkish dotted letter I with a combining dot character (U+0130) to lowercase. - */ - private static final int CODE_POINT_LOWERCASE_I = 0x69; - private static final int CODE_POINT_COMBINING_DOT = 0x307; - private static final int CODE_POINT_COMBINED_LOWERCASE_I_DOT = - CODE_POINT_LOWERCASE_I << 16 | CODE_POINT_COMBINING_DOT; - /** * Returns the lowercase version of the provided code point, with special handling for * one-to-many case mappings (i.e. characters that map to multiple characters in lowercase) and @@ -490,15 +489,15 @@ private static void appendLowercaseCodePoint(final int codePoint, final StringBu * the position in the string relative to other characters in lowercase). */ private static int getLowercaseCodePoint(final int codePoint) { - if (codePoint == 0x0130) { + if (codePoint == SpecialCodePointConstants.CAPITAL_I_WITH_DOT_ABOVE) { // Latin capital letter I with dot above is mapped to 2 lowercase characters. - return CODE_POINT_COMBINED_LOWERCASE_I_DOT; + return COMBINED_ASCII_SMALL_I_COMBINING_DOT; } - else if (codePoint == 0x03C2) { + else if (codePoint == SpecialCodePointConstants.GREEK_FINAL_SIGMA) { // Greek final and non-final letter sigma should be mapped the same. This is achieved by // mapping Greek small final sigma (U+03C2) to Greek small non-final sigma (U+03C3). Capital // letter sigma (U+03A3) is mapped to small non-final sigma (U+03C3) in the `else` branch. - return 0x03C3; + return SpecialCodePointConstants.GREEK_SMALL_SIGMA; } else { // All other characters should follow context-unaware ICU single-code point case mapping. @@ -550,6 +549,152 @@ public static UTF8String toTitleCase(final UTF8String target, final int collatio BreakIterator.getWordInstance(locale))); } + /** + * This 'HashMap' is introduced as a performance speedup. Since title-casing a codepoint can + * result in more than a single codepoint, for correctness, we would use + * 'UCharacter.toTitleCase(String)' which returns a 'String'. If we use + * 'UCharacter.toTitleCase(int)' (the version of the same function which converts a single + * codepoint to its title-case codepoint), it would be faster than the previously mentioned + * version, but the problem here is that we don't handle when title-casing a codepoint yields more + * than 1 codepoint. Since there are only 48 codepoints that are mapped to more than 1 codepoint + * when title-cased, they are precalculated here, so that the faster function for title-casing + * could be used in combination with this 'HashMap' in the method 'appendCodepointToTitleCase'. + */ + private static final HashMap codepointOneToManyTitleCaseLookupTable = + new HashMap<>(){{ + StringBuilder sb = new StringBuilder(); + for (int i = Character.MIN_CODE_POINT; i <= Character.MAX_CODE_POINT; ++i) { + sb.appendCodePoint(i); + String titleCase = UCharacter.toTitleCase(sb.toString(), null); + if (titleCase.codePointCount(0, titleCase.length()) > 1) { + put(i, titleCase); + } + sb.setLength(0); + } + }}; + + /** + * Title-casing a string using ICU case mappings. 
Iterates over the string and title-cases + * the first character in each word, and lowercases every other character. Handles lowercasing + * capital Greek letter sigma ('Σ') separately, taking into account if it should be a small final + * Greek sigma ('ς') or small non-final Greek sigma ('σ'). Words are separated by ASCII + * space(\u0020). + * + * @param source UTF8String to be title cased + * @return title cased source + */ + public static UTF8String toTitleCaseICU(UTF8String source) { + // In the default UTF8String implementation, `toLowerCase` method implicitly does UTF8String + // validation (replacing invalid UTF-8 byte sequences with Unicode replacement character + // U+FFFD), but now we have to do the validation manually. + source = source.makeValid(); + + // Building the title cased source with 'sb'. + UTF8StringBuilder sb = new UTF8StringBuilder(); + + // 'isNewWord' is true if the current character is the beginning of a word, false otherwise. + boolean isNewWord = true; + // We are maintaining if the current character is preceded by a cased letter. + // This is used when lowercasing capital Greek letter sigma ('Σ'), to figure out if it should be + // lowercased into σ or ς. + boolean precededByCasedLetter = false; + + // 'offset' is a byte offset in source's byte array pointing to the beginning of the character + // that we need to process next. + int offset = 0; + int len = source.numBytes(); + + while (offset < len) { + // We will actually call 'codePointFrom()' 2 times for each character in the worst case (once + // here, and once in 'followedByCasedLetter'). Example of a string where we call it 2 times + // for almost every character is 'ΣΣΣΣΣ' (a string consisting only of Greek capital sigma) + // and 'Σ`````' (a string consisting of a Greek capital sigma, followed by case-ignorable + // characters). + int codepoint = source.codePointFrom(offset); + // Appending the correctly cased character onto 'sb'. + appendTitleCasedCodepoint(sb, codepoint, isNewWord, precededByCasedLetter, source, offset); + // Updating 'isNewWord', 'precededByCasedLetter' and 'offset' to be ready for the next + // character that we will process. + isNewWord = (codepoint == SpecialCodePointConstants.ASCII_SPACE); + if (!UCharacter.hasBinaryProperty(codepoint, UProperty.CASE_IGNORABLE)) { + precededByCasedLetter = UCharacter.hasBinaryProperty(codepoint, UProperty.CASED); + } + offset += UTF8String.numBytesForFirstByte(source.getByte(offset)); + } + return sb.build(); + } + + private static void appendTitleCasedCodepoint( + UTF8StringBuilder sb, + int codepoint, + boolean isAfterAsciiSpace, + boolean precededByCasedLetter, + UTF8String source, + int offset) { + if (isAfterAsciiSpace) { + // Title-casing a character if it is in the beginning of a new word. + appendCodepointToTitleCase(sb, codepoint); + return; + } + if (codepoint == SpecialCodePointConstants.GREEK_CAPITAL_SIGMA) { + // Handling capital Greek letter sigma ('Σ'). + appendLowerCasedGreekCapitalSigma(sb, precededByCasedLetter, source, offset); + return; + } + // If it's not the beginning of a word, or a capital Greek letter sigma ('Σ'), we lowercase the + // character. We specially handle 'CAPITAL_I_WITH_DOT_ABOVE'. 
+ if (codepoint == SpecialCodePointConstants.CAPITAL_I_WITH_DOT_ABOVE) { + sb.appendCodePoint(SpecialCodePointConstants.ASCII_SMALL_I); + sb.appendCodePoint(SpecialCodePointConstants.COMBINING_DOT); + return; + } + sb.appendCodePoint(UCharacter.toLowerCase(codepoint)); + } + + private static void appendLowerCasedGreekCapitalSigma( + UTF8StringBuilder sb, + boolean precededByCasedLetter, + UTF8String source, + int offset) { + int codepoint = (!followedByCasedLetter(source, offset) && precededByCasedLetter) + ? SpecialCodePointConstants.GREEK_FINAL_SIGMA + : SpecialCodePointConstants.GREEK_SMALL_SIGMA; + sb.appendCodePoint(codepoint); + } + + /** + * Checks if the character beginning at 'offset'(in 'sources' byte array) is followed by a cased + * letter. + */ + private static boolean followedByCasedLetter(UTF8String source, int offset) { + // Moving the offset one character forward, so we could start the linear search from there. + offset += UTF8String.numBytesForFirstByte(source.getByte(offset)); + int len = source.numBytes(); + + while (offset < len) { + int codepoint = source.codePointFrom(offset); + + if (UCharacter.hasBinaryProperty(codepoint, UProperty.CASE_IGNORABLE)) { + offset += UTF8String.numBytesForFirstByte(source.getByte(offset)); + continue; + } + return UCharacter.hasBinaryProperty(codepoint, UProperty.CASED); + } + return false; + } + + /** + * Appends title-case of a single character to a 'StringBuilder' using the ICU root locale rules. + */ + private static void appendCodepointToTitleCase(UTF8StringBuilder sb, int codepoint) { + String toTitleCase = codepointOneToManyTitleCaseLookupTable.get(codepoint); + if (toTitleCase == null) { + sb.appendCodePoint(UCharacter.toTitleCase(codepoint)); + } else { + sb.append(toTitleCase); + } + } + /* * Returns the position of the first occurrence of the match string in the set string, * counting ASCII commas as delimiters. The match string is compared in a collation-aware manner, @@ -843,11 +988,11 @@ public static UTF8String lowercaseTranslate(final UTF8String input, } // Special handling for letter i (U+0069) followed by a combining dot (U+0307). By ensuring // that `CODE_POINT_LOWERCASE_I` is buffered, we guarantee finding a max-length match. - if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) && - codePoint == CODE_POINT_LOWERCASE_I && inputIter.hasNext()) { + if (lowercaseDict.containsKey(COMBINED_ASCII_SMALL_I_COMBINING_DOT) + && codePoint == SpecialCodePointConstants.ASCII_SMALL_I && inputIter.hasNext()) { int nextCodePoint = inputIter.next(); - if (nextCodePoint == CODE_POINT_COMBINING_DOT) { - codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT; + if (nextCodePoint == SpecialCodePointConstants.COMBINING_DOT) { + codePoint = COMBINED_ASCII_SMALL_I_COMBINING_DOT; } else { codePointBuffer = nextCodePoint; } @@ -1007,11 +1152,11 @@ public static UTF8String lowercaseTrimLeft( codePoint = getLowercaseCodePoint(srcIter.next()); } // Special handling for Turkish dotted uppercase letter I. 
- if (codePoint == CODE_POINT_LOWERCASE_I && srcIter.hasNext() && - trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) { + if (codePoint == SpecialCodePointConstants.ASCII_SMALL_I && srcIter.hasNext() && + trimChars.contains(COMBINED_ASCII_SMALL_I_COMBINING_DOT)) { codePointBuffer = codePoint; codePoint = getLowercaseCodePoint(srcIter.next()); - if (codePoint == CODE_POINT_COMBINING_DOT) { + if (codePoint == SpecialCodePointConstants.COMBINING_DOT) { searchIndex += 2; codePointBuffer = -1; } else if (trimChars.contains(codePointBuffer)) { @@ -1125,11 +1270,11 @@ public static UTF8String lowercaseTrimRight( codePoint = getLowercaseCodePoint(srcIter.next()); } // Special handling for Turkish dotted uppercase letter I. - if (codePoint == CODE_POINT_COMBINING_DOT && srcIter.hasNext() && - trimChars.contains(CODE_POINT_COMBINED_LOWERCASE_I_DOT)) { + if (codePoint == SpecialCodePointConstants.COMBINING_DOT && srcIter.hasNext() && + trimChars.contains(COMBINED_ASCII_SMALL_I_COMBINING_DOT)) { codePointBuffer = codePoint; codePoint = getLowercaseCodePoint(srcIter.next()); - if (codePoint == CODE_POINT_LOWERCASE_I) { + if (codePoint == SpecialCodePointConstants.ASCII_SMALL_I) { searchIndex -= 2; codePointBuffer = -1; } else if (trimChars.contains(codePointBuffer)) { diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 5640a2468d02e..d5dbca7eb89bc 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -23,12 +23,14 @@ import java.util.function.Function; import java.util.function.BiFunction; import java.util.function.ToLongFunction; +import java.util.stream.Stream; +import com.ibm.icu.text.CollationKey; +import com.ibm.icu.text.Collator; import com.ibm.icu.text.RuleBasedCollator; import com.ibm.icu.text.StringSearch; import com.ibm.icu.util.ULocale; -import com.ibm.icu.text.CollationKey; -import com.ibm.icu.text.Collator; +import com.ibm.icu.util.VersionInfo; import org.apache.spark.SparkException; import org.apache.spark.unsafe.types.UTF8String; @@ -88,6 +90,17 @@ public Optional getVersion() { } } + public record CollationMeta( + String catalog, + String schema, + String collationName, + String language, + String country, + String icuVersion, + String padAttribute, + boolean accentSensitivity, + boolean caseSensitivity) { } + /** * Entry encapsulating all information about a collation. 
*/ @@ -342,6 +355,23 @@ private static int collationNameToId(String collationName) throws SparkException } protected abstract Collation buildCollation(); + + protected abstract CollationMeta buildCollationMeta(); + + static List listCollations() { + return Stream.concat( + CollationSpecUTF8.listCollations().stream(), + CollationSpecICU.listCollations().stream()).toList(); + } + + static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + CollationMeta collationSpecUTF8 = + CollationSpecUTF8.loadCollationMeta(collationIdentifier); + if (collationSpecUTF8 == null) { + return CollationSpecICU.loadCollationMeta(collationIdentifier); + } + return collationSpecUTF8; + } } private static class CollationSpecUTF8 extends CollationSpec { @@ -364,6 +394,9 @@ private enum CaseSensitivity { */ private static final int CASE_SENSITIVITY_MASK = 0b1; + private static final String UTF8_BINARY_COLLATION_NAME = "UTF8_BINARY"; + private static final String UTF8_LCASE_COLLATION_NAME = "UTF8_LCASE"; + private static final int UTF8_BINARY_COLLATION_ID = new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).collationId; private static final int UTF8_LCASE_COLLATION_ID = @@ -406,7 +439,7 @@ private static CollationSpecUTF8 fromCollationId(int collationId) { protected Collation buildCollation() { if (collationId == UTF8_BINARY_COLLATION_ID) { return new Collation( - "UTF8_BINARY", + UTF8_BINARY_COLLATION_NAME, PROVIDER_SPARK, null, UTF8String::binaryCompare, @@ -417,7 +450,7 @@ protected Collation buildCollation() { /* supportsLowercaseEquality = */ false); } else { return new Collation( - "UTF8_LCASE", + UTF8_LCASE_COLLATION_NAME, PROVIDER_SPARK, null, CollationAwareUTF8String::compareLowerCase, @@ -428,6 +461,52 @@ protected Collation buildCollation() { /* supportsLowercaseEquality = */ true); } } + + @Override + protected CollationMeta buildCollationMeta() { + if (collationId == UTF8_BINARY_COLLATION_ID) { + return new CollationMeta( + CATALOG, + SCHEMA, + UTF8_BINARY_COLLATION_NAME, + /* language = */ null, + /* country = */ null, + /* icuVersion = */ null, + COLLATION_PAD_ATTRIBUTE, + /* accentSensitivity = */ true, + /* caseSensitivity = */ true); + } else { + return new CollationMeta( + CATALOG, + SCHEMA, + UTF8_LCASE_COLLATION_NAME, + /* language = */ null, + /* country = */ null, + /* icuVersion = */ null, + COLLATION_PAD_ATTRIBUTE, + /* accentSensitivity = */ true, + /* caseSensitivity = */ false); + } + } + + static List listCollations() { + CollationIdentifier UTF8_BINARY_COLLATION_IDENT = + new CollationIdentifier(PROVIDER_SPARK, UTF8_BINARY_COLLATION_NAME, "1.0"); + CollationIdentifier UTF8_LCASE_COLLATION_IDENT = + new CollationIdentifier(PROVIDER_SPARK, UTF8_LCASE_COLLATION_NAME, "1.0"); + return Arrays.asList(UTF8_BINARY_COLLATION_IDENT, UTF8_LCASE_COLLATION_IDENT); + } + + static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + try { + int collationId = CollationSpecUTF8.collationNameToId( + collationIdentifier.name, collationIdentifier.name.toUpperCase()); + return CollationSpecUTF8.fromCollationId(collationId).buildCollationMeta(); + } catch (SparkException ignored) { + // ignore + return null; + } + } } private static class CollationSpecICU extends CollationSpec { @@ -684,6 +763,20 @@ protected Collation buildCollation() { /* supportsLowercaseEquality = */ false); } + @Override + protected CollationMeta buildCollationMeta() { + return new CollationMeta( + CATALOG, + SCHEMA, + collationName(), + ICULocaleMap.get(locale).getDisplayLanguage(), + 
ICULocaleMap.get(locale).getDisplayCountry(), + VersionInfo.ICU_VERSION.toString(), + COLLATION_PAD_ATTRIBUTE, + accentSensitivity == AccentSensitivity.AS, + caseSensitivity == CaseSensitivity.CS); + } + /** * Compute normalized collation name. Components of collation name are given in order: * - Locale name @@ -704,6 +797,37 @@ private String collationName() { } return builder.toString(); } + + private static List allCollationNames() { + List collationNames = new ArrayList<>(); + for (String locale: ICULocaleToId.keySet()) { + // CaseSensitivity.CS + AccentSensitivity.AS + collationNames.add(locale); + // CaseSensitivity.CS + AccentSensitivity.AI + collationNames.add(locale + "_AI"); + // CaseSensitivity.CI + AccentSensitivity.AS + collationNames.add(locale + "_CI"); + // CaseSensitivity.CI + AccentSensitivity.AI + collationNames.add(locale + "_CI_AI"); + } + return collationNames.stream().sorted().toList(); + } + + static List listCollations() { + return allCollationNames().stream().map(name -> + new CollationIdentifier(PROVIDER_ICU, name, VersionInfo.ICU_VERSION.toString())).toList(); + } + + static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + try { + int collationId = CollationSpecICU.collationNameToId( + collationIdentifier.name, collationIdentifier.name.toUpperCase()); + return CollationSpecICU.fromCollationId(collationId).buildCollationMeta(); + } catch (SparkException ignored) { + // ignore + return null; + } + } } /** @@ -730,9 +854,12 @@ public CollationIdentifier identifier() { } } + public static final String CATALOG = "SYSTEM"; + public static final String SCHEMA = "BUILTIN"; public static final String PROVIDER_SPARK = "spark"; public static final String PROVIDER_ICU = "icu"; public static final List SUPPORTED_PROVIDERS = List.of(PROVIDER_SPARK, PROVIDER_ICU); + public static final String COLLATION_PAD_ATTRIBUTE = "NO_PAD"; public static final int UTF8_BINARY_COLLATION_ID = Collation.CollationSpecUTF8.UTF8_BINARY_COLLATION_ID; @@ -794,6 +921,18 @@ public static int collationNameToId(String collationName) throws SparkException return Collation.CollationSpec.collationNameToId(collationName); } + /** + * Returns whether the ICU collation is not Case Sensitive Accent Insensitive + * for the given collation id. + * This method is used in expressions which do not support CS_AI collations. 
+ */ + public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) { + return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity == + Collation.CollationSpecICU.CaseSensitivity.CS && + Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity == + Collation.CollationSpecICU.AccentSensitivity.AI; + } + public static void assertValidProvider(String provider) throws SparkException { if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { Map params = Map.of( @@ -923,4 +1062,12 @@ public static String getClosestSuggestionsOnInvalidName( return String.join(", ", suggestions); } + + public static List listCollations() { + return Collation.CollationSpec.listCollations(); + } + + public static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) { + return Collation.CollationSpec.loadCollationMeta(collationIdentifier); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 6516837968776..f05d9e512568f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -283,7 +283,7 @@ public static UTF8String execBinary(final UTF8String v) { return v.toLowerCase().toTitleCase(); } public static UTF8String execBinaryICU(final UTF8String v) { - return CollationAwareUTF8String.toLowerCase(v).toTitleCaseICU(); + return CollationAwareUTF8String.toTitleCaseICU(v); } public static UTF8String execLowercase(final UTF8String v) { return CollationAwareUTF8String.toTitleCase(v); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/SpecialCodePointConstants.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/SpecialCodePointConstants.java new file mode 100644 index 0000000000000..db615d747910b --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/SpecialCodePointConstants.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +/** + * 'SpecialCodePointConstants' is introduced in order to keep the codepoints used in + * 'CollationAwareUTF8String' in one place. 
+ */ +public class SpecialCodePointConstants { + + public static final int COMBINING_DOT = 0x0307; + public static final int ASCII_SMALL_I = 0x0069; + public static final int ASCII_SPACE = 0x0020; + public static final int GREEK_CAPITAL_SIGMA = 0x03A3; + public static final int GREEK_SMALL_SIGMA = 0x03C3; + public static final int GREEK_FINAL_SIGMA = 0x03C2; + public static final int CAPITAL_I_WITH_DOT_ABOVE = 0x0130; +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java index 481ea89090b2a..3a697345ce1f5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UTF8StringBuilder.java @@ -96,4 +96,33 @@ public void appendBytes(Object base, long offset, int length) { public UTF8String build() { return UTF8String.fromBytes(buffer, 0, totalSize()); } + + public void appendCodePoint(int codePoint) { + if (codePoint <= 0x7F) { + grow(1); + buffer[cursor - Platform.BYTE_ARRAY_OFFSET] = (byte) codePoint; + ++cursor; + } else if (codePoint <= 0x7FF) { + grow(2); + buffer[cursor - Platform.BYTE_ARRAY_OFFSET] = (byte) (0xC0 | (codePoint >> 6)); + buffer[cursor + 1 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | (codePoint & 0x3F)); + cursor += 2; + } else if (codePoint <= 0xFFFF) { + grow(3); + buffer[cursor - Platform.BYTE_ARRAY_OFFSET] = (byte) (0xE0 | (codePoint >> 12)); + buffer[cursor + 1 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | ((codePoint >> 6) & 0x3F)); + buffer[cursor + 2 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | (codePoint & 0x3F)); + cursor += 3; + } else if (codePoint <= 0x10FFFF) { + grow(4); + buffer[cursor - Platform.BYTE_ARRAY_OFFSET] = (byte) (0xF0 | (codePoint >> 18)); + buffer[cursor + 1 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | ((codePoint >> 12) & 0x3F)); + buffer[cursor + 2 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | ((codePoint >> 6) & 0x3F)); + buffer[cursor + 3 - Platform.BYTE_ARRAY_OFFSET] = (byte) (0x80 | (codePoint & 0x3F)); + cursor += 4; + } else { + throw new IllegalArgumentException("Invalid Unicode codePoint: " + codePoint); + } + } + } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 5cc975d38d4da..a445cde52ad57 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -629,6 +629,8 @@ public void testStartsWith() throws SparkException { assertStartsWith("İonic", "Io", "UTF8_LCASE", false); assertStartsWith("İonic", "i\u0307o", "UTF8_LCASE", true); assertStartsWith("İonic", "İo", "UTF8_LCASE", true); + assertStartsWith("oİ", "oİ", "UTF8_LCASE", true); + assertStartsWith("oİ", "oi̇", "UTF8_LCASE", true); // Conditional case mapping (e.g. Greek sigmas). assertStartsWith("σ", "σ", "UTF8_BINARY", true); assertStartsWith("σ", "ς", "UTF8_BINARY", false); @@ -880,6 +882,8 @@ public void testEndsWith() throws SparkException { assertEndsWith("the İo", "Io", "UTF8_LCASE", false); assertEndsWith("the İo", "i\u0307o", "UTF8_LCASE", true); assertEndsWith("the İo", "İo", "UTF8_LCASE", true); + assertEndsWith("İo", "İo", "UTF8_LCASE", true); + assertEndsWith("İo", "i̇o", "UTF8_LCASE", true); // Conditional case mapping (e.g. Greek sigmas). 
assertEndsWith("σ", "σ", "UTF8_BINARY", true); assertEndsWith("σ", "ς", "UTF8_BINARY", false); @@ -1334,6 +1338,23 @@ private void assertInitCap(String target, String collationName, String expected) // Note: results should be the same in these tests for both ICU and JVM-based implementations. } + private void assertInitCap( + String target, + String collationName, + String expectedICU, + String expectedNonICU) throws SparkException { + UTF8String target_utf8 = UTF8String.fromString(target); + UTF8String expectedICU_utf8 = UTF8String.fromString(expectedICU); + UTF8String expectedNonICU_utf8 = UTF8String.fromString(expectedNonICU); + int collationId = CollationFactory.collationNameToId(collationName); + // Testing the new ICU-based implementation of the Lower function. + assertEquals(expectedICU_utf8, CollationSupport.InitCap.exec(target_utf8, collationId, true)); + // Testing the old JVM-based implementation of the Lower function. + assertEquals(expectedNonICU_utf8, CollationSupport.InitCap.exec(target_utf8, collationId, + false)); + // Note: results should be the same in these tests for both ICU and JVM-based implementations. + } + @Test public void testInitCap() throws SparkException { for (String collationName: testSupportedCollations) { @@ -1372,12 +1393,22 @@ public void testInitCap() throws SparkException { assertInitCap("ÄBĆΔE", "UTF8_LCASE", "Äbćδe"); assertInitCap("ÄBĆΔE", "UNICODE", "Äbćδe"); assertInitCap("ÄBĆΔE", "UNICODE_CI", "Äbćδe"); + // Case-variable character length + assertInitCap("İo", "UTF8_BINARY", "İo", "I\u0307o"); + assertInitCap("İo", "UTF8_LCASE", "İo"); + assertInitCap("İo", "UNICODE", "İo"); + assertInitCap("İo", "UNICODE_CI", "İo"); + assertInitCap("i\u0307o", "UTF8_BINARY", "I\u0307o"); + assertInitCap("i\u0307o", "UTF8_LCASE", "I\u0307o"); + assertInitCap("i\u0307o", "UNICODE", "I\u0307o"); + assertInitCap("i\u0307o", "UNICODE_CI", "I\u0307o"); + // Different possible word boundaries assertInitCap("aB 世 de", "UTF8_BINARY", "Ab 世 De"); assertInitCap("aB 世 de", "UTF8_LCASE", "Ab 世 De"); assertInitCap("aB 世 de", "UNICODE", "Ab 世 De"); assertInitCap("aB 世 de", "UNICODE_CI", "Ab 世 De"); // One-to-many case mapping (e.g. Turkish dotted I). 
- assertInitCap("İ", "UTF8_BINARY", "I\u0307"); + assertInitCap("İ", "UTF8_BINARY", "İ", "I\u0307"); assertInitCap("İ", "UTF8_LCASE", "İ"); assertInitCap("İ", "UNICODE", "İ"); assertInitCap("İ", "UNICODE_CI", "İ"); @@ -1385,7 +1416,7 @@ public void testInitCap() throws SparkException { assertInitCap("I\u0307", "UTF8_LCASE","I\u0307"); assertInitCap("I\u0307", "UNICODE","I\u0307"); assertInitCap("I\u0307", "UNICODE_CI","I\u0307"); - assertInitCap("İonic", "UTF8_BINARY", "I\u0307onic"); + assertInitCap("İonic", "UTF8_BINARY", "İonic", "I\u0307onic"); assertInitCap("İonic", "UTF8_LCASE", "İonic"); assertInitCap("İonic", "UNICODE", "İonic"); assertInitCap("İonic", "UNICODE_CI", "İonic"); @@ -1414,23 +1445,24 @@ public void testInitCap() throws SparkException { assertInitCap("𝔸", "UTF8_LCASE", "𝔸"); assertInitCap("𝔸", "UNICODE", "𝔸"); assertInitCap("𝔸", "UNICODE_CI", "𝔸"); - assertInitCap("𐐅", "UTF8_BINARY", "𐐭"); + assertInitCap("𐐅", "UTF8_BINARY", "\uD801\uDC05", "𐐭"); assertInitCap("𐐅", "UTF8_LCASE", "𐐅"); assertInitCap("𐐅", "UNICODE", "𐐅"); assertInitCap("𐐅", "UNICODE_CI", "𐐅"); - assertInitCap("𐐭", "UTF8_BINARY", "𐐭"); + assertInitCap("𐐭", "UTF8_BINARY", "\uD801\uDC05", "𐐭"); assertInitCap("𐐭", "UTF8_LCASE", "𐐅"); assertInitCap("𐐭", "UNICODE", "𐐅"); assertInitCap("𐐭", "UNICODE_CI", "𐐅"); - assertInitCap("𐐭𝔸", "UTF8_BINARY", "𐐭𝔸"); + assertInitCap("𐐭𝔸", "UTF8_BINARY", "\uD801\uDC05\uD835\uDD38", "𐐭𝔸"); assertInitCap("𐐭𝔸", "UTF8_LCASE", "𐐅𝔸"); assertInitCap("𐐭𝔸", "UNICODE", "𐐅𝔸"); assertInitCap("𐐭𝔸", "UNICODE_CI", "𐐅𝔸"); // Ligatures. - assertInitCap("ß fi ffi ff st ῗ", "UTF8_BINARY","ß fi ffi ff st ῗ"); - assertInitCap("ß fi ffi ff st ῗ", "UTF8_LCASE","Ss Fi Ffi Ff St \u0399\u0308\u0342"); - assertInitCap("ß fi ffi ff st ῗ", "UNICODE","Ss Fi Ffi Ff St \u0399\u0308\u0342"); - assertInitCap("ß fi ffi ff st ῗ", "UNICODE","Ss Fi Ffi Ff St \u0399\u0308\u0342"); + assertInitCap("ß fi ffi ff st ῗ", "UTF8_BINARY", "Ss Fi Ffi Ff St Ϊ͂", "ß fi ffi ff st ῗ"); + assertInitCap("ß fi ffi ff st ῗ", "UTF8_LCASE", "Ss Fi Ffi Ff St \u0399\u0308\u0342"); + assertInitCap("ß fi ffi ff st ῗ", "UNICODE", "Ss Fi Ffi Ff St \u0399\u0308\u0342"); + assertInitCap("ß fi ffi ff st ῗ", "UNICODE", "Ss Fi Ffi Ff St \u0399\u0308\u0342"); + assertInitCap("œ ǽ", "UTF8_BINARY", "Œ Ǽ", "Œ Ǽ"); // Different possible word boundaries. assertInitCap("a b c", "UTF8_BINARY", "A B C"); assertInitCap("a b c", "UNICODE", "A B C"); @@ -1458,13 +1490,42 @@ public void testInitCap() throws SparkException { assertInitCap("džaba Ljubav NJegova", "UTF8_LCASE", "Džaba Ljubav Njegova"); assertInitCap("džaba Ljubav NJegova", "UNICODE_CI", "Džaba Ljubav Njegova"); assertInitCap("ß fi ffi ff st ΣΗΜΕΡΙΝΟΣ ΑΣΗΜΕΝΙΟΣ İOTA", "UTF8_BINARY", - "ß fi ffi ff st Σημερινος Ασημενιος I\u0307ota"); + "Ss Fi Ffi Ff St Σημερινος Ασημενιος İota","ß fi ffi ff st Σημερινος Ασημενιος I\u0307ota"); assertInitCap("ß fi ffi ff st ΣΗΜΕΡΙΝΟΣ ΑΣΗΜΕΝΙΟΣ İOTA", "UTF8_LCASE", "Ss Fi Ffi Ff St Σημερινος Ασημενιος İota"); assertInitCap("ß fi ffi ff st ΣΗΜΕΡΙΝΟΣ ΑΣΗΜΕΝΙΟΣ İOTA", "UNICODE", "Ss Fi Ffi Ff St Σημερινος Ασημενιος İota"); - assertInitCap("ß fi ffi ff st ΣΗΜΕΡΙΝΟΣ ΑΣΗΜΕΝΙΟΣ İOTA", "UNICODE_CI", - "Ss Fi Ffi Ff St Σημερινος Ασημενιος İota"); + assertInitCap("ß fi ffi ff st ΣΗΜΕΡςΙΝΟΣ ΑΣΗΜΕΝΙΟΣ İOTA", "UNICODE_CI", + "Ss Fi Ffi Ff St Σημερςινος Ασημενιος İota"); + // Characters that map to multiple characters when titlecased and lowercased. 
+ assertInitCap("ß fi ffi ff st İOTA", "UTF8_BINARY", "Ss Fi Ffi Ff St İota", "ß fi ffi ff st İota"); + assertInitCap("ß fi ffi ff st OİOTA", "UTF8_BINARY", + "Ss Fi Ffi Ff St Oi\u0307ota", "ß fi ffi ff st Oi̇ota"); + // Lowercasing Greek letter sigma ('Σ') when case-ignorable character present. + assertInitCap("`Σ", "UTF8_BINARY", "`σ", "`σ"); + assertInitCap("1`Σ`` AΣ", "UTF8_BINARY", "1`σ`` Aς", "1`σ`` Aς"); + assertInitCap("a1`Σ``", "UTF8_BINARY", "A1`σ``", "A1`σ``"); + assertInitCap("a`Σ``", "UTF8_BINARY", "A`ς``", "A`σ``"); + assertInitCap("a`Σ``1", "UTF8_BINARY", "A`ς``1", "A`σ``1"); + assertInitCap("a`Σ``A", "UTF8_BINARY", "A`σ``a", "A`σ``a"); + assertInitCap("ΘΑ�Σ�ΟΣ�", "UTF8_BINARY", "Θα�σ�ος�", "Θα�σ�ος�"); + assertInitCap("ΘΑᵩΣ�ΟᵩΣᵩ�", "UTF8_BINARY", "Θαᵩς�οᵩςᵩ�", "Θαᵩς�οᵩςᵩ�"); + assertInitCap("ΘΑ�ᵩΣ�ΟᵩΣᵩ�", "UTF8_BINARY", "Θα�ᵩσ�οᵩςᵩ�", "Θα�ᵩσ�οᵩςᵩ�"); + assertInitCap("ΘΑ�ᵩΣᵩ�ΟᵩΣᵩ�", "UTF8_BINARY", "Θα�ᵩσᵩ�οᵩςᵩ�", "Θα�ᵩσᵩ�οᵩςᵩ�"); + assertInitCap("ΘΑ�Σ�Ο�Σ�", "UTF8_BINARY", "Θα�σ�ο�σ�", "Θα�σ�ο�σ�"); + // Disallowed bytes and invalid sequences. + assertInitCap(UTF8String.fromBytes(new byte[] { (byte)0xC0, (byte)0xC1, (byte)0xF5}).toString(), + "UTF8_BINARY", "���", "���"); + assertInitCap(UTF8String.fromBytes( + new byte[]{(byte)0xC0, (byte)0xC1, (byte)0xF5, 0x20, 0x61, 0x41, (byte)0xC0}).toString(), + "UTF8_BINARY", + "��� Aa�", "��� Aa�"); + assertInitCap(UTF8String.fromBytes(new byte[]{(byte)0xC2,(byte)0xC2}).toString(), + "UTF8_BINARY", "��", "��"); + assertInitCap(UTF8String.fromBytes( + new byte[]{0x61, 0x41, (byte)0xC2, (byte)0xC2, 0x41}).toString(), + "UTF8_BINARY", + "Aa��a", "Aa��a"); } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 2428d40fe8016..c4a66fdffdd4d 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -26,6 +26,8 @@ import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UTF8StringBuilder; + import org.junit.jupiter.api.Test; import static org.apache.spark.unsafe.types.UTF8String.fromString; @@ -1362,4 +1364,27 @@ public void toBinaryString() { UTF8String.fromString("111111111111111111111111111111111111111111111111111111111111111"), UTF8String.toBinaryString(Long.MAX_VALUE)); } + + /** + * This tests whether appending a codepoint to a 'UTF8StringBuilder' correctly appends every + * single codepoint. We test it against an already existing 'StringBuilder.appendCodePoint' and + * 'UTF8String.fromString'. We skip testing the surrogate codepoints because at some point while + * converting the surrogate codepoint to 'UTF8String' (via 'StringBuilder' and 'UTF8String') we + * get an ill-formated byte sequence (probably because 'String' is in UTF-16 format, and a single + * surrogate codepoint is handled differently in UTF-16 than in UTF-8, so somewhere during those + * conversions some different behaviour happens). 
+ */ + @Test + public void testAppendCodepointToUTF8StringBuilder() { + int surrogateRangeLowerBound = 0xD800; + int surrogateRangeUpperBound = 0xDFFF; + for (int i = Character.MIN_CODE_POINT; i <= Character.MAX_CODE_POINT; ++i) { + if(surrogateRangeLowerBound <= i && i <= surrogateRangeUpperBound) continue; + UTF8StringBuilder usb = new UTF8StringBuilder(); + usb.appendCodePoint(i); + StringBuilder sb = new StringBuilder(); + sb.appendCodePoint(i); + assert(usb.build().equals(UTF8String.fromString(sb.toString()))); + } + } } diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 89d2627ef32ee..e83202d9e5ee3 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1,4 +1,10 @@ { + "ADD_DEFAULT_UNSUPPORTED" : { + "message" : [ + "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"." + ], + "sqlState" : "42623" + }, "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION" : { "message" : [ "Non-deterministic expression should not appear in the arguments of an aggregate function." @@ -434,6 +440,12 @@ ], "sqlState" : "42846" }, + "CANNOT_USE_KRYO" : { + "message" : [ + "Cannot load Kryo serialization codec. Kryo serialization cannot be used in the Spark Connect client. Use Java serialization, provide a custom Codec, or use Spark Classic instead." + ], + "sqlState" : "22KD3" + }, "CANNOT_WRITE_STATE_STORE" : { "message" : [ "Error writing state store files for provider ." @@ -449,13 +461,13 @@ }, "CAST_INVALID_INPUT" : { "message" : [ - "The value of the type cannot be cast to because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set to \"false\" to bypass this error." + "The value of the type cannot be cast to because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead." ], "sqlState" : "22018" }, "CAST_OVERFLOW" : { "message" : [ - "The value of the type cannot be cast to due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set to \"false\" to bypass this error." + "The value of the type cannot be cast to due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead." ], "sqlState" : "22003" }, @@ -898,7 +910,7 @@ }, "NON_STRING_TYPE" : { "message" : [ - "all arguments must be strings." + "all arguments of the function must be strings." ] }, "NULL_TYPE" : { @@ -1039,6 +1051,12 @@ ], "sqlState" : "42710" }, + "DATA_SOURCE_EXTERNAL_ERROR" : { + "message" : [ + "Encountered error when saving to external data source." + ], + "sqlState" : "KD010" + }, "DATA_SOURCE_NOT_EXIST" : { "message" : [ "Data source '' not found. Please make sure the data source is registered." @@ -1084,6 +1102,12 @@ ], "sqlState" : "42608" }, + "DEFAULT_UNSUPPORTED" : { + "message" : [ + "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"." + ], + "sqlState" : "42623" + }, "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED" : { "message" : [ "Distinct window functions are not supported: ." @@ -1432,6 +1456,12 @@ ], "sqlState" : "2203G" }, + "FAILED_TO_LOAD_ROUTINE" : { + "message" : [ + "Failed to load routine ." 
+ ], + "sqlState" : "38000" + }, "FAILED_TO_PARSE_TOO_COMPLEX" : { "message" : [ "The statement, including potential SQL functions and referenced views, was too complex to parse.", @@ -1457,6 +1487,12 @@ ], "sqlState" : "42704" }, + "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR" : { + "message" : [ + "An error occurred in the user provided function in flatMapGroupsWithState. Reason: " + ], + "sqlState" : "39000" + }, "FORBIDDEN_OPERATION" : { "message" : [ "The operation is not allowed on the : ." @@ -1469,6 +1505,12 @@ ], "sqlState" : "39000" }, + "FOREACH_USER_FUNCTION_ERROR" : { + "message" : [ + "An error occurred in the user provided function in foreach sink. Reason: " + ], + "sqlState" : "39000" + }, "FOUND_MULTIPLE_DATA_SOURCES" : { "message" : [ "Detected multiple data sources with the name ''. Please check the data source isn't simultaneously registered and located in the classpath." @@ -1565,6 +1607,36 @@ ], "sqlState" : "42601" }, + "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION" : { + "message" : [ + "Duplicated IDENTITY column sequence generator option: ." + ], + "sqlState" : "42601" + }, + "IDENTITY_COLUMNS_ILLEGAL_STEP" : { + "message" : [ + "IDENTITY column step cannot be 0." + ], + "sqlState" : "42611" + }, + "IDENTITY_COLUMNS_UNSUPPORTED_DATA_TYPE" : { + "message" : [ + "DataType is not supported for IDENTITY columns." + ], + "sqlState" : "428H2" + }, + "IDENTITY_COLUMN_WITH_DEFAULT_VALUE" : { + "message" : [ + "A column cannot have both a default value and an identity column specification but column has default value: () and identity column specification: ()." + ], + "sqlState" : "42623" + }, + "ILLEGAL_DAY_OF_WEEK" : { + "message" : [ + "Illegal input for day of week: ." + ], + "sqlState" : "22009" + }, "ILLEGAL_STATE_STORE_VALUE" : { "message" : [ "Illegal value provided to the State Store" @@ -1942,6 +2014,12 @@ }, "sqlState" : "42903" }, + "INVALID_AGNOSTIC_ENCODER" : { + "message" : [ + "Found an invalid agnostic encoder. Expects an instance of AgnosticEncoder but got . For more information consult '/api/java/index.html?org/apache/spark/sql/Encoder.html'." + ], + "sqlState" : "42001" + }, "INVALID_ARRAY_INDEX" : { "message" : [ "The index is out of bounds. The array has elements. Use the SQL function `get()` to tolerate accessing element at invalid index and return NULL instead. If necessary set to \"false\" to bypass this error." @@ -2074,6 +2152,11 @@ "message" : [ "Too many letters in datetime pattern: . Please reduce pattern length." ] + }, + "SECONDS_FRACTION" : { + "message" : [ + "Cannot detect a seconds fraction pattern of variable length. Please make sure the pattern contains 'S', and does not contain illegal characters." + ] } }, "sqlState" : "22007" @@ -2177,6 +2260,12 @@ ], "sqlState" : "42001" }, + "INVALID_EXTERNAL_TYPE" : { + "message" : [ + "The external type is not valid for the type at the expression ." + ], + "sqlState" : "42K0N" + }, "INVALID_EXTRACT_BASE_FIELD_TYPE" : { "message" : [ "Can't extract a value from . Need a complex type [STRUCT, ARRAY, MAP] but got ." @@ -2366,6 +2455,11 @@ "Uncaught arithmetic exception while parsing ''." ] }, + "DAY_TIME_PARSING" : { + "message" : [ + "Error parsing interval day-time string: ." + ] + }, "INPUT_IS_EMPTY" : { "message" : [ "Interval string cannot be empty." @@ -2376,6 +2470,11 @@ "Interval string cannot be null." ] }, + "INTERVAL_PARSING" : { + "message" : [ + "Error parsing interval string." + ] + }, "INVALID_FRACTION" : { "message" : [ " cannot have fractional part." 
@@ -2411,15 +2510,35 @@ "Expect a unit name after but hit EOL." ] }, + "SECOND_NANO_FORMAT" : { + "message" : [ + "Interval string does not match second-nano format of ss.nnnnnnnnn." + ] + }, "UNKNOWN_PARSING_ERROR" : { "message" : [ "Unknown error when parsing ." ] }, + "UNMATCHED_FORMAT_STRING" : { + "message" : [ + "Interval string does not match format of when cast to : ." + ] + }, + "UNMATCHED_FORMAT_STRING_WITH_NOTICE" : { + "message" : [ + "Interval string does not match format of when cast to : . Set \"spark.sql.legacy.fromDayTimeString.enabled\" to \"true\" to restore the behavior before Spark 3.0." + ] + }, "UNRECOGNIZED_NUMBER" : { "message" : [ "Unrecognized number ." ] + }, + "UNSUPPORTED_FROM_TO_EXPRESSION" : { + "message" : [ + "Cannot support (interval '' to ) expression." + ] } }, "sqlState" : "22006" @@ -2483,6 +2602,24 @@ ], "sqlState" : "F0000" }, + "INVALID_LABEL_USAGE" : { + "message" : [ + "The usage of the label is invalid." + ], + "subClass" : { + "DOES_NOT_EXIST" : { + "message" : [ + "Label was used in the statement, but the label does not belong to any surrounding block." + ] + }, + "ITERATE_IN_COMPOUND" : { + "message" : [ + "ITERATE statement cannot be used with a label that belongs to a compound (BEGIN...END) body." + ] + } + }, + "sqlState" : "42K0L" + }, "INVALID_LAMBDA_FUNCTION_CALL" : { "message" : [ "Invalid lambda function call." @@ -3035,12 +3172,12 @@ "subClass" : { "NOT_ALLOWED_IN_SCOPE" : { "message" : [ - "Variable was declared on line , which is not allowed in this scope." + "Declaration of the variable is not allowed in this scope." ] }, "ONLY_AT_BEGINNING" : { "message" : [ - "Variable can only be declared at the beginning of the compound, but it was declared on line ." + "Variable can only be declared at the beginning of the compound." ] } }, @@ -3665,6 +3802,12 @@ ], "sqlState" : "42K03" }, + "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION" : { + "message" : [ + "Aggregate function is not allowed when using the pipe operator |> SELECT clause; please use the pipe operator |> AGGREGATE clause instead" + ], + "sqlState" : "0A000" + }, "PIVOT_VALUE_DATA_TYPE_MISMATCH" : { "message" : [ "Invalid pivot value '': value data type does not match pivot column data type ." @@ -3866,6 +4009,12 @@ ], "sqlState" : "42K08" }, + "SHOW_COLUMNS_WITH_CONFLICT_NAMESPACE" : { + "message" : [ + "SHOW COLUMNS with conflicting namespaces: != ." + ], + "sqlState" : "42K05" + }, "SORT_BY_WITHOUT_BUCKETING" : { "message" : [ "sortBy must be used together with bucketBy." @@ -4314,6 +4463,24 @@ ], "sqlState" : "428EK" }, + "TRANSPOSE_EXCEED_ROW_LIMIT" : { + "message" : [ + "Number of rows exceeds the allowed limit of for TRANSPOSE. If this was intended, set to at least the current row count." + ], + "sqlState" : "54006" + }, + "TRANSPOSE_INVALID_INDEX_COLUMN" : { + "message" : [ + "Invalid index column for TRANSPOSE because: " + ], + "sqlState" : "42804" + }, + "TRANSPOSE_NO_LEAST_COMMON_TYPE" : { + "message" : [ + "Transpose requires non-index columns to share a least common type, but and do not." + ], + "sqlState" : "42K09" + }, "UDTF_ALIAS_NUMBER_MISMATCH" : { "message" : [ "The number of aliases supplied in the AS clause does not match the number of columns output by the UDTF.", @@ -5187,6 +5354,11 @@ "" ] }, + "SCALAR_SUBQUERY_IN_VALUES" : { + "message" : [ + "Scalar subqueries in the VALUES clause." 
+ ] + }, "UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION" : { "message" : [ "Correlated subqueries in the join predicate cannot reference both join inputs:", @@ -5685,11 +5857,6 @@ "ADD COLUMN with v1 tables cannot specify NOT NULL." ] }, - "_LEGACY_ERROR_TEMP_1057" : { - "message" : [ - "SHOW COLUMNS with conflicting databases: '' != ''." - ] - }, "_LEGACY_ERROR_TEMP_1058" : { "message" : [ "Cannot create table with both USING and ." @@ -6161,7 +6328,7 @@ "Detected implicit cartesian product for join between logical plans", "", "and", - "rightPlan", + "", "Join condition is missing or trivial.", "Either: use the CROSS JOIN syntax to allow cartesian products between these relations, or: enable implicit cartesian products by setting the configuration variable spark.sql.crossJoin.enabled=true." ] @@ -6524,21 +6691,6 @@ "Sinks cannot request distribution and ordering in continuous execution mode." ] }, - "_LEGACY_ERROR_TEMP_1344" : { - "message" : [ - "Invalid DEFAULT value for column : fails to parse as a valid literal value." - ] - }, - "_LEGACY_ERROR_TEMP_1345" : { - "message" : [ - "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"." - ] - }, - "_LEGACY_ERROR_TEMP_1346" : { - "message" : [ - "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"." - ] - }, "_LEGACY_ERROR_TEMP_2000" : { "message" : [ ". If necessary set to false to bypass this error." @@ -6554,11 +6706,6 @@ "Type does not support ordered operations." ] }, - "_LEGACY_ERROR_TEMP_2011" : { - "message" : [ - "Unexpected data type ." - ] - }, "_LEGACY_ERROR_TEMP_2013" : { "message" : [ "Negative values found in " @@ -7766,7 +7913,7 @@ }, "_LEGACY_ERROR_TEMP_3055" : { "message" : [ - "ScalarFunction '' neither implement magic method nor override 'produceResult'" + "ScalarFunction neither implement magic method nor override 'produceResult'" ] }, "_LEGACY_ERROR_TEMP_3056" : { @@ -8372,36 +8519,6 @@ "The number of fields () in the partition identifier is not equal to the partition schema length (). The identifier might not refer to one partition." ] }, - "_LEGACY_ERROR_TEMP_3209" : { - "message" : [ - "Illegal input for day of week: " - ] - }, - "_LEGACY_ERROR_TEMP_3210" : { - "message" : [ - "Interval string does not match second-nano format of ss.nnnnnnnnn" - ] - }, - "_LEGACY_ERROR_TEMP_3211" : { - "message" : [ - "Error parsing interval day-time string: " - ] - }, - "_LEGACY_ERROR_TEMP_3212" : { - "message" : [ - "Cannot support (interval '' to ) expression" - ] - }, - "_LEGACY_ERROR_TEMP_3213" : { - "message" : [ - "Error parsing interval string: " - ] - }, - "_LEGACY_ERROR_TEMP_3214" : { - "message" : [ - "Interval string does not match format of when cast to : " - ] - }, "_LEGACY_ERROR_TEMP_3215" : { "message" : [ "Expected a Boolean type expression in replaceNullWithFalse, but got the type in ." 
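The entries above are message templates whose `<param>` placeholders are filled in at error-creation time (the placeholders themselves are not visible in this excerpt), and the `ErrorClassesJsonReader` change in the next file adds a test-only guard that no supplied parameter goes unused. A minimal Scala sketch of that rendering and check — the placeholder regex and the `labelName` parameter are assumptions for illustration, not copied from the patch:

```scala
import scala.jdk.CollectionConverters._
import org.apache.commons.text.StringSubstitutor

// Sketch only: mirrors how ErrorClassesJsonReader (below) renders a template
// and how the new test-only guard flags unused message parameters.
object ErrorTemplateSketch {
  // Assumed placeholder syntax: <paramName>, as used in error-conditions.json.
  private val TemplateRegex = "<([a-zA-Z0-9_-]+)>".r

  def render(template: String, parameters: Map[String, String]): String = {
    val sub = new StringSubstitutor(parameters.asJava)
    sub.setEnableUndefinedVariableException(true)
    sub.setDisableSubstitutionInValues(true)
    // Rewrite <param> into ${param} before substitution, as the reader does.
    val message = sub.replace(TemplateRegex.replaceAllIn(template, "\\$\\{$1\\}"))
    val placeholders = TemplateRegex.findAllIn(template).length
    require(placeholders >= parameters.size,
      s"${parameters.size} parameters supplied but only $placeholders placeholders found")
    message
  }

  def main(args: Array[String]): Unit = {
    println(render("The usage of the label <labelName> is invalid.",
      Map("labelName" -> "my_loop")))
  }
}
```

Running this prints `The usage of the label my_loop is invalid.`; passing an extra, unreferenced parameter would trip the `require`, which is the situation the new check surfaces during testing.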
diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json index c5c55f11a6aa8..87811fef9836e 100644 --- a/common/utils/src/main/resources/error/error-states.json +++ b/common/utils/src/main/resources/error/error-states.json @@ -4625,6 +4625,12 @@ "standard": "N", "usedBy": ["Spark"] }, + "42K0N": { + "description": "Invalid external type.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, "42KD0": { "description": "Ambiguous name reference.", "origin": "Databricks", @@ -7411,6 +7417,12 @@ "standard": "N", "usedBy": ["Databricks"] }, + "KD010": { + "description": "external data source failure", + "origin": "Databricks", + "standard": "N", + "usedBy": ["Databricks"] + }, "P0000": { "description": "procedural logic error", "origin": "PostgreSQL", diff --git a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala index a1934dcf7a007..e2dd0da1aac85 100644 --- a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala +++ b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.net.URL -import scala.collection.immutable.Map import scala.jdk.CollectionConverters._ import com.fasterxml.jackson.annotation.JsonIgnore @@ -52,7 +51,7 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) { val sub = new StringSubstitutor(sanitizedParameters.asJava) sub.setEnableUndefinedVariableException(true) sub.setDisableSubstitutionInValues(true) - try { + val errorMessage = try { sub.replace(ErrorClassesJsonReader.TEMPLATE_REGEX.replaceAllIn( messageTemplate, "\\$\\{$1\\}")) } catch { @@ -61,6 +60,17 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) { s"MessageTemplate: $messageTemplate, " + s"Parameters: $messageParameters", i) } + if (util.SparkEnvUtils.isTesting) { + val placeHoldersNum = ErrorClassesJsonReader.TEMPLATE_REGEX.findAllIn(messageTemplate).length + if (placeHoldersNum < sanitizedParameters.size) { + throw SparkException.internalError( + s"Found unused message parameters of the error class '$errorClass'. " + + s"Its error message format has $placeHoldersNum placeholders, " + + s"but the passed message parameters map has ${sanitizedParameters.size} items. 
" + + "Consider to add placeholders to the error format or remove unused message parameters.") + } + } + errorMessage } def getMessageParameters(errorClass: String): Seq[String] = { diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index db5eff72e124a..428c9d2a49351 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -74,7 +74,7 @@ private[spark] object SparkThrowableHelper { } def isInternalError(errorClass: String): Boolean = { - errorClass.startsWith("INTERNAL_ERROR") + errorClass != null && errorClass.startsWith("INTERNAL_ERROR") } def getMessage(e: SparkThrowable with Throwable, format: ErrorMessageFormat.Value): String = { diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index a7e4f186000b5..12d456a371d07 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -266,6 +266,7 @@ private[spark] object LogKeys { case object FEATURE_NAME extends LogKey case object FETCH_SIZE extends LogKey case object FIELD_NAME extends LogKey + case object FIELD_TYPE extends LogKey case object FILES extends LogKey case object FILE_ABSOLUTE_PATH extends LogKey case object FILE_END_OFFSET extends LogKey @@ -652,6 +653,7 @@ private[spark] object LogKeys { case object RECEIVER_IDS extends LogKey case object RECORDS extends LogKey case object RECOVERY_STATE extends LogKey + case object RECURSIVE_DEPTH extends LogKey case object REDACTED_STATEMENT extends LogKey case object REDUCE_ID extends LogKey case object REGEX extends LogKey diff --git a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala index 7ffaef0855fd1..7471b764bd2b3 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/Logging.scala @@ -334,7 +334,7 @@ trait Logging { // If Log4j 2 is used but is initialized by default configuration, // load a default properties file // scalastyle:off println - if (Logging.islog4j2DefaultConfigured()) { + if (Logging.defaultSparkLog4jConfig || Logging.islog4j2DefaultConfigured()) { Logging.defaultSparkLog4jConfig = true val defaultLogProps = if (Logging.isStructuredLoggingEnabled) { "org/apache/spark/log4j2-defaults.properties" @@ -424,7 +424,6 @@ private[spark] object Logging { def uninitialize(): Unit = initLock.synchronized { if (isLog4j2()) { if (defaultSparkLog4jConfig) { - defaultSparkLog4jConfig = false val context = LogManager.getContext(false).asInstanceOf[LoggerContext] context.reconfigure() } else { diff --git a/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala b/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala new file mode 100644 index 0000000000000..1f1d3492d6ac5 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler + +import com.fasterxml.jackson.annotation.JsonTypeInfo + +import org.apache.spark.annotation.DeveloperApi + +@DeveloperApi +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") +trait SparkListenerEvent { + /* Whether output this event to the event log */ + protected[spark] def logEvent: Boolean = true +} diff --git a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala index 42a1d1612aeeb..d54a2f2ed9cea 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala @@ -342,7 +342,7 @@ private[spark] object MavenUtils extends Logging { } /* Set ivy settings for location of cache, if option is supplied */ - private def processIvyPathArg(ivySettings: IvySettings, ivyPath: Option[String]): Unit = { + private[util] def processIvyPathArg(ivySettings: IvySettings, ivyPath: Option[String]): Unit = { val alternateIvyDir = ivyPath.filterNot(_.trim.isEmpty).getOrElse { // To protect old Ivy-based systems like old Spark from Apache Ivy 2.5.2's incompatibility. System.getProperty("ivy.home", diff --git a/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala b/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala index 76062074edcaf..140de836622f4 100644 --- a/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala +++ b/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala @@ -365,7 +365,7 @@ private[spark] object IvyTestUtils { useIvyLayout: Boolean = false, withPython: Boolean = false, withR: Boolean = false, - ivySettings: IvySettings = new IvySettings)(f: String => Unit): Unit = { + ivySettings: IvySettings = defaultIvySettings())(f: String => Unit): Unit = { val deps = dependencies.map(MavenUtils.extractMavenCoordinates) purgeLocalIvyCache(artifact, deps, ivySettings) val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout, @@ -401,4 +401,16 @@ private[spark] object IvyTestUtils { } } } + + /** + * Creates and initializes a new instance of IvySettings with default configurations. + * The method processes the Ivy path argument using MavenUtils to ensure proper setup. + * + * @return A newly created and configured instance of IvySettings. + */ + private def defaultIvySettings(): IvySettings = { + val settings = new IvySettings + MavenUtils.processIvyPathArg(ivySettings = settings, ivyPath = None) + settings + } } diff --git a/common/variant/README.md b/common/variant/README.md index a66d708da75bf..4ed7c16f5b6ed 100644 --- a/common/variant/README.md +++ b/common/variant/README.md @@ -333,27 +333,27 @@ The Decimal type contains a scale, but no precision. 
The implied precision of a | Object | `2` | A collection of (string-key, variant-value) pairs | | Array | `3` | An ordered sequence of variant values | -| Primitive Type | Type ID | Equivalent Parquet Type | Binary format | -|-----------------------------|---------|-----------------------------|---------------------------------------------------------------------------------------------------------------------| -| null | `0` | any | none | -| boolean (True) | `1` | BOOLEAN | none | -| boolean (False) | `2` | BOOLEAN | none | -| int8 | `3` | INT(8, signed) | 1 byte | -| int16 | `4` | INT(16, signed) | 2 byte little-endian | -| int32 | `5` | INT(32, signed) | 4 byte little-endian | -| int64 | `6` | INT(64, signed) | 8 byte little-endian | -| double | `7` | DOUBLE | IEEE little-endian | -| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | -| date | `11` | DATE | 4 byte little-endian | -| timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | -| timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | -| float | `14` | FLOAT | IEEE little-endian | -| binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | -| string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | -| year-month interval | `19` | INT(32, signed)1 | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | -| day-time interval | `20` | INT(64, signed)1 | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. 
| +| Logical Type | Physical Type | Type ID | Equivalent Parquet Type | Binary format | +|----------------------|-----------------------------|---------|-----------------------------|---------------------------------------------------------------------------------------------------------------------| +| NullType | null | `0` | any | none | +| Boolean | boolean (True) | `1` | BOOLEAN | none | +| Boolean | boolean (False) | `2` | BOOLEAN | none | +| Exact Numeric | int8 | `3` | INT(8, signed) | 1 byte | +| Exact Numeric | int16 | `4` | INT(16, signed) | 2 byte little-endian | +| Exact Numeric | int32 | `5` | INT(32, signed) | 4 byte little-endian | +| Exact Numeric | int64 | `6` | INT(64, signed) | 8 byte little-endian | +| Double | double | `7` | DOUBLE | IEEE little-endian | +| Exact Numeric | decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| Exact Numeric | decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| Exact Numeric | decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) | +| Date | date | `11` | DATE | 4 byte little-endian | +| Timestamp | timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian | +| TimestampNTZ | timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian | +| Float | float | `14` | FLOAT | IEEE little-endian | +| Binary | binary | `15` | BINARY | 4 byte little-endian size, followed by bytes | +| String | string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes | +| YMInterval | year-month interval | `19` | INT(32, signed)1 | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. | +| DTInterval | day-time interval | `20` | INT(64, signed)1 | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. | | Decimal Precision | Decimal value type | |-----------------------|--------------------| @@ -362,6 +362,8 @@ The Decimal type contains a scale, but no precision. The implied precision of a | 18 <= precision <= 38 | int128 | | > 38 | Not supported | +The *Logical Type* column indicates logical equivalence of physically encoded types. For example, a user expression operating on a string value containing "hello" should behave the same, whether it is encoded with the short string optimization, or long string encoding. Similarly, user expressions operating on an *int8* value of 1 should behave the same as a decimal16 with scale 2 and unscaled value 100. + The year-month and day-time interval types have one byte at the beginning indicating the start and end fields. In the case of the year-month interval, the least significant bit denotes the start field and the next least significant bit denotes the end field. The remaining 6 bits are unused. A field value of 0 represents YEAR and 1 represents MONTH. In the case of the day-time interval, the least significant 2 bits denote the start field and the next least significant 2 bits denote the end field. The remaining 4 bits are unused. A field value of 0 represents DAY, 1 represents HOUR, 2 represents MINUTE, and 3 represents SECOND. Type IDs 17 and 18 were originally reserved for a prototype feature (string-from-metadata) that was never implemented. 
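As a worked illustration of the interval header byte described above (a sketch, not part of the spec or of the Spark variant library), the following Scala snippet decodes the leading byte of a day-time interval value into its start and end fields:

```scala
// Decodes the leading fields byte of a day-time interval value: field codes
// 0=DAY, 1=HOUR, 2=MINUTE, 3=SECOND; the start field sits in the 2 least
// significant bits and the end field in the next 2 bits.
object DayTimeIntervalHeaderSketch {
  private val fieldNames = Array("DAY", "HOUR", "MINUTE", "SECOND")

  def decode(header: Byte): (String, String) = {
    val start = header & 0x3        // bits 0-1
    val end = (header >> 2) & 0x3   // bits 2-3
    (fieldNames(start), fieldNames(end))
  }

  def main(args: Array[String]): Unit = {
    // 0x0C = 0b00001100: start = DAY (0), end = SECOND (3)
    println(decode(0x0C.toByte))    // prints (DAY,SECOND)
  }
}
```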
These IDs are available for use by new types. diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java index f5e5f729459f7..375d69034fd31 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java @@ -26,10 +26,7 @@ import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; +import java.util.*; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonParser; @@ -43,24 +40,29 @@ * Build variant value and metadata by parsing JSON values. */ public class VariantBuilder { + public VariantBuilder(boolean allowDuplicateKeys) { + this.allowDuplicateKeys = allowDuplicateKeys; + } + /** * Parse a JSON string as a Variant value. * @throws VariantSizeLimitException if the resulting variant value or metadata would exceed * the SIZE_LIMIT (for example, this could be a maximum of 16 MiB). * @throws IOException if any JSON parsing error happens. */ - public static Variant parseJson(String json) throws IOException { + public static Variant parseJson(String json, boolean allowDuplicateKeys) throws IOException { try (JsonParser parser = new JsonFactory().createParser(json)) { parser.nextToken(); - return parseJson(parser); + return parseJson(parser, allowDuplicateKeys); } } /** - * Similar {@link #parseJson(String)}, but takes a JSON parser instead of string input. + * Similar {@link #parseJson(String, boolean)}, but takes a JSON parser instead of string input. */ - public static Variant parseJson(JsonParser parser) throws IOException { - VariantBuilder builder = new VariantBuilder(); + public static Variant parseJson(JsonParser parser, boolean allowDuplicateKeys) + throws IOException { + VariantBuilder builder = new VariantBuilder(allowDuplicateKeys); builder.buildJson(parser); return builder.result(); } @@ -274,23 +276,63 @@ public int getWritePos() { // record the offset of the field. The offset is computed as `getWritePos() - start`. // 3. The caller calls `finishWritingObject` to finish writing a variant object. // - // This function is responsible to sort the fields by key and check for any duplicate field keys. + // This function is responsible to sort the fields by key. If there are duplicate field keys: + // - when `allowDuplicateKeys` is true, the field with the greatest offset value (the last + // appended one) is kept. + // - otherwise, throw an exception. public void finishWritingObject(int start, ArrayList fields) { - int dataSize = writePos - start; int size = fields.size(); Collections.sort(fields); int maxId = size == 0 ? 0 : fields.get(0).id; - // Check for duplicate field keys. Only need to check adjacent key because they are sorted. - for (int i = 1; i < size; ++i) { - maxId = Math.max(maxId, fields.get(i).id); - String key = fields.get(i).key; - if (key.equals(fields.get(i - 1).key)) { - @SuppressWarnings("unchecked") - Map parameters = Map$.MODULE$.empty().updated("key", key); - throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY", parameters, - null, new QueryContext[]{}, ""); + if (allowDuplicateKeys) { + int distinctPos = 0; + // Maintain a list of distinct keys in-place. 
+ for (int i = 1; i < size; ++i) { + maxId = Math.max(maxId, fields.get(i).id); + if (fields.get(i).id == fields.get(i - 1).id) { + // Found a duplicate key. Keep the field with a greater offset. + if (fields.get(distinctPos).offset < fields.get(i).offset) { + fields.set(distinctPos, fields.get(distinctPos).withNewOffset(fields.get(i).offset)); + } + } else { + // Found a distinct key. Add the field to the list. + ++distinctPos; + fields.set(distinctPos, fields.get(i)); + } + } + if (distinctPos + 1 < fields.size()) { + size = distinctPos + 1; + // Resize `fields` to `size`. + fields.subList(size, fields.size()).clear(); + // Sort the fields by offsets so that we can move the value data of each field to the new + // offset without overwriting the fields after it. + fields.sort(Comparator.comparingInt(f -> f.offset)); + int currentOffset = 0; + for (int i = 0; i < size; ++i) { + int oldOffset = fields.get(i).offset; + int fieldSize = VariantUtil.valueSize(writeBuffer, start + oldOffset); + System.arraycopy(writeBuffer, start + oldOffset, + writeBuffer, start + currentOffset, fieldSize); + fields.set(i, fields.get(i).withNewOffset(currentOffset)); + currentOffset += fieldSize; + } + writePos = start + currentOffset; + // Change back to the sort order by field keys to meet the variant spec. + Collections.sort(fields); + } + } else { + for (int i = 1; i < size; ++i) { + maxId = Math.max(maxId, fields.get(i).id); + String key = fields.get(i).key; + if (key.equals(fields.get(i - 1).key)) { + @SuppressWarnings("unchecked") + Map parameters = Map$.MODULE$.empty().updated("key", key); + throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY", parameters, + null, new QueryContext[]{}, ""); + } } } + int dataSize = writePos - start; boolean largeSize = size > U8_MAX; int sizeBytes = largeSize ? U32_SIZE : 1; int idSize = getIntegerSize(maxId); @@ -415,6 +457,10 @@ public FieldEntry(String key, int id, int offset) { this.offset = offset; } + FieldEntry withNewOffset(int newOffset) { + return new FieldEntry(key, id, newOffset); + } + @Override public int compareTo(FieldEntry other) { return key.compareTo(other.key); @@ -518,4 +564,5 @@ private boolean tryParseDecimal(String input) { private final HashMap dictionary = new HashMap<>(); // Store all keys in `dictionary` in the order of id. private final ArrayList dictionaryKeys = new ArrayList<>(); + private final boolean allowDuplicateKeys; } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 7d80998d96eb1..0b85b208242cb 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -42,7 +42,8 @@ private[sql] case class AvroDataToCatalyst( val dt = SchemaConverters.toSqlType( expectedSchema, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType).dataType + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth).dataType parseMode match { // With PermissiveMode, the output Catalyst row might contain columns of null values for // corrupt records, even if some of the columns are not nullable in the user-provided schema. 
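Stepping back to the `VariantBuilder` change above: a hedged Scala sketch of the two duplicate-key modes. This is illustrative usage only, not code from the patch; it relies just on the `parseJson(String, boolean)` entry point and the `VARIANT_DUPLICATE_KEY` error shown in the diff.

```scala
import org.apache.spark.SparkRuntimeException
import org.apache.spark.types.variant.VariantBuilder

// With allowDuplicateKeys = true the last-appended occurrence of a key wins;
// with false the builder keeps the pre-existing behaviour and rejects the input.
object VariantDuplicateKeySketch {
  def main(args: Array[String]): Unit = {
    val json = """{"a": 1, "a": 2}"""

    // Lenient mode: the resulting variant holds a single field a = 2.
    val variant = VariantBuilder.parseJson(json, true)
    println(s"parsed variant: $variant")

    // Strict mode: duplicate keys are rejected.
    try VariantBuilder.parseJson(json, false)
    catch {
      case e: SparkRuntimeException =>
        println(e.getErrorClass) // VARIANT_DUPLICATE_KEY
    }
  }
}
```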
@@ -69,7 +70,8 @@ private[sql] case class AvroDataToCatalyst( dataType, avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) @transient private var decoder: BinaryDecoder = _ diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 877c3f89e88c0..ac20614553ca2 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -51,14 +51,16 @@ private[sql] class AvroDeserializer( datetimeRebaseSpec: RebaseSpec, filters: StructFilters, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String) { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int) { def this( rootAvroType: Schema, rootCatalystType: DataType, datetimeRebaseMode: String, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String) = { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int) = { this( rootAvroType, rootCatalystType, @@ -66,7 +68,8 @@ private[sql] class AvroDeserializer( RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), new NoopFilters, useStableIdForUnionType, - stableIdPrefixForUnionType) + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) } private lazy val decimalConversions = new DecimalConversion() @@ -128,7 +131,8 @@ private[sql] class AvroDeserializer( s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})" val realDataType = SchemaConverters.toSqlType( - avroType, useStableIdForUnionType, stableIdPrefixForUnionType).dataType + avroType, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth).dataType (avroType.getType, catalystType) match { case (NULL, NullType) => (updater, ordinal, _) => diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 372f24b54f5c4..264c3a1f48abe 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -145,7 +145,8 @@ private[sql] class AvroFileFormat extends FileFormat datetimeRebaseMode, avroFilters, parsedOptions.useStableIdForUnionType, - parsedOptions.stableIdPrefixForUnionType) + parsedOptions.stableIdPrefixForUnionType, + parsedOptions.recursiveFieldMaxDepth) override val stopPosition = file.start + file.length override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index 4332904339f19..e0c6ad3ee69d3 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf /** @@ -136,6 +137,15 @@ private[sql] class AvroOptions( val 
stableIdPrefixForUnionType: String = parameters .getOrElse(STABLE_ID_PREFIX_FOR_UNION_TYPE, "member_") + + val recursiveFieldMaxDepth: Int = + parameters.get(RECURSIVE_FIELD_MAX_DEPTH).map(_.toInt).getOrElse(-1) + + if (recursiveFieldMaxDepth > RECURSIVE_FIELD_MAX_DEPTH_LIMIT) { + throw QueryCompilationErrors.avroOptionsException( + RECURSIVE_FIELD_MAX_DEPTH, + s"Should not be greater than $RECURSIVE_FIELD_MAX_DEPTH_LIMIT.") + } } private[sql] object AvroOptions extends DataSourceOptions { @@ -170,4 +180,25 @@ private[sql] object AvroOptions extends DataSourceOptions { // When STABLE_ID_FOR_UNION_TYPE is enabled, the option allows to configure the prefix for fields // of Avro Union type. val STABLE_ID_PREFIX_FOR_UNION_TYPE = newOption("stableIdentifierPrefixForUnionType") + + /** + * Adds support for recursive fields. If this option is not specified or is set to 0, recursive + * fields are not permitted. Setting it to 1 drops all recursive fields, 2 allows recursive + * fields to be recursed once, and 3 allows it to be recursed twice and so on, up to 15. + * Values larger than 15 are not allowed in order to avoid inadvertently creating very large + * schemas. If an avro message has depth beyond this limit, the Spark struct returned is + * truncated after the recursion limit. + * + * Examples: Consider an Avro schema with a recursive field: + * {"type" : "record", "name" : "Node", "fields" : [{"name": "Id", "type": "int"}, + * {"name": "Next", "type": ["null", "Node"]}]} + * The following lists the parsed schema with different values for this setting. + * 1: `struct` + * 2: `struct>` + * 3: `struct>>` + * and so on. + */ + val RECURSIVE_FIELD_MAX_DEPTH = newOption("recursiveFieldMaxDepth") + + val RECURSIVE_FIELD_MAX_DEPTH_LIMIT: Int = 15 } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 7cbc30f1fb3dc..594ebb4716c41 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -65,7 +65,8 @@ private[sql] object AvroUtils extends Logging { SchemaConverters.toSqlType( avroSchema, parsedOptions.useStableIdForUnionType, - parsedOptions.stableIdPrefixForUnionType).dataType match { + parsedOptions.stableIdPrefixForUnionType, + parsedOptions.recursiveFieldMaxDepth).dataType match { case t: StructType => Some(t) case _ => throw new RuntimeException( s"""Avro schema cannot be converted to a Spark SQL StructType: diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index b2285aa966ddb..1168a887abd8e 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -27,6 +27,10 @@ import org.apache.avro.LogicalTypes.{Date, Decimal, LocalTimestampMicros, LocalT import org.apache.avro.Schema.Type._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys.{FIELD_NAME, FIELD_TYPE, RECURSIVE_DEPTH} +import org.apache.spark.internal.MDC +import org.apache.spark.sql.avro.AvroOptions.RECURSIVE_FIELD_MAX_DEPTH_LIMIT import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types._ import org.apache.spark.sql.types.Decimal.minBytesForPrecision @@ -36,7 +40,7 
@@ import org.apache.spark.sql.types.Decimal.minBytesForPrecision * versa. */ @DeveloperApi -object SchemaConverters { +object SchemaConverters extends Logging { private lazy val nullSchema = Schema.create(Schema.Type.NULL) /** @@ -48,14 +52,27 @@ object SchemaConverters { /** * Converts an Avro schema to a corresponding Spark SQL schema. - * + * + * @param avroSchema The Avro schema to convert. + * @param useStableIdForUnionType If true, Avro schema is deserialized into Spark SQL schema, + * and the Avro Union type is transformed into a structure where + * the field names remain consistent with their respective types. + * @param stableIdPrefixForUnionType The prefix to use to configure the prefix for fields of + * Avro Union type + * @param recursiveFieldMaxDepth The maximum depth to recursively process fields in Avro schema. + * -1 means not supported. * @since 4.0.0 */ def toSqlType( avroSchema: Schema, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType, stableIdPrefixForUnionType) + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int = -1): SchemaType = { + val schema = toSqlTypeHelper(avroSchema, Map.empty, useStableIdForUnionType, + stableIdPrefixForUnionType, recursiveFieldMaxDepth) + // the top level record should never return null + assert(schema != null) + schema } /** * Converts an Avro schema to a corresponding Spark SQL schema. @@ -63,17 +80,17 @@ object SchemaConverters { * @since 2.4.0 */ def toSqlType(avroSchema: Schema): SchemaType = { - toSqlType(avroSchema, false, "") + toSqlType(avroSchema, false, "", -1) } @deprecated("using toSqlType(..., useStableIdForUnionType: Boolean) instead", "4.0.0") def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = { val avroOptions = AvroOptions(options) - toSqlTypeHelper( + toSqlType( avroSchema, - Set.empty, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) } // The property specifies Catalyst type of the given field @@ -81,9 +98,10 @@ object SchemaConverters { private def toSqlTypeHelper( avroSchema: Schema, - existingRecordNames: Set[String], + existingRecordNames: Map[String, Int], useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String): SchemaType = { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int): SchemaType = { avroSchema.getType match { case INT => avroSchema.getLogicalType match { case _: Date => SchemaType(DateType, nullable = false) @@ -128,62 +146,110 @@ object SchemaConverters { case NULL => SchemaType(NullType, nullable = true) case RECORD => - if (existingRecordNames.contains(avroSchema.getFullName)) { + val recursiveDepth: Int = existingRecordNames.getOrElse(avroSchema.getFullName, 0) + if (recursiveDepth > 0 && recursiveFieldMaxDepth <= 0) { throw new IncompatibleSchemaException(s""" - |Found recursive reference in Avro schema, which can not be processed by Spark: - |${avroSchema.toString(true)} + |Found recursive reference in Avro schema, which can not be processed by Spark by + | default: ${avroSchema.toString(true)}. Try setting the option `recursiveFieldMaxDepth` + | to 1 - $RECURSIVE_FIELD_MAX_DEPTH_LIMIT. 
""".stripMargin) - } - val newRecordNames = existingRecordNames + avroSchema.getFullName - val fields = avroSchema.getFields.asScala.map { f => - val schemaType = toSqlTypeHelper( - f.schema(), - newRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType) - StructField(f.name, schemaType.dataType, schemaType.nullable) - } + } else if (recursiveDepth > 0 && recursiveDepth >= recursiveFieldMaxDepth) { + logInfo( + log"The field ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} is dropped at recursive depth " + + log"${MDC(RECURSIVE_DEPTH, recursiveDepth)}." + ) + null + } else { + val newRecordNames = + existingRecordNames + (avroSchema.getFullName -> (recursiveDepth + 1)) + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlTypeHelper( + f.schema(), + newRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + null + } + else { + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + }.filter(_ != null).toSeq - SchemaType(StructType(fields.toArray), nullable = false) + SchemaType(StructType(fields), nullable = false) + } case ARRAY => val schemaType = toSqlTypeHelper( avroSchema.getElementType, existingRecordNames, useStableIdForUnionType, - stableIdPrefixForUnionType) - SchemaType( - ArrayType(schemaType.dataType, containsNull = schemaType.nullable), - nullable = false) + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." + ) + null + } else { + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + } case MAP => val schemaType = toSqlTypeHelper(avroSchema.getValueType, - existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) - SchemaType( - MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), - nullable = false) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." 
+ ) + null + } else { + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + } case UNION => if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema) - if (remainingUnionTypes.size == 1) { - toSqlTypeHelper( - remainingUnionTypes.head, - existingRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType).copy(nullable = true) + val remainingSchema = + if (remainingUnionTypes.size == 1) { + remainingUnionTypes.head + } else { + Schema.createUnion(remainingUnionTypes.asJava) + } + val schemaType = toSqlTypeHelper( + remainingSchema, + existingRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." + ) + null } else { - toSqlTypeHelper( - Schema.createUnion(remainingUnionTypes.asJava), - existingRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType).copy(nullable = true) + schemaType.copy(nullable = true) } } else avroSchema.getTypes.asScala.map(_.getType).toSeq match { case Seq(t1) => toSqlTypeHelper(avroSchema.getTypes.get(0), - existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth) case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => SchemaType(LongType, nullable = false) case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => @@ -201,29 +267,33 @@ object SchemaConverters { s, existingRecordNames, useStableIdForUnionType, - stableIdPrefixForUnionType) - - val fieldName = if (useStableIdForUnionType) { - // Avro's field name may be case sensitive, so field names for two named type - // could be "a" and "A" and we need to distinguish them. In this case, we throw - // an exception. - // Stable id prefix can be empty so the name of the field can be just the type. - val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}" - if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { - throw new IncompatibleSchemaException( - "Cannot generate stable identifier for Avro union type due to name " + - s"conflict of type name ${s.getName}") - } - tempFieldName + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + null } else { - s"member$i" - } + val fieldName = if (useStableIdForUnionType) { + // Avro's field name may be case sensitive, so field names for two named type + // could be "a" and "A" and we need to distinguish them. In this case, we throw + // an exception. + // Stable id prefix can be empty so the name of the field can be just the type. 
+ val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}" + if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { + throw new IncompatibleSchemaException( + "Cannot generate stable identifier for Avro union type due to name " + + s"conflict of type name ${s.getName}") + } + tempFieldName + } else { + s"member$i" + } - // All fields are nullable because only one of them is set at a time - StructField(fieldName, schemaType.dataType, nullable = true) - } + // All fields are nullable because only one of them is set at a time + StructField(fieldName, schemaType.dataType, nullable = true) + } + }.filter(_ != null).toSeq - SchemaType(StructType(fields.toArray), nullable = false) + SchemaType(StructType(fields), nullable = false) } case other => throw new IncompatibleSchemaException(s"Unsupported type $other") diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index 1083c99160724..a13faf3b51560 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -105,7 +105,8 @@ case class AvroPartitionReaderFactory( datetimeRebaseMode, avroFilters, options.useStableIdForUnionType, - options.stableIdPrefixForUnionType) + options.stableIdPrefixForUnionType, + options.recursiveFieldMaxDepth) override val stopPosition = partitionedFile.start + partitionedFile.length override def next(): Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala index fe61fe3db8786..8ec711b2757f5 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala @@ -37,7 +37,7 @@ case class AvroTable( fallbackFileFormat: Class[_ <: FileFormat]) extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { override def newScanBuilder(options: CaseInsensitiveStringMap): AvroScanBuilder = - new AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) + AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options)) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = AvroUtils.inferSchema(sparkSession, options.asScala.toMap, files) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 388347537a4d6..311eda3a1b6ae 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -291,7 +291,8 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite RebaseSpec(LegacyBehaviorPolicy.CORRECTED), filters, false, - "") + "", + -1) val deserialized = deserializer.deserialize(data) expected match { case None => assert(deserialized == None) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala index 256b608feaa1f..0db9d284c4512 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala +++ 
b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala @@ -54,7 +54,7 @@ class AvroCodecSuite extends FileSourceCodecSuite { s"""CREATE TABLE avro_t |USING $format OPTIONS('compression'='unsupported') |AS SELECT 1 as id""".stripMargin)), - errorClass = "CODEC_SHORT_NAME_NOT_FOUND", + condition = "CODEC_SHORT_NAME_NOT_FOUND", sqlState = Some("42704"), parameters = Map("codecName" -> "unsupported") ) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index 432c3fa9be3ac..a7f7abadcf485 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{col, lit, struct} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} class AvroFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -329,7 +329,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { s""" |select to_avro(s, 42) as result from t |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map("sqlExpr" -> "\"to_avro(s, 42)\"", "msg" -> ("The second argument of the TO_AVRO SQL function must be a constant string " + "containing the JSON representation of the schema to use for converting the value to " + @@ -344,7 +344,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { s""" |select from_avro(s, 42, '') as result from t |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map("sqlExpr" -> "\"from_avro(s, 42, )\"", "msg" -> ("The second argument of the FROM_AVRO SQL function must be a constant string " + "containing the JSON representation of the schema to use for converting the value " + @@ -359,7 +359,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { s""" |select from_avro(s, '$jsonFormatSchema', 42) as result from t |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> s"\"from_avro(s, $jsonFormatSchema, 42)\"".stripMargin, @@ -374,6 +374,37 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { } } + + test("roundtrip in to_avro and from_avro - recursive schema") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4))), Row(1, null))), + catalystSchema).select(struct("Id", "Name").as("struct")) + + val avroStructDF = df.select(functions.to_avro($"struct", 
avroSchema).as("avro")) + checkAnswer(avroStructDF.select( + functions.from_avro($"avro", avroSchema, Map( + "recursiveFieldMaxDepth" -> "3").asJava)), df) + } + private def serialize(record: GenericRecord, avroSchema: String): Array[Byte] = { val schema = new Schema.Parser().parse(avroSchema) val datumWriter = new GenericDatumWriter[GenericRecord](schema) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala index 429f3c0deca6a..751ac275e048a 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -439,7 +439,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession { assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkArithmeticException], - errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", + condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", parameters = Map( "value" -> "0", "precision" -> "4", diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index 9b3bb929a700d..c1ab96a63eb26 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -77,7 +77,8 @@ class AvroRowReaderSuite RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) override val stopPosition = fileSize override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index cbcbc2e7e76a6..3643a95abe19c 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -228,7 +228,8 @@ object AvroSerdeSuite { RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) } /** diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index b20ee4b3cc231..be887bd5237b0 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -894,7 +894,7 @@ abstract class AvroSuite assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], - errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + condition = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "decimal\\(12,10\\)", @@ -972,7 +972,7 @@ abstract class AvroSuite assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], - errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + condition = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "interval day to second", @@ -1009,7 +1009,7 @@ abstract class AvroSuite assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], - errorClass = "AVRO_INCOMPATIBLE_READ_TYPE", + 
condition = "AVRO_INCOMPATIBLE_READ_TYPE", parameters = Map("avroPath" -> "field 'a'", "sqlPath" -> "field 'a'", "avroType" -> "interval year to month", @@ -1673,7 +1673,7 @@ abstract class AvroSuite exception = intercept[AnalysisException] { sql("select interval 1 days").write.format("avro").mode("overwrite").save(tempDir) }, - errorClass = "_LEGACY_ERROR_TEMP_1136", + condition = "_LEGACY_ERROR_TEMP_1136", parameters = Map.empty ) checkError( @@ -1681,7 +1681,7 @@ abstract class AvroSuite spark.udf.register("testType", () => new IntervalData()) sql("select testType()").write.format("avro").mode("overwrite").save(tempDir) }, - errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", parameters = Map( "columnName" -> "`testType()`", "columnType" -> "UDT(\"INTERVAL\")", @@ -2220,7 +2220,8 @@ abstract class AvroSuite } } - private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = { + private def checkSchemaWithRecursiveLoop(avroSchema: String, recursiveFieldMaxDepth: Int): + Unit = { val message = intercept[IncompatibleSchemaException] { SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false, "") }.getMessage @@ -2229,7 +2230,79 @@ abstract class AvroSuite } test("Detect recursive loop") { - checkSchemaWithRecursiveLoop(""" + for (recursiveFieldMaxDepth <- Seq(-1, 0)) { + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, // each element has a long + | {"name": "next", "type": ["null", "LongList"]} // optional next element + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields": [ + | { + | "name": "value", + | "type": { + | "type": "record", + | "name": "foo", + | "fields": [ + | { + | "name": "parent", + | "type": "LongList" + | } + | ] + | } + | } + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "array", "type": {"type": "array", "items": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "map", "type": {"type": "map", "values": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + } + } + + private def checkSparkSchemaEquals( + avroSchema: String, expectedSchema: StructType, recursiveFieldMaxDepth: Int): Unit = { + val sparkSchema = + SchemaConverters.toSqlType( + new Schema.Parser().parse(avroSchema), false, "", recursiveFieldMaxDepth).dataType + + assert(sparkSchema === expectedSchema) + } + + test("Translate recursive schema - union") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2238,9 +2311,57 @@ abstract class AvroSuite | {"name": "next", "type": ["null", "LongList"]} // optional next element | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("next", expectedSchema) + } + } + + test("Translate recursive schema - union - 2 non-null fields") { + val avroSchema = """ + |{ + | "type": "record", + | "name": 
"TreeNode", + | "fields": [ + | { + | "name": "name", + | "type": "string" + | }, + | { + | "name": "value", + | "type": [ + | "long" + | ] + | }, + | { + | "name": "children", + | "type": [ + | "null", + | { + | "type": "array", + | "items": "TreeNode" + | } + | ], + | "default": null + | } + | ] + |} + """.stripMargin + val nonRecursiveFields = new StructType().add("name", StringType, nullable = false) + .add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("children", + new ArrayType(expectedSchema, false), nullable = true) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - record") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2260,9 +2381,18 @@ abstract class AvroSuite | } | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", StructType(Seq()), nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = new StructType().add("value", + new StructType().add("parent", expectedSchema, nullable = false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - array") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2271,9 +2401,18 @@ abstract class AvroSuite | {"name": "array", "type": {"type": "array", "items": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("array", new ArrayType(expectedSchema, false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - map") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2282,7 +2421,70 @@ abstract class AvroSuite | {"name": "map", "type": {"type": "map", "values": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("map", + new MapType(StringType, expectedSchema, false), nullable = false) + } + } + + test("recursive schema integration test") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", NullType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4, null))), Row(1, null))), + catalystSchema) + + withTempPath { tempDir => + df.write.format("avro").save(tempDir.getPath) + + val exc = intercept[AnalysisException] { + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 16) + .load(tempDir.getPath) + } + assert(exc.getMessage.contains("Should not be greater than 15.")) + 
+ checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 10) + .load(tempDir.getPath), + df) + + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 1) + .load(tempDir.getPath), + df.select("Id")) + } } test("log a warning of ignoreExtension deprecation") { @@ -2726,7 +2928,7 @@ abstract class AvroSuite |LOCATION '${dir}' |AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin) }, - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map( "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`") ) @@ -2777,7 +2979,7 @@ abstract class AvroSuite } test("SPARK-40667: validate Avro Options") { - assert(AvroOptions.getAllOptions.size == 11) + assert(AvroOptions.getAllOptions.size == 12) // Please add validation on any new Avro options here assert(AvroOptions.isValidOption("ignoreExtension")) assert(AvroOptions.isValidOption("mode")) @@ -2790,6 +2992,7 @@ abstract class AvroSuite assert(AvroOptions.isValidOption("datetimeRebaseMode")) assert(AvroOptions.isValidOption("enableStableIdentifiersForUnionType")) assert(AvroOptions.isValidOption("stableIdentifierPrefixForUnionType")) + assert(AvroOptions.isValidOption("recursiveFieldMaxDepth")) } test("SPARK-46633: read file with empty blocks") { @@ -2831,7 +3034,7 @@ class AvroV1Suite extends AvroSuite { sql("SELECT ID, IF(ID=1,1,0) FROM v").write.mode(SaveMode.Overwrite) .format("avro").save(dir.getCanonicalPath) }, - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map( "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`") ) @@ -2844,7 +3047,7 @@ class AvroV1Suite extends AvroSuite { .write.mode(SaveMode.Overwrite) .format("avro").save(dir.getCanonicalPath) }, - errorClass = "INVALID_COLUMN_NAME_AS_PATH", + condition = "INVALID_COLUMN_NAME_AS_PATH", parameters = Map( "datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`") ) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 7d484d82ec25c..3777f82594aae 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -17,118 +17,27 @@ package org.apache.spark.sql -import java.util.Locale - import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{NAReplace, Relation} import org.apache.spark.connect.proto.Expression.{Literal => GLiteral} import org.apache.spark.connect.proto.NAReplace.Replacement -import org.apache.spark.util.ArrayImplicits._ +import org.apache.spark.sql.connect.ConnectConversions._ /** * Functionality for working with missing data in `DataFrame`s. * * @since 3.4.0 */ -final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation) { +final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation) + extends api.DataFrameNaFunctions { import sparkSession.RichColumn - /** - * Returns a new `DataFrame` that drops rows containing any null or NaN values. - * - * @since 3.4.0 - */ - def drop(): DataFrame = buildDropDataFrame(None, None) - - /** - * Returns a new `DataFrame` that drops rows containing null or NaN values. 
- * - * If `how` is "any", then drop rows containing any null or NaN values. If `how` is "all", then - * drop rows only if every column is null or NaN for that row. - * - * @since 3.4.0 - */ - def drop(how: String): DataFrame = { - buildDropDataFrame(None, buildMinNonNulls(how)) - } - - /** - * Returns a new `DataFrame` that drops rows containing any null or NaN values in the specified - * columns. - * - * @since 3.4.0 - */ - def drop(cols: Array[String]): DataFrame = drop(cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values - * in the specified columns. - * - * @since 3.4.0 - */ - def drop(cols: Seq[String]): DataFrame = buildDropDataFrame(Some(cols), None) - - /** - * Returns a new `DataFrame` that drops rows containing null or NaN values in the specified - * columns. - * - * If `how` is "any", then drop rows containing any null or NaN values in the specified columns. - * If `how` is "all", then drop rows only if every specified column is null or NaN for that row. - * - * @since 3.4.0 - */ - def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values in - * the specified columns. - * - * If `how` is "any", then drop rows containing any null or NaN values in the specified columns. - * If `how` is "all", then drop rows only if every specified column is null or NaN for that row. - * - * @since 3.4.0 - */ - def drop(how: String, cols: Seq[String]): DataFrame = { - buildDropDataFrame(Some(cols), buildMinNonNulls(how)) - } - - /** - * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and - * non-NaN values. - * - * @since 3.4.0 - */ - def drop(minNonNulls: Int): DataFrame = { - buildDropDataFrame(None, Some(minNonNulls)) - } + override protected def drop(minNonNulls: Option[Int]): Dataset[Row] = + buildDropDataFrame(None, minNonNulls) - /** - * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and - * non-NaN values in the specified columns. - * - * @since 3.4.0 - */ - def drop(minNonNulls: Int, cols: Array[String]): DataFrame = - drop(minNonNulls, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that drops rows containing less than `minNonNulls` - * non-null and non-NaN values in the specified columns. - * - * @since 3.4.0 - */ - def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { - buildDropDataFrame(Some(cols), Some(minNonNulls)) - } - - private def buildMinNonNulls(how: String): Option[Int] = { - how.toLowerCase(Locale.ROOT) match { - case "any" => None // No-Op. Do nothing. - case "all" => Some(1) - case _ => throw new IllegalArgumentException(s"how ($how) must be 'any' or 'all'") - } - } + override protected def drop(minNonNulls: Option[Int], cols: Seq[String]): Dataset[Row] = + buildDropDataFrame(Option(cols), minNonNulls) private def buildDropDataFrame( cols: Option[Seq[String]], @@ -140,110 +49,42 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: } } - /** - * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Long): DataFrame = { buildFillDataFrame(None, GLiteral.newBuilder().setLong(value).build()) } - /** - * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. 
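For illustration only, a hedged sketch of the drop/fill behaviour documented above, assuming an existing DataFrame `df`; "any" corresponds to no minimum non-null count, while "all" corresponds to requiring at least one non-null value per row.

df.na.drop("any")                 // drop rows containing any null or NaN value
df.na.drop("all", Seq("a", "b"))  // drop rows only if both listed columns are null/NaN
df.na.fill(0L, Seq("a", "b"))     // replace nulls/NaNs in the numeric columns "a" and "b"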
If a - * specified column is not a numeric column, it is ignored. - * - * @since 3.4.0 - */ - def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified - * numeric columns. If a specified column is not a numeric column, it is ignored. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Long, cols: Seq[String]): DataFrame = { buildFillDataFrame(Some(cols), GLiteral.newBuilder().setLong(value).build()) } - /** - * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Double): DataFrame = { buildFillDataFrame(None, GLiteral.newBuilder().setDouble(value).build()) } - /** - * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a - * specified column is not a numeric column, it is ignored. - * - * @since 3.4.0 - */ - def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified - * numeric columns. If a specified column is not a numeric column, it is ignored. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Double, cols: Seq[String]): DataFrame = { buildFillDataFrame(Some(cols), GLiteral.newBuilder().setDouble(value).build()) } - /** - * Returns a new `DataFrame` that replaces null values in string columns with `value`. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: String): DataFrame = { buildFillDataFrame(None, GLiteral.newBuilder().setString(value).build()) } - /** - * Returns a new `DataFrame` that replaces null values in specified string columns. If a - * specified column is not a string column, it is ignored. - * - * @since 3.4.0 - */ - def fill(value: String, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified string - * columns. If a specified column is not a string column, it is ignored. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: String, cols: Seq[String]): DataFrame = { buildFillDataFrame(Some(cols), GLiteral.newBuilder().setString(value).build()) } - /** - * Returns a new `DataFrame` that replaces null values in boolean columns with `value`. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Boolean): DataFrame = { buildFillDataFrame(None, GLiteral.newBuilder().setBoolean(value).build()) } - /** - * Returns a new `DataFrame` that replaces null values in specified boolean columns. If a - * specified column is not a boolean column, it is ignored. - * - * @since 3.4.0 - */ - def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value, cols.toImmutableArraySeq) - - /** - * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified boolean - * columns. If a specified column is not a boolean column, it is ignored. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def fill(value: Boolean, cols: Seq[String]): DataFrame = { buildFillDataFrame(Some(cols), GLiteral.newBuilder().setBoolean(value).build()) } @@ -256,43 +97,7 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: } } - /** - * Returns a new `DataFrame` that replaces null values. - * - * The key of the map is the column name, and the value of the map is the replacement value. 
The - * value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`, - * `Boolean`. Replacement values are cast to the column data type. - * - * For example, the following replaces null values in column "A" with string "unknown", and null - * values in column "B" with numeric value 1.0. - * {{{ - * import com.google.common.collect.ImmutableMap; - * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0)); - * }}} - * - * @since 3.4.0 - */ - def fill(valueMap: java.util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq) - - /** - * Returns a new `DataFrame` that replaces null values. - * - * The key of the map is the column name, and the value of the map is the replacement value. The - * value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`, - * `Boolean`. Replacement values are cast to the column data type. - * - * For example, the following replaces null values in column "A" with string "unknown", and null - * values in column "B" with numeric value 1.0. - * {{{ - * import com.google.common.collect.ImmutableMap; - * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0)); - * }}} - * - * @since 3.4.0 - */ - def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq) - - private def fillMap(values: Seq[(String, Any)]): DataFrame = { + protected def fillMap(values: Seq[(String, Any)]): DataFrame = { sparkSession.newDataFrame { builder => val fillNaBuilder = builder.getFillNaBuilder.setInput(root) values.map { case (colName, replaceValue) => @@ -301,104 +106,13 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: } } - /** - * Replaces values matching keys in `replacement` map with the corresponding values. - * - * {{{ - * import com.google.common.collect.ImmutableMap; - * - * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.na.replace("height", ImmutableMap.of(1.0, 2.0)); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); - * }}} - * - * @param col - * name of the column to apply the value replacement. If `col` is "*", replacement is applied - * on all string, numeric or boolean columns. - * @param replacement - * value replacement map. Key and value of `replacement` map must have the same type, and can - * only be doubles, strings or booleans. The map value can have nulls. - * @since 3.4.0 - */ - def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = - replace(col, replacement.asScala.toMap) - - /** - * (Scala-specific) Replaces values matching keys in `replacement` map. - * - * {{{ - * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.na.replace("height", Map(1.0 -> 2.0)); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.na.replace("name", Map("UNKNOWN" -> "unnamed")); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.na.replace("*", Map("UNKNOWN" -> "unnamed")); - * }}} - * - * @param col - * name of the column to apply the value replacement. If `col` is "*", replacement is applied - * on all string, numeric or boolean columns. - * @param replacement - * value replacement map. Key and value of `replacement` map must have the same type, and can - * only be doubles, strings or booleans. 
The map value can have nulls. - * @since 3.4.0 - */ + /** @inheritdoc */ def replace[T](col: String, replacement: Map[T, T]): DataFrame = { val cols = if (col != "*") Some(Seq(col)) else None buildReplaceDataFrame(cols, buildReplacement(replacement)) } - /** - * Replaces values matching keys in `replacement` map with the corresponding values. - * - * {{{ - * import com.google.common.collect.ImmutableMap; - * - * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); - * }}} - * - * @param cols - * list of columns to apply the value replacement. If `col` is "*", replacement is applied on - * all string, numeric or boolean columns. - * @param replacement - * value replacement map. Key and value of `replacement` map must have the same type, and can - * only be doubles, strings or booleans. The map value can have nulls. - * @since 3.4.0 - */ - def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = { - replace(cols.toImmutableArraySeq, replacement.asScala.toMap) - } - - /** - * (Scala-specific) Replaces values matching keys in `replacement` map. - * - * {{{ - * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); - * - * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); - * }}} - * - * @param cols - * list of columns to apply the value replacement. If `col` is "*", replacement is applied on - * all string, numeric or boolean columns. - * @param replacement - * value replacement map. Key and value of `replacement` map must have the same type, and can - * only be doubles, strings or booleans. The map value can have nulls. 
- * @since 3.4.0 - */ + /** @inheritdoc */ def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = { buildReplaceDataFrame(Some(cols), buildReplacement(replacement)) } @@ -441,4 +155,59 @@ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: case v => throw new IllegalArgumentException(s"Unsupported value type ${v.getClass.getName} ($v).") } + + /** @inheritdoc */ + override def drop(): DataFrame = super.drop() + + /** @inheritdoc */ + override def drop(cols: Array[String]): DataFrame = super.drop(cols) + + /** @inheritdoc */ + override def drop(cols: Seq[String]): DataFrame = super.drop(cols) + + /** @inheritdoc */ + override def drop(how: String, cols: Array[String]): DataFrame = super.drop(how, cols) + + /** @inheritdoc */ + override def drop(minNonNulls: Int, cols: Array[String]): DataFrame = + super.drop(minNonNulls, cols) + + /** @inheritdoc */ + override def drop(how: String): DataFrame = super.drop(how) + + /** @inheritdoc */ + override def drop(how: String, cols: Seq[String]): DataFrame = super.drop(how, cols) + + /** @inheritdoc */ + override def drop(minNonNulls: Int): DataFrame = super.drop(minNonNulls) + + /** @inheritdoc */ + override def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = + super.drop(minNonNulls, cols) + + /** @inheritdoc */ + override def fill(value: Long, cols: Array[String]): DataFrame = super.fill(value, cols) + + /** @inheritdoc */ + override def fill(value: Double, cols: Array[String]): DataFrame = super.fill(value, cols) + + /** @inheritdoc */ + override def fill(value: String, cols: Array[String]): DataFrame = super.fill(value, cols) + + /** @inheritdoc */ + override def fill(value: Boolean, cols: Array[String]): DataFrame = super.fill(value, cols) + + /** @inheritdoc */ + override def fill(valueMap: java.util.Map[String, Any]): DataFrame = super.fill(valueMap) + + /** @inheritdoc */ + override def fill(valueMap: Map[String, Any]): DataFrame = super.fill(valueMap) + + /** @inheritdoc */ + override def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = + super.replace[T](col, replacement) + + /** @inheritdoc */ + override def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = + super.replace(cols, replacement) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 1ad98dc91b216..60bacd4e18ede 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -23,11 +23,8 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.connect.proto.Parse.ParseFormat -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.DataTypeProtoConverter -import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types.StructType /** @@ -37,144 +34,44 @@ import org.apache.spark.sql.types.StructType * @since 3.4.0 */ @Stable -class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging { - - /** - * Specifies the input data source format. 
- * - * @since 3.4.0 - */ - def format(source: String): DataFrameReader = { - this.source = source - this - } +class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.DataFrameReader { + type DS[U] = Dataset[U] - /** - * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data source can skip - * the schema inference step, and thus speed up data loading. - * - * @since 3.4.0 - */ - def schema(schema: StructType): DataFrameReader = { - if (schema != null) { - val replaced = SparkCharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - this.userSpecifiedSchema = Option(replaced) - } - this - } + /** @inheritdoc */ + override def format(source: String): this.type = super.format(source) - /** - * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) - * can infer the input schema automatically from data. By specifying the schema here, the - * underlying data source can skip the schema inference step, and thus speed up data loading. - * - * {{{ - * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv") - * }}} - * - * @since 3.4.0 - */ - def schema(schemaString: String): DataFrameReader = { - schema(StructType.fromDDL(schemaString)) - } + /** @inheritdoc */ + override def schema(schema: StructType): this.type = super.schema(schema) - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def option(key: String, value: String): DataFrameReader = { - this.extraOptions = this.extraOptions + (key -> value) - this - } + /** @inheritdoc */ + override def schema(schemaString: String): this.type = super.schema(schemaString) - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def option(key: String, value: Long): DataFrameReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def option(key: String, value: Double): DataFrameReader = option(key, value.toString) - - /** - * (Scala-specific) Adds input options for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. 
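For illustration only, a small sketch of the case-insensitive option handling described above, assuming an active SparkSession `spark`; the option keys shown are arbitrary.

val reader = spark.read
  .option("multiLine", "false")
  .option("MULTILINE", "true")            // same key ignoring case, overrides the previous value
  .options(Map("mode" -> "PERMISSIVE"))   // bulk options are merged the same way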
- * - * @since 3.4.0 - */ - def options(options: scala.collection.Map[String, String]): DataFrameReader = { - this.extraOptions ++= options - this - } + /** @inheritdoc */ + override def option(key: String, value: String): this.type = super.option(key, value) - /** - * Adds input options for the underlying data source. - * - * All options are maintained in a case-insensitive way in terms of key names. If a new option - * has the same key case-insensitively, it will override the existing option. - * - * @since 3.4.0 - */ - def options(options: java.util.Map[String, String]): DataFrameReader = { - this.options(options.asScala) - this - } + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) - /** - * Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external - * key-value stores). - * - * @since 3.4.0 - */ - def load(): DataFrame = { - load(Seq.empty: _*) // force invocation of `load(...varargs...)` - } + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) - /** - * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a - * local or distributed file system). - * - * @since 3.4.0 - */ - def load(path: String): DataFrame = { - // force invocation of `load(...varargs...)` - load(Seq(path): _*) - } + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = + super.options(options) - /** - * Loads input in as a `DataFrame`, for data sources that support multiple paths. Only works if - * the source is a HadoopFsRelationProvider. - * - * @since 3.4.0 - */ + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = super.options(options) + + /** @inheritdoc */ + override def load(): DataFrame = load(Nil: _*) + + /** @inheritdoc */ + def load(path: String): DataFrame = load(Seq(path): _*) + + /** @inheritdoc */ @scala.annotation.varargs def load(paths: String*): DataFrame = { sparkSession.newDataFrame { builder => @@ -190,93 +87,29 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL url named - * table and connection properties. - * - * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC - * in - * Data Source Option in the version you use. - * - * @since 3.4.0 - */ - def jdbc(url: String, table: String, properties: Properties): DataFrame = { - // properties should override settings in extraOptions. - this.extraOptions ++= properties.asScala - // explicit url and dbtable should override all - this.extraOptions ++= Seq("url" -> url, "dbtable" -> table) - format("jdbc").load() - } + /** @inheritdoc */ + override def jdbc(url: String, table: String, properties: Properties): DataFrame = + super.jdbc(url, table, properties) - // scalastyle:off line.size.limit - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL url named - * table. Partitions of the table will be retrieved in parallel based on the parameters passed - * to this function. - * - * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash - * your external database systems. 
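For illustration only, a hedged sketch of the partitioned JDBC read handled by the overload above; the URL, table name, partition column and bounds are hypothetical, and an active SparkSession `spark` is assumed.

val props = new java.util.Properties()
props.put("user", "test")
props.put("password", "secret")
val jdbcDF = spark.read.jdbc(
  "jdbc:postgresql://host:5432/db",   // hypothetical JDBC URL
  "public.events",                    // hypothetical table
  "id",                               // partitionColumn used to split the scan
  0L, 1000000L,                       // lowerBound, upperBound
  8,                                  // numPartitions
  props)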
- * - * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC - * in - * Data Source Option in the version you use. - * - * @param table - * Name of the table in the external database. - * @param columnName - * Alias of `partitionColumn` option. Refer to `partitionColumn` in - * Data Source Option in the version you use. - * @param connectionProperties - * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least - * a "user" and "password" property should be included. "fetchsize" can be used to control the - * number of rows per fetch and "queryTimeout" can be used to wait for a Statement object to - * execute to the given number of seconds. - * @since 3.4.0 - */ - // scalastyle:on line.size.limit - def jdbc( + /** @inheritdoc */ + override def jdbc( url: String, table: String, columnName: String, lowerBound: Long, upperBound: Long, numPartitions: Int, - connectionProperties: Properties): DataFrame = { - // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions. - this.extraOptions ++= Map( - "partitionColumn" -> columnName, - "lowerBound" -> lowerBound.toString, - "upperBound" -> upperBound.toString, - "numPartitions" -> numPartitions.toString) - jdbc(url, table, connectionProperties) - } - - /** - * Construct a `DataFrame` representing the database table accessible via JDBC URL url named - * table using connection properties. The `predicates` parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition of the `DataFrame`. - * - * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash - * your external database systems. - * - * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC - * in - * Data Source Option in the version you use. - * - * @param table - * Name of the table in the external database. - * @param predicates - * Condition in the where clause for each partition. - * @param connectionProperties - * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least - * a "user" and "password" property should be included. "fetchsize" can be used to control the - * number of rows per fetch. - * @since 3.4.0 - */ + connectionProperties: Properties): DataFrame = + super.jdbc( + url, + table, + columnName, + lowerBound, + upperBound, + numPartitions, + connectionProperties) + + /** @inheritdoc */ def jdbc( url: String, table: String, @@ -296,207 +129,56 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } - /** - * Loads a JSON file and returns the results as a `DataFrame`. - * - * See the documentation on the overloaded `json()` method with varargs for more details. - * - * @since 3.4.0 - */ - def json(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - json(Seq(path): _*) - } + /** @inheritdoc */ + override def json(path: String): DataFrame = super.json(path) - /** - * Loads JSON files and returns the results as a `DataFrame`. - * - * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `multiLine` option to true. - * - * This function goes through the input once to determine the input schema. If you know the - * schema in advance, use the version that specifies the schema to avoid the extra scan. 
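For illustration only, a minimal sketch of supplying the schema up front so the JSON reader skips the inference pass mentioned above; the DDL string and path are hypothetical, and an active SparkSession `spark` is assumed.

val events = spark.read
  .schema("id INT, name STRING, ts TIMESTAMP")  // known schema, so no extra scan for inference
  .json("/data/events.jsonl")                   // hypothetical path to JSON Lines data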
- * - * You can find the JSON-specific options for reading JSON files in - * Data Source Option in the version you use. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def json(paths: String*): DataFrame = { - format("json").load(paths: _*) - } + override def json(paths: String*): DataFrame = super.json(paths: _*) - /** - * Loads a `Dataset[String]` storing JSON objects (JSON Lines - * text format or newline-delimited JSON) and returns the result as a `DataFrame`. - * - * Unless the schema is specified using `schema` function, this function goes through the input - * once to determine the input schema. - * - * @param jsonDataset - * input Dataset with one JSON object per record - * @since 3.4.0 - */ + /** @inheritdoc */ def json(jsonDataset: Dataset[String]): DataFrame = parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON) - /** - * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other - * overloaded `csv()` method for more details. - * - * @since 3.4.0 - */ - def csv(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - csv(Seq(path): _*) - } + /** @inheritdoc */ + override def csv(path: String): DataFrame = super.csv(path) - /** - * Loads CSV files and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can find the CSV-specific options for reading CSV files in - * Data Source Option in the version you use. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def csv(paths: String*): DataFrame = format("csv").load(paths: _*) - - /** - * Loads an `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`. - * - * If the schema is not specified using `schema` function and `inferSchema` option is enabled, - * this function goes through the input once to determine the input schema. - * - * If the schema is not specified using `schema` function and `inferSchema` option is disabled, - * it determines the columns as string types and it reads only the first line to determine the - * names and the number of fields. - * - * If the enforceSchema is set to `false`, only the CSV header in the first line is checked to - * conform specified or inferred schema. - * - * @note - * if `header` option is set to `true` when calling this API, all lines same with the header - * will be removed if exists. - * @param csvDataset - * input Dataset with one CSV row per record - * @since 3.4.0 - */ + override def csv(paths: String*): DataFrame = super.csv(paths: _*) + + /** @inheritdoc */ def csv(csvDataset: Dataset[String]): DataFrame = parse(csvDataset, ParseFormat.PARSE_FORMAT_CSV) - /** - * Loads a XML file and returns the result as a `DataFrame`. See the documentation on the other - * overloaded `xml()` method for more details. - * - * @since 4.0.0 - */ - def xml(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - xml(Seq(path): _*) - } + /** @inheritdoc */ + override def xml(path: String): DataFrame = super.xml(path) - /** - * Loads XML files and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. 
To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can find the XML-specific options for reading XML files in - * Data Source Option in the version you use. - * - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def xml(paths: String*): DataFrame = format("xml").load(paths: _*) - - /** - * Loads an `Dataset[String]` storing XML object and returns the result as a `DataFrame`. - * - * If the schema is not specified using `schema` function and `inferSchema` option is enabled, - * this function goes through the input once to determine the input schema. - * - * @param xmlDataset - * input Dataset with one XML object per record - * @since 4.0.0 - */ + override def xml(paths: String*): DataFrame = super.xml(paths: _*) + + /** @inheritdoc */ def xml(xmlDataset: Dataset[String]): DataFrame = parse(xmlDataset, ParseFormat.PARSE_FORMAT_UNSPECIFIED) - /** - * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the - * other overloaded `parquet()` method for more details. - * - * @since 3.4.0 - */ - def parquet(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - parquet(Seq(path): _*) - } + /** @inheritdoc */ + override def parquet(path: String): DataFrame = super.parquet(path) - /** - * Loads a Parquet file, returning the result as a `DataFrame`. - * - * Parquet-specific option(s) for reading Parquet files can be found in Data - * Source Option in the version you use. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def parquet(paths: String*): DataFrame = { - format("parquet").load(paths: _*) - } + override def parquet(paths: String*): DataFrame = super.parquet(paths: _*) - /** - * Loads an ORC file and returns the result as a `DataFrame`. - * - * @param path - * input path - * @since 3.4.0 - */ - def orc(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - orc(Seq(path): _*) - } + /** @inheritdoc */ + override def orc(path: String): DataFrame = super.orc(path) - /** - * Loads ORC files and returns the result as a `DataFrame`. - * - * ORC-specific option(s) for reading ORC files can be found in Data - * Source Option in the version you use. - * - * @param paths - * input paths - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def orc(paths: String*): DataFrame = format("orc").load(paths: _*) - - /** - * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch - * reading and the returned DataFrame is the batch scan query plan of this table. If it's a - * view, the returned DataFrame is simply the query plan of the view, which can either be a - * batch or streaming query plan. - * - * @param tableName - * is either a qualified or unqualified name that designates a table or view. If a database is - * specified, it identifies the table/view from the database. Otherwise, it first attempts to - * find a temporary view with the given name and then match the table/view from the current - * database. Note that, the global temporary view database is also valid here. 
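For illustration only, a short sketch of the table() resolution order described above, assuming an active SparkSession `spark` and hypothetical table names; note that the new implementation rejects a user-specified schema via assertNoSpecifiedSchema("table").

val t1 = spark.read.table("events")         // resolves a temp view first, then the current database
val t2 = spark.read.table("sales.events")   // database-qualified table or view name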
- * @since 3.4.0 - */ + override def orc(paths: String*): DataFrame = super.orc(paths: _*) + + /** @inheritdoc */ def table(tableName: String): DataFrame = { + assertNoSpecifiedSchema("table") sparkSession.newDataFrame { builder => builder.getReadBuilder.getNamedTableBuilder .setUnparsedIdentifier(tableName) @@ -504,80 +186,19 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. See the documentation on the - * other overloaded `text()` method for more details. - * - * @since 3.4.0 - */ - def text(path: String): DataFrame = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - text(Seq(path): _*) - } + /** @inheritdoc */ + override def text(path: String): DataFrame = super.text(path) - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. The text files must be encoded - * as UTF-8. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.read.text("/path/to/spark/README.md") - * - * // Java: - * spark.read().text("/path/to/spark/README.md") - * }}} - * - * You can find the text-specific options for reading text files in - * Data Source Option in the version you use. - * - * @param paths - * input paths - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def text(paths: String*): DataFrame = format("text").load(paths: _*) - - /** - * Loads text files and returns a [[Dataset]] of String. See the documentation on the other - * overloaded `textFile()` method for more details. - * @since 3.4.0 - */ - def textFile(path: String): Dataset[String] = { - // This method ensures that calls that explicit need single argument works, see SPARK-16009 - textFile(Seq(path): _*) - } + override def text(paths: String*): DataFrame = super.text(paths: _*) - /** - * Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset - * contains a single string column named "value". The text files must be encoded as UTF-8. - * - * If the directory structure of the text files contains partitioning information, those are - * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.read.textFile("/path/to/spark/README.md") - * - * // Java: - * spark.read().textFile("/path/to/spark/README.md") - * }}} - * - * You can set the text-specific options as specified in `DataFrameReader.text`. - * - * @param paths - * input path - * @since 3.4.0 - */ + /** @inheritdoc */ + override def textFile(path: String): Dataset[String] = super.textFile(path) + + /** @inheritdoc */ @scala.annotation.varargs - def textFile(paths: String*): Dataset[String] = { - assertNoSpecifiedSchema("textFile") - text(paths: _*).select("value").as(StringEncoder) - } + override def textFile(paths: String*): Dataset[String] = super.textFile(paths: _*) private def assertSourceFormatSpecified(): Unit = { if (source == null) { @@ -597,24 +218,4 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } } - - /** - * A convenient function for schema validation in APIs. 
- */ - private def assertNoSpecifiedSchema(operation: String): Unit = { - if (userSpecifiedSchema.nonEmpty) { - throw DataTypeErrors.userSpecifiedSchemaUnsupportedError(operation) - } - } - - /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options - /////////////////////////////////////////////////////////////////////////////////////// - - private var source: String = _ - - private var userSpecifiedSchema: Option[StructType] = None - - private var extraOptions = CaseInsensitiveMap[String](Map.empty) - } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 6365f387afce4..bb7cfa75a9ab9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -18,89 +18,24 @@ package org.apache.spark.sql import java.{lang => jl, util => ju} -import java.io.ByteArrayInputStream - -import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.functions.lit -import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} /** * Statistic functions for `DataFrame`s. * * @since 3.4.0 */ -final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, root: Relation) { - import sparkSession.RichColumn - - /** - * Calculates the approximate quantiles of a numerical column of a DataFrame. - * - * The result of this algorithm has the following deterministic bound: If the DataFrame has N - * elements and if we request the quantile at probability `p` up to error `err`, then the - * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is - * close to (p * N). More precisely, - * - * {{{ - * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N) - * }}} - * - * This method implements a variation of the Greenwald-Khanna algorithm (with some speed - * optimizations). The algorithm was first present in Space-efficient Online Computation of Quantile - * Summaries by Greenwald and Khanna. - * - * @param col - * the name of the numerical column - * @param probabilities - * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the - * minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError - * The relative target precision to achieve (greater than or equal to 0). If set to zero, the - * exact quantiles are computed, which could be very expensive. Note that values greater than - * 1 are accepted but give the same result as 1. - * @return - * the approximate quantiles at the given probabilities - * - * @note - * null and NaN values will be removed from the numerical column before calculation. If the - * dataframe is empty or the column only contains null or NaN, an empty array is returned. 
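For illustration only, a minimal sketch of the approxQuantile call documented above, computing approximate quartiles of a hypothetical numeric column "price" on an assumed DataFrame `df`.

val Array(q1, median, q3) =
  df.stat.approxQuantile("price", Array(0.25, 0.5, 0.75), 0.01)  // 1% relative error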
- * - * @since 3.4.0 - */ - def approxQuantile( - col: String, - probabilities: Array[Double], - relativeError: Double): Array[Double] = { - approxQuantile(Array(col), probabilities, relativeError).head - } +final class DataFrameStatFunctions private[sql] (protected val df: DataFrame) + extends api.DataFrameStatFunctions { + private def root: Relation = df.plan.getRoot + private val sparkSession: SparkSession = df.sparkSession - /** - * Calculates the approximate quantiles of numerical columns of a DataFrame. - * @see - * `approxQuantile(col:Str* approxQuantile)` for detailed description. - * - * @param cols - * the names of the numerical columns - * @param probabilities - * a list of quantile probabilities Each number must belong to [0, 1]. For example 0 is the - * minimum, 0.5 is the median, 1 is the maximum. - * @param relativeError - * The relative target precision to achieve (greater than or equal to 0). If set to zero, the - * exact quantiles are computed, which could be very expensive. Note that values greater than - * 1 are accepted but give the same result as 1. - * @return - * the approximate quantiles at the given probabilities of each column - * - * @note - * null and NaN values will be ignored in numerical columns before calculation. For columns - * only containing null or NaN values, an empty array is returned. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def approxQuantile( cols: Array[String], probabilities: Array[Double], @@ -120,24 +55,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo .head() } - /** - * Calculate the sample covariance of two numerical columns of a DataFrame. - * @param col1 - * the name of the first column - * @param col2 - * the name of the second column - * @return - * the covariance of the two columns. - * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.cov("rand1", "rand2") - * res1: Double = 0.065... - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def cov(col1: String, col2: String): Double = { sparkSession .newDataset(PrimitiveDoubleEncoder) { builder => @@ -146,27 +64,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo .head() } - /** - * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson - * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in - * MLlib's Statistics. - * - * @param col1 - * the name of the column - * @param col2 - * the name of the column to calculate the correlation against - * @return - * The Pearson Correlation Coefficient as a Double. - * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.corr("rand1", "rand2") - * res1: Double = 0.613... - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def corr(col1: String, col2: String, method: String): Double = { require( method == "pearson", @@ -179,289 +77,48 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo .head() } - /** - * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame. - * - * @param col1 - * the name of the column - * @param col2 - * the name of the column to calculate the correlation against - * @return - * The Pearson Correlation Coefficient as a Double. 
- * - * {{{ - * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) - * .withColumn("rand2", rand(seed=27)) - * df.stat.corr("rand1", "rand2", "pearson") - * res1: Double = 0.613... - * }}} - * - * @since 3.4.0 - */ - def corr(col1: String, col2: String): Double = { - corr(col1, col2, "pearson") - } - - /** - * Computes a pair-wise frequency table of the given columns. Also known as a contingency table. - * The first column of each row will be the distinct values of `col1` and the column names will - * be the distinct values of `col2`. The name of the first column will be `col1_col2`. Counts - * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts. - * Null elements will be replaced by "null", and back ticks will be dropped from elements if - * they exist. - * - * @param col1 - * The name of the first column. Distinct items will make the first item of each row. - * @param col2 - * The name of the second column. Distinct items will make the column names of the DataFrame. - * @return - * A DataFrame containing for the contingency table. - * - * {{{ - * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))) - * .toDF("key", "value") - * val ct = df.stat.crosstab("key", "value") - * ct.show() - * +---------+---+---+---+ - * |key_value| 1| 2| 3| - * +---------+---+---+---+ - * | 2| 2| 0| 1| - * | 1| 1| 1| 0| - * | 3| 0| 1| 1| - * +---------+---+---+---+ - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def crosstab(col1: String, col2: String): DataFrame = { sparkSession.newDataFrame { builder => builder.getCrosstabBuilder.setInput(root).setCol1(col1).setCol2(col2) } } - /** - * Finding frequent items for columns, possibly with false positives. Using the frequent element - * count algorithm described in here, - * proposed by Karp, Schenker, and Papadimitriou. The `support` should be greater than 1e-4. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. - * @param support - * The minimum frequency for an item to be considered `frequent`. Should be greater than 1e-4. - * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * {{{ - * val rows = Seq.tabulate(100) { i => - * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) - * } - * val df = spark.createDataFrame(rows).toDF("a", "b") - * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns - * // "a" and "b" - * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4) - * freqSingles.show() - * +-----------+-------------+ - * |a_freqItems| b_freqItems| - * +-----------+-------------+ - * | [1, 99]|[-1.0, -99.0]| - * +-----------+-------------+ - * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" - * val pairDf = df.select(struct("a", "b").as("a-b")) - * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1) - * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() - * +----------+ - * | freq_ab| - * +----------+ - * | [1,-1.0]| - * | ... 
| - * +----------+ - * }}} - * - * @since 3.4.0 - */ - def freqItems(cols: Array[String], support: Double): DataFrame = { - sparkSession.newDataFrame { builder => - val freqItemsBuilder = builder.getFreqItemsBuilder.setInput(root).setSupport(support) - cols.foreach(freqItemsBuilder.addCols) - } - } + /** @inheritdoc */ + override def freqItems(cols: Array[String], support: Double): DataFrame = + super.freqItems(cols, support) - /** - * Finding frequent items for columns, possibly with false positives. Using the frequent element - * count algorithm described in here, - * proposed by Karp, Schenker, and Papadimitriou. Uses a `default` support of 1%. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. - * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * @since 3.4.0 - */ - def freqItems(cols: Array[String]): DataFrame = { - freqItems(cols, 0.01) - } + /** @inheritdoc */ + override def freqItems(cols: Array[String]): DataFrame = super.freqItems(cols) + + /** @inheritdoc */ + override def freqItems(cols: Seq[String]): DataFrame = super.freqItems(cols) - /** - * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in here, proposed by Karp, Schenker, and - * Papadimitriou. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. - * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * {{{ - * val rows = Seq.tabulate(100) { i => - * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) - * } - * val df = spark.createDataFrame(rows).toDF("a", "b") - * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns - * // "a" and "b" - * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4) - * freqSingles.show() - * +-----------+-------------+ - * |a_freqItems| b_freqItems| - * +-----------+-------------+ - * | [1, 99]|[-1.0, -99.0]| - * +-----------+-------------+ - * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" - * val pairDf = df.select(struct("a", "b").as("a-b")) - * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1) - * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() - * +----------+ - * | freq_ab| - * +----------+ - * | [1,-1.0]| - * | ... | - * +----------+ - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def freqItems(cols: Seq[String], support: Double): DataFrame = { - freqItems(cols.toArray, support) + df.sparkSession.newDataFrame { builder => + val freqItemsBuilder = builder.getFreqItemsBuilder + .setInput(df.plan.getRoot) + .setSupport(support) + cols.foreach(freqItemsBuilder.addCols) + } } - /** - * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the - * frequent element count algorithm described in here, proposed by Karp, Schenker, and - * Papadimitriou. Uses a `default` support of 1%. - * - * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting `DataFrame`. - * - * @param cols - * the names of the columns to search frequent items in. 
- * @return - * A Local DataFrame with the Array of frequent items for each column. - * - * @since 3.4.0 - */ - def freqItems(cols: Seq[String]): DataFrame = { - freqItems(cols.toArray, 0.01) - } + /** @inheritdoc */ + override def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = + super.sampleBy(col, fractions, seed) - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. - * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * {{{ - * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), - * (3, 3))).toDF("key", "value") - * val fractions = Map(1 -> 1.0, 3 -> 0.5) - * df.stat.sampleBy("key", fractions, 36L).show() - * +---+-----+ - * |key|value| - * +---+-----+ - * | 1| 1| - * | 1| 2| - * | 3| 2| - * +---+-----+ - * }}} - * - * @since 3.4.0 - */ - def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { - sampleBy(Column(col), fractions, seed) - } + /** @inheritdoc */ + override def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = + super.sampleBy(col, fractions, seed) - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. - * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * @since 3.4.0 - */ - def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { - sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) - } + /** @inheritdoc */ + override def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = + super.sampleBy(col, fractions, seed) - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. 
- * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * The stratified sample can be performed over multiple columns: - * {{{ - * import org.apache.spark.sql.Row - * import org.apache.spark.sql.functions.struct - * - * val df = spark.createDataFrame(Seq(("Bob", 17), ("Alice", 10), ("Nico", 8), ("Bob", 17), - * ("Alice", 10))).toDF("name", "age") - * val fractions = Map(Row("Alice", 10) -> 0.3, Row("Nico", 8) -> 1.0) - * df.stat.sampleBy(struct($"name", $"age"), fractions, 36L).show() - * +-----+---+ - * | name|age| - * +-----+---+ - * | Nico| 8| - * |Alice| 10| - * +-----+---+ - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = { + import sparkSession.RichColumn require( fractions.values.forall(p => p >= 0.0 && p <= 1.0), s"Fractions must be in [0, 1], but got $fractions.") @@ -479,180 +136,6 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo } } } - - /** - * (Java-specific) Returns a stratified sample without replacement based on the fraction given - * on each stratum. - * @param col - * column that defines strata - * @param fractions - * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as - * zero. - * @param seed - * random seed - * @tparam T - * stratum type - * @return - * a new `DataFrame` that represents the stratified sample - * - * @since 3.4.0 - */ - def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { - sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param colName - * name of the column over which the sketch is built - * @param depth - * depth of the sketch - * @param width - * width of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = { - countMinSketch(Column(colName), depth, width, seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param colName - * name of the column over which the sketch is built - * @param eps - * relative error of the sketch - * @param confidence - * confidence of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch( - colName: String, - eps: Double, - confidence: Double, - seed: Int): CountMinSketch = { - countMinSketch(Column(colName), eps, confidence, seed) - } - - /** - * Builds a Count-min Sketch over a specified column. - * - * @param col - * the column over which the sketch is built - * @param depth - * depth of the sketch - * @param width - * width of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = { - countMinSketch(col, eps = 2.0 / width, confidence = 1 - 1 / Math.pow(2, depth), seed) - } - - /** - * Builds a Count-min Sketch over a specified column. 
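For illustration only, a hedged sketch of building a Count-min Sketch over a hypothetical "userId" column of an assumed DataFrame `df` and querying an estimated count from it.

val sketch = df.stat.countMinSketch("userId", depth = 10, width = 1000, seed = 42)
val estimate = sketch.estimateCount(12345L)   // approximate frequency of the value 12345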
- * - * @param col - * the column over which the sketch is built - * @param eps - * relative error of the sketch - * @param confidence - * confidence of the sketch - * @param seed - * random seed - * @return - * a `CountMinSketch` over column `colName` - * @since 3.4.0 - */ - def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { - val agg = Column.fn("count_min_sketch", col, lit(eps), lit(confidence), lit(seed)) - val ds = sparkSession.newDataset(BinaryEncoder) { builder => - builder.getProjectBuilder - .setInput(root) - .addExpressions(agg.expr) - } - CountMinSketch.readFrom(ds.head()) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param colName - * name of the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param fpp - * expected false positive probability of the filter. - * @since 3.5.0 - */ - def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = { - bloomFilter(Column(colName), expectedNumItems, fpp) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param col - * the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param fpp - * expected false positive probability of the filter. - * @since 3.5.0 - */ - def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = { - val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp) - bloomFilter(col, expectedNumItems, numBits) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param colName - * name of the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param numBits - * expected number of bits of the filter. - * @since 3.5.0 - */ - def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = { - bloomFilter(Column(colName), expectedNumItems, numBits) - } - - /** - * Builds a Bloom filter over a specified column. - * - * @param col - * the column over which the filter is built - * @param expectedNumItems - * expected number of items which will be put into the filter. - * @param numBits - * expected number of bits of the filter. 
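A brief sketch of how the sketch builders above are typically invoked; the DataFrame df, the column name "id", and the parameter values are illustrative assumptions:

    // Count-min sketch sized by depth/width (an eps/confidence overload also exists).
    val cms = df.stat.countMinSketch("id", depth = 10, width = 100, seed = 42)
    // Bloom filter sized by expected item count and false-positive probability.
    val bf = df.stat.bloomFilter("id", expectedNumItems = 1000L, fpp = 0.01)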
- * @since 3.5.0 - */ - def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { - val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits)) - val ds = sparkSession.newDataset(BinaryEncoder) { builder => - builder.getProjectBuilder - .setInput(root) - .addExpressions(agg.expr) - } - BloomFilter.readFrom(new ByteArrayInputStream(ds.head())) - } } private object DataFrameStatFunctions { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 164871a6334c0..a5ecb2297cb7b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -33,12 +33,13 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.OrderUtils +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkResult -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter, UdfUtils} +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.{struct, to_json} -import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, UnresolvedAttribute, UnresolvedRegex} +import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex} import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types.{Metadata, StructType} import org.apache.spark.storage.StorageLevel @@ -135,7 +136,8 @@ class Dataset[T] private[sql] ( val sparkSession: SparkSession, @DeveloperApi val plan: proto.Plan, val encoder: Encoder[T]) - extends api.Dataset[T, Dataset] { + extends api.Dataset[T] { + type DS[U] = Dataset[U] import sparkSession.RichColumn @@ -284,29 +286,11 @@ class Dataset[T] private[sql] ( } } - /** - * Returns a [[DataFrameNaFunctions]] for working with missing data. - * {{{ - * // Dropping rows containing any null values. - * ds.na.drop() - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ def na: DataFrameNaFunctions = new DataFrameNaFunctions(sparkSession, plan.getRoot) - /** - * Returns a [[DataFrameStatFunctions]] for working statistic functions support. - * {{{ - * // Finding frequent items in column with name 'a'. 
- * ds.stat.freqItems(Seq("a")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def stat: DataFrameStatFunctions = new DataFrameStatFunctions(sparkSession, plan.getRoot) + /** @inheritdoc */ + def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF()) private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit): DataFrame = { checkSameSparkSession(right) @@ -504,7 +488,7 @@ class Dataset[T] private[sql] ( val unpivot = builder.getUnpivotBuilder .setInput(plan.getRoot) .addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava) - .setValueColumnName(variableColumnName) + .setVariableColumnName(variableColumnName) .setValueColumnName(valueColumnName) valuesOption.foreach { values => unpivot.getValuesBuilder @@ -512,58 +496,20 @@ class Dataset[T] private[sql] ( } } - /** - * Groups the Dataset using the specified columns, so we can run aggregation on them. See - * [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy($"department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + private def buildTranspose(indices: Seq[Column]): DataFrame = + sparkSession.newDataFrame { builder => + val transpose = builder.getTransposeBuilder.setInput(plan.getRoot) + indices.foreach { indexColumn => + transpose.addIndexColumns(indexColumn.expr) + } + } + + /** @inheritdoc */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) } - /** - * Groups the Dataset using the specified columns, so that we can run aggregation on them. See - * [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * This is a variant of groupBy that can only group by existing columns using column names (i.e. - * cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy("department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def groupBy(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - new RelationalGroupedDataset( - toDF(), - colNames.map(colName => Column(colName)), - proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY) - } - /** @inheritdoc */ def reduce(func: (T, T) => T): T = { val udf = SparkUserDefinedFunction( @@ -584,155 +530,24 @@ class Dataset[T] private[sql] ( result(0) } - /** - * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func) } - /** - * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. 
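For reference, a compact sketch of the accessors and grouping API touched here, assuming a Dataset named ds with the columns used in the surrounding examples and spark.implicits._ imported for the $ syntax:

    ds.na.drop()                      // drop rows containing any null values
    ds.stat.freqItems(Seq("a"))       // frequent items in column "a"
    ds.groupBy($"department", $"gender").agg(Map("salary" -> "avg", "age" -> "max"))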
- * - * @group typedrel - * @since 3.5.0 - */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - groupByKey(UdfUtils.mapFunctionToScalaFunc(func))(encoder) - - /** - * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * {{{ - * // Compute the average for all numeric columns rolled up by department and group. - * ds.rollup($"department", $"group").avg() - * - * // Compute the max age and average salary, rolled up by department and gender. - * ds.rollup($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def rollup(cols: Column*): RelationalGroupedDataset = { new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) } - /** - * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * This is a variant of rollup that can only group by existing columns using column names (i.e. - * cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns rolled up by department and group. - * ds.rollup("department", "group").avg() - * - * // Compute the max age and average salary, rolled up by department and gender. - * ds.rollup($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def rollup(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - new RelationalGroupedDataset( - toDF(), - colNames.map(colName => Column(colName)), - proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP) - } - - /** - * Create a multi-dimensional cube for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * {{{ - * // Compute the average for all numeric columns cubed by department and group. - * ds.cube($"department", $"group").avg() - * - * // Compute the max age and average salary, cubed by department and gender. - * ds.cube($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def cube(cols: Column*): RelationalGroupedDataset = { new RelationalGroupedDataset(toDF(), cols, proto.Aggregate.GroupType.GROUP_TYPE_CUBE) } - /** - * Create a multi-dimensional cube for the current Dataset using the specified columns, so we - * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate - * functions. - * - * This is a variant of cube that can only group by existing columns using column names (i.e. - * cannot construct expressions). - * - * {{{ - * // Compute the average for all numeric columns cubed by department and group. - * ds.cube("department", "group").avg() - * - * // Compute the max age and average salary, cubed by department and gender. 
- * ds.cube($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def cube(col1: String, cols: String*): RelationalGroupedDataset = { - val colNames: Seq[String] = col1 +: cols - new RelationalGroupedDataset( - toDF(), - colNames.map(colName => Column(colName)), - proto.Aggregate.GroupType.GROUP_TYPE_CUBE) - } - - /** - * Create multi-dimensional aggregation for the current Dataset using the specified grouping - * sets, so we can run aggregation on them. See [[RelationalGroupedDataset]] for all the - * available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns group by specific grouping sets. - * ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg() - * - * // Compute the max age and average salary, group by specific grouping sets. - * ds.groupingSets(Seq($"department", $"gender"), Seq()), $"department", $"group").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = { val groupingSetMsgs = groupingSets.map { groupingSet => @@ -749,61 +564,6 @@ class Dataset[T] private[sql] ( groupingSets = Some(groupingSetMsgs)) } - /** - * (Scala-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg("age" -> "max", "salary" -> "avg") - * ds.groupBy().agg("age" -> "max", "salary" -> "avg") - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs: _*) - } - - /** - * (Scala-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg(Map("age" -> "max", "salary" -> "avg")) - * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) - - /** - * (Java-specific) Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) - * ds.agg(Map("age" -> "max", "salary" -> "avg")) - * ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) - - /** - * Aggregates on the entire Dataset without groups. - * {{{ - * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) 
- * ds.agg(max($"age"), avg($"salary")) - * ds.groupBy().agg(max($"age"), avg($"salary")) - * }}} - * - * @group untypedrel - * @since 3.4.0 - */ - @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs: _*) - /** @inheritdoc */ def unpivot( ids: Array[Column], @@ -821,6 +581,14 @@ class Dataset[T] private[sql] ( buildUnpivot(ids, None, variableColumnName, valueColumnName) } + /** @inheritdoc */ + def transpose(indexColumn: Column): DataFrame = + buildTranspose(Seq(indexColumn)) + + /** @inheritdoc */ + def transpose(): DataFrame = + buildTranspose(Seq.empty) + /** @inheritdoc */ def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder => builder.getLimitBuilder @@ -1104,17 +872,17 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ def filter(f: FilterFunction[T]): Dataset[T] = { - filter(UdfUtils.filterFuncToScalaFunc(f)) + filter(ToScalaUDF(f)) } /** @inheritdoc */ def map[U: Encoder](f: T => U): Dataset[U] = { - mapPartitions(UdfUtils.mapFuncToMapPartitionsAdaptor(f)) + mapPartitions(UDFAdaptors.mapToMapPartitions(f)) } /** @inheritdoc */ def map[U](f: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - map(UdfUtils.mapFunctionToScalaFunc(f))(encoder) + mapPartitions(UDFAdaptors.mapToMapPartitions(f))(encoder) } /** @inheritdoc */ @@ -1131,25 +899,11 @@ class Dataset[T] private[sql] ( } } - /** @inheritdoc */ - def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - mapPartitions(UdfUtils.mapPartitionsFuncToScalaFunc(f))(encoder) - } - - /** @inheritdoc */ - override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = - mapPartitions(UdfUtils.flatMapFuncToMapPartitionsAdaptor(func)) - - /** @inheritdoc */ - override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { - flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder) - } - /** @inheritdoc */ @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0") def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = { val generator = SparkUserDefinedFunction( - UdfUtils.iterableOnceToSeq(f), + UDFAdaptors.iterableOnceToSeq(f), UnboundRowEncoder :: Nil, ScalaReflection.encoderFor[Seq[A]]) select(col("*"), functions.inline(generator(struct(input: _*)))) @@ -1160,31 +914,16 @@ class Dataset[T] private[sql] ( def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)( f: A => IterableOnce[B]): DataFrame = { val generator = SparkUserDefinedFunction( - UdfUtils.iterableOnceToSeq(f), + UDFAdaptors.iterableOnceToSeq(f), Nil, ScalaReflection.encoderFor[Seq[B]]) select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn))) } - /** @inheritdoc */ - def foreach(f: T => Unit): Unit = { - foreachPartition(UdfUtils.foreachFuncToForeachPartitionsAdaptor(f)) - } - - /** @inheritdoc */ - override def foreach(func: ForeachFunction[T]): Unit = - foreach(UdfUtils.foreachFuncToScalaFunc(func)) - /** @inheritdoc */ def foreachPartition(f: Iterator[T] => Unit): Unit = { // Delegate to mapPartition with empty result. 
- mapPartitions(UdfUtils.foreachPartitionFuncToMapPartitionsAdaptor(f))(RowEncoder(Seq.empty)) - .collect() - } - - /** @inheritdoc */ - override def foreachPartition(func: ForeachPartitionFunction[T]): Unit = { - foreachPartition(UdfUtils.foreachPartitionFuncToScalaFunc(func)) + mapPartitions(UDFAdaptors.foreachPartitionToMapPartitions(f))(NullEncoder).collect() } /** @inheritdoc */ @@ -1281,61 +1020,17 @@ class Dataset[T] private[sql] ( .asScala .toArray - /** - * Interface for saving the content of the non-streaming Dataset out into external storage. - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def write: DataFrameWriter[T] = { - new DataFrameWriter[T](this) + new DataFrameWriterImpl[T](this) } - /** - * Create a write configuration builder for v2 sources. - * - * This builder is used to configure and execute write operations. For example, to append to an - * existing table, run: - * - * {{{ - * df.writeTo("catalog.db.table").append() - * }}} - * - * This can also be used to create or replace existing tables: - * - * {{{ - * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace() - * }}} - * - * @group basic - * @since 3.4.0 - */ + /** @inheritdoc */ def writeTo(table: String): DataFrameWriterV2[T] = { - new DataFrameWriterV2[T](table, this) + new DataFrameWriterV2Impl[T](table, this) } - /** - * Merges a set of updates, insertions, and deletions based on a source table into a target - * table. - * - * Scala Examples: - * {{{ - * spark.table("source") - * .mergeInto("target", $"source.id" === $"target.id") - * .whenMatched($"salary" === 100) - * .delete() - * .whenNotMatched() - * .insertAll() - * .whenNotMatchedBySource($"salary" === 100) - * .update(Map( - * "salary" -> lit(200) - * )) - * .merge() - * }}} - * - * @group basic - * @since 4.0.0 - */ + /** @inheritdoc */ def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = { if (isStreaming) { throw new AnalysisException( @@ -1343,7 +1038,7 @@ class Dataset[T] private[sql] ( messageParameters = Map("methodName" -> toSQLId("mergeInto"))) } - new MergeIntoWriter[T](table, this, condition) + new MergeIntoWriterImpl[T](table, this, condition) } /** @@ -1422,28 +1117,7 @@ class Dataset[T] private[sql] ( } } - /** - * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This is - * equivalent to calling `observe(String, Column, Column*)` but does not require to collect all - * results before returning the metrics - the metrics are filled during iterating the results, - * as soon as they are available. This method does not support streaming datasets. - * - * A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`. 
- * - * {{{ - * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it - * val observation = Observation("my_metrics") - * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid")) - * observed_ds.write.parquet("ds.parquet") - * val metrics = observation.get - * }}} - * - * @throws IllegalArgumentException - * If this is a streaming Dataset (this.isStreaming == true) - * - * @group typedrel - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = { val df = observe(observation.name, expr, exprs: _*) @@ -1729,6 +1403,22 @@ class Dataset[T] private[sql] ( override def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = super.dropDuplicatesWithinWatermark(col1, cols: _*) + /** @inheritdoc */ + override def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = + super.mapPartitions(f, encoder) + + /** @inheritdoc */ + override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = + super.flatMap(func) + + /** @inheritdoc */ + override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = + super.flatMap(f, encoder) + + /** @inheritdoc */ + override def foreachPartition(func: ForeachPartitionFunction[T]): Unit = + super.foreachPartition(func) + /** @inheritdoc */ @scala.annotation.varargs override def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = @@ -1751,4 +1441,39 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ override def distinct(): Dataset[T] = super.distinct() + + /** @inheritdoc */ + @scala.annotation.varargs + override def groupBy(col1: String, cols: String*): RelationalGroupedDataset = + super.groupBy(col1, cols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def rollup(col1: String, cols: String*): RelationalGroupedDataset = + super.rollup(col1, cols: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def cube(col1: String, cols: String*): RelationalGroupedDataset = + super.cube(col1, cols: _*) + + /** @inheritdoc */ + override def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = + super.agg(aggExpr, aggExprs: _*) + + /** @inheritdoc */ + override def agg(exprs: Map[String, String]): DataFrame = super.agg(exprs) + + /** @inheritdoc */ + override def agg(exprs: java.util.Map[String, String]): DataFrame = super.agg(exprs) + + /** @inheritdoc */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + + /** @inheritdoc */ + override def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 04b620bdf8f98..6bf2518901470 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -19,17 +19,19 @@ package org.apache.spark.sql import java.util.Arrays +import scala.annotation.unused import scala.jdk.CollectionConverters._ -import scala.language.existentials import org.apache.spark.api.java.function._ import 
org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr +import org.apache.spark.sql.internal.UDFAdaptors import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode} /** @@ -39,7 +41,10 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * * @since 3.5.0 */ -class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { +class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDataset[K, V] { + type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL] + + private def unsupported(): Nothing = throw new UnsupportedOperationException() /** * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the @@ -48,499 +53,52 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { * * @since 3.5.0 */ - def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = { - throw new UnsupportedOperationException - } + def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = unsupported() - /** - * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to - * the data. The grouping key is unchanged by this. - * - * {{{ - * // Create values grouped by key from a Dataset[(K, V)] - * ds.groupByKey(_._1).mapValues(_._2) // Scala - * }}} - * - * @since 3.5.0 - */ - def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = { - throw new UnsupportedOperationException - } + /** @inheritdoc */ + def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = + unsupported() - /** - * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to - * the data. The grouping key is unchanged by this. - * - * {{{ - * // Create Integer values grouped by String key from a Dataset> - * Dataset> ds = ...; - * KeyValueGroupedDataset grouped = - * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); - * }}} - * - * @since 3.5.0 - */ - def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = { - mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder) - } - - /** - * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping over - * the Dataset to extract the keys and then running a distinct operation on those. - * - * @since 3.5.0 - */ - def keys: Dataset[K] = { - throw new UnsupportedOperationException - } + /** @inheritdoc */ + def keys: Dataset[K] = unsupported() - /** - * (Scala-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an iterator containing elements of an arbitrary type which - * will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. 
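A minimal sketch of the key-value grouping flow these members belong to, assuming a Dataset[(String, Int)] named ds and spark.implicits._ in scope for the encoders:

    val grouped = ds.groupByKey(_._1).mapValues(_._2)   // values grouped by key
    val sums = grouped.flatMapGroups { (key, values) =>
      Iterator(key -> values.sum)                       // one aggregated row per key
    }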
If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { - flatMapSortedGroups()(f) - } - - /** - * (Java-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an iterator containing elements of an arbitrary type which - * will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and a sorted iterator that contains all of the elements - * in the group. The function can return an iterator containing elements of an arbitrary type - * which will be returned as a new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def flatMapSortedGroups[U: Encoder](sortExprs: Column*)( - f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and a sorted iterator that contains all of the elements - * in the group. The function can return an iterator containing elements of an arbitrary type - * which will be returned as a new [[Dataset]]. 
- * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ - def flatMapSortedGroups[U]( - SortExprs: Array[Column], - f: FlatMapGroupsFunction[K, V, U], - encoder: Encoder[U]): Dataset[U] = { - import org.apache.spark.util.ArrayImplicits._ - flatMapSortedGroups(SortExprs.toImmutableArraySeq: _*)( - UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an element of arbitrary type which will be returned as a - * new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { - flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f)) - } - - /** - * (Java-specific) Applies the given function to each group of data. For each unique group, the - * function will be passed the group key and an iterator that contains all of the elements in - * the group. The function can return an element of arbitrary type which will be returned as a - * new [[Dataset]]. - * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. If an application intends to perform an aggregation over each - * key, it is best to use the reduce function or an - * `org.apache.spark.sql.expressions#Aggregator`. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the - * memory constraints of their cluster. - * - * @since 3.5.0 - */ - def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { - mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Reduces the elements of each group of data using the specified binary - * function. 
The given function must be commutative and associative or the result may be - * non-deterministic. - * - * @since 3.5.0 - */ - def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Reduces the elements of each group of data using the specified binary - * function. The given function must be commutative and associative or the result may be - * non-deterministic. - * - * @since 3.5.0 - */ - def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = { - reduceGroups(UdfUtils.mapReduceFuncToScalaFunc(f)) - } - - /** - * Internal helper function for building typed aggregations that return tuples. For simplicity - * and code reuse, we do this without the help of the type system and then use helper functions - * that cast appropriately for the user facing interface. - */ - protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - throw new UnsupportedOperationException - } - - /** - * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key and the - * result of computing this aggregation over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = - aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = - aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] + f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = + unsupported() - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = - aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] + /** @inheritdoc */ + def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = unsupported() - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = - aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = - aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]] + /** @inheritdoc */ + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = unsupported() - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. 
- * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5, U6]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5], - col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = - aggUntyped(col1, col2, col3, col4, col5, col6) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5, U6, U7]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5], - col6: TypedColumn[V, U6], - col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = - aggUntyped(col1, col2, col3, col4, col5, col6, col7) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]] - - /** - * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and - * the result of computing these aggregations over all elements in the group. - * - * @since 3.5.0 - */ - def agg[U1, U2, U3, U4, U5, U6, U7, U8]( - col1: TypedColumn[V, U1], - col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5], - col6: TypedColumn[V, U6], - col7: TypedColumn[V, U7], - col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = - aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8) - .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]] - - /** - * Returns a [[Dataset]] that contains a tuple with each key and the number of items present for - * that key. - * - * @since 3.5.0 - */ - def count(): Dataset[(K, Long)] = agg(functions.count("*")) - - /** - * (Scala-specific) Applies the given function to each cogrouped data. For each unique group, - * the function will be passed the grouping key and 2 iterators containing all elements in the - * group from [[Dataset]] `this` and `other`. The function can return an iterator containing - * elements of an arbitrary type which will be returned as a new [[Dataset]]. - * - * @since 3.5.0 - */ - def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( - f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - cogroupSorted(other)()()(f) - } - - /** - * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the - * function will be passed the grouping key and 2 iterators containing all elements in the group - * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements - * of an arbitrary type which will be returned as a new [[Dataset]]. - * - * @since 3.5.0 - */ - def cogroup[U, R]( - other: KeyValueGroupedDataset[K, U], - f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): Dataset[R] = { - cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder) - } - - /** - * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique - * group, the function will be passed the grouping key and 2 sorted iterators containing all - * elements in the group from [[Dataset]] `this` and `other`. The function can return an - * iterator containing elements of an arbitrary type which will be returned as a new - * [[Dataset]]. - * - * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be - * sorted according to the given sort expressions. 
That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)( - otherSortExprs: Column*)( - f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique - * group, the function will be passed the grouping key and 2 sorted iterators containing all - * elements in the group from [[Dataset]] `this` and `other`. The function can return an - * iterator containing elements of an arbitrary type which will be returned as a new - * [[Dataset]]. - * - * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be - * sorted according to the given sort expressions. That sorting does not add computational - * complexity. - * - * @since 3.5.0 - */ - def cogroupSorted[U, R]( - other: KeyValueGroupedDataset[K, U], - thisSortExprs: Array[Column], - otherSortExprs: Array[Column], - f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): Dataset[R] = { - import org.apache.spark.util.ArrayImplicits._ - cogroupSorted(other)(thisSortExprs.toImmutableArraySeq: _*)( - otherSortExprs.toImmutableArraySeq: _*)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder) - } + otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = + unsupported() protected[sql] def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder]( outputMode: Option[OutputMode], timeoutConf: GroupStateTimeout, initialState: Option[KeyValueGroupedDataset[K, S]], isMapGroupWithState: Boolean)( - func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { - throw new UnsupportedOperationException - } + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = unsupported() - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See - * [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { mapGroupsWithState(GroupStateTimeout.NoTimeout)(func) } - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See - * [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S - * The type of the user-defined state. 
Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { flatMapGroupsWithStateHelper(None, timeoutConf, None, isMapGroupWithState = true)( - UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func)) + UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func)) } - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See - * [[org.apache.spark.sql.streaming.GroupState]] for more details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param timeoutConf - * Timeout Conf, see GroupStateTimeout for more details - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. The user defined function will be called on the state data even if - * there are no other values in the group. To convert a Dataset ds of type Dataset[(K, S)] to - * a KeyValueGroupedDataset[K, S] do {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout, initialState: KeyValueGroupedDataset[K, S])( @@ -549,134 +107,10 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { None, timeoutConf, Some(initialState), - isMapGroupWithState = true)(UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func)) + isMapGroupWithState = true)(UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func)) } - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
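A hedged sketch of the stateful grouping API referenced here; Event, RunningCount, and the events Dataset are assumptions made purely for illustration, with spark.implicits._ in scope for the encoders:

    import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

    case class Event(user: String, value: Long)
    case class RunningCount(count: Long)

    val counts = events
      .groupByKey(_.user)
      .mapGroupsWithState(GroupStateTimeout.NoTimeout) {
        (user: String, rows: Iterator[Event], state: GroupState[RunningCount]) =>
          val total = state.getOption.map(_.count).getOrElse(0L) + rows.size
          state.update(RunningCount(total))   // state is saved across trigger invocations
          (user, total)
      }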
- * @since 3.5.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - mapGroupsWithState[S, U](UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))( - stateEncoder, - outputEncoder) - } - - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout): Dataset[U] = { - mapGroupsWithState[S, U](timeoutConf)(UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))( - stateEncoder, - outputEncoder) - } - - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. The user defined function will be called on the state data even if - * there are no other values in the group. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - mapGroupsWithState[S, U](timeoutConf, initialState)( - UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(stateEncoder, outputEncoder) - } - - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. 
For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout)( @@ -688,33 +122,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { isMapGroupWithState = false)(func) } - /** - * (Scala-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. The user defined function will be called on the state data even if - * there are no other values in the group. To covert a Dataset `ds` of type of type - * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use - * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} See [[Encoder]] for more details on what - * types are encodable to Spark SQL. - * @since 3.5.0 - */ + /** @inheritdoc */ def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout, @@ -727,201 +135,244 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { isMapGroupWithState = false)(func) } - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. 
- * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def flatMapGroupsWithState[S, U]( + /** @inheritdoc */ + private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + timeMode: TimeMode, + outputMode: OutputMode): Dataset[U] = + unsupported() + + /** @inheritdoc */ + private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + timeMode: TimeMode, + outputMode: OutputMode, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + unsupported() + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( + statefulProcessor: StatefulProcessor[K, V, U], + eventTimeColumnName: String, + outputMode: OutputMode): Dataset[U] = unsupported() + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( + statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], + eventTimeColumnName: String, + outputMode: OutputMode, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported() + + // Overrides... + /** @inheritdoc */ + override def mapValues[W]( + func: MapFunction[V, W], + encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = super.mapValues(func, encoder) + + /** @inheritdoc */ + override def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = + super.flatMapGroups(f) + + /** @inheritdoc */ + override def flatMapGroups[U]( + f: FlatMapGroupsFunction[K, V, U], + encoder: Encoder[U]): Dataset[U] = super.flatMapGroups(f, encoder) + + /** @inheritdoc */ + override def flatMapSortedGroups[U]( + SortExprs: Array[Column], + f: FlatMapGroupsFunction[K, V, U], + encoder: Encoder[U]): Dataset[U] = super.flatMapSortedGroups(SortExprs, f, encoder) + + /** @inheritdoc */ + override def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = super.mapGroups(f) + + /** @inheritdoc */ + override def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = + super.mapGroups(f, encoder) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf) + + /** @inheritdoc */ + override def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = + super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf, initialState) + + /** @inheritdoc */ + override def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout): Dataset[U] = { - val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func) - flatMapGroupsWithState[S, U](outputMode, 
timeoutConf)(f)(stateEncoder, outputEncoder) - } + timeoutConf: GroupStateTimeout): Dataset[U] = + super.flatMapGroupsWithState(func, outputMode, stateEncoder, outputEncoder, timeoutConf) - /** - * (Java-specific) Applies the given function to each group of data, while maintaining a - * user-defined per-group state. The result Dataset will represent the objects returned by the - * function. For a static batch Dataset, the function will be invoked once per group. For a - * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, - * and updates to each group's state will be saved across invocations. See `GroupState` for more - * details. - * - * @tparam S - * The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param func - * Function to be called on every group. - * @param outputMode - * The output mode of the function. - * @param stateEncoder - * Encoder for the state type. - * @param outputEncoder - * Encoder for the output type. - * @param timeoutConf - * Timeout configuration for groups that do not receive data for a while. - * @param initialState - * The user provided state that will be initialized when the first batch of data is processed - * in the streaming query. The user defined function will be called on the state data even if - * there are no other values in the group. To covert a Dataset `ds` of type of type - * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use - * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 3.5.0 - */ - def flatMapGroupsWithState[S, U]( + /** @inheritdoc */ + override def flatMapGroupsWithState[S, U]( func: FlatMapGroupsWithStateFunction[K, V, S, U], outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func) - flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(f)( - stateEncoder, - outputEncoder) - } - - /** - * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state - * API v2. We allow the user to act on per-group set of input rows along with keyed state and - * the user can choose to output/return 0 or more rows. For a streaming dataframe, we will - * repeatedly invoke the interface methods for new rows in each trigger and the user's - * state/state variables will be stored persistently across invocations. Currently this operator - * is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. 
- */ - private[sql] def transformWithState[U: Encoder]( + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = super.flatMapGroupsWithState( + func, + outputMode, + stateEncoder, + outputEncoder, + timeoutConf, + initialState) + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeMode: TimeMode, - outputMode: OutputMode): Dataset[U] = { - throw new UnsupportedOperationException - } + outputMode: OutputMode, + outputEncoder: Encoder[U]) = + super.transformWithState(statefulProcessor, timeMode, outputMode, outputEncoder) - /** - * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API - * v2. We allow the user to act on per-group set of input rows along with keyed state and the - * user can choose to output/return 0 or more rows. For a streaming dataframe, we will - * repeatedly invoke the interface methods for new rows in each trigger and the user's - * state/state variables will be stored persistently across invocations. Currently this operator - * is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - * @param outputEncoder - * Encoder for the output type. - */ - private[sql] def transformWithState[U: Encoder]( + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], - timeMode: TimeMode, + eventTimeColumnName: String, outputMode: OutputMode, - outputEncoder: Encoder[U]): Dataset[U] = { - throw new UnsupportedOperationException - } + outputEncoder: Encoder[U]) = + super.transformWithState(statefulProcessor, eventTimeColumnName, outputMode, outputEncoder) - /** - * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state - * API v2. Functions as the function above, but with additional initial state. Currently this - * operator is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @tparam S - * The type of initial state objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - * @param initialState - * User provided initial state that will be used to initiate state for the query in the first - * batch. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ - private[sql] def transformWithState[U: Encoder, S: Encoder]( + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, outputMode: OutputMode, - initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { - throw new UnsupportedOperationException - } - - /** - * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API - * v2. Functions as the function above, but with additional initial state. 
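Illustrative usage (not part of the patch): a minimal sketch of the Scala `flatMapGroupsWithState` variants documented above, without and with a user-provided initial state. `Event`, `RunningCount`, `Summary`, `events`, and `initialCounts` are hypothetical, and a `SparkSession` named `spark` with `spark.implicits._` imported is assumed.
{{{
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
import spark.implicits._

case class Event(user: String, count: Long)
case class RunningCount(total: Long)
case class Summary(user: String, total: Long)

// State update function: fold the new events into the running count for a user.
def update(
    user: String,
    events: Iterator[Event],
    state: GroupState[RunningCount]): Iterator[Summary] = {
  val previous = state.getOption.getOrElse(RunningCount(0L))
  val updated = RunningCount(previous.total + events.map(_.count).sum)
  state.update(updated)
  Iterator(Summary(user, updated.total))
}

val grouped = events.groupByKey(_.user)

// Without initial state.
val summaries = grouped.flatMapGroupsWithState(
  OutputMode.Update, GroupStateTimeout.NoTimeout)(update)

// With initial state: convert a Dataset[(K, S)] into a
// KeyValueGroupedDataset[K, S] as the scaladoc above suggests.
val initialState = initialCounts.groupByKey(_._1).mapValues(_._2)
val withInitial = grouped.flatMapGroupsWithState(
  OutputMode.Update, GroupStateTimeout.NoTimeout, initialState)(update)
}}}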
Currently this - * operator is not supported with Spark Connect. - * - * @tparam U - * The type of the output objects. Must be encodable to Spark SQL types. - * @tparam S - * The type of initial state objects. Must be encodable to Spark SQL types. - * @param statefulProcessor - * Instance of statefulProcessor whose functions will be invoked by the operator. - * @param timeMode - * The time mode semantics of the stateful processor for timers and TTL. - * @param outputMode - * The output mode of the stateful processor. - * @param initialState - * User provided initial state that will be used to initiate state for the query in the first - * batch. - * @param outputEncoder - * Encoder for the output type. - * @param initialStateEncoder - * Encoder for the initial state type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - */ - private[sql] def transformWithState[U: Encoder, S: Encoder]( + initialState: KeyValueGroupedDataset[K, S], + outputEncoder: Encoder[U], + initialStateEncoder: Encoder[S]) = super.transformWithState( + statefulProcessor, + timeMode, + outputMode, + initialState, + outputEncoder, + initialStateEncoder) + + /** @inheritdoc */ + override private[sql] def transformWithState[U: Encoder, S: Encoder]( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], - timeMode: TimeMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S], + eventTimeColumnName: String, outputEncoder: Encoder[U], - initialStateEncoder: Encoder[S]): Dataset[U] = { - throw new UnsupportedOperationException - } + initialStateEncoder: Encoder[S]) = super.transformWithState( + statefulProcessor, + outputMode, + initialState, + eventTimeColumnName, + outputEncoder, + initialStateEncoder) + + /** @inheritdoc */ + override def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = super.reduceGroups(f) + + /** @inheritdoc */ + override def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = super.agg(col1) + + /** @inheritdoc */ + override def agg[U1, U2]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = super.agg(col1, col2) + + /** @inheritdoc */ + override def agg[U1, U2, U3]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = super.agg(col1, col2, col3) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = super.agg(col1, col2, col3, col4) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = + super.agg(col1, col2, col3, col4, col5) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = + super.agg(col1, col2, col3, col4, col5, col6) + + /** @inheritdoc */ + override def agg[U1, U2, U3, U4, U5, U6, U7]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = + super.agg(col1, col2, col3, col4, col5, col6, col7) + + /** @inheritdoc 
*/ + override def agg[U1, U2, U3, U4, U5, U6, U7, U8]( + col1: TypedColumn[V, U1], + col2: TypedColumn[V, U2], + col3: TypedColumn[V, U3], + col4: TypedColumn[V, U4], + col5: TypedColumn[V, U5], + col6: TypedColumn[V, U6], + col7: TypedColumn[V, U7], + col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = + super.agg(col1, col2, col3, col4, col5, col6, col7, col8) + + /** @inheritdoc */ + override def count(): Dataset[(K, Long)] = super.count() + + /** @inheritdoc */ + override def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( + f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = + super.cogroup(other)(f) + + /** @inheritdoc */ + override def cogroup[U, R]( + other: KeyValueGroupedDataset[K, U], + f: CoGroupFunction[K, V, U, R], + encoder: Encoder[R]): Dataset[R] = super.cogroup(other, f, encoder) + + /** @inheritdoc */ + override def cogroupSorted[U, R]( + other: KeyValueGroupedDataset[K, U], + thisSortExprs: Array[Column], + otherSortExprs: Array[Column], + f: CoGroupFunction[K, V, U, R], + encoder: Encoder[R]): Dataset[R] = + super.cogroupSorted(other, thisSortExprs, otherSortExprs, f, encoder) } /** @@ -934,12 +385,11 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable { private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( private val sparkSession: SparkSession, private val plan: proto.Plan, - private val ikEncoder: AgnosticEncoder[IK], private val kEncoder: AgnosticEncoder[K], private val ivEncoder: AgnosticEncoder[IV], private val vEncoder: AgnosticEncoder[V], private val groupingExprs: java.util.List[proto.Expression], - private val valueMapFunc: IV => V, + private val valueMapFunc: Option[IV => V], private val keysFunc: () => Dataset[IK]) extends KeyValueGroupedDataset[K, V] { import sparkSession.RichColumn @@ -948,7 +398,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( new KeyValueGroupedDatasetImpl[L, V, IK, IV]( sparkSession, plan, - ikEncoder, encoderFor[L], ivEncoder, vEncoder, @@ -961,12 +410,13 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( new KeyValueGroupedDatasetImpl[K, W, IK, IV]( sparkSession, plan, - ikEncoder, kEncoder, ivEncoder, encoderFor[W], groupingExprs, - valueMapFunc.andThen(valueFunc), + valueMapFunc + .map(_.andThen(valueFunc)) + .orElse(Option(valueFunc.asInstanceOf[IV => W])), keysFunc) } @@ -979,8 +429,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( override def flatMapSortedGroups[U: Encoder](sortExprs: Column*)( f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { // Apply mapValues changes to the udf - val nf = - if (valueMapFunc == UdfUtils.identical()) f else UdfUtils.mapValuesAdaptor(f, valueMapFunc) + val nf = UDFAdaptors.flatMapGroupsWithMappedValues(f, valueMapFunc) val outputEncoder = encoderFor[U] sparkSession.newDataset[U](outputEncoder) { builder => builder.getGroupMapBuilder @@ -994,10 +443,9 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( override def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])( thisSortExprs: Column*)(otherSortExprs: Column*)( f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { - assert(other.isInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, _]]) - val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, _]] + val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, Any]] // Apply mapValues changes to the udf - val nf = UdfUtils.mapValuesAdaptor(f, valueMapFunc, otherImpl.valueMapFunc) + val nf = 
UDFAdaptors.coGroupWithMappedValues(f, valueMapFunc, otherImpl.valueMapFunc) val outputEncoder = encoderFor[R] sparkSession.newDataset[R](outputEncoder) { builder => builder.getCoGroupMapBuilder @@ -1012,8 +460,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( } override protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - // TODO(SPARK-43415): For each column, apply the valueMap func first - // apply keyAs change + // TODO(SPARK-43415): For each column, apply the valueMap func first... val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => encoderFor(c.encoder))) sparkSession.newDataset(rEnc) { builder => builder.getAggregateBuilder @@ -1047,22 +494,15 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( throw new IllegalArgumentException("The output mode of function should be append or update") } - if (initialState.isDefined) { - assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]) - } - val initialStateImpl = if (initialState.isDefined) { + assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]) initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]] } else { null } val outputEncoder = encoderFor[U] - val nf = if (valueMapFunc == UdfUtils.identical()) { - func - } else { - UdfUtils.mapValuesAdaptor(func, valueMapFunc) - } + val nf = UDFAdaptors.flatMapGroupsWithStateWithMappedValues(func, valueMapFunc) sparkSession.newDataset[U](outputEncoder) { builder => val groupMapBuilder = builder.getGroupMapBuilder @@ -1097,6 +537,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( * We cannot deserialize a connect [[KeyValueGroupedDataset]] because of a class clash on the * server side. We null out the instance for now. */ + @unused("this is used by java serialization") private def writeReplace(): Any = null } @@ -1114,11 +555,10 @@ private object KeyValueGroupedDatasetImpl { session, ds.plan, kEncoder, - kEncoder, ds.agnosticEncoder, ds.agnosticEncoder, Arrays.asList(toExpr(gf.apply(col("*")))), - UdfUtils.identical(), + None, () => ds.map(groupingFunc)(kEncoder)) } @@ -1137,11 +577,10 @@ private object KeyValueGroupedDatasetImpl { session, df.plan, kEncoder, - kEncoder, vEncoder, vEncoder, (Seq(dummyGroupingFunc) ++ groupingExprs).map(toExpr).asJava, - UdfUtils.identical(), + None, () => df.select(groupingExprs: _*).as(kEncoder)) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 0c8657e12d8df..14ceb3f4bb144 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql -import java.util.Locale - import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.ConnectConversions._ /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -36,341 +35,121 @@ import org.apache.spark.connect.proto * @since 3.4.0 */ class RelationalGroupedDataset private[sql] ( - private[sql] val df: DataFrame, + protected val df: DataFrame, private[sql] val groupingExprs: Seq[Column], groupType: proto.Aggregate.GroupType, pivot: Option[proto.Aggregate.Pivot] = None, - groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) { + groupingSets: 
Option[Seq[proto.Aggregate.GroupingSets]] = None)
+    extends api.RelationalGroupedDataset {
   import df.sparkSession.RichColumn

-  private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
+  protected def toDF(aggExprs: Seq[Column]): DataFrame = {
     df.sparkSession.newDataFrame { builder =>
-      builder.getAggregateBuilder
+      val aggBuilder = builder.getAggregateBuilder
         .setInput(df.plan.getRoot)
-        .addAllGroupingExpressions(groupingExprs.map(_.expr).asJava)
-        .addAllAggregateExpressions(aggExprs.map(e => e.typedExpr(df.encoder)).asJava)
+      groupingExprs.foreach(c => aggBuilder.addGroupingExpressions(c.expr))
+      aggExprs.foreach(c => aggBuilder.addAggregateExpressions(c.typedExpr(df.encoder)))
       groupType match {
         case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
-          builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+          aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
         case proto.Aggregate.GroupType.GROUP_TYPE_CUBE =>
-          builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+          aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
         case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
-          builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+          aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
         case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT =>
           assert(pivot.isDefined)
-          builder.getAggregateBuilder
+          aggBuilder
             .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT)
             .setPivot(pivot.get)
         case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
           assert(groupingSets.isDefined)
-          val aggBuilder = builder.getAggregateBuilder
-            .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS)
+          aggBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS)
           groupingSets.get.foreach(aggBuilder.addGroupingSets)
         case g => throw new UnsupportedOperationException(g.toString)
       }
     }
   }

-  /**
-   * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions of
-   * current `RelationalGroupedDataset`.
-   *
-   * @since 3.5.0
-   */
-  def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
-    KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs)
+  protected def selectNumericColumns(colNames: Seq[String]): Seq[Column] = {
+    // This behaves differently from the classic implementation. The classic implementation
+    // validates that a column is actually numeric, and if it is not, it throws an error
+    // immediately. In Connect it depends on the input type (casting) rules of the method
+    // invoked; if the input violates these rules, a different error will be thrown. However,
+    // it is also possible to get a result for a non-numeric column in Connect, for example
+    // when you use min/max.
+    colNames.map(df.col)
   }

-  /**
-   * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The
-   * resulting `DataFrame` will also contain the grouping columns.
-   *
-   * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
- * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg( - * "age" -> "max", - * "expense" -> "sum" - * ) - * }}} - * - * @since 3.4.0 - */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - toDF((aggExpr +: aggExprs).map { case (colName, expr) => - strToColumn(expr, df(colName)) - }) + /** @inheritdoc */ + def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = { + KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs) } - /** - * (Scala-specific) Compute aggregates by specifying a map from column name to aggregate - * methods. The resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg(Map( - * "age" -> "max", - * "expense" -> "sum" - * )) - * }}} - * - * @since 3.4.0 - */ - def agg(exprs: Map[String, String]): DataFrame = { - toDF(exprs.map { case (colName, expr) => - strToColumn(expr, df(colName)) - }.toSeq) - } + /** @inheritdoc */ + override def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = + super.agg(aggExpr, aggExprs: _*) - /** - * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods. - * The resulting `DataFrame` will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * import com.google.common.collect.ImmutableMap; - * df.groupBy("department").agg(ImmutableMap.of("age", "max", "expense", "sum")); - * }}} - * - * @since 3.4.0 - */ - def agg(exprs: java.util.Map[String, String]): DataFrame = { - agg(exprs.asScala.toMap) - } + /** @inheritdoc */ + override def agg(exprs: Map[String, String]): DataFrame = super.agg(exprs) - private[this] def strToColumn(expr: String, inputExpr: Column): Column = { - expr.toLowerCase(Locale.ROOT) match { - case "avg" | "average" | "mean" => functions.avg(inputExpr) - case "stddev" | "std" => functions.stddev(inputExpr) - case "count" | "size" => functions.count(inputExpr) - case name => Column.fn(name, inputExpr) - } - } + /** @inheritdoc */ + override def agg(exprs: java.util.Map[String, String]): DataFrame = super.agg(exprs) - /** - * Compute aggregates by specifying a series of aggregate columns. Note that this function by - * default retains the grouping columns in its output. To not retain grouping columns, set - * `spark.sql.retainGroupColumns` to false. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * Note that before Spark 1.4, the default behavior is to NOT retain grouping columns. To change - * to that behavior, set config variable `spark.sql.retainGroupColumns` to `false`. 
- * {{{ - * // Scala, 1.3.x: - * df.groupBy("department").agg($"department", max("age"), sum("expense")) - * - * // Java, 1.3.x: - * df.groupBy("department").agg(col("department"), max("age"), sum("expense")); - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = { - toDF(expr +: exprs) - } + override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) - /** - * Count the number of rows for each group. The resulting `DataFrame` will also contain the - * grouping columns. - * - * @since 3.4.0 - */ - def count(): DataFrame = toDF(Seq(functions.count(functions.lit(1)).alias("count"))) + /** @inheritdoc */ + override def count(): DataFrame = super.count() - /** - * Compute the average value for each numeric columns for each group. This is an alias for - * `avg`. The resulting `DataFrame` will also contain the grouping columns. When specified - * columns are given, only compute the average values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def mean(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.mean(colName))) - } + override def mean(colNames: String*): DataFrame = super.mean(colNames: _*) - /** - * Compute the max value for each numeric columns for each group. The resulting `DataFrame` will - * also contain the grouping columns. When specified columns are given, only compute the max - * values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def max(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.max(colName))) - } + override def max(colNames: String*): DataFrame = super.max(colNames: _*) - /** - * Compute the mean value for each numeric columns for each group. The resulting `DataFrame` - * will also contain the grouping columns. When specified columns are given, only compute the - * mean values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def avg(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.avg(colName))) - } + override def avg(colNames: String*): DataFrame = super.avg(colNames: _*) - /** - * Compute the min value for each numeric column for each group. The resulting `DataFrame` will - * also contain the grouping columns. When specified columns are given, only compute the min - * values for them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def min(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.min(colName))) - } + override def min(colNames: String*): DataFrame = super.min(colNames: _*) - /** - * Compute the sum for each numeric columns for each group. The resulting `DataFrame` will also - * contain the grouping columns. When specified columns are given, only compute the sum for - * them. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def sum(colNames: String*): DataFrame = { - toDF(colNames.map(colName => functions.sum(colName))) - } + override def sum(colNames: String*): DataFrame = super.sum(colNames: _*) - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * - * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine the - * resulting schema of the transformation. To avoid any eager computations, provide an explicit - * list of values via `pivot(pivotColumn: String, values: Seq[Any])`. 
- * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * Name of the column to pivot. - * @since 3.4.0 - */ - def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(Column(pivotColumn)) + /** @inheritdoc */ + override def pivot(pivotColumn: String): RelationalGroupedDataset = super.pivot(pivotColumn) - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. There are - * two versions of pivot function: one that requires the caller to specify the list of distinct - * values to pivot on, and one that does not. The latter is more concise but less efficient, - * because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings") - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings") - * }}} - * - * From Spark 3.0.0, values can be literal columns, for instance, struct. For pivoting by - * multiple columns, use the `struct` function to combine the columns and values: - * - * {{{ - * df.groupBy("year") - * .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts")))) - * .agg(sum($"earnings")) - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * Name of the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. - * @since 3.4.0 - */ - def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = { - pivot(Column(pivotColumn), values) - } + /** @inheritdoc */ + override def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = + super.pivot(pivotColumn, values) + + /** @inheritdoc */ + override def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = + super.pivot(pivotColumn, values) - /** - * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. - * - * There are two versions of pivot function: one that requires the caller to specify the list of - * distinct values to pivot on, and one that does not. The latter is more concise but less - * efficient, because Spark needs to first compute the list of distinct values internally. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy("year").pivot("course", Arrays.asList("dotNET", "Java")).sum("earnings"); - * - * // Or without specifying column values (less efficient) - * df.groupBy("year").pivot("course").sum("earnings"); - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * Name of the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. 
- * @since 3.4.0 - */ - def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { - pivot(Column(pivotColumn), values) + /** @inheritdoc */ + override def pivot( + pivotColumn: Column, + values: java.util.List[Any]): RelationalGroupedDataset = { + super.pivot(pivotColumn, values) } - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. This is an - * overloaded version of the `pivot` method with `pivotColumn` of the `String` type. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings") - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. - * @since 3.4.0 - */ + /** @inheritdoc */ def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY => - val valueExprs = values.map(_ match { + val valueExprs = values.map { case c: Column if c.expr.hasLiteral => c.expr.getLiteral case c: Column if !c.expr.hasLiteral => throw new IllegalArgumentException("values only accept literal Column") case v => functions.lit(v).expr.getLiteral - }) + } new RelationalGroupedDataset( df, groupingExprs, @@ -386,46 +165,8 @@ class RelationalGroupedDataset private[sql] ( } } - /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * - * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine the - * resulting schema of the transformation. To avoid any eager computations, provide an explicit - * list of values via `pivot(pivotColumn: Column, values: Seq[Any])`. - * - * {{{ - * // Compute the sum of earnings for each year by course with each course as a separate column - * df.groupBy($"year").pivot($"course").sum($"earnings"); - * }}} - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * he column to pivot. - * @since 3.4.0 - */ + /** @inheritdoc */ def pivot(pivotColumn: Column): RelationalGroupedDataset = { - pivot(pivotColumn, Seq()) - } - - /** - * (Java-specific) Pivots a column of the current `DataFrame` and performs the specified - * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of the - * `String` type. - * - * @see - * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the - * aggregation. - * - * @param pivotColumn - * the column to pivot. - * @param values - * List of values that will be translated to columns in the output DataFrame. 
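Illustrative usage (not part of the patch): a sketch of `pivot` with an explicit value list against the Connect implementation above, which accepts plain values and literal `Column`s but rejects non-literal `Column`s. The DataFrame `df` with `year`, `course`, `trainingCourse`, and `earnings` columns is hypothetical, mirroring the scaladoc examples; `spark.implicits._` is assumed for the `$` interpolator.
{{{
import org.apache.spark.sql.functions._
import spark.implicits._

// Plain values are wrapped into literals internally.
df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum("earnings")

// Literal Columns are accepted as-is; combine several with struct(...) to
// pivot on multiple columns.
df.groupBy($"year")
  .pivot($"trainingCourse", Seq(struct(lit("java"), lit("Experts"))))
  .agg(sum($"earnings"))

// Passing a non-literal Column (e.g. $"course") in the value list fails with
// "values only accept literal Column".
}}}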
- * @since 3.4.0 - */ - def pivot(pivotColumn: Column, values: java.util.List[Any]): RelationalGroupedDataset = { - pivot(pivotColumn, values.asScala.toSeq) + pivot(pivotColumn, Nil) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 6914b2cc8a0f7..484341cb1f0ef 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -18,11 +18,13 @@ package org.apache.spark.sql import java.math.BigInteger import java.net.URI +import java.nio.file.{Files, Paths} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag +import scala.util.Try import com.google.common.cache.{CacheBuilder, CacheLoader} import io.grpc.ClientInterceptor @@ -40,7 +42,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf} +import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig, SessionCleaner, SqlApiConf} import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.{toExpr, toTypedExpr} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager @@ -68,7 +70,7 @@ import org.apache.spark.util.ArrayImplicits._ class SparkSession private[sql] ( private[sql] val client: SparkConnectClient, private val planIdGenerator: AtomicLong) - extends api.SparkSession[Dataset] + extends api.SparkSession with Logging { private[this] val allocator = new RootAllocator() @@ -87,16 +89,8 @@ class SparkSession private[sql] ( client.hijackServerSideSessionIdForTesting(suffix) } - /** - * Runtime configuration interface for Spark. - * - * This is the interface through which the user can get and set all Spark configurations that - * are relevant to Spark SQL. When getting the value of a config, his defaults to the value set - * in server, if any. - * - * @since 3.4.0 - */ - val conf: RuntimeConfig = new RuntimeConfig(client) + /** @inheritdoc */ + val conf: RuntimeConfig = new ConnectRuntimeConfig(client) /** @inheritdoc */ @transient @@ -213,16 +207,7 @@ class SparkSession private[sql] ( sql(query, Array.empty) } - /** - * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a - * `DataFrame`. - * {{{ - * sparkSession.read.parquet("/path/to/file.parquet") - * sparkSession.read.schema(schema).json("/path/to/file.json") - * }}} - * - * @since 3.4.0 - */ + /** @inheritdoc */ def read: DataFrameReader = new DataFrameReader(this) /** @@ -238,12 +223,7 @@ class SparkSession private[sql] ( lazy val streams: StreamingQueryManager = new StreamingQueryManager(this) - /** - * Interface through which the user may create, drop, alter or query underlying databases, - * tables, functions etc. - * - * @since 3.5.0 - */ + /** @inheritdoc */ lazy val catalog: Catalog = new CatalogImpl(this) /** @inheritdoc */ @@ -396,77 +376,30 @@ class SparkSession private[sql] ( execute(command) } - /** - * Add a single artifact to the client session. 
- * - * Currently only local files with extensions .jar and .class are supported. - * - * @since 3.4.0 - */ + /** @inheritdoc */ @Experimental - def addArtifact(path: String): Unit = client.addArtifact(path) + override def addArtifact(path: String): Unit = client.addArtifact(path) - /** - * Add a single artifact to the client session. - * - * Currently it supports local files with extensions .jar and .class and Apache Ivy URIs - * - * @since 3.4.0 - */ + /** @inheritdoc */ @Experimental - def addArtifact(uri: URI): Unit = client.addArtifact(uri) + override def addArtifact(uri: URI): Unit = client.addArtifact(uri) - /** - * Add a single in-memory artifact to the session while preserving the directory structure - * specified by `target` under the session's working directory of that particular file - * extension. - * - * Supported target file extensions are .jar and .class. - * - * ==Example== - * {{{ - * addArtifact(bytesBar, "foo/bar.class") - * addArtifact(bytesFlat, "flat.class") - * // Directory structure of the session's working directory for class files would look like: - * // ${WORKING_DIR_FOR_CLASS_FILES}/flat.class - * // ${WORKING_DIR_FOR_CLASS_FILES}/foo/bar.class - * }}} - * - * @since 4.0.0 - */ + /** @inheritdoc */ @Experimental - def addArtifact(bytes: Array[Byte], target: String): Unit = client.addArtifact(bytes, target) + override def addArtifact(bytes: Array[Byte], target: String): Unit = { + client.addArtifact(bytes, target) + } - /** - * Add a single artifact to the session while preserving the directory structure specified by - * `target` under the session's working directory of that particular file extension. - * - * Supported target file extensions are .jar and .class. - * - * ==Example== - * {{{ - * addArtifact("/Users/dummyUser/files/foo/bar.class", "foo/bar.class") - * addArtifact("/Users/dummyUser/files/flat.class", "flat.class") - * // Directory structure of the session's working directory for class files would look like: - * // ${WORKING_DIR_FOR_CLASS_FILES}/flat.class - * // ${WORKING_DIR_FOR_CLASS_FILES}/foo/bar.class - * }}} - * - * @since 4.0.0 - */ + /** @inheritdoc */ @Experimental - def addArtifact(source: String, target: String): Unit = client.addArtifact(source, target) + override def addArtifact(source: String, target: String): Unit = { + client.addArtifact(source, target) + } - /** - * Add one or more artifacts to the session. - * - * Currently it supports local files with extensions .jar and .class and Apache Ivy URIs - * - * @since 3.4.0 - */ + /** @inheritdoc */ @Experimental @scala.annotation.varargs - def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri) + override def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri) /** * Register a ClassFinder for dynamically generated classes. 
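Illustrative usage (not part of the patch): a sketch of the `addArtifact`/`addArtifacts` overloads kept above; all paths, URIs, and coordinates are hypothetical.
{{{
import java.net.URI
import java.nio.file.{Files, Paths}

// Local .jar or .class files, or an Apache Ivy URI.
spark.addArtifact("/path/to/my-udfs.jar")
spark.addArtifact(new URI("ivy://org.example:my-lib:1.0.0"))

// An in-memory artifact, preserving the target directory structure under the
// session's working directory for class files.
val bytes = Files.readAllBytes(Paths.get("/path/to/foo/bar.class"))
spark.addArtifact(bytes, "foo/bar.class")

// Several artifacts at once.
spark.addArtifacts(
  new URI("file:///path/to/a.jar"),
  new URI("file:///path/to/b.jar"))
}}}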
@@ -493,7 +426,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptAll(): Seq[String] = { + override def interruptAll(): Seq[String] = { client.interruptAll().getInterruptedIdsList.asScala.toSeq } @@ -506,7 +439,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptTag(tag: String): Seq[String] = { + override def interruptTag(tag: String): Seq[String] = { client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq } @@ -519,7 +452,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptOperation(operationId: String): Seq[String] = { + override def interruptOperation(operationId: String): Seq[String] = { client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq } @@ -550,65 +483,17 @@ class SparkSession private[sql] ( SparkSession.onSessionClose(this) } - /** - * Add a tag to be assigned to all the operations started by this thread in this session. - * - * Often, a unit of execution in an application consists of multiple Spark executions. - * Application programmers can use this method to group all those jobs together and give a group - * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all - * running running executions with this tag. For example: - * {{{ - * // In the main thread: - * spark.addTag("myjobs") - * spark.range(10).map(i => { Thread.sleep(10); i }).collect() - * - * // In a separate thread: - * spark.interruptTag("myjobs") - * }}} - * - * There may be multiple tags present at the same time, so different parts of application may - * use different tags to perform cancellation at different levels of granularity. - * - * @param tag - * The tag to be added. Cannot contain ',' (comma) character or be an empty string. - * - * @since 3.5.0 - */ - def addTag(tag: String): Unit = { - client.addTag(tag) - } + /** @inheritdoc */ + override def addTag(tag: String): Unit = client.addTag(tag) - /** - * Remove a tag previously added to be assigned to all the operations started by this thread in - * this session. Noop if such a tag was not added earlier. - * - * @param tag - * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. - * - * @since 3.5.0 - */ - def removeTag(tag: String): Unit = { - client.removeTag(tag) - } + /** @inheritdoc */ + override def removeTag(tag: String): Unit = client.removeTag(tag) - /** - * Get the tags that are currently set to be assigned to all the operations started by this - * thread. - * - * @since 3.5.0 - */ - def getTags(): Set[String] = { - client.getTags() - } + /** @inheritdoc */ + override def getTags(): Set[String] = client.getTags() - /** - * Clear the current thread's operation tags. - * - * @since 3.5.0 - */ - def clearTags(): Unit = { - client.clearTags() - } + /** @inheritdoc */ + override def clearTags(): Unit = client.clearTags() /** * We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side. 
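Illustrative usage (not part of the patch): the tag and interrupt methods overridden above, restating the example from the removed scaladoc.
{{{
// In the main thread:
spark.addTag("myjobs")
spark.range(10).map(i => { Thread.sleep(10); i }).collect()

// In a separate thread: cancel everything tagged "myjobs"; the returned
// sequence contains the ids of the interrupted operations.
val interrupted: Seq[String] = spark.interruptTag("myjobs")

// Related controls:
spark.getTags()            // tags currently set on this thread
spark.removeTag("myjobs")  // stop tagging newly started operations
spark.clearTags()          // clear all of this thread's tags
spark.interruptAll()       // interrupt every operation in this session
}}}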
@@ -622,17 +507,14 @@ class SparkSession private[sql] ( private[sql] var releaseSessionOnClose = true private[sql] def registerObservation(planId: Long, observation: Observation): Unit = { - if (observationRegistry.putIfAbsent(planId, observation) != null) { - throw new IllegalArgumentException("An Observation can be used with a Dataset only once") - } + observation.markRegistered() + observationRegistry.putIfAbsent(planId, observation) } - private[sql] def setMetricsAndUnregisterObservation( - planId: Long, - metrics: Map[String, Any]): Unit = { + private[sql] def setMetricsAndUnregisterObservation(planId: Long, metrics: Row): Unit = { val observationOrNull = observationRegistry.remove(planId) if (observationOrNull != null) { - observationOrNull.setMetricsAndNotify(Some(metrics)) + observationOrNull.setMetricsAndNotify(metrics) } } @@ -647,6 +529,10 @@ class SparkSession private[sql] ( object SparkSession extends Logging { private val MAX_CACHED_SESSIONS = 100 private val planIdGenerator = new AtomicLong + private var server: Option[Process] = None + private[sql] val sparkOptions = sys.props.filter { p => + p._1.startsWith("spark.") && p._2.nonEmpty + }.toMap private val sessions = CacheBuilder .newBuilder() @@ -679,6 +565,51 @@ object SparkSession extends Logging { } } + /** + * Create a new Spark Connect server to connect locally. + */ + private[sql] def withLocalConnectServer[T](f: => T): T = { + synchronized { + val remoteString = sparkOptions + .get("spark.remote") + .orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit + .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE)) + + val maybeConnectScript = + Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh")) + + if (server.isEmpty && + remoteString.exists(_.startsWith("local")) && + maybeConnectScript.exists(Files.exists(_))) { + server = Some { + val args = + Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions + .filter(p => !p._1.startsWith("spark.remote")) + .flatMap { case (k, v) => Seq("--conf", s"$k=$v") } + val pb = new ProcessBuilder(args: _*) + // So don't exclude spark-sql jar in classpath + pb.environment().remove(SparkConnectClient.SPARK_REMOTE) + pb.start() + } + + // Let the server start. We will directly request to set the configurations + // and this sleep makes less noisy with retries. + Thread.sleep(2000L) + System.setProperty("spark.remote", "sc://localhost") + + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(new Thread() { + override def run(): Unit = if (server.isDefined) { + new ProcessBuilder(maybeConnectScript.get.toString) + .start() + } + }) + // scalastyle:on runtimeaddshutdownhook + } + } + f + } + /** * Create a new [[SparkSession]] based on the connect client [[Configuration]]. */ @@ -821,6 +752,16 @@ object SparkSession extends Logging { } private def applyOptions(session: SparkSession): Unit = { + // Only attempts to set Spark SQL configurations. + // If the configurations are static, it might throw an exception so + // simply ignore it for now. 
+ sparkOptions + .filter { case (k, _) => + k.startsWith("spark.sql.") + } + .foreach { case (key, value) => + Try(session.conf.set(key, value)) + } options.foreach { case (key, value) => session.conf.set(key, value) } @@ -843,7 +784,7 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def create(): SparkSession = { + def create(): SparkSession = withLocalConnectServer { val session = tryCreateSessionFromClient() .getOrElse(SparkSession.this.create(builder.configuration)) setDefaultAndActiveSession(session) @@ -863,7 +804,7 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def getOrCreate(): SparkSession = { + def getOrCreate(): SparkSession = withLocalConnectServer { val session = tryCreateSessionFromClient() .getOrElse({ var existingSession = sessions.get(builder.configuration) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala index 86775803a0937..63fa2821a6c6a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.application import java.io.{InputStream, OutputStream} -import java.nio.file.Paths import java.util.concurrent.Semaphore -import scala.util.Try import scala.util.control.NonFatal import ammonite.compiler.CodeClassWrapper @@ -34,6 +32,7 @@ import ammonite.util.Util.newLine import org.apache.spark.SparkBuildInfo.spark_version import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.SparkSession.withLocalConnectServer import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClientParser} /** @@ -64,37 +63,7 @@ Spark session available as 'spark'. semaphore: Option[Semaphore] = None, inputStream: InputStream = System.in, outputStream: OutputStream = System.out, - errorStream: OutputStream = System.err): Unit = { - val configs: Map[String, String] = - sys.props - .filter(p => - p._1.startsWith("spark.") && - p._2.nonEmpty && - // Don't include spark.remote that we manually set later. - !p._1.startsWith("spark.remote")) - .toMap - - val remoteString: Option[String] = - Option(System.getProperty("spark.remote")) // Set from Spark Submit - .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE)) - - if (remoteString.exists(_.startsWith("local"))) { - server = Some { - val args = Seq( - Paths.get(sparkHome, "sbin", "start-connect-server.sh").toString, - "--master", - remoteString.get) ++ configs.flatMap { case (k, v) => Seq("--conf", s"$k=$v") } - val pb = new ProcessBuilder(args: _*) - // So don't exclude spark-sql jar in classpath - pb.environment().remove(SparkConnectClient.SPARK_REMOTE) - pb.start() - } - // Let the server start. We will directly request to set the configurations - // and this sleep makes less noisy with retries. - Thread.sleep(2000L) - System.setProperty("spark.remote", "sc://localhost") - } - + errorStream: OutputStream = System.err): Unit = withLocalConnectServer { // Build the client. val client = try { @@ -118,13 +87,6 @@ Spark session available as 'spark'. // Build the session. val spark = SparkSession.builder().client(client).getOrCreate() - - // The configurations might not be all runtime configurations. - // Try to set them with ignoring failures for now. 
- configs - .filter(_._1.startsWith("spark.sql")) - .foreach { case (k, v) => Try(spark.conf.set(k, v)) } - val sparkBind = new Bind("spark", spark) // Add the proper imports and register a [[ClassFinder]]. @@ -197,18 +159,12 @@ Spark session available as 'spark'. } } } - try { - if (semaphore.nonEmpty) { - // Used for testing. - main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get)) - } else { - main.run(sparkBind) - } - } finally { - if (server.isDefined) { - new ProcessBuilder(Paths.get(sparkHome, "sbin", "stop-connect-server.sh").toString) - .start() - } + + if (semaphore.nonEmpty) { + // Used for testing. + main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get)) + } else { + main.run(sparkBind) } } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index cf0fef147ee84..86b1dbe4754e6 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -17,660 +17,152 @@ package org.apache.spark.sql.catalog -import scala.jdk.CollectionConverters._ +import java.util -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset} +import org.apache.spark.sql.{api, DataFrame, Dataset} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.types.StructType -import org.apache.spark.storage.StorageLevel -/** - * Catalog interface for Spark. To access this, use `SparkSession.catalog`. - * - * @since 3.5.0 - */ -abstract class Catalog { - - /** - * Returns the current database (namespace) in this session. - * - * @since 3.5.0 - */ - def currentDatabase: String - - /** - * Sets the current database (namespace) in this session. - * - * @since 3.5.0 - */ - def setCurrentDatabase(dbName: String): Unit - - /** - * Returns a list of databases (namespaces) available within the current catalog. - * - * @since 3.5.0 - */ - def listDatabases(): Dataset[Database] - - /** - * Returns a list of databases (namespaces) which name match the specify pattern and available - * within the current catalog. - * - * @since 3.5.0 - */ - def listDatabases(pattern: String): Dataset[Database] - - /** - * Returns a list of tables/views in the current database (namespace). This includes all - * temporary views. - * - * @since 3.5.0 - */ - def listTables(): Dataset[Table] - - /** - * Returns a list of tables/views in the specified database (namespace) (the name can be - * qualified with catalog). This includes all temporary views. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listTables(dbName: String): Dataset[Table] - - /** - * Returns a list of tables/views in the specified database (namespace) which name match the - * specify pattern (the name can be qualified with catalog). This includes all temporary views. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listTables(dbName: String, pattern: String): Dataset[Table] - - /** - * Returns a list of functions registered in the current database (namespace). This includes all - * temporary functions. - * - * @since 3.5.0 - */ - def listFunctions(): Dataset[Function] - - /** - * Returns a list of functions registered in the specified database (namespace) (the name can be - * qualified with catalog). This includes all built-in and temporary functions. 
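Illustrative usage (not part of the patch): a sketch of the catalog listing methods declared above; the database, table, and pattern names are hypothetical.
{{{
spark.catalog.setCurrentDatabase("sales")
spark.catalog.currentDatabase                       // "sales"

spark.catalog.listDatabases().show()
spark.catalog.listTables("sales").show()            // includes temporary views
spark.catalog.listFunctions("sales", "agg*").show()
spark.catalog.listColumns("sales", "orders").show()
}}}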
- * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listFunctions(dbName: String): Dataset[Function] - - /** - * Returns a list of functions registered in the specified database (namespace) which name match - * the specify pattern (the name can be qualified with catalog). This includes all built-in and - * temporary functions. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def listFunctions(dbName: String, pattern: String): Dataset[Function] - - /** - * Returns a list of columns for the given table/view or temporary view. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. It follows the same - * resolution rule with SQL: search for temp views first then table/views in the current - * database (namespace). - * @since 3.5.0 - */ - @throws[AnalysisException]("table does not exist") - def listColumns(tableName: String): Dataset[Column] - - /** - * Returns a list of columns for the given table/view in the specified database under the Hive - * Metastore. - * - * To list columns for table/view in other catalogs, please use `listColumns(tableName)` with - * qualified table/view name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param tableName - * is an unqualified name that designates a table/view. - * @since 3.5.0 - */ - @throws[AnalysisException]("database or table does not exist") - def listColumns(dbName: String, tableName: String): Dataset[Column] - - /** - * Get the database (namespace) with the specified name (can be qualified with catalog). This - * throws an AnalysisException when the database (namespace) cannot be found. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database does not exist") - def getDatabase(dbName: String): Database - - /** - * Get the table or view with the specified name. This table can be a temporary view or a - * table/view. This throws an AnalysisException when no Table can be found. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. It follows the same - * resolution rule with SQL: search for temp views first then table/views in the current - * database (namespace). - * @since 3.5.0 - */ - @throws[AnalysisException]("table does not exist") - def getTable(tableName: String): Table - - /** - * Get the table or view with the specified name in the specified database under the Hive - * Metastore. This throws an AnalysisException when no Table can be found. - * - * To get table/view in other catalogs, please use `getTable(tableName)` with qualified - * table/view name instead. - * - * @since 3.5.0 - */ - @throws[AnalysisException]("database or table does not exist") - def getTable(dbName: String, tableName: String): Table - - /** - * Get the function with the specified name. This function can be a temporary function or a - * function. This throws an AnalysisException when the function cannot be found. - * - * @param functionName - * is either a qualified or unqualified name that designates a function. It follows the same - * resolution rule with SQL: search for built-in/temp functions first then functions in the - * current database (namespace). - * @since 3.5.0 - */ - @throws[AnalysisException]("function does not exist") - def getFunction(functionName: String): Function - - /** - * Get the function with the specified name in the specified database under the Hive Metastore. - * This throws an AnalysisException when the function cannot be found. 
- * - * To get functions in other catalogs, please use `getFunction(functionName)` with qualified - * function name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param functionName - * is an unqualified name that designates a function in the specified database - * @since 3.5.0 - */ - @throws[AnalysisException]("database or function does not exist") - def getFunction(dbName: String, functionName: String): Function - - /** - * Check if the database (namespace) with the specified name exists (the name can be qualified - * with catalog). - * - * @since 3.5.0 - */ - def databaseExists(dbName: String): Boolean - - /** - * Check if the table or view with the specified name exists. This can either be a temporary - * view or a table/view. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. It follows the same - * resolution rule with SQL: search for temp views first then table/views in the current - * database (namespace). - * @since 3.5.0 - */ - def tableExists(tableName: String): Boolean - - /** - * Check if the table or view with the specified name exists in the specified database under the - * Hive Metastore. - * - * To check existence of table/view in other catalogs, please use `tableExists(tableName)` with - * qualified table/view name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param tableName - * is an unqualified name that designates a table. - * @since 3.5.0 - */ - def tableExists(dbName: String, tableName: String): Boolean - - /** - * Check if the function with the specified name exists. This can either be a temporary function - * or a function. - * - * @param functionName - * is either a qualified or unqualified name that designates a function. It follows the same - * resolution rule with SQL: search for built-in/temp functions first then functions in the - * current database (namespace). - * @since 3.5.0 - */ - def functionExists(functionName: String): Boolean - - /** - * Check if the function with the specified name exists in the specified database under the Hive - * Metastore. - * - * To check existence of functions in other catalogs, please use `functionExists(functionName)` - * with qualified function name instead. - * - * @param dbName - * is an unqualified name that designates a database. - * @param functionName - * is an unqualified name that designates a function. - * @since 3.5.0 - */ - def functionExists(dbName: String, functionName: String): Boolean - - /** - * Creates a table from the given path and returns the corresponding DataFrame. It will use the - * default data source configured by spark.sql.sources.default. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String): DataFrame = { - createTable(tableName, path) - } - - /** - * Creates a table from the given path and returns the corresponding DataFrame. It will use the - * default data source configured by spark.sql.sources.default. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. 
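Illustrative usage (not part of the patch): existence checks and the path-based `createTable` described above; names and paths are hypothetical.
{{{
spark.catalog.databaseExists("sales")
spark.catalog.functionExists("sales", "my_udf")

if (!spark.catalog.tableExists("sales", "orders")) {
  // Uses the default data source configured by spark.sql.sources.default.
  spark.catalog.createTable("sales.orders", "/data/warehouse/orders")
}
}}}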
- * @since 3.5.0 - */ - def createTable(tableName: String, path: String): DataFrame - - /** - * Creates a table from the given path based on a data source and returns the corresponding - * DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String, source: String): DataFrame = { - createTable(tableName, path, source) - } - - /** - * Creates a table from the given path based on a data source and returns the corresponding - * DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable(tableName: String, path: String, source: String): DataFrame - - /** - * Creates a table from the given path based on a data source and a set of options. Then, - * returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( +/** @inheritdoc */ +abstract class Catalog extends api.Catalog { + + /** @inheritdoc */ + override def listDatabases(): Dataset[Database] + + /** @inheritdoc */ + override def listDatabases(pattern: String): Dataset[Database] + + /** @inheritdoc */ + override def listTables(): Dataset[Table] + + /** @inheritdoc */ + override def listTables(dbName: String): Dataset[Table] + + /** @inheritdoc */ + override def listTables(dbName: String, pattern: String): Dataset[Table] + + /** @inheritdoc */ + override def listFunctions(): Dataset[Function] + + /** @inheritdoc */ + override def listFunctions(dbName: String): Dataset[Function] + + /** @inheritdoc */ + override def listFunctions(dbName: String, pattern: String): Dataset[Function] + + /** @inheritdoc */ + override def listColumns(tableName: String): Dataset[Column] + + /** @inheritdoc */ + override def listColumns(dbName: String, tableName: String): Dataset[Column] + + /** @inheritdoc */ + override def createTable(tableName: String, path: String): DataFrame + + /** @inheritdoc */ + override def createTable(tableName: String, path: String, source: String): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, options) - } - - /** - * Creates a table based on the dataset in a data source and a set of options. Then, returns the - * corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, options.asScala.toMap) - } - - /** - * (Scala-specific) Creates a table from the given path based on a data source and a set of - * options. Then, returns the corresponding DataFrame. 
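For context on the `Catalog` refactor above (the concrete class now extends the shared `api.Catalog` and only overrides its declarations), here is a minimal sketch of the caller-facing surface that stays unchanged; the session, table name, and path below are hypothetical and not taken from this patch.

import org.apache.spark.sql.SparkSession

object CatalogUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("catalog-sketch").getOrCreate()

    // Name resolution follows SQL: temp views first, then the current database (namespace).
    spark.catalog.listTables().show()

    if (!spark.catalog.tableExists("events")) {
      // Path-based creation; passing "parquet" overrides spark.sql.sources.default.
      spark.catalog.createTable("events", "/tmp/events_data", "parquet")
    }

    // Column listing and cache management against the same table.
    spark.catalog.listColumns("events").show()
    spark.catalog.cacheTable("events")
    spark.catalog.uncacheTable("events")

    spark.stop()
  }
}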
- * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + description: String, + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - options: Map[String, String]): DataFrame = { - createTable(tableName, source, options) - } - - /** - * (Scala-specific) Creates a table based on the dataset in a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable(tableName: String, source: String, options: Map[String, String]): DataFrame - - /** - * Create a table from the given path based on a data source, a schema and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + schema: StructType, + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options) - } - - /** - * Creates a table based on the dataset in a data source and a set of options. Then, returns the - * corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + description: String, + options: Map[String, String]): DataFrame + + /** @inheritdoc */ + override def listCatalogs(): Dataset[CatalogMetadata] + + /** @inheritdoc */ + override def listCatalogs(pattern: String): Dataset[CatalogMetadata] + + /** @inheritdoc */ + override def createExternalTable(tableName: String, path: String): DataFrame = + super.createExternalTable(tableName, path) + + /** @inheritdoc */ + override def createExternalTable(tableName: String, path: String, source: String): DataFrame = + super.createExternalTable(tableName, path, source) + + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, - description: String, - options: java.util.Map[String, String]): DataFrame = { - createTable( - tableName, - source = source, - description = description, - options = options.asScala.toMap) - } - - /** - * (Scala-specific) Creates a table based on the dataset in a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. 
- * @since 3.5.0 - */ - def createTable( + options: util.Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, options) + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, - description: String, - options: Map[String, String]): DataFrame + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, options) - /** - * Create a table based on the dataset in a data source, a schema and a set of options. Then, - * returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options.asScala.toMap) - } - - /** - * (Scala-specific) Create a table from the given path based on a data source, a schema and a - * set of options. Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - @deprecated("use createTable instead.", "2.2.0") - def createExternalTable( + options: Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, options) + + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, schema: StructType, - options: Map[String, String]): DataFrame = { - createTable(tableName, source, schema, options) - } - - /** - * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of - * options. Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + options: util.Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, schema, options) + + /** @inheritdoc */ + override def createTable( + tableName: String, + source: String, + description: String, + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, description, options) + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, - options: Map[String, String]): DataFrame + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, schema, options) - /** - * Create a table based on the dataset in a data source, a schema and a set of options. Then, - * returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. 
- * @since 3.5.0 - */ - def createTable( + /** @inheritdoc */ + override def createExternalTable( tableName: String, source: String, schema: StructType, - description: String, - options: java.util.Map[String, String]): DataFrame = { - createTable( - tableName, - source = source, - schema = schema, - description = description, - options = options.asScala.toMap) - } - - /** - * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of - * options. Then, returns the corresponding DataFrame. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def createTable( + options: Map[String, String]): DataFrame = + super.createExternalTable(tableName, source, schema, options) + + /** @inheritdoc */ + override def createTable( tableName: String, source: String, schema: StructType, description: String, - options: Map[String, String]): DataFrame - - /** - * Drops the local temporary view with the given view name in the catalog. If the view has been - * cached before, then it will also be uncached. - * - * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that - * created it, i.e. it will be automatically dropped when the session terminates. It's not tied - * to any databases, i.e. we can't use `db1.view1` to reference a local temporary view. - * - * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean in - * Spark 2.1. - * - * @param viewName - * the name of the temporary view to be dropped. - * @return - * true if the view is dropped successfully, false otherwise. - * @since 3.5.0 - */ - def dropTempView(viewName: String): Boolean - - /** - * Drops the global temporary view with the given view name in the catalog. If the view has been - * cached before, then it will also be uncached. - * - * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark - * application, i.e. it will be automatically dropped when the application terminates. It's tied - * to a system preserved database `global_temp`, and we must use the qualified name to refer a - * global temp view, e.g. `SELECT * FROM global_temp.view1`. - * - * @param viewName - * the unqualified name of the temporary view to be dropped. - * @return - * true if the view is dropped successfully, false otherwise. - * @since 3.5.0 - */ - def dropGlobalTempView(viewName: String): Boolean - - /** - * Recovers all the partitions in the directory of a table and update the catalog. Only works - * with a partitioned table, and not a view. - * - * @param tableName - * is either a qualified or unqualified name that designates a table. If no database - * identifier is provided, it refers to a table in the current database. - * @since 3.5.0 - */ - def recoverPartitions(tableName: String): Unit - - /** - * Returns true if the table is currently cached in-memory. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def isCached(tableName: String): Boolean - - /** - * Caches the specified table in-memory. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. 
If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def cacheTable(tableName: String): Unit - - /** - * Caches the specified table with the given storage level. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @param storageLevel - * storage level to cache table. - * @since 3.5.0 - */ - def cacheTable(tableName: String, storageLevel: StorageLevel): Unit - - /** - * Removes the specified table from the in-memory cache. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def uncacheTable(tableName: String): Unit - - /** - * Removes all cached tables from the in-memory cache. - * - * @since 3.5.0 - */ - def clearCache(): Unit - - /** - * Invalidates and refreshes all the cached data and metadata of the given table. For - * performance reasons, Spark SQL or the external data source library it uses might cache - * certain metadata about a table, such as the location of blocks. When those change outside of - * Spark SQL, users should call this function to invalidate the cache. - * - * If this table is cached as an InMemoryRelation, drop the original cached version and make the - * new version cached lazily. - * - * @param tableName - * is either a qualified or unqualified name that designates a table/view. If no database - * identifier is provided, it refers to a temporary view or a table/view in the current - * database. - * @since 3.5.0 - */ - def refreshTable(tableName: String): Unit - - /** - * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset` - * that contains the given data source path. Path matching is by prefix, i.e. "/" would - * invalidate everything that is cached. - * - * @since 3.5.0 - */ - def refreshByPath(path: String): Unit - - /** - * Returns the current catalog in this session. - * - * @since 3.5.0 - */ - def currentCatalog(): String - - /** - * Sets the current catalog in this session. - * - * @since 3.5.0 - */ - def setCurrentCatalog(catalogName: String): Unit - - /** - * Returns a list of catalogs available in this session. - * - * @since 3.5.0 - */ - def listCatalogs(): Dataset[CatalogMetadata] - - /** - * Returns a list of catalogs which name match the specify pattern and available in this - * session. - * - * @since 3.5.0 - */ - def listCatalogs(pattern: String): Dataset[CatalogMetadata] + options: util.Map[String, String]): DataFrame = + super.createTable(tableName, source, schema, description, options) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala new file mode 100644 index 0000000000000..7d81f4ead7857 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect + +import scala.language.implicitConversions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql._ + +/** + * Conversions from sql interfaces to the Connect specific implementation. + * + * This class is mainly used by the implementation. In the case of connect it should be extremely + * rare that a developer needs these classes. + * + * We provide both a trait and an object. The trait is useful in situations where an extension + * developer needs to use these conversions in a project covering multiple Spark versions. They + * can create a shim for these conversions, the Spark 4+ version of the shim implements this + * trait, and shims for older versions do not. + */ +@DeveloperApi +trait ConnectConversions { + implicit def castToImpl(session: api.SparkSession): SparkSession = + session.asInstanceOf[SparkSession] + + implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] = + ds.asInstanceOf[Dataset[T]] + + implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset = + rgds.asInstanceOf[RelationalGroupedDataset] + + implicit def castToImpl[K, V]( + kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] = + kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] +} + +object ConnectConversions extends ConnectConversions diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala index 4ebc22202b0b7..b359a871d8c28 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/AmmoniteClassFinder.scala @@ -22,6 +22,8 @@ import java.nio.file.Paths import ammonite.repl.api.Session import ammonite.runtime.SpecialClassLoader +import org.apache.spark.sql.Artifact + /** * A special [[ClassFinder]] for the Ammonite REPL to handle in-memory class files. * diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala similarity index 68% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala rename to connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala index f77dd512ef257..7578e2424fb42 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala @@ -14,10 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.internal import org.apache.spark.connect.proto.{ConfigRequest, ConfigResponse, KeyValue} import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.connect.client.SparkConnectClient /** @@ -25,61 +26,31 @@ import org.apache.spark.sql.connect.client.SparkConnectClient * * @since 3.4.0 */ -class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging { +class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) + extends RuntimeConfig + with Logging { - /** - * Sets the given Spark runtime configuration property. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def set(key: String, value: String): Unit = { executeConfigRequest { builder => builder.getSetBuilder.addPairsBuilder().setKey(key).setValue(value) } } - /** - * Sets the given Spark runtime configuration property. - * - * @since 3.4.0 - */ - def set(key: String, value: Boolean): Unit = set(key, String.valueOf(value)) - - /** - * Sets the given Spark runtime configuration property. - * - * @since 3.4.0 - */ - def set(key: String, value: Long): Unit = set(key, String.valueOf(value)) - - /** - * Returns the value of Spark runtime configuration property for the given key. - * - * @throws java.util.NoSuchElementException - * if the key is not set and does not have a default value - * @since 3.4.0 - */ + /** @inheritdoc */ @throws[NoSuchElementException]("if the key is not set") def get(key: String): String = getOption(key).getOrElse { throw new NoSuchElementException(key) } - /** - * Returns the value of Spark runtime configuration property for the given key. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def get(key: String, default: String): String = { executeConfigRequestSingleValue { builder => builder.getGetWithDefaultBuilder.addPairsBuilder().setKey(key).setValue(default) } } - /** - * Returns all properties set in this conf. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def getAll: Map[String, String] = { val response = executeConfigRequest { builder => builder.getGetAllBuilder @@ -92,11 +63,7 @@ class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging { builder.result() } - /** - * Returns the value of Spark runtime configuration property for the given key. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def getOption(key: String): Option[String] = { val pair = executeConfigRequestSinglePair { builder => builder.getGetOptionBuilder.addKeys(key) @@ -108,27 +75,14 @@ class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging { } } - /** - * Resets the configuration property for the given key. - * - * @since 3.4.0 - */ + /** @inheritdoc */ def unset(key: String): Unit = { executeConfigRequest { builder => builder.getUnsetBuilder.addKeys(key) } } - /** - * Indicates whether the configuration property with the given key is modifiable in the current - * session. - * - * @return - * `true` if the configuration property is modifiable. For static SQL, Spark Core, invalid - * (not existing) and other non-modifiable configuration properties, the returned value is - * `false`. 
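For orientation on the conf refactor above: `ConnectRuntimeConfig` keeps only the RPC plumbing, while the documented behaviour now lives on the shared `RuntimeConfig` parent. A minimal caller-side sketch, assuming a running session; the keys are standard SQL confs and the values are arbitrary.

import org.apache.spark.sql.SparkSession

object RuntimeConfigSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()

    // The Long/Boolean overloads delegate to the String setter.
    spark.conf.set("spark.sql.shuffle.partitions", 64L)

    val shufflePartitions = spark.conf.get("spark.sql.shuffle.partitions", "200")
    val ansiEnabled = spark.conf.getOption("spark.sql.ansi.enabled")
    println(s"shuffle partitions: $shufflePartitions, ansi enabled: $ansiEnabled")

    // Only runtime-modifiable keys can be changed or unset; static confs report false.
    if (spark.conf.isModifiable("spark.sql.shuffle.partitions")) {
      spark.conf.unset("spark.sql.shuffle.partitions")
    }

    spark.stop()
  }
}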
- * @since 3.4.0 - */ + /** @inheritdoc */ def isModifiable(key: String): Boolean = { val modifiable = executeConfigRequestSingleValue { builder => builder.getIsModifiableBuilder.addKeys(key) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala new file mode 100644 index 0000000000000..58fbfea48afec --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.annotation.Stable +import org.apache.spark.connect.proto +import org.apache.spark.sql.{DataFrameWriter, Dataset, SaveMode} + +/** + * Interface used to write a [[Dataset]] to external storage systems (e.g. file systems, key-value + * stores, etc). Use `Dataset.write` to access this. 
+ * + * @since 3.4.0 + */ +@Stable +final class DataFrameWriterImpl[T] private[sql] (ds: Dataset[T]) extends DataFrameWriter[T] { + + /** @inheritdoc */ + override def mode(saveMode: SaveMode): this.type = super.mode(saveMode) + + /** @inheritdoc */ + override def mode(saveMode: String): this.type = super.mode(saveMode) + + /** @inheritdoc */ + override def format(source: String): this.type = super.format(source) + + /** @inheritdoc */ + override def option(key: String, value: String): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = + super.options(options) + + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = + super.options(options) + + /** @inheritdoc */ + @scala.annotation.varargs + override def partitionBy(colNames: String*): this.type = super.partitionBy(colNames: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def bucketBy(numBuckets: Int, colName: String, colNames: String*): this.type = + super.bucketBy(numBuckets, colName, colNames: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def sortBy(colName: String, colNames: String*): this.type = + super.sortBy(colName, colNames: _*) + + /** @inheritdoc */ + @scala.annotation.varargs + override def clusterBy(colName: String, colNames: String*): this.type = + super.clusterBy(colName, colNames: _*) + + /** @inheritdoc */ + def save(path: String): Unit = { + saveInternal(Some(path)) + } + + /** @inheritdoc */ + def save(): Unit = saveInternal(None) + + private def saveInternal(path: Option[String]): Unit = { + executeWriteOperation(builder => path.foreach(builder.setPath)) + } + + private def executeWriteOperation(f: proto.WriteOperation.Builder => Unit): Unit = { + val builder = proto.WriteOperation.newBuilder() + + builder.setInput(ds.plan.getRoot) + + // Set path or table + f(builder) + + // Cannot both be set + require(!(builder.hasPath && builder.hasTable)) + + builder.setMode(mode match { + case SaveMode.Append => proto.WriteOperation.SaveMode.SAVE_MODE_APPEND + case SaveMode.Overwrite => proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE + case SaveMode.Ignore => proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE + case SaveMode.ErrorIfExists => proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS + }) + + if (source.nonEmpty) { + builder.setSource(source) + } + sortColumnNames.foreach(names => builder.addAllSortColumnNames(names.asJava)) + partitioningColumns.foreach(cols => builder.addAllPartitioningColumns(cols.asJava)) + clusteringColumns.foreach(cols => builder.addAllClusteringColumns(cols.asJava)) + + numBuckets.foreach(n => { + val bucketBuilder = proto.WriteOperation.BucketBy.newBuilder() + bucketBuilder.setNumBuckets(n) + bucketColumnNames.foreach(names => bucketBuilder.addAllBucketColumnNames(names.asJava)) + builder.setBucketBy(bucketBuilder) + }) + + extraOptions.foreach { case (k, v) => + builder.putOptions(k, v) + } + + ds.sparkSession.execute(proto.Command.newBuilder().setWriteOperation(builder).build()) + } + + /** @inheritdoc */ + def insertInto(tableName: String): Unit = { + executeWriteOperation(builder => { + 
builder.setTable( + proto.WriteOperation.SaveTable + .newBuilder() + .setTableName(tableName) + .setSaveMethod( + proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO)) + }) + } + + /** @inheritdoc */ + def saveAsTable(tableName: String): Unit = { + executeWriteOperation(builder => { + builder.setTable( + proto.WriteOperation.SaveTable + .newBuilder() + .setTableName(tableName) + .setSaveMethod( + proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE)) + }) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala new file mode 100644 index 0000000000000..4afa8b6d566c5 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.connect.proto +import org.apache.spark.sql.{Column, DataFrameWriterV2, Dataset} + +/** + * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 + * API. 
+ * + * @since 3.4.0 + */ +@Experimental +final class DataFrameWriterV2Impl[T] private[sql] (table: String, ds: Dataset[T]) + extends DataFrameWriterV2[T] { + import ds.sparkSession.RichColumn + + private val builder = proto.WriteOperationV2 + .newBuilder() + .setInput(ds.plan.getRoot) + .setTableName(table) + + /** @inheritdoc */ + override def using(provider: String): this.type = { + builder.setProvider(provider) + this + } + + /** @inheritdoc */ + override def option(key: String, value: String): this.type = { + builder.putOptions(key, value) + this + } + + /** @inheritdoc */ + override def options(options: scala.collection.Map[String, String]): this.type = { + builder.putAllOptions(options.asJava) + this + } + + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = { + builder.putAllOptions(options) + this + } + + /** @inheritdoc */ + override def tableProperty(property: String, value: String): this.type = { + builder.putTableProperties(property, value) + this + } + + /** @inheritdoc */ + @scala.annotation.varargs + override def partitionedBy(column: Column, columns: Column*): this.type = { + builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava) + this + } + + /** @inheritdoc */ + @scala.annotation.varargs + override def clusterBy(colName: String, colNames: String*): this.type = { + builder.addAllClusteringColumns((colName +: colNames).asJava) + this + } + + /** @inheritdoc */ + override def create(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE) + } + + /** @inheritdoc */ + override def replace(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_REPLACE) + } + + /** @inheritdoc */ + override def createOrReplace(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE) + } + + /** @inheritdoc */ + def append(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_APPEND) + } + + /** @inheritdoc */ + def overwrite(condition: Column): Unit = { + builder.setOverwriteCondition(condition.expr) + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE) + } + + /** @inheritdoc */ + def overwritePartitions(): Unit = { + executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS) + } + + private def executeWriteOperation(mode: proto.WriteOperationV2.Mode): Unit = { + val command = proto.Command + .newBuilder() + .setWriteOperationV2(builder.setMode(mode)) + .build() + ds.sparkSession.execute(command) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala new file mode 100644 index 0000000000000..fba3c6343558b --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
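Stepping back from the two writer implementations added above: `DataFrameWriterImpl` turns the fluent v1 builder into a `WriteOperation` command, and `DataFrameWriterV2Impl` turns the v2 builder into `WriteOperationV2`. A minimal sketch of the caller-side API they back; the output path, catalog, and table names are hypothetical, and the v2 path assumes a catalog that supports DataSource V2 writes.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object WriterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    val df = spark.range(100).withColumn("bucket", col("id") % 10)

    // V1 builder: ends up as a WriteOperation command.
    df.write
      .format("parquet")
      .mode("overwrite")
      .partitionBy("bucket")
      .save("/tmp/writer_sketch")

    // V2 builder: ends up as a WriteOperationV2 command.
    df.writeTo("demo_catalog.db.events")
      .using("parquet")
      .tableProperty("owner", "analytics")
      .partitionedBy(col("bucket"))
      .createOrReplace()

    spark.stop()
  }
}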
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkRuntimeException +import org.apache.spark.annotation.Experimental +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.{Expression, MergeAction, MergeIntoTableCommand} +import org.apache.spark.connect.proto.MergeAction.ActionType._ +import org.apache.spark.sql.{Column, Dataset, MergeIntoWriter} +import org.apache.spark.sql.functions.expr + +/** + * `MergeIntoWriter` provides methods to define and execute merge actions based on specified + * conditions. + * + * @tparam T + * the type of data in the Dataset. + * @param table + * the name of the target table for the merge operation. + * @param ds + * the source Dataset to merge into the target table. + * @param on + * the merge condition. + * + * @since 4.0.0 + */ +@Experimental +class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on: Column) + extends MergeIntoWriter[T] { + import ds.sparkSession.RichColumn + + private val builder = MergeIntoTableCommand + .newBuilder() + .setTargetTableName(table) + .setSourceTablePlan(ds.plan.getRoot) + .setMergeCondition(on.expr) + + /** + * Executes the merge operation. + */ + def merge(): Unit = { + if (builder.getMatchActionsCount == 0 && + builder.getNotMatchedActionsCount == 0 && + builder.getNotMatchedBySourceActionsCount == 0) { + throw new SparkRuntimeException( + errorClass = "NO_MERGE_ACTION_SPECIFIED", + messageParameters = Map.empty) + } + ds.sparkSession.execute( + proto.Command + .newBuilder() + .setMergeIntoTableCommand(builder.setWithSchemaEvolution(schemaEvolutionEnabled)) + .build()) + } + + override protected[sql] def insertAll(condition: Option[Column]): MergeIntoWriter[T] = { + builder.addNotMatchedActions(buildMergeAction(ACTION_TYPE_INSERT_STAR, condition)) + this + } + + override protected[sql] def insert( + condition: Option[Column], + map: Map[String, Column]): MergeIntoWriter[T] = { + builder.addNotMatchedActions(buildMergeAction(ACTION_TYPE_INSERT, condition, map)) + this + } + + override protected[sql] def updateAll( + condition: Option[Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction( + buildMergeAction(ACTION_TYPE_UPDATE_STAR, condition), + notMatchedBySource) + } + + override protected[sql] def update( + condition: Option[Column], + map: Map[String, Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction( + buildMergeAction(ACTION_TYPE_UPDATE, condition, map), + notMatchedBySource) + } + + override protected[sql] def delete( + condition: Option[Column], + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + appendUpdateDeleteAction(buildMergeAction(ACTION_TYPE_DELETE, condition), notMatchedBySource) + } + + private def appendUpdateDeleteAction( + action: Expression, + notMatchedBySource: Boolean): MergeIntoWriter[T] = { + if (notMatchedBySource) { + builder.addNotMatchedBySourceActions(action) + } else { + builder.addMatchActions(action) + } + this + } + + private def buildMergeAction( + actionType: MergeAction.ActionType, + condition: Option[Column], + 
assignments: Map[String, Column] = Map.empty): Expression = { + val builder = proto.MergeAction.newBuilder().setActionType(actionType) + condition.foreach(c => builder.setCondition(c.expr)) + assignments.foreach { case (k, v) => + builder + .addAssignmentsBuilder() + .setKey(expr(k).expr) + .setValue(v.expr) + } + Expression + .newBuilder() + .setMergeAction(builder) + .build() + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 13a26fa79085e..29fbcc443deb9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -18,166 +18,21 @@ package org.apache.spark.sql.streaming import java.util.UUID -import java.util.concurrent.TimeoutException import scala.jdk.CollectionConverters._ -import org.apache.spark.annotation.Evolving import org.apache.spark.connect.proto.Command import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.connect.proto.StreamingQueryCommand import org.apache.spark.connect.proto.StreamingQueryCommandResult import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{api, SparkSession} -/** - * A handle to a query that is executing continuously in the background as new data arrives. All - * these methods are thread-safe. - * @since 3.5.0 - */ -@Evolving -trait StreamingQuery { - // This is a copy of StreamingQuery in sql/core/.../streaming/StreamingQuery.scala - - /** - * Returns the user-specified name of the query, or null if not specified. This name can be - * specified in the `org.apache.spark.sql.streaming.DataStreamWriter` as - * `dataframe.writeStream.queryName("query").start()`. This name, if set, must be unique across - * all active queries. - * - * @since 3.5.0 - */ - def name: String - - /** - * Returns the unique id of this query that persists across restarts from checkpoint data. That - * is, this id is generated when a query is started for the first time, and will be the same - * every time it is restarted from checkpoint data. Also see [[runId]]. - * - * @since 3.5.0 - */ - def id: UUID - - /** - * Returns the unique id of this run of the query. That is, every start/restart of a query will - * generate a unique runId. Therefore, every time a query is restarted from checkpoint, it will - * have the same [[id]] but different [[runId]]s. - */ - def runId: UUID - - /** - * Returns the `SparkSession` associated with `this`. - * - * @since 3.5.0 - */ - def sparkSession: SparkSession - - /** - * Returns `true` if this query is actively running. - * - * @since 3.5.0 - */ - def isActive: Boolean - - /** - * Returns the [[StreamingQueryException]] if the query was terminated by an exception. - * @since 3.5.0 - */ - def exception: Option[StreamingQueryException] - - /** - * Returns the current status of the query. - * - * @since 3.5.0 - */ - def status: StreamingQueryStatus - - /** - * Returns an array of the most recent [[StreamingQueryProgress]] updates for this query. The - * number of progress updates retained for each stream is configured by Spark session - * configuration `spark.sql.streaming.numRecentProgressUpdates`. 
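Looking back at `MergeIntoWriterImpl` above: the builder accumulates matched, not-matched, and not-matched-by-source actions, and `merge()` submits a single `MergeIntoTableCommand` (failing with `NO_MERGE_ACTION_SPECIFIED` if no action was added). A minimal caller-side sketch; the `source` and `target` tables are hypothetical and the target must support row-level MERGE.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit}

object MergeIntoSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()

    spark.table("source")
      .mergeInto("target", col("source.id") === col("target.id"))
      .whenMatched()
      .updateAll()
      .whenNotMatched()
      .insertAll()
      .whenNotMatchedBySource(col("target.active") === lit(false))
      .delete()
      .merge()

    spark.stop()
  }
}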
- * - * @since 3.5.0 - */ - def recentProgress: Array[StreamingQueryProgress] - - /** - * Returns the most recent [[StreamingQueryProgress]] update of this streaming query. - * - * @since 3.5.0 - */ - def lastProgress: StreamingQueryProgress - - /** - * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If - * the query has terminated with an exception, then the exception will be thrown. - * - * If the query has terminated, then all subsequent calls to this method will either return - * immediately (if the query was terminated by `stop()`), or throw the exception immediately (if - * the query has terminated with exception). - * - * @throws StreamingQueryException - * if the query has terminated with an exception. - * @since 3.5.0 - */ - @throws[StreamingQueryException] - def awaitTermination(): Unit - - /** - * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If - * the query has terminated with an exception, then the exception will be thrown. Otherwise, it - * returns whether the query has terminated or not within the `timeoutMs` milliseconds. - * - * If the query has terminated, then all subsequent calls to this method will either return - * `true` immediately (if the query was terminated by `stop()`), or throw the exception - * immediately (if the query has terminated with exception). - * - * @throws StreamingQueryException - * if the query has terminated with an exception - * @since 3.5.0 - */ - @throws[StreamingQueryException] - def awaitTermination(timeoutMs: Long): Boolean - - /** - * Blocks until all available data in the source has been processed and committed to the sink. - * This method is intended for testing. Note that in the case of continually arriving data, this - * method may block forever. Additionally, this method is only guaranteed to block until data - * that has been synchronously appended data to a - * `org.apache.spark.sql.execution.streaming.Source` prior to invocation. (i.e. `getOffset` must - * immediately reflect the addition). - * @since 3.5.0 - */ - def processAllAvailable(): Unit - - /** - * Stops the execution of this query if it is running. This waits until the termination of the - * query execution threads or until a timeout is hit. - * - * By default stop will block indefinitely. You can configure a timeout by the configuration - * `spark.sql.streaming.stopTimeout`. A timeout of 0 (or negative) milliseconds will block - * indefinitely. If a `TimeoutException` is thrown, users can retry stopping the stream. If the - * issue persists, it is advisable to kill the Spark application. - * - * @since 3.5.0 - */ - @throws[TimeoutException] - def stop(): Unit - - /** - * Prints the physical plan to the console for debugging purposes. - * @since 3.5.0 - */ - def explain(): Unit - - /** - * Prints the physical plan to the console for debugging purposes. 
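As a reference point for the `StreamingQuery` API whose documentation is being consolidated above, a minimal lifecycle sketch using the built-in `rate` source and `console` sink; the query name is hypothetical and `processAllAvailable()` is intended for testing only.

import org.apache.spark.sql.SparkSession

object StreamingQuerySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()

    val query = spark.readStream
      .format("rate")
      .load()
      .writeStream
      .format("console")
      .queryName("rate_to_console")
      .start()

    assert(query.isActive)
    println(query.status)          // current status of the query
    query.processAllAvailable()    // blocks until currently available data is processed
    Option(query.lastProgress).foreach(p => println(p.json))
    query.stop()

    spark.stop()
  }
}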
- * - * @param extended - * whether to do extended explain or not - * @since 3.5.0 - */ - def explain(extended: Boolean): Unit +/** @inheritdoc */ +trait StreamingQuery extends api.StreamingQuery { + + /** @inheritdoc */ + override def sparkSession: SparkSession } class RemoteStreamingQuery( diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala index 1b166f8ace1d5..04367d3b95f14 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala @@ -71,35 +71,46 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach { test("write") { val df = ss.newDataFrame(_ => ()).limit(10) - val builder = proto.WriteOperation.newBuilder() - builder + def toPlan(builder: proto.WriteOperation.Builder): proto.Plan = { + proto.Plan + .newBuilder() + .setCommand(proto.Command.newBuilder().setWriteOperation(builder)) + .build() + } + + val builder = proto.WriteOperation + .newBuilder() .setInput(df.plan.getRoot) .setPath("my/test/path") .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS) .setSource("parquet") - .addSortColumnNames("col1") - .addPartitioningColumns("col99") - .setBucketBy( - proto.WriteOperation.BucketBy - .newBuilder() - .setNumBuckets(2) - .addBucketColumnNames("col1") - .addBucketColumnNames("col2")) - .addClusteringColumns("col3") - val expectedPlan = proto.Plan - .newBuilder() - .setCommand(proto.Command.newBuilder().setWriteOperation(builder)) - .build() + val partitionedPlan = toPlan( + builder + .clone() + .addSortColumnNames("col1") + .addPartitioningColumns("col99") + .setBucketBy( + proto.WriteOperation.BucketBy + .newBuilder() + .setNumBuckets(2) + .addBucketColumnNames("col1") + .addBucketColumnNames("col2"))) df.write .sortBy("col1") .partitionBy("col99") .bucketBy(2, "col1", "col2") + .parquet("my/test/path") + val actualPartionedPlan = service.getAndClearLatestInputPlan() + assert(actualPartionedPlan.equals(partitionedPlan)) + + val clusteredPlan = toPlan(builder.clone().addClusteringColumns("col3")) + df.write .clusterBy("col3") .parquet("my/test/path") - val actualPlan = service.getAndClearLatestInputPlan() - assert(actualPlan.equals(expectedPlan)) + val actualClusteredPlan = service.getAndClearLatestInputPlan() + assert(actualClusteredPlan.equals(clusteredPlan)) } test("write jdbc") { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 90f980387aa4e..4bb833e16eeab 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -24,7 +24,7 @@ import java.util.Properties import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} -import scala.concurrent.duration.DurationInt +import scala.concurrent.duration.{DurationInt, FiniteDuration} import scala.jdk.CollectionConverters._ import org.apache.commons.io.FileUtils @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import 
org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} +import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession, SQLHelper} @@ -1573,6 +1573,25 @@ class ClientE2ETestSuite val result = df.select(trim(col("col"), " ").as("trimmed_col")).collect() assert(result sameElements Array(Row("a"), Row("b"), Row("c"))) } + + test("SPARK-49673: new batch size, multiple batches") { + val maxBatchSize = spark.conf.get("spark.connect.grpc.arrow.maxBatchSize").dropRight(1).toInt + // Adjust client grpcMaxMessageSize to maxBatchSize (10MiB; set in RemoteSparkSession config) + val sparkWithLowerMaxMessageSize = SparkSession + .builder() + .client( + SparkConnectClient + .builder() + .userId("test") + .port(port) + .grpcMaxMessageSize(maxBatchSize) + .retryPolicy(RetryPolicy + .defaultPolicy() + .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s")))) + .build()) + .create() + assert(sparkWithLowerMaxMessageSize.range(maxBatchSize).collect().length == maxBatchSize) + } } private[sql] case class ClassData(a: String, b: Int) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 474eac138ab78..315f80e13eff7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -552,6 +552,14 @@ class PlanGenerationTestSuite valueColumnName = "value") } + test("transpose index_column") { + simple.transpose(indexColumn = fn.col("id")) + } + + test("transpose no_index_column") { + simple.transpose() + } + test("offset") { simple.offset(1000) } @@ -1801,7 +1809,11 @@ class PlanGenerationTestSuite fn.sentences(fn.col("g")) } - functionTest("sentences with locale") { + functionTest("sentences with language") { + fn.sentences(fn.col("g"), lit("en")) + } + + functionTest("sentences with language and country") { fn.sentences(fn.col("g"), lit("en"), lit("US")) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala index bbc396a937c3e..66a2c943af5f6 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -30,6 +30,7 @@ import org.apache.commons.codec.digest.DigestUtils.sha256Hex import org.scalatest.BeforeAndAfterEach import org.apache.spark.connect.proto.AddArtifactsRequest +import org.apache.spark.sql.Artifact import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.IvyTestUtils diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 94bf18027b43a..16f6983efb187 100644 --- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -145,21 +145,6 @@ object CheckConnectJvmClientCompatibility { checkMiMaCompatibility(clientJar, protobufJar, includedRules, excludeRules) } - private lazy val mergeIntoWriterExcludeRules: Seq[ProblemFilter] = { - // Exclude some auto-generated methods in [[MergeIntoWriter]] classes. - // The incompatible changes are due to the uses of [[proto.Expression]] instead - // of [[catalyst.Expression]] in the method signature. - val classNames = Seq("WhenMatched", "WhenNotMatched", "WhenNotMatchedBySource") - val methodNames = Seq("apply", "condition", "copy", "copy$*", "unapply") - - classNames.flatMap { className => - methodNames.map { methodName => - ProblemFilters.exclude[IncompatibleSignatureProblem]( - s"org.apache.spark.sql.$className.$methodName") - } - } - } - private def checkMiMaCompatibilityWithSqlModule( clientJar: File, sqlJar: File): List[Problem] = { @@ -173,6 +158,7 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.columnar.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.jdbc.*"), @@ -207,8 +193,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.Dataset$" // private[sql] ), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"), // TODO (SPARK-49096): // Mima check might complain the following Dataset rules does not filter any problem. // This is due to a potential bug in Mima that all methods in `class Dataset` are not being @@ -271,30 +255,16 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.streaming.TestGroupState"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.TestGroupState$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.streaming.PythonStreamingQueryListener"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.streaming.PythonStreamingQueryListenerWrapper"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.sql.streaming.StreamingQueryListener$Event"), // SQLImplicits ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.rddToDatasetHolder"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.session"), - // Artifact Manager + // Artifact Manager, client has a totally different implementation. 
ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.artifact.ArtifactManager"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.artifact.ArtifactManager$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.artifact.util.ArtifactUtils"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.artifact.util.ArtifactUtils$"), - - // UDFRegistration - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.register"), // ColumnNode conversions ProblemFilters.exclude[DirectMissingMethodProblem]( @@ -310,6 +280,8 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.*"), // UDFRegistration + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.register"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.UDFRegistration"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.UDFRegistration.log*"), @@ -326,12 +298,13 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.UDFRegistration.initializeLogIfNecessary$default$2"), - // Datasource V2 partition transforms - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.PartitionTransform$ExtractTransform")) ++ - mergeIntoWriterExcludeRules + // Protected DataFrameReader methods... + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.DataFrameReader.validateSingleVariantColumn"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.DataFrameReader.validateJsonSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.DataFrameReader.validateXmlSchema")) checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules) } @@ -391,27 +364,8 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession.execute"), // Experimental - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.addArtifact"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.addArtifacts"), ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession.registerClassFinder"), - // public - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptAll"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptOperation"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.addTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.removeTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.getTags"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.clearTags"), // SparkSession#Builder ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.remote"), @@ -445,8 +399,7 @@ object CheckConnectJvmClientCompatibility { // Encoders 
are in the wrong JAR ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$")) ++ - mergeIntoWriterExcludeRules + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$")) checkMiMaCompatibility(sqlJar, clientJar, includedRules, excludeRules) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 70b471cf74b33..5397dae9dcc5f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -30,7 +30,7 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{sql, SparkUnsupportedOperationException} +import org.apache.spark.{sql, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.{AnalysisException, Encoders, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes} @@ -776,6 +776,16 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } } + test("kryo serialization") { + val e = intercept[SparkRuntimeException] { + val encoder = sql.encoderFor(Encoders.kryo[(Int, String)]) + roundTripAndCheckIdentical(encoder) { () => + Iterator.tabulate(10)(i => (i, "itr_" + i)) + } + } + assert(e.getErrorClass == "CANNOT_USE_KRYO") + } + test("transforming encoder") { val schema = new StructType() .add("key", IntegerType) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index 758262ead7f1e..27b1ee014a719 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -334,8 +334,6 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L assert(exception.getErrorClass != null) assert(exception.getMessageParameters().get("id") == query.id.toString) assert(exception.getMessageParameters().get("runId") == query.runId.toString) - assert(!exception.getMessageParameters().get("startOffset").isEmpty) - assert(!exception.getMessageParameters().get("endOffset").isEmpty) assert(exception.getCause.isInstanceOf[SparkException]) assert(exception.getCause.getCause.isInstanceOf[SparkException]) assert( @@ -374,8 +372,6 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L assert(exception.getErrorClass != null) assert(exception.getMessageParameters().get("id") == query.id.toString) assert(exception.getMessageParameters().get("runId") == query.runId.toString) - assert(!exception.getMessageParameters().get("startOffset").isEmpty) - assert(!exception.getMessageParameters().get("endOffset").isEmpty) assert(exception.getCause.isInstanceOf[SparkException]) assert(exception.getCause.getCause.isInstanceOf[SparkException]) assert( diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala index a878e42b40aa7..36aaa2cc7fbf6 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala @@ -24,6 +24,9 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration.FiniteDuration import org.scalatest.{BeforeAndAfterAll, Suite} +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Futures.timeout +import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkBuildInfo import org.apache.spark.sql.SparkSession @@ -121,6 +124,8 @@ object SparkConnectServerUtils { // to make the tests exercise reattach. "spark.connect.execute.reattachable.senderMaxStreamDuration=1s", "spark.connect.execute.reattachable.senderMaxStreamSize=123", + // Testing SPARK-49673, setting maxBatchSize to 10MiB + s"spark.connect.grpc.arrow.maxBatchSize=${10 * 1024 * 1024}", // Disable UI "spark.ui.enabled=false") Seq("--jars", catalystTestJar) ++ confs.flatMap(v => "--conf" :: v :: Nil) @@ -184,12 +189,14 @@ object SparkConnectServerUtils { .port(port) .retryPolicy(RetryPolicy .defaultPolicy() - .copy(maxRetries = Some(7), maxBackoff = Some(FiniteDuration(10, "s")))) + .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s")))) .build()) .create() // Execute an RPC which will get retried until the server is up. - assert(spark.version == SparkBuildInfo.spark_version) + eventually(timeout(1.minute)) { + assert(spark.version == SparkBuildInfo.spark_version) + } // Auto-sync dependencies. 
SparkConnectServerUtils.syncTestDependencies(spark) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index 30009c03c49fd..90cd68e6e1d24 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -490,7 +490,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite { .option("query", "SELECT @myvariant1 as variant1, @myvariant2 as variant2") .load() }, - errorClass = "UNRECOGNIZED_SQL_TYPE", + condition = "UNRECOGNIZED_SQL_TYPE", parameters = Map("typeName" -> "sql_variant", "jdbcType" -> "-156")) } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index b337eb2fc9b3b..91a82075a3607 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -87,7 +87,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"DOUBLE\"", "newType" -> "\"VARCHAR(10)\"", diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 27ec98e9ac451..e5fd453cb057c 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -97,7 +97,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING\"", "newType" -> "\"INT\"", @@ -115,7 +115,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD exception = intercept[SparkSQLFeatureNotSupportedException] { sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL") }, - errorClass = "_LEGACY_ERROR_TEMP_2271") + condition = "_LEGACY_ERROR_TEMP_2271") } test("SPARK-47440: SQLServer does not support boolean expression in binary comparison") { diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 81aacf2c14d7a..700c05b54a256 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -77,8 +77,19 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest s"""CREATE TABLE 
pattern_testing_table ( |pattern_testing_col LONGTEXT |) - """.stripMargin + |""".stripMargin ).executeUpdate() + connection.prepareStatement( + "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + .executeUpdate() + } + + override def dataPreparation(connection: Connection): Unit = { + super.dataPreparation(connection) + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate() + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -98,7 +109,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING\"", "newType" -> "\"INT\"", @@ -131,7 +142,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest exception = intercept[SparkSQLFeatureNotSupportedException] { sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL") }, - errorClass = "_LEGACY_ERROR_TEMP_2271") + condition = "_LEGACY_ERROR_TEMP_2271") } override def testCreateTableWithProperty(tbl: String): Unit = { @@ -157,6 +168,79 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest assert(sql(s"SELECT char_length(c1) from $tableName").head().get(0) === 65536) } } + + override def testDatetime(tbl: String): Unit = { + val df1 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ") + checkFilterPushed(df1) + val rows1 = df1.collect() + assert(rows1.length === 2) + assert(rows1(0).getString(0) === "amy") + assert(rows1(1).getString(0) === "alex") + + val df2 = sql(s"SELECT name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2") + checkFilterPushed(df2) + val rows2 = df2.collect() + assert(rows2.length === 2) + assert(rows2(0).getString(0) === "amy") + assert(rows2(1).getString(0) === "alex") + + val df3 = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5") + checkFilterPushed(df3) + val rows3 = df3.collect() + assert(rows3.length === 2) + assert(rows3(0).getString(0) === "amy") + assert(rows3(1).getString(0) === "alex") + + val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0") + checkFilterPushed(df4) + val rows4 = df4.collect() + assert(rows4.length === 2) + assert(rows4(0).getString(0) === "amy") + assert(rows4(1).getString(0) === "alex") + + val df5 = sql(s"SELECT name FROM $tbl WHERE " + + "extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022") + checkFilterPushed(df5) + val rows5 = df5.collect() + assert(rows5.length === 2) + assert(rows5(0).getString(0) === "amy") + assert(rows5(1).getString(0) === "alex") + + val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " + + "AND datediff(date1, '2022-05-10') > 0") + checkFilterPushed(df6) + val rows6 = df6.collect() + assert(rows6.length === 1) + assert(rows6(0).getString(0) === "amy") + + val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2") + checkFilterPushed(df7) + val rows7 = df7.collect() + assert(rows7.length === 1) + assert(rows7(0).getString(0) === "alex") + + val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4") + checkFilterPushed(df8) + val rows8 = df8.collect() + assert(rows8.length === 1) + 
assert(rows8(0).getString(0) === "alex") + + val df9 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 order by dayofyear(date1) limit 1") + checkFilterPushed(df9) + val rows9 = df9.collect() + assert(rows9.length === 1) + assert(rows9(0).getString(0) === "alex") + + // MySQL does not support + val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'") + checkFilterPushed(df10, false) + val rows10 = df10.collect() + assert(rows10.length === 2) + assert(rows10(0).getString(0) === "amy") + assert(rows10(1).getString(0) === "alex") + } } /** diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala index 42d82233b421b..5e40f0bbc4554 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -62,7 +62,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac exception = intercept[SparkSQLFeatureNotSupportedException] { catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) }, - errorClass = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", + condition = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", parameters = Map("namespace" -> "`foo`") ) assert(catalog.namespaceExists(Array("foo")) === false) @@ -74,7 +74,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac Array("foo"), NamespaceChange.setProperty("comment", "comment for foo")) }, - errorClass = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", + condition = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE", parameters = Map("namespace" -> "`foo`") ) @@ -82,7 +82,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac exception = intercept[SparkSQLFeatureNotSupportedException] { catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) }, - errorClass = "UNSUPPORTED_FEATURE.REMOVE_NAMESPACE_COMMENT", + condition = "UNSUPPORTED_FEATURE.REMOVE_NAMESPACE_COMMENT", parameters = Map("namespace" -> "`foo`") ) @@ -90,7 +90,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac exception = intercept[SparkSQLFeatureNotSupportedException] { catalog.dropNamespace(Array("foo"), cascade = false) }, - errorClass = "UNSUPPORTED_FEATURE.DROP_NAMESPACE", + condition = "UNSUPPORTED_FEATURE.DROP_NAMESPACE", parameters = Map("namespace" -> "`foo`") ) catalog.dropNamespace(Array("foo"), cascade = true) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index 342fb4bb38e60..2c97a588670a8 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -118,7 +118,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"DECIMAL(19,0)\"", "newType" -> "\"INT\"", @@ -139,7 +139,7 @@ class 
OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes exception = intercept[SparkRuntimeException] { sql(s"INSERT INTO $tableName SELECT rpad('hi', 256, 'spark')") }, - errorClass = "EXCEED_LIMIT_LENGTH", + condition = "EXCEED_LIMIT_LENGTH", parameters = Map("limit" -> "255") ) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index e22136a09a56c..850391e8dc33c 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -84,7 +84,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT exception = intercept[AnalysisException] { sql(sql1) }, - errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + condition = "NOT_SUPPORTED_CHANGE_COLUMN", parameters = Map( "originType" -> "\"STRING\"", "newType" -> "\"INT\"", @@ -118,7 +118,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT sql(s"CREATE TABLE $t2(c int)") checkError( exception = intercept[TableAlreadyExistsException](sql(s"ALTER TABLE $t1 RENAME TO t2")), - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`t2`") ) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala index e4cc88cec0f5e..3b1a457214be7 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala @@ -92,7 +92,7 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte catalog.listNamespaces(Array("foo")) } checkError(e, - errorClass = "SCHEMA_NOT_FOUND", + condition = "SCHEMA_NOT_FOUND", parameters = Map("schemaName" -> "`foo`")) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index b0ab614b27d1f..54635f69f8b65 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -71,7 +71,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`bad_column`", @@ -92,11 +92,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu private def checkErrorFailedJDBC( e: AnalysisException, - errorClass: String, + condition: String, tbl: String): Unit = { checkErrorMatchPVals( exception = e, - errorClass = errorClass, + condition = condition, parameters = Map( "url" -> "jdbc:.*", "tableName" -> s"`$tbl`") @@ -126,7 +126,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = 
intercept[AnalysisException] { sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C3 DOUBLE)") }, - errorClass = "FIELD_ALREADY_EXISTS", + condition = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "add", "fieldNames" -> "`C3`", @@ -159,7 +159,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`bad_column`", @@ -182,7 +182,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(sqlText) }, - errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION", sqlState = "42703", parameters = Map( "objectName" -> "`bad_column`", @@ -206,7 +206,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(s"ALTER TABLE $catalogName.alt_table RENAME COLUMN ID1 TO ID2") }, - errorClass = "FIELD_ALREADY_EXISTS", + condition = "FIELD_ALREADY_EXISTS", parameters = Map( "op" -> "rename", "fieldNames" -> "`ID2`", @@ -308,7 +308,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[IndexAlreadyExistsException] { sql(s"CREATE index i1 ON $catalogName.new_table (col1)") }, - errorClass = "INDEX_ALREADY_EXISTS", + condition = "INDEX_ALREADY_EXISTS", parameters = Map("indexName" -> "`i1`", "tableName" -> "`new_table`") ) @@ -333,7 +333,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[NoSuchIndexException] { sql(s"DROP index i1 ON $catalogName.new_table") }, - errorClass = "INDEX_NOT_FOUND", + condition = "INDEX_NOT_FOUND", parameters = Map("indexName" -> "`i1`", "tableName" -> "`new_table`") ) } @@ -353,7 +353,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - private def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { + protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { val filter = df.queryExecution.optimizedPlan.collect { case f: Filter => f } @@ -975,9 +975,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu exception = intercept[AnalysisException] { sql(s"ALTER TABLE $catalogName.tbl2 RENAME TO tbl1") }, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> "`tbl1`") ) } } + + def testDatetime(tbl: String): Unit = {} + + test("scan with filter push-down with date time functions") { + testDatetime(s"$catalogAndNamespace.${caseConvert("datetime")}") + } } diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala index 56456f9b1f776..8d0bcc5816775 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala @@ -50,7 +50,7 @@ private[kafka010] class KafkaRecordToRowConverter { new GenericArrayData(cr.headers.iterator().asScala .map(header => InternalRow(UTF8String.fromString(header.key()), header.value()) - ).toArray) + ).toArray[Any]) } else { null 
} diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 9ae6a9290f80a..1d119de43970f 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -1156,7 +1156,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with test("allow group.id prefix") { // Group ID prefix is only supported by consumer based offset reader - if (spark.conf.get(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { + if (sqlConf.getConf(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { testGroupId("groupIdPrefix", (expected, actual) => { assert(actual.exists(_.startsWith(expected)) && !actual.exists(_ === expected), "Valid consumer groups don't contain the expected group id - " + @@ -1167,7 +1167,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with test("allow group.id override") { // Group ID override is only supported by consumer based offset reader - if (spark.conf.get(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { + if (sqlConf.getConf(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) { testGroupId("kafka.group.id", (expected, actual) => { assert(actual.exists(_ === expected), "Valid consumer groups don't " + s"contain the expected group id - Valid consumer groups: $actual / " + diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala index 320485a79e59d..6fc22e7ac5e03 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala @@ -153,7 +153,7 @@ class KafkaOffsetReaderSuite extends QueryTest with SharedSparkSession with Kafk } checkError( exception = ex, - errorClass = "KAFKA_START_OFFSET_DOES_NOT_MATCH_ASSIGNED", + condition = "KAFKA_START_OFFSET_DOES_NOT_MATCH_ASSIGNED", parameters = Map( "specifiedPartitions" -> "Set\\(.*,.*\\)", "assignedPartitions" -> "Set\\(.*,.*,.*\\)"), diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index 31050887936bd..3b0def8fc73f7 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -20,7 +20,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.Column -import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.protobuf.utils.ProtobufUtils // scalastyle:off: object.name @@ -71,7 +71,13 @@ object functions { messageName: String, binaryFileDescriptorSet: Array[Byte], options: java.util.Map[String, String]): Column = { - ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap) + Column.fnWithOptions( + "from_protobuf", + options.asScala.iterator, + data, + lit(messageName), + lit(binaryFileDescriptorSet) + ) } /** @@ -90,7 +96,7 @@ object functions { 
@Experimental def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - ProtobufDataToCatalyst(data, messageName, Some(fileContent)) + from_protobuf(data, messageName, fileContent) } /** @@ -109,7 +115,12 @@ object functions { @Experimental def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) : Column = { - ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet)) + Column.fn( + "from_protobuf", + data, + lit(messageName), + lit(binaryFileDescriptorSet) + ) } /** @@ -129,7 +140,11 @@ object functions { */ @Experimental def from_protobuf(data: Column, messageClassName: String): Column = { - ProtobufDataToCatalyst(data, messageClassName) + Column.fn( + "from_protobuf", + data, + lit(messageClassName) + ) } /** @@ -153,7 +168,12 @@ object functions { data: Column, messageClassName: String, options: java.util.Map[String, String]): Column = { - ProtobufDataToCatalyst(data, messageClassName, None, options.asScala.toMap) + Column.fnWithOptions( + "from_protobuf", + options.asScala.iterator, + data, + lit(messageClassName) + ) } /** @@ -191,7 +211,12 @@ object functions { @Experimental def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) : Column = { - CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet)) + Column.fn( + "to_protobuf", + data, + lit(messageName), + lit(binaryFileDescriptorSet) + ) } /** * Converts a column into binary of protobuf format. The Protobuf definition is provided @@ -213,7 +238,7 @@ object functions { descFilePath: String, options: java.util.Map[String, String]): Column = { val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - CatalystDataToProtobuf(data, messageName, Some(fileContent), options.asScala.toMap) + to_protobuf(data, messageName, fileContent, options) } /** @@ -237,7 +262,13 @@ object functions { binaryFileDescriptorSet: Array[Byte], options: java.util.Map[String, String] ): Column = { - CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap) + Column.fnWithOptions( + "to_protobuf", + options.asScala.iterator, + data, + lit(messageName), + lit(binaryFileDescriptorSet) + ) } /** @@ -257,7 +288,11 @@ object functions { */ @Experimental def to_protobuf(data: Column, messageClassName: String): Column = { - CatalystDataToProtobuf(data, messageClassName) + Column.fn( + "to_protobuf", + data, + lit(messageClassName) + ) } /** @@ -279,6 +314,11 @@ object functions { @Experimental def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String]) : Column = { - CatalystDataToProtobuf(data, messageClassName, None, options.asScala.toMap) + Column.fnWithOptions( + "to_protobuf", + options.asScala.iterator, + data, + lit(messageClassName) + ) } } diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala index 6644bce98293b..e85097a272f24 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala @@ -43,8 +43,8 @@ private[sql] class ProtobufOptions( /** * Adds support for recursive fields. If this option is is not specified, recursive fields are - * not permitted. 
Setting it to 0 drops the recursive fields, 1 allows it to be recursed once, - * and 2 allows it to be recursed twice and so on, up to 10. Values larger than 10 are not + * not permitted. Setting it to 1 drops the recursive fields, 2 allows it to be recursed once, + * and 3 allows it to be recursed twice and so on, up to 10. Values larger than 10 are not * allowed in order avoid inadvertently creating very large schemas. If a Protobuf message * has depth beyond this limit, the Spark struct returned is truncated after the recursion limit. * * `message Person { string name = 1; Person friend = 2; }` * The following lists the schema with different values for this setting. * 1: `struct` - * 2: `struct>` - * 3: `struct>>` + * 2: `struct>` + * 3: `struct>>` * and so on. */ val recursiveFieldMaxDepth: Int = parameters.getOrElse("recursive.fields.max.depth", "-1").toInt @@ -181,7 +181,7 @@ private[sql] class ProtobufOptions( val upcastUnsignedInts: Boolean = parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean - // Whether to unwrap the struct representation for well known primitve wrapper types when + // Whether to unwrap the struct representation for well known primitive wrapper types when // deserializing. By default, the wrapper types for primitives (i.e. google.protobuf.Int32Value, // google.protobuf.Int64Value, etc.) will get deserialized as structs. We allow the option to // deserialize them as their respective primitives. @@ -221,7 +221,7 @@ private[sql] class ProtobufOptions( // By default, in the spark schema field a will be dropped, which result in schema // b struct // If retain.empty.message.types=true, field a will be retained by inserting a dummy column. - // b struct, name: string> + // b struct, name: string> val retainEmptyMessage: Boolean = parameters.getOrElse("retain.empty.message.types", false.toString).toBoolean } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala index 4ff432cf7a055..3eaa91e472c43 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -708,7 +708,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( exception = e, - errorClass = "PROTOBUF_DEPENDENCY_NOT_FOUND", + condition = "PROTOBUF_DEPENDENCY_NOT_FOUND", parameters = Map("dependencyName" -> "nestedenum.proto")) } @@ -1057,7 +1057,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( ex, - errorClass = "PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND", + condition = "PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND", parameters = Map("filePath" -> "/non/existent/path.desc") ) assert(ex.getCause != null) @@ -1699,7 +1699,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( exception = parseError, - errorClass = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", + condition = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", parameters = Map( "sqlColumn" -> "`basic_enum`", "protobufColumn" -> "field 'basic_enum'", @@ -1711,7 +1711,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot } checkError( exception = parseError, - errorClass = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", + condition =
"CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE", parameters = Map( "sqlColumn" -> "`basic_enum`", "protobufColumn" -> "field 'basic_enum'", @@ -2093,7 +2093,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot | to_protobuf(complex_struct, 42, '$testFileDescFile', map()) |FROM protobuf_test_table |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> s"""\"to_protobuf(complex_struct, 42, $testFileDescFile, map())\"""", "msg" -> ("The second argument of the TO_PROTOBUF SQL function must be a constant " + @@ -2111,11 +2111,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot | to_protobuf(complex_struct, 'SimpleMessageJavaTypes', 42, map()) |FROM protobuf_test_table |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> "\"to_protobuf(complex_struct, SimpleMessageJavaTypes, 42, map())\"", "msg" -> ("The third argument of the TO_PROTOBUF SQL function must be a constant " + - "string representing the Protobuf descriptor file path"), + "string or binary data representing the Protobuf descriptor file path"), "hint" -> ""), queryContext = Array(ExpectedContext( fragment = "to_protobuf(complex_struct, 'SimpleMessageJavaTypes', 42, map())", @@ -2130,7 +2130,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot | to_protobuf(complex_struct, 'SimpleMessageJavaTypes', '$testFileDescFile', 42) |FROM protobuf_test_table |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> s"""\"to_protobuf(complex_struct, SimpleMessageJavaTypes, $testFileDescFile, 42)\"""", @@ -2152,7 +2152,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot |SELECT from_protobuf(protobuf_data, 42, '$testFileDescFile', map()) |FROM ($toProtobufSql) |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> s"""\"from_protobuf(protobuf_data, 42, $testFileDescFile, map())\"""", "msg" -> ("The second argument of the FROM_PROTOBUF SQL function must be a constant " + @@ -2169,11 +2169,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot |SELECT from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', 42, map()) |FROM ($toProtobufSql) |""".stripMargin)), - errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> "\"from_protobuf(protobuf_data, SimpleMessageJavaTypes, 42, map())\"", "msg" -> ("The third argument of the FROM_PROTOBUF SQL function must be a constant " + - "string representing the Protobuf descriptor file path"), + "string or binary data representing the Protobuf descriptor file path"), "hint" -> ""), queryContext = Array(ExpectedContext( fragment = "from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', 42, map())", @@ -2188,7 +2188,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot | from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', '$testFileDescFile', 42) |FROM ($toProtobufSql) |""".stripMargin)), - errorClass = 
"DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", parameters = Map( "sqlExpr" -> s"""\"from_protobuf(protobuf_data, SimpleMessageJavaTypes, $testFileDescFile, 42)\"""", diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala index 03285c73f1ff1..2737bb9feb3ad 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala @@ -95,7 +95,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { protoFile, Deserializer, fieldMatch, - errorClass = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", + condition = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", params = Map( "protobufType" -> "MissMatchTypeInRoot", "toType" -> toSQLType(CATALYST_STRUCT))) @@ -104,7 +104,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { protoFile, Serializer, fieldMatch, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "MissMatchTypeInRoot", "toType" -> toSQLType(CATALYST_STRUCT))) @@ -122,7 +122,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { protoFile, Serializer, BY_NAME, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "FieldMissingInProto", "toType" -> toSQLType(CATALYST_STRUCT))) @@ -132,7 +132,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, BY_NAME, nonnullCatalyst, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "FieldMissingInProto", "toType" -> toSQLType(nonnullCatalyst))) @@ -150,7 +150,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Deserializer, fieldMatch, catalyst, - errorClass = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", + condition = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE", params = Map( "protobufType" -> "MissMatchTypeInDeepNested", "toType" -> toSQLType(catalyst))) @@ -160,7 +160,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, fieldMatch, catalyst, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "MissMatchTypeInDeepNested", "toType" -> toSQLType(catalyst))) @@ -177,7 +177,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, BY_NAME, catalystSchema = foobarSQLType, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "FoobarWithRequiredFieldBar", "toType" -> toSQLType(foobarSQLType))) @@ -199,7 +199,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { Serializer, BY_NAME, catalystSchema = nestedFoobarSQLType, - errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", + condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE", params = Map( "protobufType" -> "NestedFoobarWithRequiredFieldBar", "toType" -> toSQLType(nestedFoobarSQLType))) @@ -222,7 +222,7 @@ class 
ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { checkError( exception = e1, - errorClass = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR") + condition = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR") val basicMessageDescWithoutImports = descriptorSetWithoutImports( ProtobufUtils.readDescriptorFileContent( @@ -240,7 +240,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { checkError( exception = e2, - errorClass = "PROTOBUF_DEPENDENCY_NOT_FOUND", + condition = "PROTOBUF_DEPENDENCY_NOT_FOUND", parameters = Map("dependencyName" -> "nestedenum.proto")) } @@ -254,7 +254,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { serdeFactory: SerdeFactory[_], fieldMatchType: MatchType, catalystSchema: StructType = CATALYST_STRUCT, - errorClass: String, + condition: String, params: Map[String, String]): Unit = { val e = intercept[AnalysisException] { @@ -274,7 +274,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase { assert(e.getMessage === expectMsg) checkError( exception = e, - errorClass = errorClass, + condition = condition, parameters = params) } diff --git a/core/benchmarks/ChecksumBenchmark-jdk21-results.txt b/core/benchmarks/ChecksumBenchmark-jdk21-results.txt new file mode 100644 index 0000000000000..85370450f355c --- /dev/null +++ b/core/benchmarks/ChecksumBenchmark-jdk21-results.txt @@ -0,0 +1,14 @@ +================================================================================================ +Benchmark Checksum Algorithms +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +AMD EPYC 7763 64-Core Processor +Checksum Algorithms: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +CRC32 2743 2746 3 0.0 2678409.9 1.0X +CRC32C 1974 2055 70 0.0 1928129.2 1.4X +Adler32 12689 12709 17 0.0 12391425.9 0.2X +hadoop PureJavaCrc32C 23027 23041 13 0.0 22487098.9 0.1X + + diff --git a/core/benchmarks/ChecksumBenchmark-results.txt b/core/benchmarks/ChecksumBenchmark-results.txt new file mode 100644 index 0000000000000..cce5a61abf637 --- /dev/null +++ b/core/benchmarks/ChecksumBenchmark-results.txt @@ -0,0 +1,14 @@ +================================================================================================ +Benchmark Checksum Algorithms +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +AMD EPYC 7763 64-Core Processor +Checksum Algorithms: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +CRC32 2757 2758 1 0.0 2692250.2 1.0X +CRC32C 2142 2244 116 0.0 2091901.8 1.3X +Adler32 12699 12712 15 0.0 12401205.6 0.2X +hadoop PureJavaCrc32C 23049 23066 15 0.0 22508320.3 0.1X + + diff --git a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt index 40b999e7ee08f..b3bffea826e5f 100644 --- a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt +++ b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt @@ -6,44 +6,44 @@ OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg 
Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 654 667 12 0.0 65384.8 1.0X -Compression 10000 times at level 2 without buffer pool 701 702 0 0.0 70133.1 0.9X -Compression 10000 times at level 3 without buffer pool 798 799 1 0.0 79817.3 0.8X -Compression 10000 times at level 1 with buffer pool 593 596 3 0.0 59339.9 1.1X -Compression 10000 times at level 2 with buffer pool 629 634 7 0.0 62857.3 1.0X -Compression 10000 times at level 3 with buffer pool 737 738 1 0.0 73690.9 0.9X +Compression 10000 times at level 1 without buffer pool 657 670 14 0.0 65699.2 1.0X +Compression 10000 times at level 2 without buffer pool 697 697 1 0.0 69673.4 0.9X +Compression 10000 times at level 3 without buffer pool 799 802 3 0.0 79855.2 0.8X +Compression 10000 times at level 1 with buffer pool 593 595 1 0.0 59326.9 1.1X +Compression 10000 times at level 2 with buffer pool 622 624 3 0.0 62194.1 1.1X +Compression 10000 times at level 3 with buffer pool 732 733 1 0.0 73178.6 0.9X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 821 822 2 0.0 82050.4 1.0X -Decompression 10000 times from level 2 without buffer pool 820 821 1 0.0 82038.1 1.0X -Decompression 10000 times from level 3 without buffer pool 817 819 2 0.0 81732.1 1.0X -Decompression 10000 times from level 1 with buffer pool 745 746 1 0.0 74456.2 1.1X -Decompression 10000 times from level 2 with buffer pool 746 747 1 0.0 74590.2 1.1X -Decompression 10000 times from level 3 with buffer pool 746 747 1 0.0 74593.1 1.1X +Decompression 10000 times from level 1 without buffer pool 813 820 11 0.0 81273.2 1.0X +Decompression 10000 times from level 2 without buffer pool 810 813 3 0.0 80986.2 1.0X +Decompression 10000 times from level 3 without buffer pool 812 813 2 0.0 81183.1 1.0X +Decompression 10000 times from level 1 with buffer pool 746 747 2 0.0 74568.7 1.1X +Decompression 10000 times from level 2 with buffer pool 744 746 2 0.0 74414.5 1.1X +Decompression 10000 times from level 3 with buffer pool 745 746 1 0.0 74538.6 1.1X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 3: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 49 49 1 0.0 379018.5 1.0X -Parallel Compression with 1 workers 35 37 4 0.0 271777.5 1.4X -Parallel Compression with 2 workers 34 38 2 0.0 261820.6 1.4X -Parallel Compression with 4 workers 37 39 2 0.0 285987.9 1.3X -Parallel Compression with 8 workers 39 41 1 0.0 303005.9 1.3X -Parallel Compression with 16 workers 43 45 1 0.0 337834.5 1.1X +Parallel Compression with 0 workers 48 49 1 0.0 374256.1 1.0X +Parallel Compression with 1 workers 34 36 3 0.0 267557.3 1.4X +Parallel Compression with 2 workers 34 38 2 0.0 263684.3 1.4X +Parallel Compression with 4 workers 37 39 2 0.0 289956.1 1.3X +Parallel Compression with 8 workers 39 41 1 0.0 306975.2 1.2X +Parallel 
Compression with 16 workers 44 45 1 0.0 340992.0 1.1X OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 9: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 156 157 1 0.0 1215755.2 1.0X -Parallel Compression with 1 workers 186 187 3 0.0 1449769.5 0.8X -Parallel Compression with 2 workers 111 116 5 0.0 865458.3 1.4X -Parallel Compression with 4 workers 105 110 3 0.0 821557.7 1.5X -Parallel Compression with 8 workers 111 114 2 0.0 868777.0 1.4X -Parallel Compression with 16 workers 110 115 2 0.0 859766.8 1.4X +Parallel Compression with 0 workers 156 158 1 0.0 1220760.5 1.0X +Parallel Compression with 1 workers 191 192 2 0.0 1495168.2 0.8X +Parallel Compression with 2 workers 111 117 5 0.0 864459.9 1.4X +Parallel Compression with 4 workers 106 109 2 0.0 831025.5 1.5X +Parallel Compression with 8 workers 112 115 2 0.0 875732.7 1.4X +Parallel Compression with 16 workers 110 114 2 0.0 858160.9 1.4X diff --git a/core/benchmarks/ZStandardBenchmark-results.txt b/core/benchmarks/ZStandardBenchmark-results.txt index 6b67147e6a63a..b230f825fecac 100644 --- a/core/benchmarks/ZStandardBenchmark-results.txt +++ b/core/benchmarks/ZStandardBenchmark-results.txt @@ -6,44 +6,44 @@ OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 641 649 10 0.0 64087.9 1.0X -Compression 10000 times at level 2 without buffer pool 688 690 2 0.0 68761.5 0.9X -Compression 10000 times at level 3 without buffer pool 777 777 1 0.0 77675.7 0.8X -Compression 10000 times at level 1 with buffer pool 574 575 0 0.0 57407.8 1.1X -Compression 10000 times at level 2 with buffer pool 604 605 1 0.0 60366.5 1.1X -Compression 10000 times at level 3 with buffer pool 708 708 1 0.0 70794.2 0.9X +Compression 10000 times at level 1 without buffer pool 638 638 0 0.0 63765.0 1.0X +Compression 10000 times at level 2 without buffer pool 675 676 1 0.0 67529.4 0.9X +Compression 10000 times at level 3 without buffer pool 775 783 11 0.0 77531.6 0.8X +Compression 10000 times at level 1 with buffer pool 572 573 1 0.0 57223.2 1.1X +Compression 10000 times at level 2 with buffer pool 603 605 1 0.0 60323.7 1.1X +Compression 10000 times at level 3 with buffer pool 720 727 6 0.0 71980.9 0.9X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 585 586 1 0.0 58531.7 1.0X -Decompression 10000 times from level 2 without buffer pool 585 587 2 0.0 58496.8 1.0X -Decompression 10000 times from level 3 without buffer pool 588 589 1 0.0 58831.8 1.0X -Decompression 10000 times from level 1 with buffer pool 533 534 1 0.0 53331.8 1.1X -Decompression 10000 times from level 2 with buffer pool 533 534 0 0.0 53324.1 1.1X -Decompression 10000 times from 
level 3 with buffer pool 533 534 0 0.0 53303.4 1.1X +Decompression 10000 times from level 1 without buffer pool 584 585 1 0.0 58381.0 1.0X +Decompression 10000 times from level 2 without buffer pool 585 585 0 0.0 58465.9 1.0X +Decompression 10000 times from level 3 without buffer pool 585 586 1 0.0 58499.5 1.0X +Decompression 10000 times from level 1 with buffer pool 534 534 0 0.0 53375.7 1.1X +Decompression 10000 times from level 2 with buffer pool 533 533 0 0.0 53312.3 1.1X +Decompression 10000 times from level 3 with buffer pool 533 533 1 0.0 53255.1 1.1X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 3: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 47 48 1 0.0 364123.8 1.0X -Parallel Compression with 1 workers 34 36 3 0.0 268638.6 1.4X -Parallel Compression with 2 workers 32 36 2 0.0 252026.9 1.4X -Parallel Compression with 4 workers 35 38 4 0.0 271762.4 1.3X -Parallel Compression with 8 workers 38 40 1 0.0 298137.9 1.2X -Parallel Compression with 16 workers 42 44 1 0.0 324881.0 1.1X +Parallel Compression with 0 workers 46 48 1 0.0 360483.5 1.0X +Parallel Compression with 1 workers 34 36 2 0.0 265816.1 1.4X +Parallel Compression with 2 workers 33 36 2 0.0 254525.8 1.4X +Parallel Compression with 4 workers 34 37 1 0.0 266270.8 1.4X +Parallel Compression with 8 workers 37 39 1 0.0 289289.2 1.2X +Parallel Compression with 16 workers 41 43 1 0.0 320243.3 1.1X OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 9: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 155 157 1 0.0 1210833.3 1.0X -Parallel Compression with 1 workers 192 193 3 0.0 1500386.2 0.8X -Parallel Compression with 2 workers 114 121 9 0.0 888645.9 1.4X -Parallel Compression with 4 workers 106 109 2 0.0 830468.4 1.5X -Parallel Compression with 8 workers 110 113 2 0.0 857123.0 1.4X -Parallel Compression with 16 workers 109 114 3 0.0 854349.3 1.4X +Parallel Compression with 0 workers 154 156 2 0.0 1205934.0 1.0X +Parallel Compression with 1 workers 191 194 4 0.0 1495729.9 0.8X +Parallel Compression with 2 workers 110 114 5 0.0 859158.9 1.4X +Parallel Compression with 4 workers 105 108 3 0.0 822932.2 1.5X +Parallel Compression with 8 workers 109 113 2 0.0 851560.0 1.4X +Parallel Compression with 16 workers 111 115 2 0.0 870695.9 1.4X diff --git a/core/pom.xml b/core/pom.xml index 0a339e11a5d20..19f58940ed942 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -32,7 +32,6 @@ core - **/OpenTelemetry*.scala @@ -122,19 +121,14 @@ io.jsonwebtoken jjwt-api - 0.12.6 io.jsonwebtoken jjwt-impl - 0.12.6 - test io.jsonwebtoken jjwt-jackson - 0.12.6 - test @@ -627,34 +613,10 @@ - opentelemetry + jjwt - + compile - - - io.opentelemetry - opentelemetry-exporter-otlp - 1.41.0 - - - io.opentelemetry - opentelemetry-sdk-extension-autoconfigure-spi - - - - - io.opentelemetry - opentelemetry-sdk - 1.41.0 - - - com.squareup.okhttp3 - okhttp - 3.12.12 - test - - sparkr diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index 
4e251a1c2901b..412d612c7f1d5 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -17,6 +17,7 @@ package org.apache.spark.io; import org.apache.spark.storage.StorageUtils; +import org.apache.spark.unsafe.Platform; import java.io.File; import java.io.IOException; @@ -47,7 +48,7 @@ public final class NioBufferedFileInputStream extends InputStream { private final FileChannel fileChannel; public NioBufferedFileInputStream(File file, int bufferSizeInBytes) throws IOException { - byteBuffer = ByteBuffer.allocateDirect(bufferSizeInBytes); + byteBuffer = Platform.allocateDirectBuffer(bufferSizeInBytes); fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); byteBuffer.flip(); this.cleanable = CLEANER.register(this, new ResourceCleaner(fileChannel, byteBuffer)); diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.css b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.css deleted file mode 100644 index 6db36f6e75d39..0000000000000 --- a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.css +++ /dev/null @@ -1 +0,0 @@ -:root{--dt-row-selected: 2, 117, 216;--dt-row-selected-text: 255, 255, 255;--dt-row-selected-link: 9, 10, 11;--dt-row-stripe: 0, 0, 0;--dt-row-hover: 0, 0, 0;--dt-column-ordering: 0, 0, 0;--dt-html-background: white}:root.dark{--dt-html-background: rgb(33, 37, 41)}table.dataTable td.dt-control{text-align:center;cursor:pointer}table.dataTable td.dt-control:before{display:inline-block;color:rgba(0, 0, 0, 0.5);content:"►"}table.dataTable tr.dt-hasChild td.dt-control:before{content:"▼"}html.dark table.dataTable td.dt-control:before{color:rgba(255, 255, 255, 0.5)}html.dark table.dataTable tr.dt-hasChild td.dt-control:before{color:rgba(255, 255, 255, 0.5)}table.dataTable thead>tr>th.sorting,table.dataTable thead>tr>th.sorting_asc,table.dataTable thead>tr>th.sorting_desc,table.dataTable thead>tr>th.sorting_asc_disabled,table.dataTable thead>tr>th.sorting_desc_disabled,table.dataTable thead>tr>td.sorting,table.dataTable thead>tr>td.sorting_asc,table.dataTable thead>tr>td.sorting_desc,table.dataTable thead>tr>td.sorting_asc_disabled,table.dataTable thead>tr>td.sorting_desc_disabled{cursor:pointer;position:relative;padding-right:26px}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable thead>tr>td.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after{position:absolute;display:block;opacity:.125;right:10px;line-height:9px;font-size:.8em}table.dataTable thead>tr>th.sorting:before,table.dataTable 
thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:before{bottom:50%;content:"▲";content:"▲"/""}table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable thead>tr>td.sorting_desc_disabled:after{top:50%;content:"▼";content:"▼"/""}table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:after{opacity:.6}table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting_asc_disabled:before{display:none}table.dataTable thead>tr>th:active,table.dataTable thead>tr>td:active{outline:none}div.dataTables_scrollBody>table.dataTable>thead>tr>th:before,div.dataTables_scrollBody>table.dataTable>thead>tr>th:after,div.dataTables_scrollBody>table.dataTable>thead>tr>td:before,div.dataTables_scrollBody>table.dataTable>thead>tr>td:after{display:none}div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:2px}div.dataTables_processing>div:last-child{position:relative;width:80px;height:15px;margin:1em auto}div.dataTables_processing>div:last-child>div{position:absolute;top:0;width:13px;height:13px;border-radius:50%;background:rgb(2, 117, 216);background:rgb(var(--dt-row-selected));animation-timing-function:cubic-bezier(0, 1, 1, 0)}div.dataTables_processing>div:last-child>div:nth-child(1){left:8px;animation:datatables-loader-1 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(2){left:8px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(3){left:32px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(4){left:56px;animation:datatables-loader-3 .6s infinite}@keyframes datatables-loader-1{0%{transform:scale(0)}100%{transform:scale(1)}}@keyframes datatables-loader-3{0%{transform:scale(1)}100%{transform:scale(0)}}@keyframes datatables-loader-2{0%{transform:translate(0, 0)}100%{transform:translate(24px, 0)}}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable th.dt-left,table.dataTable td.dt-left{text-align:left}table.dataTable th.dt-center,table.dataTable td.dt-center,table.dataTable td.dataTables_empty{text-align:center}table.dataTable th.dt-right,table.dataTable td.dt-right{text-align:right}table.dataTable th.dt-justify,table.dataTable td.dt-justify{text-align:justify}table.dataTable th.dt-nowrap,table.dataTable td.dt-nowrap{white-space:nowrap}table.dataTable thead th,table.dataTable thead td,table.dataTable tfoot th,table.dataTable tfoot td{text-align:left}table.dataTable 
thead th.dt-head-left,table.dataTable thead td.dt-head-left,table.dataTable tfoot th.dt-head-left,table.dataTable tfoot td.dt-head-left{text-align:left}table.dataTable thead th.dt-head-center,table.dataTable thead td.dt-head-center,table.dataTable tfoot th.dt-head-center,table.dataTable tfoot td.dt-head-center{text-align:center}table.dataTable thead th.dt-head-right,table.dataTable thead td.dt-head-right,table.dataTable tfoot th.dt-head-right,table.dataTable tfoot td.dt-head-right{text-align:right}table.dataTable thead th.dt-head-justify,table.dataTable thead td.dt-head-justify,table.dataTable tfoot th.dt-head-justify,table.dataTable tfoot td.dt-head-justify{text-align:justify}table.dataTable thead th.dt-head-nowrap,table.dataTable thead td.dt-head-nowrap,table.dataTable tfoot th.dt-head-nowrap,table.dataTable tfoot td.dt-head-nowrap{white-space:nowrap}table.dataTable tbody th.dt-body-left,table.dataTable tbody td.dt-body-left{text-align:left}table.dataTable tbody th.dt-body-center,table.dataTable tbody td.dt-body-center{text-align:center}table.dataTable tbody th.dt-body-right,table.dataTable tbody td.dt-body-right{text-align:right}table.dataTable tbody th.dt-body-justify,table.dataTable tbody td.dt-body-justify{text-align:justify}table.dataTable tbody th.dt-body-nowrap,table.dataTable tbody td.dt-body-nowrap{white-space:nowrap}table.dataTable{clear:both;margin-top:6px !important;margin-bottom:6px !important;max-width:none !important;border-collapse:separate !important;border-spacing:0}table.dataTable td,table.dataTable th{-webkit-box-sizing:content-box;box-sizing:content-box}table.dataTable td.dataTables_empty,table.dataTable th.dataTables_empty{text-align:center}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable.table-striped>tbody>tr:nth-of-type(2n+1){background-color:transparent}table.dataTable>tbody>tr{background-color:transparent}table.dataTable>tbody>tr.selected>*{box-shadow:inset 0 0 0 9999px rgb(2, 117, 216);box-shadow:inset 0 0 0 9999px rgb(var(--dt-row-selected));color:rgb(255, 255, 255);color:rgb(var(--dt-row-selected-text))}table.dataTable>tbody>tr.selected a{color:rgb(9, 10, 11);color:rgb(var(--dt-row-selected-link))}table.dataTable.table-striped>tbody>tr.odd>*{box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-stripe), 0.05)}table.dataTable.table-striped>tbody>tr.odd.selected>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.95);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.95)}table.dataTable.table-hover>tbody>tr:hover>*{box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-hover), 0.075)}table.dataTable.table-hover>tbody>tr.selected:hover>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.975);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.975)}div.dataTables_wrapper div.dataTables_length label{font-weight:normal;text-align:left;white-space:nowrap}div.dataTables_wrapper div.dataTables_length select{width:auto;display:inline-block}div.dataTables_wrapper div.dataTables_filter{text-align:right}div.dataTables_wrapper div.dataTables_filter label{font-weight:normal;white-space:nowrap;text-align:left}div.dataTables_wrapper div.dataTables_filter input{margin-left:.5em;display:inline-block;width:auto}div.dataTables_wrapper div.dataTables_info{padding-top:.85em}div.dataTables_wrapper div.dataTables_paginate{margin:0;white-space:nowrap;text-align:right}div.dataTables_wrapper div.dataTables_paginate ul.pagination{margin:2px 0;white-space:nowrap;justify-content:flex-end}div.dataTables_wrapper 
div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:1em 0}div.dataTables_scrollHead table.dataTable{margin-bottom:0 !important}div.dataTables_scrollBody>table{border-top:none;margin-top:0 !important;margin-bottom:0 !important}div.dataTables_scrollBody>table>thead .sorting:before,div.dataTables_scrollBody>table>thead .sorting_asc:before,div.dataTables_scrollBody>table>thead .sorting_desc:before,div.dataTables_scrollBody>table>thead .sorting:after,div.dataTables_scrollBody>table>thead .sorting_asc:after,div.dataTables_scrollBody>table>thead .sorting_desc:after{display:none}div.dataTables_scrollBody>table>tbody tr:first-child th,div.dataTables_scrollBody>table>tbody tr:first-child td{border-top:none}div.dataTables_scrollFoot>.dataTables_scrollFootInner{box-sizing:content-box}div.dataTables_scrollFoot>.dataTables_scrollFootInner>table{margin-top:0 !important;border-top:none}@media screen and (max-width: 767px){div.dataTables_wrapper div.dataTables_length,div.dataTables_wrapper div.dataTables_filter,div.dataTables_wrapper div.dataTables_info,div.dataTables_wrapper div.dataTables_paginate{text-align:center}div.dataTables_wrapper div.dataTables_paginate ul.pagination{justify-content:center !important}}table.dataTable.table-sm>thead>tr>th:not(.sorting_disabled){padding-right:20px}table.table-bordered.dataTable{border-right-width:0}table.table-bordered.dataTable th,table.table-bordered.dataTable td{border-left-width:0}table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable td:last-child,table.table-bordered.dataTable td:last-child{border-right-width:1px}table.table-bordered.dataTable tbody th,table.table-bordered.dataTable tbody td{border-bottom-width:0}div.dataTables_scrollHead table.table-bordered{border-bottom-width:0}div.table-responsive>div.dataTables_wrapper>div.row{margin:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:first-child{padding-left:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:last-child{padding-right:0} diff --git a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.css b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.css new file mode 100644 index 0000000000000..d344f78a39748 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.css @@ -0,0 +1 @@ +:root{--dt-row-selected: 2, 117, 216;--dt-row-selected-text: 255, 255, 255;--dt-row-selected-link: 9, 10, 11;--dt-row-stripe: 0, 0, 0;--dt-row-hover: 0, 0, 0;--dt-column-ordering: 0, 0, 0;--dt-html-background: white}:root.dark{--dt-html-background: rgb(33, 37, 41)}table.dataTable td.dt-control{text-align:center;cursor:pointer}table.dataTable td.dt-control:before{display:inline-block;color:rgba(0, 0, 0, 0.5);content:"▶"}table.dataTable tr.dt-hasChild td.dt-control:before{content:"▼"}html.dark table.dataTable td.dt-control:before,:root[data-bs-theme=dark] table.dataTable td.dt-control:before{color:rgba(255, 255, 255, 0.5)}html.dark table.dataTable tr.dt-hasChild td.dt-control:before,:root[data-bs-theme=dark] table.dataTable tr.dt-hasChild td.dt-control:before{color:rgba(255, 255, 255, 0.5)}table.dataTable thead>tr>th.sorting,table.dataTable thead>tr>th.sorting_asc,table.dataTable thead>tr>th.sorting_desc,table.dataTable thead>tr>th.sorting_asc_disabled,table.dataTable thead>tr>th.sorting_desc_disabled,table.dataTable 
thead>tr>td.sorting,table.dataTable thead>tr>td.sorting_asc,table.dataTable thead>tr>td.sorting_desc,table.dataTable thead>tr>td.sorting_asc_disabled,table.dataTable thead>tr>td.sorting_desc_disabled{cursor:pointer;position:relative;padding-right:26px}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable thead>tr>td.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after{position:absolute;display:block;opacity:.125;right:10px;line-height:9px;font-size:.8em}table.dataTable thead>tr>th.sorting:before,table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:before,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>th.sorting_desc_disabled:before,table.dataTable thead>tr>td.sorting:before,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:before,table.dataTable thead>tr>td.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:before{bottom:50%;content:"▲";content:"▲"/""}table.dataTable thead>tr>th.sorting:after,table.dataTable thead>tr>th.sorting_asc:after,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>th.sorting_asc_disabled:after,table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting:after,table.dataTable thead>tr>td.sorting_asc:after,table.dataTable thead>tr>td.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc_disabled:after,table.dataTable thead>tr>td.sorting_desc_disabled:after{top:50%;content:"▼";content:"▼"/""}table.dataTable thead>tr>th.sorting_asc:before,table.dataTable thead>tr>th.sorting_desc:after,table.dataTable thead>tr>td.sorting_asc:before,table.dataTable thead>tr>td.sorting_desc:after{opacity:.6}table.dataTable thead>tr>th.sorting_desc_disabled:after,table.dataTable thead>tr>th.sorting_asc_disabled:before,table.dataTable thead>tr>td.sorting_desc_disabled:after,table.dataTable thead>tr>td.sorting_asc_disabled:before{display:none}table.dataTable thead>tr>th:active,table.dataTable thead>tr>td:active{outline:none}div.dataTables_scrollBody>table.dataTable>thead>tr>th:before,div.dataTables_scrollBody>table.dataTable>thead>tr>th:after,div.dataTables_scrollBody>table.dataTable>thead>tr>td:before,div.dataTables_scrollBody>table.dataTable>thead>tr>td:after{display:none}div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:2px;z-index:10}div.dataTables_processing>div:last-child{position:relative;width:80px;height:15px;margin:1em auto}div.dataTables_processing>div:last-child>div{position:absolute;top:0;width:13px;height:13px;border-radius:50%;background:rgb(2, 117, 
216);background:rgb(var(--dt-row-selected));animation-timing-function:cubic-bezier(0, 1, 1, 0)}div.dataTables_processing>div:last-child>div:nth-child(1){left:8px;animation:datatables-loader-1 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(2){left:8px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(3){left:32px;animation:datatables-loader-2 .6s infinite}div.dataTables_processing>div:last-child>div:nth-child(4){left:56px;animation:datatables-loader-3 .6s infinite}@keyframes datatables-loader-1{0%{transform:scale(0)}100%{transform:scale(1)}}@keyframes datatables-loader-3{0%{transform:scale(1)}100%{transform:scale(0)}}@keyframes datatables-loader-2{0%{transform:translate(0, 0)}100%{transform:translate(24px, 0)}}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable th.dt-left,table.dataTable td.dt-left{text-align:left}table.dataTable th.dt-center,table.dataTable td.dt-center,table.dataTable td.dataTables_empty{text-align:center}table.dataTable th.dt-right,table.dataTable td.dt-right{text-align:right}table.dataTable th.dt-justify,table.dataTable td.dt-justify{text-align:justify}table.dataTable th.dt-nowrap,table.dataTable td.dt-nowrap{white-space:nowrap}table.dataTable thead th,table.dataTable thead td,table.dataTable tfoot th,table.dataTable tfoot td{text-align:left}table.dataTable thead th.dt-head-left,table.dataTable thead td.dt-head-left,table.dataTable tfoot th.dt-head-left,table.dataTable tfoot td.dt-head-left{text-align:left}table.dataTable thead th.dt-head-center,table.dataTable thead td.dt-head-center,table.dataTable tfoot th.dt-head-center,table.dataTable tfoot td.dt-head-center{text-align:center}table.dataTable thead th.dt-head-right,table.dataTable thead td.dt-head-right,table.dataTable tfoot th.dt-head-right,table.dataTable tfoot td.dt-head-right{text-align:right}table.dataTable thead th.dt-head-justify,table.dataTable thead td.dt-head-justify,table.dataTable tfoot th.dt-head-justify,table.dataTable tfoot td.dt-head-justify{text-align:justify}table.dataTable thead th.dt-head-nowrap,table.dataTable thead td.dt-head-nowrap,table.dataTable tfoot th.dt-head-nowrap,table.dataTable tfoot td.dt-head-nowrap{white-space:nowrap}table.dataTable tbody th.dt-body-left,table.dataTable tbody td.dt-body-left{text-align:left}table.dataTable tbody th.dt-body-center,table.dataTable tbody td.dt-body-center{text-align:center}table.dataTable tbody th.dt-body-right,table.dataTable tbody td.dt-body-right{text-align:right}table.dataTable tbody th.dt-body-justify,table.dataTable tbody td.dt-body-justify{text-align:justify}table.dataTable tbody th.dt-body-nowrap,table.dataTable tbody td.dt-body-nowrap{white-space:nowrap}table.dataTable{clear:both;margin-top:6px !important;margin-bottom:6px !important;max-width:none !important;border-collapse:separate !important;border-spacing:0}table.dataTable td,table.dataTable th{-webkit-box-sizing:content-box;box-sizing:content-box}table.dataTable td.dataTables_empty,table.dataTable th.dataTables_empty{text-align:center}table.dataTable.nowrap th,table.dataTable.nowrap td{white-space:nowrap}table.dataTable.table-striped>tbody>tr:nth-of-type(2n+1){background-color:transparent}table.dataTable>tbody>tr{background-color:transparent}table.dataTable>tbody>tr.selected>*{box-shadow:inset 0 0 0 9999px rgb(2, 117, 216);box-shadow:inset 0 0 0 9999px rgb(var(--dt-row-selected));color:rgb(255, 255, 255);color:rgb(var(--dt-row-selected-text))}table.dataTable>tbody>tr.selected 
a{color:rgb(9, 10, 11);color:rgb(var(--dt-row-selected-link))}table.dataTable.table-striped>tbody>tr.odd>*{box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-stripe), 0.05)}table.dataTable.table-striped>tbody>tr.odd.selected>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.95);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.95)}table.dataTable.table-hover>tbody>tr:hover>*{box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-hover), 0.075)}table.dataTable.table-hover>tbody>tr.selected:hover>*{box-shadow:inset 0 0 0 9999px rgba(2, 117, 216, 0.975);box-shadow:inset 0 0 0 9999px rgba(var(--dt-row-selected), 0.975)}div.dataTables_wrapper div.dataTables_length label{font-weight:normal;text-align:left;white-space:nowrap}div.dataTables_wrapper div.dataTables_length select{width:auto;display:inline-block}div.dataTables_wrapper div.dataTables_filter{text-align:right}div.dataTables_wrapper div.dataTables_filter label{font-weight:normal;white-space:nowrap;text-align:left}div.dataTables_wrapper div.dataTables_filter input{margin-left:.5em;display:inline-block;width:auto}div.dataTables_wrapper div.dataTables_info{padding-top:.85em}div.dataTables_wrapper div.dataTables_paginate{margin:0;white-space:nowrap;text-align:right}div.dataTables_wrapper div.dataTables_paginate ul.pagination{margin:2px 0;white-space:nowrap;justify-content:flex-end}div.dataTables_wrapper div.dataTables_processing{position:absolute;top:50%;left:50%;width:200px;margin-left:-100px;margin-top:-26px;text-align:center;padding:1em 0}div.dataTables_scrollHead table.dataTable{margin-bottom:0 !important}div.dataTables_scrollBody>table{border-top:none;margin-top:0 !important;margin-bottom:0 !important}div.dataTables_scrollBody>table>thead .sorting:before,div.dataTables_scrollBody>table>thead .sorting_asc:before,div.dataTables_scrollBody>table>thead .sorting_desc:before,div.dataTables_scrollBody>table>thead .sorting:after,div.dataTables_scrollBody>table>thead .sorting_asc:after,div.dataTables_scrollBody>table>thead .sorting_desc:after{display:none}div.dataTables_scrollBody>table>tbody tr:first-child th,div.dataTables_scrollBody>table>tbody tr:first-child td{border-top:none}div.dataTables_scrollFoot>.dataTables_scrollFootInner{box-sizing:content-box}div.dataTables_scrollFoot>.dataTables_scrollFootInner>table{margin-top:0 !important;border-top:none}@media screen and (max-width: 767px){div.dataTables_wrapper div.dataTables_length,div.dataTables_wrapper div.dataTables_filter,div.dataTables_wrapper div.dataTables_info,div.dataTables_wrapper div.dataTables_paginate{text-align:center}div.dataTables_wrapper div.dataTables_paginate ul.pagination{justify-content:center !important}}table.dataTable.table-sm>thead>tr>th:not(.sorting_disabled){padding-right:20px}table.table-bordered.dataTable{border-right-width:0}table.table-bordered.dataTable th,table.table-bordered.dataTable td{border-left-width:0}table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable th:last-child,table.table-bordered.dataTable td:last-child,table.table-bordered.dataTable td:last-child{border-right-width:1px}table.table-bordered.dataTable tbody th,table.table-bordered.dataTable tbody td{border-bottom-width:0}div.dataTables_scrollHead table.table-bordered{border-bottom-width:0}div.table-responsive>div.dataTables_wrapper>div.row{margin:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:first-child{padding-left:0}div.table-responsive>div.dataTables_wrapper>div.row>div[class^=col-]:last-child{padding-right:0} diff --git 
a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.js b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.js similarity index 83% rename from core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.js rename to core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.js index 04de9c97cc514..c99016713ab1f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.1.13.5.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/dataTables.bootstrap4.min.js @@ -1,4 +1,4 @@ /*! DataTables Bootstrap 4 integration * ©2011-2017 SpryMedia Ltd - datatables.net/license */ -!function(t){var n,o;"function"==typeof define&&define.amd?define(["jquery","datatables.net"],function(e){return t(e,window,document)}):"object"==typeof exports?(n=require("jquery"),o=function(e,a){a.fn.dataTable||require("datatables.net")(e,a)},"undefined"==typeof window?module.exports=function(e,a){return e=e||window,a=a||n(e),o(e,a),t(a,0,e.document)}:(o(window,n),module.exports=t(n,window,window.document))):t(jQuery,window,document)}(function(x,e,n,o){"use strict";var r=x.fn.dataTable;return x.extend(!0,r.defaults,{dom:"<'row'<'col-sm-12 col-md-6'l><'col-sm-12 col-md-6'f>><'row'<'col-sm-12'tr>><'row'<'col-sm-12 col-md-5'i><'col-sm-12 col-md-7'p>>",renderer:"bootstrap"}),x.extend(r.ext.classes,{sWrapper:"dataTables_wrapper dt-bootstrap4",sFilterInput:"form-control form-control-sm",sLengthSelect:"custom-select custom-select-sm form-control form-control-sm",sProcessing:"dataTables_processing card",sPageButton:"paginate_button page-item"}),r.ext.renderer.pageButton.bootstrap=function(i,e,d,a,l,c){function u(e,a){for(var t,n,o=function(e){e.preventDefault(),x(e.currentTarget).hasClass("disabled")||m.page()==e.data.action||m.page(e.data.action).draw("page")},r=0,s=a.length;r",{class:b.sPageButton+" "+f,id:0===d&&"string"==typeof t?i.sTableId+"_"+t:null}).append(x("",{href:n?null:"#","aria-controls":i.sTableId,"aria-disabled":n?"true":null,"aria-label":w[t],role:"link","aria-current":"active"===f?"page":null,"data-dt-idx":t,tabindex:i.iTabIndex,class:"page-link"}).html(p)).appendTo(e),i.oApi._fnBindAction(n,{action:t},o))}}var p,f,t,m=new r.Api(i),b=i.oClasses,g=i.oLanguage.oPaginate,w=i.oLanguage.oAria.paginate||{};try{t=x(e).find(n.activeElement).data("dt-idx")}catch(e){}u(x(e).empty().html('
    ').children("ul"),a),t!==o&&x(e).find("[data-dt-idx="+t+"]").trigger("focus")},r}); \ No newline at end of file +!function(t){var n,o;"function"==typeof define&&define.amd?define(["jquery","datatables.net"],function(e){return t(e,window,document)}):"object"==typeof exports?(n=require("jquery"),o=function(e,a){a.fn.dataTable||require("datatables.net")(e,a)},"undefined"==typeof window?module.exports=function(e,a){return e=e||window,a=a||n(e),o(e,a),t(a,0,e.document)}:(o(window,n),module.exports=t(n,window,window.document))):t(jQuery,window,document)}(function(x,e,n,o){"use strict";var r=x.fn.dataTable;return x.extend(!0,r.defaults,{dom:"<'row'<'col-sm-12 col-md-6'l><'col-sm-12 col-md-6'f>><'row'<'col-sm-12'tr>><'row'<'col-sm-12 col-md-5'i><'col-sm-12 col-md-7'p>>",renderer:"bootstrap"}),x.extend(r.ext.classes,{sWrapper:"dataTables_wrapper dt-bootstrap4",sFilterInput:"form-control form-control-sm",sLengthSelect:"custom-select custom-select-sm form-control form-control-sm",sProcessing:"dataTables_processing card",sPageButton:"paginate_button page-item"}),r.ext.renderer.pageButton.bootstrap=function(i,e,d,a,l,c){function u(e,a){for(var t,n,o=function(e){e.preventDefault(),x(e.currentTarget).hasClass("disabled")||m.page()==e.data.action||m.page(e.data.action).draw("page")},r=0,s=a.length;r",{class:b.sPageButton+" "+f,id:0===d&&"string"==typeof t?i.sTableId+"_"+t:null}).append(x("",{href:n?null:"#","aria-controls":i.sTableId,"aria-disabled":n?"true":null,"aria-label":w[t],role:"link","aria-current":"active"===f?"page":null,"data-dt-idx":t,tabindex:n?-1:i.iTabIndex,class:"page-link"}).html(p)).appendTo(e),i.oApi._fnBindAction(n,{action:t},o))}}var p,f,t,m=new r.Api(i),b=i.oClasses,g=i.oLanguage.oPaginate,w=i.oLanguage.oAria.paginate||{};try{t=x(e).find(n.activeElement).data("dt-idx")}catch(e){}u(x(e).empty().html('
      ').children("ul"),a),t!==o&&x(e).find("[data-dt-idx="+t+"]").trigger("focus")},r}); \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.13.5.min.js b/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.13.5.min.js deleted file mode 100644 index e71f4cd8ec92a..0000000000000 --- a/core/src/main/resources/org/apache/spark/ui/static/jquery.dataTables.1.13.5.min.js +++ /dev/null @@ -1,4 +0,0 @@ -/*! DataTables 1.13.5 - * ©2008-2023 SpryMedia Ltd - datatables.net/license - */ -!function(n){"use strict";var a;"function"==typeof define&&define.amd?define(["jquery"],function(t){return n(t,window,document)}):"object"==typeof exports?(a=require("jquery"),"undefined"!=typeof window?module.exports=function(t,e){return t=t||window,e=e||a(t),n(e,t,t.document)}:n(a,window,window.document)):window.DataTable=n(jQuery,window,document)}(function(P,j,y,H){"use strict";function d(t){var e=parseInt(t,10);return!isNaN(e)&&isFinite(t)?e:null}function l(t,e,n){var a=typeof t,r="string"==a;return"number"==a||"bigint"==a||!!h(t)||(e&&r&&(t=$(t,e)),n&&r&&(t=t.replace(q,"")),!isNaN(parseFloat(t))&&isFinite(t))}function a(t,e,n){var a;return!!h(t)||(h(a=t)||"string"==typeof a)&&!!l(t.replace(V,"").replace(/ + - + } @@ -446,16 +449,24 @@ private[spark] object UIUtils extends Logging { val startRatio = if (total == 0) 0.0 else (boundedStarted.toDouble / total) * 100 val startWidth = "width: %s%%".format(startRatio) + val killTaskReasonText = reasonToNumKilled.toSeq.sortBy(-_._2).map { + case (reason, count) => s" ($count killed: $reason)" + }.mkString + val progressTitle = s"$completed/$total" + { + if (started > 0) s" ($started running)" else "" + } + { + if (failed > 0) s" ($failed failed)" else "" + } + { + if (skipped > 0) s" ($skipped skipped)" else "" + } + killTaskReasonText +
      - + {completed}/{total} { if (failed == 0 && skipped == 0 && started > 0) s"($started running)" } { if (failed > 0) s"($failed failed)" } { if (skipped > 0) s"($skipped skipped)" } - { reasonToNumKilled.toSeq.sortBy(-_._2).map { - case (reason, count) => s"($count killed: $reason)" - } - } + { killTaskReasonText }
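The UIUtils change above only pre-computes the progress-bar tooltip text (running, failed, skipped, and killed-task counts) instead of splicing those fragments directly into the markup. A minimal standalone sketch of the same composition; the function name and signature are illustrative, not the actual UIUtils API:

def progressTitle(
    completed: Int,
    total: Int,
    started: Int,
    failed: Int,
    skipped: Int,
    reasonToNumKilled: Map[String, Int]): String = {
  // Killed-task reasons, most frequent first, e.g. " (4 killed: stage cancelled)"
  val killedText = reasonToNumKilled.toSeq.sortBy(-_._2)
    .map { case (reason, count) => s" ($count killed: $reason)" }
    .mkString
  s"$completed/$total" +
    (if (started > 0) s" ($started running)" else "") +
    (if (failed > 0) s" ($failed failed)" else "") +
    (if (skipped > 0) s" ($skipped skipped)" else "") +
    killedText
}

// e.g. progressTitle(3, 10, 2, 1, 0, Map("stage cancelled" -> 4))
//      == "3/10 (2 running) (1 failed) (4 killed: stage cancelled)"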
      diff --git a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala index a4145bb36acc9..1683e892511f9 100644 --- a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala @@ -57,7 +57,7 @@ private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputS if (newCapacity < minCapacity) newCapacity = minCapacity val oldBuffer = buffer oldBuffer.flip() - val newBuffer = ByteBuffer.allocateDirect(newCapacity) + val newBuffer = Platform.allocateDirectBuffer(newCapacity) newBuffer.put(oldBuffer) StorageUtils.dispose(oldBuffer) buffer = newBuffer diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 19cefbc0479a9..e30380f41566a 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -28,6 +28,7 @@ import org.json4s.jackson.JsonMethods.compact import org.apache.spark._ import org.apache.spark.executor._ +import org.apache.spark.internal.config._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.{DeterministicLevel, RDDOperationScope} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceInformation, ResourceProfile, TaskResourceRequest} @@ -37,6 +38,16 @@ import org.apache.spark.storage._ import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils.weakIntern +/** + * Helper class for passing configuration options to JsonProtocol. + * We use this instead of passing SparkConf directly because it lets us avoid + * repeated re-parsing of configuration values on each read. + */ +private[spark] class JsonProtocolOptions(conf: SparkConf) { + val includeTaskMetricsAccumulators: Boolean = + conf.get(EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS) +} + /** * Serializes SparkListener events to/from JSON. This protocol provides strong backwards- * and forwards-compatibility guarantees: any version of Spark should be able to read JSON output @@ -55,30 +66,41 @@ import org.apache.spark.util.Utils.weakIntern private[spark] object JsonProtocol extends JsonUtils { // TODO: Remove this file and put JSON serialization into each individual class. + private[util] + val defaultOptions: JsonProtocolOptions = new JsonProtocolOptions(new SparkConf(false)) + /** ------------------------------------------------- * * JSON serialization methods for SparkListenerEvents | * -------------------------------------------------- */ + // Only for use in tests. Production code should use the two-argument overload defined below. 
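// Illustrative sketch only (not part of the patch): a production caller such as an
// event-log writer is expected to build JsonProtocolOptions once from its SparkConf
// and reuse it for every event, so EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS is
// read a single time rather than per serialized event. `logEvents` and `events`
// are hypothetical names; JsonProtocol and JsonProtocolOptions are private[spark],
// so this only compiles inside an org.apache.spark package.
def logEvents(conf: SparkConf, events: Seq[SparkListenerEvent]): Seq[String] = {
  val options = new JsonProtocolOptions(conf)  // parse the config once, up front
  events.map(event => JsonProtocol.sparkEventToJsonString(event, options))
}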
def sparkEventToJsonString(event: SparkListenerEvent): String = { + sparkEventToJsonString(event, defaultOptions) + } + + def sparkEventToJsonString(event: SparkListenerEvent, options: JsonProtocolOptions): String = { toJsonString { generator => - writeSparkEventToJson(event, generator) + writeSparkEventToJson(event, generator, options) } } - def writeSparkEventToJson(event: SparkListenerEvent, g: JsonGenerator): Unit = { + def writeSparkEventToJson( + event: SparkListenerEvent, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { event match { case stageSubmitted: SparkListenerStageSubmitted => - stageSubmittedToJson(stageSubmitted, g) + stageSubmittedToJson(stageSubmitted, g, options) case stageCompleted: SparkListenerStageCompleted => - stageCompletedToJson(stageCompleted, g) + stageCompletedToJson(stageCompleted, g, options) case taskStart: SparkListenerTaskStart => - taskStartToJson(taskStart, g) + taskStartToJson(taskStart, g, options) case taskGettingResult: SparkListenerTaskGettingResult => - taskGettingResultToJson(taskGettingResult, g) + taskGettingResultToJson(taskGettingResult, g, options) case taskEnd: SparkListenerTaskEnd => - taskEndToJson(taskEnd, g) + taskEndToJson(taskEnd, g, options) case jobStart: SparkListenerJobStart => - jobStartToJson(jobStart, g) + jobStartToJson(jobStart, g, options) case jobEnd: SparkListenerJobEnd => jobEndToJson(jobEnd, g) case environmentUpdate: SparkListenerEnvironmentUpdate => @@ -112,12 +134,15 @@ private[spark] object JsonProtocol extends JsonUtils { } } - def stageSubmittedToJson(stageSubmitted: SparkListenerStageSubmitted, g: JsonGenerator): Unit = { + def stageSubmittedToJson( + stageSubmitted: SparkListenerStageSubmitted, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageSubmitted) g.writeFieldName("Stage Info") // SPARK-42205: don't log accumulables in start events: - stageInfoToJson(stageSubmitted.stageInfo, g, includeAccumulables = false) + stageInfoToJson(stageSubmitted.stageInfo, g, options, includeAccumulables = false) Option(stageSubmitted.properties).foreach { properties => g.writeFieldName("Properties") propertiesToJson(properties, g) @@ -125,38 +150,48 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeEndObject() } - def stageCompletedToJson(stageCompleted: SparkListenerStageCompleted, g: JsonGenerator): Unit = { + def stageCompletedToJson( + stageCompleted: SparkListenerStageCompleted, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageCompleted) g.writeFieldName("Stage Info") - stageInfoToJson(stageCompleted.stageInfo, g, includeAccumulables = true) + stageInfoToJson(stageCompleted.stageInfo, g, options, includeAccumulables = true) g.writeEndObject() } - def taskStartToJson(taskStart: SparkListenerTaskStart, g: JsonGenerator): Unit = { + def taskStartToJson( + taskStart: SparkListenerTaskStart, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskStart) g.writeNumberField("Stage ID", taskStart.stageId) g.writeNumberField("Stage Attempt ID", taskStart.stageAttemptId) g.writeFieldName("Task Info") // SPARK-42205: don't log accumulables in start events: - taskInfoToJson(taskStart.taskInfo, g, includeAccumulables = false) + taskInfoToJson(taskStart.taskInfo, g, options, 
includeAccumulables = false) g.writeEndObject() } def taskGettingResultToJson( taskGettingResult: SparkListenerTaskGettingResult, - g: JsonGenerator): Unit = { + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { val taskInfo = taskGettingResult.taskInfo g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskGettingResult) g.writeFieldName("Task Info") // SPARK-42205: don't log accumulables in "task getting result" events: - taskInfoToJson(taskInfo, g, includeAccumulables = false) + taskInfoToJson(taskInfo, g, options, includeAccumulables = false) g.writeEndObject() } - def taskEndToJson(taskEnd: SparkListenerTaskEnd, g: JsonGenerator): Unit = { + def taskEndToJson( + taskEnd: SparkListenerTaskEnd, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskEnd) g.writeNumberField("Stage ID", taskEnd.stageId) @@ -165,7 +200,7 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeFieldName("Task End Reason") taskEndReasonToJson(taskEnd.reason, g) g.writeFieldName("Task Info") - taskInfoToJson(taskEnd.taskInfo, g, includeAccumulables = true) + taskInfoToJson(taskEnd.taskInfo, g, options, includeAccumulables = true) g.writeFieldName("Task Executor Metrics") executorMetricsToJson(taskEnd.taskExecutorMetrics, g) Option(taskEnd.taskMetrics).foreach { m => @@ -175,7 +210,10 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeEndObject() } - def jobStartToJson(jobStart: SparkListenerJobStart, g: JsonGenerator): Unit = { + def jobStartToJson( + jobStart: SparkListenerJobStart, + g: JsonGenerator, + options: JsonProtocolOptions): Unit = { g.writeStartObject() g.writeStringField("Event", SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobStart) g.writeNumberField("Job ID", jobStart.jobId) @@ -186,7 +224,7 @@ private[spark] object JsonProtocol extends JsonUtils { // the job was submitted: it is technically possible for a stage to belong to multiple // concurrent jobs, so this situation can arise even without races occurring between // event logging and stage completion. 
- jobStart.stageInfos.foreach(stageInfoToJson(_, g, includeAccumulables = true)) + jobStart.stageInfos.foreach(stageInfoToJson(_, g, options, includeAccumulables = true)) g.writeEndArray() g.writeArrayFieldStart("Stage IDs") jobStart.stageIds.foreach(g.writeNumber) @@ -386,6 +424,7 @@ private[spark] object JsonProtocol extends JsonUtils { def stageInfoToJson( stageInfo: StageInfo, g: JsonGenerator, + options: JsonProtocolOptions, includeAccumulables: Boolean): Unit = { g.writeStartObject() g.writeNumberField("Stage ID", stageInfo.stageId) @@ -404,7 +443,10 @@ private[spark] object JsonProtocol extends JsonUtils { stageInfo.failureReason.foreach(g.writeStringField("Failure Reason", _)) g.writeFieldName("Accumulables") if (includeAccumulables) { - accumulablesToJson(stageInfo.accumulables.values, g) + accumulablesToJson( + stageInfo.accumulables.values, + g, + includeTaskMetricsAccumulators = options.includeTaskMetricsAccumulators) } else { g.writeStartArray() g.writeEndArray() @@ -418,6 +460,7 @@ private[spark] object JsonProtocol extends JsonUtils { def taskInfoToJson( taskInfo: TaskInfo, g: JsonGenerator, + options: JsonProtocolOptions, includeAccumulables: Boolean): Unit = { g.writeStartObject() g.writeNumberField("Task ID", taskInfo.taskId) @@ -435,7 +478,10 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeBooleanField("Killed", taskInfo.killed) g.writeFieldName("Accumulables") if (includeAccumulables) { - accumulablesToJson(taskInfo.accumulables, g) + accumulablesToJson( + taskInfo.accumulables, + g, + includeTaskMetricsAccumulators = options.includeTaskMetricsAccumulators) } else { g.writeStartArray() g.writeEndArray() @@ -443,13 +489,23 @@ private[spark] object JsonProtocol extends JsonUtils { g.writeEndObject() } - private lazy val accumulableExcludeList = Set("internal.metrics.updatedBlockStatuses") + private[util] val accumulableExcludeList = Set(InternalAccumulator.UPDATED_BLOCK_STATUSES) + + private[this] val taskMetricAccumulableNames = TaskMetrics.empty.nameToAccums.keySet.toSet - def accumulablesToJson(accumulables: Iterable[AccumulableInfo], g: JsonGenerator): Unit = { + def accumulablesToJson( + accumulables: Iterable[AccumulableInfo], + g: JsonGenerator, + includeTaskMetricsAccumulators: Boolean = true): Unit = { g.writeStartArray() accumulables - .filterNot(_.name.exists(accumulableExcludeList.contains)) - .toList.sortBy(_.id).foreach(a => accumulableInfoToJson(a, g)) + .filterNot { acc => + acc.name.exists(accumulableExcludeList.contains) || + (!includeTaskMetricsAccumulators && acc.name.exists(taskMetricAccumulableNames.contains)) + } + .toList + .sortBy(_.id) + .foreach(a => accumulableInfoToJson(a, g)) g.writeEndArray() } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index f0d7059e29be1..380231ce97c0b 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -208,7 +208,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft ThreadUtils.awaitReady(job, Duration.Inf).failed.foreach { case e: SparkException => checkError( exception = e, - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map( "jobId" -> "0", @@ -222,7 +222,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sc.setJobGroup(jobGroupName, "") sc.parallelize(1 
to 100).count() }, - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map( "jobId" -> "1", @@ -258,7 +258,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft ThreadUtils.awaitReady(job, Duration.Inf).failed.foreach { case e: SparkException => checkError( exception = e, - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map( "jobId" -> "0", diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 7106a780b3256..22c6280198c9a 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -27,7 +27,10 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel def sc: SparkContext = _sc - val conf = new SparkConf(false) + // SPARK-49647: use `SparkConf()` instead of `SparkConf(false)` because we want to + // load defaults from system properties and the classpath, including default test + // settings specified in the SBT and Maven build definitions. + val conf: SparkConf = new SparkConf() /** * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 1966a60c1665e..9f310c06ac5ae 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -327,9 +327,9 @@ abstract class SparkFunSuite } /** - * Checks an exception with an error class against expected results. + * Checks an exception with an error condition against expected results. * @param exception The exception to check - * @param errorClass The expected error class identifying the error + * @param condition The expected error condition identifying the error * @param sqlState Optional the expected SQLSTATE, not verified if not supplied * @param parameters A map of parameter names and values. The names are as defined * in the error-classes file. 
@@ -338,12 +338,12 @@ abstract class SparkFunSuite */ protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: Option[String] = None, parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, queryContext: Array[ExpectedContext] = Array.empty): Unit = { - assert(exception.getErrorClass === errorClass) + assert(exception.getErrorClass === condition) sqlState.foreach(state => assert(exception.getSqlState === state)) val expectedParameters = exception.getMessageParameters.asScala if (matchPVals) { @@ -390,55 +390,55 @@ abstract class SparkFunSuite protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: String, parameters: Map[String, String]): Unit = - checkError(exception, errorClass, Some(sqlState), parameters) + checkError(exception, condition, Some(sqlState), parameters) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: String, parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context)) + checkError(exception, condition, Some(sqlState), parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, None, parameters, false, Array(context)) + checkError(exception, condition, None, parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: String, context: ExpectedContext): Unit = - checkError(exception, errorClass, None, Map.empty, false, Array(context)) + checkError(exception, condition, Some(sqlState), Map.empty, false, Array(context)) protected def checkError( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: Option[String], parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, sqlState, parameters, + checkError(exception, condition, sqlState, parameters, false, Array(context)) protected def checkErrorMatchPVals( exception: SparkThrowable, - errorClass: String, + condition: String, parameters: Map[String, String]): Unit = - checkError(exception, errorClass, None, parameters, matchPVals = true) + checkError(exception, condition, None, parameters, matchPVals = true) protected def checkErrorMatchPVals( exception: SparkThrowable, - errorClass: String, + condition: String, sqlState: Option[String], parameters: Map[String, String], context: ExpectedContext): Unit = - checkError(exception, errorClass, sqlState, parameters, + checkError(exception, condition, sqlState, parameters, matchPVals = true, Array(context)) protected def checkErrorTableNotFound( @@ -446,7 +446,7 @@ abstract class SparkFunSuite tableName: String, queryContext: ExpectedContext): Unit = checkError(exception = exception, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> tableName), queryContext = Array(queryContext)) @@ -454,13 +454,13 @@ abstract class SparkFunSuite exception: SparkThrowable, tableName: String): Unit = checkError(exception = exception, - errorClass = "TABLE_OR_VIEW_NOT_FOUND", + condition = "TABLE_OR_VIEW_NOT_FOUND", parameters = Map("relationName" -> tableName)) protected def checkErrorTableAlreadyExists(exception: SparkThrowable, tableName: String): Unit 
= checkError(exception = exception, - errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS", + condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> tableName)) case class ExpectedContext( diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 0c22edbe984cc..946ea75686e32 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -43,16 +43,13 @@ class SparkThrowableSuite extends SparkFunSuite { /* Used to regenerate the error class file. Run: {{{ SPARK_GENERATE_GOLDEN_FILES=1 build/sbt \ - "core/testOnly *SparkThrowableSuite -- -t \"Error classes are correctly formatted\"" + "core/testOnly *SparkThrowableSuite -- -t \"Error conditions are correctly formatted\"" }}} */ private val regenerateCommand = "SPARK_GENERATE_GOLDEN_FILES=1 build/sbt " + "\"core/testOnly *SparkThrowableSuite -- -t \\\"Error classes match with document\\\"\"" private val errorJsonFilePath = getWorkspaceFilePath( - // Note that though we call them "error classes" here, the proper name is "error conditions", - // hence why the name of the JSON file is different. We will address this inconsistency as part - // of this ticket: https://issues.apache.org/jira/browse/SPARK-47429 "common", "utils", "src", "main", "resources", "error", "error-conditions.json") private val errorReader = new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL)) @@ -81,8 +78,8 @@ class SparkThrowableSuite extends SparkFunSuite { mapper.readValue(errorJsonFilePath.toUri.toURL, new TypeReference[Map[String, ErrorInfo]]() {}) } - test("Error classes are correctly formatted") { - val errorClassFileContents = + test("Error conditions are correctly formatted") { + val errorConditionFileContents = IOUtils.toString(errorJsonFilePath.toUri.toURL.openStream(), StandardCharsets.UTF_8) val mapper = JsonMapper.builder() .addModule(DefaultScalaModule) @@ -96,33 +93,30 @@ class SparkThrowableSuite extends SparkFunSuite { .writeValueAsString(errorReader.errorInfoMap) if (regenerateGoldenFiles) { - if (rewrittenString.trim != errorClassFileContents.trim) { - val errorClassesFile = errorJsonFilePath.toFile - logInfo(s"Regenerating error class file $errorClassesFile") - Files.delete(errorClassesFile.toPath) + if (rewrittenString.trim != errorConditionFileContents.trim) { + val errorConditionsFile = errorJsonFilePath.toFile + logInfo(s"Regenerating error conditions file $errorConditionsFile") + Files.delete(errorConditionsFile.toPath) FileUtils.writeStringToFile( - errorClassesFile, + errorConditionsFile, rewrittenString + lineSeparator, StandardCharsets.UTF_8) } } else { - assert(rewrittenString.trim == errorClassFileContents.trim) + assert(rewrittenString.trim == errorConditionFileContents.trim) } } test("SQLSTATE is mandatory") { - val errorClassesNoSqlState = errorReader.errorInfoMap.filter { + val errorConditionsNoSqlState = errorReader.errorInfoMap.filter { case (error: String, info: ErrorInfo) => !error.startsWith("_LEGACY_ERROR_TEMP") && info.sqlState.isEmpty }.keys.toSeq - assert(errorClassesNoSqlState.isEmpty, - s"Error classes without SQLSTATE: ${errorClassesNoSqlState.mkString(", ")}") + assert(errorConditionsNoSqlState.isEmpty, + s"Error classes without SQLSTATE: ${errorConditionsNoSqlState.mkString(", ")}") } test("Error class and error state / SQLSTATE invariants") { - // Unlike in the rest of the codebase, the term "error class" is used here as it is 
in our - // documentation as well as in the SQL standard. We can remove this comment as part of this - // ticket: https://issues.apache.org/jira/browse/SPARK-47429 val errorClassesJson = Utils.getSparkClassLoader.getResource("error/error-classes.json") val errorStatesJson = Utils.getSparkClassLoader.getResource("error/error-states.json") val mapper = JsonMapper.builder() @@ -171,9 +165,9 @@ class SparkThrowableSuite extends SparkFunSuite { .enable(SerializationFeature.INDENT_OUTPUT) .build() mapper.writeValue(tmpFile, errorReader.errorInfoMap) - val rereadErrorClassToInfoMap = mapper.readValue( + val rereadErrorConditionToInfoMap = mapper.readValue( tmpFile, new TypeReference[Map[String, ErrorInfo]]() {}) - assert(rereadErrorClassToInfoMap == errorReader.errorInfoMap) + assert(rereadErrorConditionToInfoMap == errorReader.errorInfoMap) } test("Error class names should contain only capital letters, numbers and underscores") { @@ -207,13 +201,6 @@ class SparkThrowableSuite extends SparkFunSuite { } assert(e.getErrorClass === "INTERNAL_ERROR") assert(e.getMessageParameters().get("message").contains("Undefined error message parameter")) - - // Does not fail with too many args (expects 0 args) - assert(getMessage("DIVIDE_BY_ZERO", Map("config" -> "foo", "a" -> "bar")) == - "[DIVIDE_BY_ZERO] Division by zero. " + - "Use `try_divide` to tolerate divisor being 0 and return NULL instead. " + - "If necessary set foo to \"false\" " + - "to bypass this error. SQLSTATE: 22012") } test("Error message is formatted") { @@ -259,6 +246,7 @@ class SparkThrowableSuite extends SparkFunSuite { } catch { case e: SparkThrowable => assert(e.getErrorClass == null) + assert(!e.isInternalError) assert(e.getSqlState == null) case _: Throwable => // Should not end up here @@ -275,6 +263,7 @@ class SparkThrowableSuite extends SparkFunSuite { } catch { case e: SparkThrowable => assert(e.getErrorClass == "CANNOT_PARSE_DECIMAL") + assert(!e.isInternalError) assert(e.getSqlState == "22018") case _: Throwable => // Should not end up here @@ -502,7 +491,7 @@ class SparkThrowableSuite extends SparkFunSuite { |{ | "MISSING_PARAMETER" : { | "message" : [ - | "Parameter ${param} is missing." + | "Parameter is missing." | ] | } |} @@ -515,4 +504,28 @@ class SparkThrowableSuite extends SparkFunSuite { assert(errorMessage.contains("Parameter null is missing.")) } } + + test("detect unused message parameters") { + checkError( + exception = intercept[SparkException] { + SparkThrowableHelper.getMessage( + errorClass = "CANNOT_UP_CAST_DATATYPE", + messageParameters = Map( + "expression" -> "CAST('aaa' AS LONG)", + "sourceType" -> "STRING", + "targetType" -> "LONG", + "op" -> "CAST", // unused parameter + "details" -> "implicit cast" + )) + }, + condition = "INTERNAL_ERROR", + parameters = Map( + "message" -> + ("Found unused message parameters of the error class 'CANNOT_UP_CAST_DATATYPE'. " + + "Its error message format has 4 placeholders, but the passed message parameters map " + + "has 5 items. Consider to add placeholders to the error format or " + + "remove unused message parameters.") + ) + ) + } } diff --git a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala index e7315d6119be0..7e88c7ee684bd 100644 --- a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala +++ b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala @@ -94,9 +94,11 @@ private[spark] class Benchmark( /** * Runs the benchmark and outputs the results to stdout. 
This should be copied and added as * a comment with the benchmark. Although the results vary from machine to machine, it should - * provide some baseline. + * provide some baseline. If `relativeTime` is set to `true`, the `Relative` column will be + * the relative time of each case relative to the first case (less is better). Otherwise, it + * will be the relative execution speed of each case relative to the first case (more is better). */ - def run(): Unit = { + def run(relativeTime: Boolean = false): Unit = { require(benchmarks.nonEmpty) // scalastyle:off println("Running benchmark: " + name) @@ -112,10 +114,12 @@ private[spark] class Benchmark( out.println(Benchmark.getJVMOSInfo()) out.println(Benchmark.getProcessorName()) val nameLen = Math.max(40, Math.max(name.length, benchmarks.map(_.name.length).max)) + val relativeHeader = if (relativeTime) "Relative time" else "Relative" out.printf(s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n", - name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)", "Rate(M/s)", "Per Row(ns)", "Relative") + name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)", "Rate(M/s)", "Per Row(ns)", relativeHeader) out.println("-" * (nameLen + 80)) results.zip(benchmarks).foreach { case (result, benchmark) => + val relative = if (relativeTime) result.bestMs / firstBest else firstBest / result.bestMs out.printf(s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n", benchmark.name, "%5.0f" format result.bestMs, @@ -123,7 +127,7 @@ private[spark] class Benchmark( "%5.0f" format result.stdevMs, "%10.1f" format result.bestRate, "%6.1f" format (1000 / result.bestRate), - "%3.1fX" format (firstBest / result.bestMs)) + "%3.1fX" format relative) } out.println() // scalastyle:on diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 40d8eae644a07..ca81283e073ac 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -1802,6 +1802,23 @@ class SparkSubmitSuite val (_, classpath, _, _) = submit.prepareSubmitEnvironment(appArgs) assert(classpath.contains(".")) } + + // Requires Python dependencies for Spark Connect. Should be enabled by default. 
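// Aside on the Benchmark.run(relativeTime) change above, written as a comment so it
// does not interrupt the surrounding suite; the numbers are made up for illustration.
// With firstBest = 50 ms and a case whose bestMs = 100 ms:
//   default             -> firstBest / bestMs = 0.5X  (relative speed, more is better)
//   relativeTime = true -> bestMs / firstBest = 2.0X  (relative time, less is better)
// i.e. exactly the branch `if (relativeTime) result.bestMs / firstBest else firstBest / result.bestMs`.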
+ ignore("Spark Connect application submission (Python)") { + val pyFile = File.createTempFile("remote_test", ".py") + pyFile.deleteOnExit() + val content = + "from pyspark.sql import SparkSession;" + + "spark = SparkSession.builder.getOrCreate();" + + "assert 'connect' in str(type(spark));" + + "assert spark.range(1).first()[0] == 0" + FileUtils.write(pyFile, content, StandardCharsets.UTF_8) + val args = Seq( + "--name", "testPyApp", + "--remote", "local", + pyFile.getAbsolutePath) + runSparkSubmit(args) + } } object JarCreationTest extends Logging { diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 5c09a1f965b9e..ff971b72d8910 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -132,7 +132,7 @@ class CompressionCodecSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { CompressionCodec.createCodec(conf, "foobar") }, - errorClass = "CODEC_NOT_AVAILABLE.WITH_CONF_SUGGESTION", + condition = "CODEC_NOT_AVAILABLE.WITH_CONF_SUGGESTION", parameters = Map( "codecName" -> "foobar", "configKey" -> "\"spark.io.compression.codec\"", @@ -171,7 +171,7 @@ class CompressionCodecSuite extends SparkFunSuite { exception = intercept[SparkIllegalArgumentException] { CompressionCodec.getShortName(codecClass.toUpperCase(Locale.ROOT)) }, - errorClass = "CODEC_SHORT_NAME_NOT_FOUND", + condition = "CODEC_SHORT_NAME_NOT_FOUND", parameters = Map("codecName" -> codecClass.toUpperCase(Locale.ROOT))) } } diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala index 55d82aed5c3f2..817d660763361 100644 --- a/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/sink/GraphiteSinkSuite.scala @@ -88,7 +88,7 @@ class GraphiteSinkSuite extends SparkFunSuite { val e = intercept[SparkException] { new GraphiteSink(props, registry) } - checkError(e, errorClass = "GRAPHITE_SINK_PROPERTY_MISSING", + checkError(e, condition = "GRAPHITE_SINK_PROPERTY_MISSING", parameters = Map("property" -> "host")) } @@ -100,7 +100,7 @@ class GraphiteSinkSuite extends SparkFunSuite { val e = intercept[SparkException] { new GraphiteSink(props, registry) } - checkError(e, errorClass = "GRAPHITE_SINK_PROPERTY_MISSING", + checkError(e, condition = "GRAPHITE_SINK_PROPERTY_MISSING", parameters = Map("property" -> "port")) } @@ -115,7 +115,7 @@ class GraphiteSinkSuite extends SparkFunSuite { exception = intercept[SparkException] { new GraphiteSink(props, registry) }, - errorClass = "GRAPHITE_SINK_INVALID_PROTOCOL", + condition = "GRAPHITE_SINK_INVALID_PROTOCOL", parameters = Map("protocol" -> "http") ) } diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporterSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporterSuite.scala deleted file mode 100644 index 3f9c75062f78f..0000000000000 --- a/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporterSuite.scala +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.metrics.sink.opentelemetry - -import com.codahale.metrics._ -import org.junit.jupiter.api.Assertions.assertNotNull -import org.scalatest.PrivateMethodTester - -import org.apache.spark.SparkFunSuite - -class OpenTelemetryPushReporterSuite - extends SparkFunSuite with PrivateMethodTester { - val reporter = new OpenTelemetryPushReporter( - registry = new MetricRegistry(), - trustedCertificatesPath = null, - privateKeyPemPath = null, - certificatePemPath = null - ) - - test("Normalize metric name key") { - val name = "local-1592132938718.driver.LiveListenerBus." + - "listenerProcessingTime.org.apache.spark.HeartbeatReceiver" - val metricsName = reporter invokePrivate PrivateMethod[String]( - Symbol("normalizeMetricName") - )(name) - assert( - metricsName == "local_1592132938718_driver_livelistenerbus_" + - "listenerprocessingtime_org_apache_spark_heartbeatreceiver" - ) - } - - test("OpenTelemetry actions when one codahale gauge is added") { - val gauge = new Gauge[Double] { - override def getValue: Double = 1.23 - } - reporter.onGaugeAdded("test-gauge", gauge) - assertNotNull(reporter.openTelemetryGauges("test_gauge")) - } - - test("OpenTelemetry actions when one codahale counter is added") { - val counter = new Counter - reporter.onCounterAdded("test_counter", counter) - assertNotNull(reporter.openTelemetryCounters("test_counter")) - } - - test("OpenTelemetry actions when one codahale histogram is added") { - val histogram = new Histogram(new UniformReservoir) - reporter.onHistogramAdded("test_hist", histogram) - assertNotNull(reporter.openTelemetryHistograms("test_hist_count")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_max")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_min")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_mean")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_std_dev")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_50_percentile")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_75_percentile")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_95_percentile")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_98_percentile")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_99_percentile")) - assertNotNull(reporter.openTelemetryHistograms("test_hist_999_percentile")) - } - - test("OpenTelemetry actions when one codahale meter is added") { - val meter = new Meter() - reporter.onMeterAdded("test_meter", meter) - assertNotNull(reporter.openTelemetryGauges("test_meter_meter_count")) - assertNotNull(reporter.openTelemetryGauges("test_meter_meter_mean_rate")) - assertNotNull( - reporter.openTelemetryGauges("test_meter_meter_one_minute_rate") - ) - assertNotNull( - reporter.openTelemetryGauges("test_meter_meter_five_minute_rate") - ) - assertNotNull( - reporter.openTelemetryGauges("test_meter_meter_fifteen_minute_rate") - ) - } - - test("OpenTelemetry 
actions when one codahale timer is added") { - val timer = new Timer() - reporter.onTimerAdded("test_timer", timer) - assertNotNull(reporter.openTelemetryHistograms("test_timer_timer_count")) - assertNotNull(reporter.openTelemetryHistograms("test_timer_timer_max")) - assertNotNull(reporter.openTelemetryHistograms("test_timer_timer_min")) - assertNotNull(reporter.openTelemetryHistograms("test_timer_timer_mean")) - assertNotNull(reporter.openTelemetryHistograms("test_timer_timer_std_dev")) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_50_percentile") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_75_percentile") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_95_percentile") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_95_percentile") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_99_percentile") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_999_percentile") - ) - - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_fifteen_minute_rate") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_five_minute_rate") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_one_minute_rate") - ) - assertNotNull( - reporter.openTelemetryHistograms("test_timer_timer_mean_rate") - ) - } -} diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSinkSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSinkSuite.scala deleted file mode 100644 index 25aca82a22c40..0000000000000 --- a/core/src/test/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSinkSuite.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.metrics.sink.opentelemetry - -import java.util.Properties - -import com.codahale.metrics._ -import org.junit.jupiter.api.Assertions.assertEquals -import org.scalatest.PrivateMethodTester - -import org.apache.spark.SparkFunSuite - -class OpenTelemetryPushSinkSuite - extends SparkFunSuite with PrivateMethodTester { - test("fetch properties map") { - val properties = new Properties - properties.put("foo1.foo2.foo3.foo4.header.key1.key2.key3", "value1") - properties.put("foo1.foo2.foo3.foo4.header.key2", "value2") - val keyPrefix = "foo1.foo2.foo3.foo4.header" - val propertiesMap: Map[String, String] = OpenTelemetryPushSink invokePrivate - PrivateMethod[Map[String, String]](Symbol("fetchMapFromProperties"))( - properties, - keyPrefix - ) - - assert("value1".equals(propertiesMap("key1.key2.key3"))) - assert("value2".equals(propertiesMap("key2"))) - } - - test("OpenTelemetry sink with one counter added") { - val props = new Properties - props.put("endpoint", "http://127.0.0.1:10086") - val registry = new MetricRegistry - val sink = new OpenTelemetryPushSink(props, registry) - sink.start() - val reporter = sink.reporter - val counter = registry.counter("test-counter") - assertEquals(reporter.openTelemetryCounters.size, 1) - } -} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 7c5db914cd5ba..8bb96a0f53c73 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -922,7 +922,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext with Eventually { exception = intercept[SparkIllegalArgumentException] { rdd1.cartesian(rdd2).partitions }, - errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE", + condition = "COLLECTION_SIZE_LIMIT_EXCEEDED.INITIALIZE", sqlState = "54000", parameters = Map( "numberOfElements" -> (numSlices.toLong * numSlices.toLong).toString, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 978ceb16b376c..243d33fe55a79 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -779,7 +779,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(failureReason.isDefined) checkError( exception = failureReason.get.asInstanceOf[SparkException], - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map("jobId" -> "0", "reason" -> "") ) @@ -901,7 +901,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti cancel(jobId) checkError( exception = failure.asInstanceOf[SparkException], - errorClass = "SPARK_JOB_CANCELLED", + condition = "SPARK_JOB_CANCELLED", sqlState = "XXKDA", parameters = scala.collection.immutable.Map("jobId" -> jobId.toString, "reason" -> "") ) diff --git a/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala index e7a57c22ef66e..478e578130fcb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/HealthTrackerSuite.scala @@ -441,6 +441,23 @@ class HealthTrackerSuite extends SparkFunSuite with MockitoSugar with LocalSpark assert(1000 === 
HealthTracker.getExcludeOnFailureTimeout(conf)) } + test("SPARK-49252: check exclusion enabling config on the application level") { + val conf = new SparkConf().setMaster("local") + assert(!HealthTracker.isExcludeOnFailureEnabled(conf)) + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED, true) + assert(HealthTracker.isExcludeOnFailureEnabled(conf)) + // Turn off taskset level exclusion; the application level health tracker should still be enabled. + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, false) + assert(HealthTracker.isExcludeOnFailureEnabled(conf)) + // Turn off the application level exclusion specifically; this overrides the global setting. + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_APPLICATION, false) + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, false) + assert(!HealthTracker.isExcludeOnFailureEnabled(conf)) + // Turn on application level exclusion; the health tracker should be enabled. + conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_APPLICATION, true) + assert(HealthTracker.isExcludeOnFailureEnabled(conf)) + } + test("check exclude configuration invariants") { val conf = new SparkConf().setMaster("yarn").set(config.SUBMIT_DEPLOY_MODE, "cluster") Seq( diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ab2c00e368468..7607d4d9fe6d9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -2725,6 +2725,39 @@ class TaskSetManagerSuite assert(executorMonitor.isExecutorIdle("exec2")) } + test("SPARK-49252: TaskSetExcludeList can be created without HealthTracker") { + // When excludeOnFailure is enabled at the taskset level, the TaskSetManager should create a + // TaskSetExcludelist even if the application level HealthTracker is not defined. + val conf = new SparkConf().set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, true) + + // Create a task set with a single task. + sc = new SparkContext("local", "test", conf) + sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(1) + + val taskSetManager = new TaskSetManager(sched, taskSet, 1, + // No application level HealthTracker. + healthTracker = None) + assert(taskSetManager.taskSetExcludelistHelperOpt.isDefined) + } + + test("SPARK-49252: TaskSetExcludeList will be running in dry run mode when " + + "excludeOnFailure at taskset level is disabled but health tracker is enabled") { + // Disable excludeOnFailure at the taskset level. + val conf = new SparkConf().set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, false) + + // Create a task set with a single task. + sc = new SparkContext("local", "test", conf) + sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(1) + + val taskSetManager = new TaskSetManager(sched, taskSet, 1, + // Enable the application level HealthTracker.
+ healthTracker = Some(new HealthTracker(sc, None))) + assert(taskSetManager.taskSetExcludelistHelperOpt.isDefined) + assert(taskSetManager.taskSetExcludelistHelperOpt.get.isDryRun) + } + } class FakeLongTasks(stageId: Int, partitionId: Int) extends FakeTask(stageId, partitionId) { diff --git a/core/src/test/scala/org/apache/spark/shuffle/ChecksumBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/ChecksumBenchmark.scala new file mode 100644 index 0000000000000..16a50fabb7ffd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ChecksumBenchmark.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.util.zip.{Adler32, CRC32, CRC32C} + +import org.apache.hadoop.util.PureJavaCrc32C + +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} + +/** + * Benchmark for Checksum Algorithms used by shuffle. + * {{{ + * To run this benchmark: + * 1. without sbt: bin/spark-submit --class <this class> + * 2. build/sbt "core/Test/runMain <this class>" + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "core/Test/runMain <this class>" + * Results will be written to "benchmarks/ChecksumBenchmark-results.txt".
+ * }}} + */ +object ChecksumBenchmark extends BenchmarkBase { + + val N = 1024 + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Benchmark Checksum Algorithms") { + val data: Array[Byte] = (1 until 32 * 1024 * 1024).map(_.toByte).toArray + val benchmark = new Benchmark("Checksum Algorithms", N, 3, output = output) + benchmark.addCase("CRC32") { _ => + (1 to N).foreach(_ => new CRC32().update(data)) + } + benchmark.addCase(s"CRC32C") { _ => + (1 to N).foreach(_ => new CRC32C().update(data)) + } + benchmark.addCase(s"Adler32") { _ => + (1 to N).foreach(_ => new Adler32().update(data)) + } + benchmark.addCase(s"hadoop PureJavaCrc32C") { _ => + (1 to N).foreach(_ => new PureJavaCrc32C().update(data)) + } + benchmark.run() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala index 5b6fb31d598ac..aad649b7b2612 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala @@ -111,7 +111,7 @@ class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext wi exception = intercept[SparkOutOfMemoryError] { sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) }, - errorClass = "UNABLE_TO_ACQUIRE_MEMORY", + condition = "UNABLE_TO_ACQUIRE_MEMORY", parameters = Map("requestedBytes" -> "800", "receivedBytes" -> "400")) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index cdee6ccda706e..30c9693e6dee3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -32,6 +32,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark._ import org.apache.spark.executor._ +import org.apache.spark.internal.config._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.{DeterministicLevel, RDDOperationScope} import org.apache.spark.resource._ @@ -276,7 +277,8 @@ class JsonProtocolSuite extends SparkFunSuite { test("StageInfo backward compatibility (details, accumulables)") { val info = makeStageInfo(1, 2, 3, 4L, 5L) - val newJson = toJsonString(JsonProtocol.stageInfoToJson(info, _, includeAccumulables = true)) + val newJson = toJsonString( + JsonProtocol.stageInfoToJson(info, _, defaultOptions, includeAccumulables = true)) // Fields added after 1.0.0. assert(info.details.nonEmpty) @@ -294,7 +296,8 @@ class JsonProtocolSuite extends SparkFunSuite { test("StageInfo resourceProfileId") { val info = makeStageInfo(1, 2, 3, 4L, 5L, 5) - val json = toJsonString(JsonProtocol.stageInfoToJson(info, _, includeAccumulables = true)) + val json = toJsonString( + JsonProtocol.stageInfoToJson(info, _, defaultOptions, includeAccumulables = true)) // Fields added after 1.0.0. 
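Returning to the SPARK-49252 changes in HealthTrackerSuite and TaskSetManagerSuite above, the precedence between the exclusion flags can be summarized with a small sketch; the object name is hypothetical, while the `config.EXCLUDE_ON_FAILURE_*` entries and `HealthTracker.isExcludeOnFailureEnabled` are the ones exercised by those tests.

```scala
package org.apache.spark.scheduler

import org.apache.spark.SparkConf
import org.apache.spark.internal.config

// Sketch of the precedence exercised by the SPARK-49252 tests above: the application-level
// flag, when set, overrides the global flag, and the taskset-level flag does not disable the
// application-level health tracker.
object ExclusionConfigSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local")
      .set(config.EXCLUDE_ON_FAILURE_ENABLED, true)                 // global exclusion on
      .set(config.EXCLUDE_ON_FAILURE_ENABLED_TASK_AND_STAGE, false) // taskset-level exclusion off
    // Application-level exclusion is still considered enabled here ...
    assert(HealthTracker.isExcludeOnFailureEnabled(conf))
    // ... until the application-level flag is explicitly turned off.
    conf.set(config.EXCLUDE_ON_FAILURE_ENABLED_APPLICATION, false)
    assert(!HealthTracker.isExcludeOnFailureEnabled(conf))
  }
}
```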
assert(info.details.nonEmpty) @@ -471,7 +474,7 @@ class JsonProtocolSuite extends SparkFunSuite { stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown", resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) - val oldEvent = toJsonString(JsonProtocol.jobStartToJson(jobStart, _)).removeField("Stage Infos") + val oldEvent = sparkEventToJsonString(jobStart).removeField("Stage Infos") val expectedJobStart = SparkListenerJobStart(10, jobSubmissionTime, dummyStageInfos, properties) assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldEvent)) @@ -483,8 +486,7 @@ class JsonProtocolSuite extends SparkFunSuite { val stageIds = Seq[Int](1, 2, 3, 4) val stageInfos = stageIds.map(x => makeStageInfo(x * 10, x * 20, x * 30, x * 40L, x * 50L)) val jobStart = SparkListenerJobStart(11, jobSubmissionTime, stageInfos, properties) - val oldStartEvent = toJsonString(JsonProtocol.jobStartToJson(jobStart, _)) - .removeField("Submission Time") + val oldStartEvent = sparkEventToJsonString(jobStart).removeField("Submission Time") val expectedJobStart = SparkListenerJobStart(11, -1, stageInfos, properties) assertEquals(expectedJobStart, JsonProtocol.jobStartFromJson(oldStartEvent)) @@ -519,8 +521,9 @@ class JsonProtocolSuite extends SparkFunSuite { val stageInfo = new StageInfo(1, 1, "me-stage", 1, Seq.empty, Seq(1, 2, 3), "details", resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) val oldStageInfo = - toJsonString(JsonProtocol.stageInfoToJson(stageInfo, _, includeAccumulables = true)) - .removeField("Parent IDs") + toJsonString( + JsonProtocol.stageInfoToJson(stageInfo, _, defaultOptions, includeAccumulables = true) + ).removeField("Parent IDs") val expectedStageInfo = new StageInfo(1, 1, "me-stage", 1, Seq.empty, Seq.empty, "details", resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) assertEquals(expectedStageInfo, JsonProtocol.stageInfoFromJson(oldStageInfo)) @@ -785,6 +788,87 @@ class JsonProtocolSuite extends SparkFunSuite { assert(JsonProtocol.sparkEventFromJson(unknownFieldsJson) === expected) } + test("SPARK-42204: spark.eventLog.includeTaskMetricsAccumulators config") { + val includeConf = new JsonProtocolOptions( + new SparkConf().set(EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS, true)) + val excludeConf = new JsonProtocolOptions( + new SparkConf().set(EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS, false)) + + val taskMetricsAccumulables = TaskMetrics + .empty + .nameToAccums + .view + .filterKeys(!JsonProtocol.accumulableExcludeList.contains(_)) + .values + .map(_.toInfo(Some(1), None)) + .toSeq + + val taskInfoWithTaskMetricsAccums = makeTaskInfo(222L, 333, 1, 333, 444L, false) + taskInfoWithTaskMetricsAccums.setAccumulables(taskMetricsAccumulables) + val taskInfoWithoutTaskMetricsAccums = makeTaskInfo(222L, 333, 1, 333, 444L, false) + taskInfoWithoutTaskMetricsAccums.setAccumulables(Seq.empty) + + val stageInfoWithTaskMetricsAccums = makeStageInfo(100, 200, 300, 400L, 500L) + stageInfoWithTaskMetricsAccums.accumulables.clear() + stageInfoWithTaskMetricsAccums.accumulables ++= taskMetricsAccumulables.map(x => (x.id, x)) + val stageInfoWithoutTaskMetricsAccums = makeStageInfo(100, 200, 300, 400L, 500L) + stageInfoWithoutTaskMetricsAccums.accumulables.clear() + + // Test events which should be impacted by the config. 
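Before the individual event cases that follow, a minimal user-facing sketch of the knob this SPARK-42204 test exercises: the string key mirrors the test name, while the test itself uses the typed `EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS` entry; the object name is made up.

```scala
import org.apache.spark.SparkConf

// Sketch (SPARK-42204): when the flag is false, TaskMetrics-backed accumulators are dropped
// from logged TaskEnd/StageCompleted/JobStart events, as the cases below verify.
object EventLogTrimSketch {
  val conf: SparkConf = new SparkConf()
    .set("spark.eventLog.enabled", "true")
    .set("spark.eventLog.includeTaskMetricsAccumulators", "false")
}
```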
+ + // TaskEnd + { + val originalEvent = SparkListenerTaskEnd(1, 0, "ShuffleMapTask", Success, + taskInfoWithTaskMetricsAccums, + new ExecutorMetrics(Array(12L, 23L, 45L, 67L, 78L, 89L, + 90L, 123L, 456L, 789L, 40L, 20L, 20L, 10L, 20L, 10L, 301L)), + makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, 0, + hasHadoopInput = false, hasOutput = false)) + assertEquals( + originalEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, includeConf))) + val trimmedEvent = originalEvent.copy(taskInfo = taskInfoWithoutTaskMetricsAccums) + assertEquals( + trimmedEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, excludeConf))) + } + + // StageCompleted + { + val originalEvent = SparkListenerStageCompleted(stageInfoWithTaskMetricsAccums) + assertEquals( + originalEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, includeConf))) + val trimmedEvent = originalEvent.copy(stageInfo = stageInfoWithoutTaskMetricsAccums) + assertEquals( + trimmedEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, excludeConf))) + } + + // JobStart + { + val originalEvent = + SparkListenerJobStart(1, 1, Seq(stageInfoWithTaskMetricsAccums), properties) + assertEquals( + originalEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, includeConf))) + val trimmedEvent = originalEvent.copy(stageInfos = Seq(stageInfoWithoutTaskMetricsAccums)) + assertEquals( + trimmedEvent, + sparkEventFromJson(sparkEventToJsonString(originalEvent, excludeConf))) + } + + // ExecutorMetricsUpdate events should be unaffected by the config: + val executorMetricsUpdate = + SparkListenerExecutorMetricsUpdate("0", Seq((0, 0, 0, taskMetricsAccumulables))) + assert( + sparkEventToJsonString(executorMetricsUpdate, includeConf) === + sparkEventToJsonString(executorMetricsUpdate, excludeConf)) + assertEquals( + JsonProtocol.sparkEventFromJson(sparkEventToJsonString(executorMetricsUpdate, includeConf)), + executorMetricsUpdate) + } + test("SPARK-42403: properly handle null string values") { // Null string values can appear in a few different event types, // so we test multiple known cases here: @@ -966,7 +1050,8 @@ private[spark] object JsonProtocolSuite extends Assertions { private def testStageInfo(info: StageInfo): Unit = { val newInfo = JsonProtocol.stageInfoFromJson( - toJsonString(JsonProtocol.stageInfoToJson(info, _, includeAccumulables = true))) + toJsonString( + JsonProtocol.stageInfoToJson(info, _, defaultOptions, includeAccumulables = true))) assertEquals(info, newInfo) } @@ -990,7 +1075,8 @@ private[spark] object JsonProtocolSuite extends Assertions { private def testTaskInfo(info: TaskInfo): Unit = { val newInfo = JsonProtocol.taskInfoFromJson( - toJsonString(JsonProtocol.taskInfoToJson(info, _, includeAccumulables = true))) + toJsonString( + JsonProtocol.taskInfoToJson(info, _, defaultOptions, includeAccumulables = true))) assertEquals(info, newInfo) } diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 8bad50951a78f..b82cb7078c9f3 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -31,12 +31,13 @@ graphlib-dot.min.js sorttable.js vis-timeline-graph2d.min.js vis-timeline-graph2d.min.css -dataTables.bootstrap4.*.min.css -dataTables.bootstrap4.*.min.js +dataTables.bootstrap4.min.css +dataTables.bootstrap4.min.js dataTables.rowsGroup.js jquery.blockUI.min.js jquery.cookies.2.2.0.min.js -jquery.dataTables.*.min.js +jquery.dataTables.min.css +jquery.dataTables.min.js jquery.mustache.js .*\.avsc .*\.txt @@ -139,3 +140,4 @@ ui-test/package.json ui-test/package-lock.json 
core/src/main/resources/org/apache/spark/ui/static/package.json .*\.har +.nojekyll diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index e86b91968bf80..3cba72d042ed6 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -94,7 +94,7 @@ ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library RUN add-apt-repository ppa:pypy/ppa RUN mkdir -p /usr/local/pypy/pypy3.9 && \ curl -sqL https://downloads.python.org/pypy/pypy3.9-v7.3.16-linux64.tar.bz2 | tar xjf - -C /usr/local/pypy/pypy3.9 --strip-components=1 && \ - ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.8 && \ + ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml @@ -137,6 +137,7 @@ RUN python3.9 -m pip list RUN gem install --no-document "bundler:2.4.22" RUN ln -s "$(which python3.9)" "/usr/local/bin/python" +RUN ln -s "$(which python3.9)" "/usr/local/bin/python3" WORKDIR /opt/spark-rm/output diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index e93e8e94a993e..419625f48fa11 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -4,7 +4,7 @@ JTransforms/3.1//JTransforms-3.1.jar RoaringBitmap/1.2.1//RoaringBitmap-1.2.1.jar ST4/4.0.4//ST4-4.0.4.jar activation/1.1.1//activation-1.1.1.jar -aircompressor/0.27//aircompressor-0.27.jar +aircompressor/2.0.2//aircompressor-2.0.2.jar algebra_2.13/2.8.0//algebra_2.13-2.8.0.jar aliyun-java-sdk-core/4.5.10//aliyun-java-sdk-core-4.5.10.jar aliyun-java-sdk-kms/2.11.0//aliyun-java-sdk-kms-2.11.0.jar @@ -33,6 +33,7 @@ breeze-macros_2.13/2.1.0//breeze-macros_2.13-2.1.0.jar breeze_2.13/2.1.0//breeze_2.13-2.1.0.jar bundle/2.24.6//bundle-2.24.6.jar cats-kernel_2.13/2.8.0//cats-kernel_2.13-2.8.0.jar +checker-qual/3.42.0//checker-qual-3.42.0.jar chill-java/0.10.0//chill-java-0.10.0.jar chill_2.13/0.10.0//chill_2.13-0.10.0.jar commons-cli/1.9.0//commons-cli-1.9.0.jar @@ -43,9 +44,9 @@ commons-compiler/3.1.9//commons-compiler-3.1.9.jar commons-compress/1.27.1//commons-compress-1.27.1.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar -commons-io/2.16.1//commons-io-2.16.1.jar +commons-io/2.17.0//commons-io-2.17.0.jar commons-lang/2.6//commons-lang-2.6.jar -commons-lang3/3.16.0//commons-lang3-3.16.0.jar +commons-lang3/3.17.0//commons-lang3-3.17.0.jar commons-math3/3.6.1//commons-math3-3.6.1.jar commons-pool/1.5.4//commons-pool-1.5.4.jar commons-text/1.12.0//commons-text-1.12.0.jar @@ -62,12 +63,14 @@ derby/10.16.1.1//derby-10.16.1.1.jar derbyshared/10.16.1.1//derbyshared-10.16.1.1.jar derbytools/10.16.1.1//derbytools-10.16.1.1.jar dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar +error_prone_annotations/2.26.1//error_prone_annotations-2.26.1.jar esdk-obs-java/3.20.4.2//esdk-obs-java-3.20.4.2.jar +failureaccess/1.0.2//failureaccess-1.0.2.jar flatbuffers-java/24.3.25//flatbuffers-java-24.3.25.jar gcs-connector/hadoop3-2.2.21/shaded/gcs-connector-hadoop3-2.2.21-shaded.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar gson/2.11.0//gson-2.11.0.jar -guava/14.0.1//guava-14.0.1.jar +guava/33.2.1-jre//guava-33.2.1-jre.jar hadoop-aliyun/3.4.0//hadoop-aliyun-3.4.0.jar 
hadoop-annotations/3.4.0//hadoop-annotations-3.4.0.jar hadoop-aws/3.4.0//hadoop-aws-3.4.0.jar @@ -101,6 +104,7 @@ icu4j/75.1//icu4j-75.1.jar ini4j/0.5.4//ini4j-0.5.4.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.2//ivy-2.5.2.jar +j2objc-annotations/3.0.0//j2objc-annotations-3.0.0.jar jackson-annotations/2.17.2//jackson-annotations-2.17.2.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar jackson-core/2.17.2//jackson-core-2.17.2.jar @@ -142,7 +146,7 @@ jjwt-api/0.12.6//jjwt-api-0.12.6.jar jline/2.14.6//jline-2.14.6.jar jline/3.25.1//jline-3.25.1.jar jna/5.14.0//jna-5.14.0.jar -joda-time/2.12.7//joda-time-2.12.7.jar +joda-time/2.13.0//joda-time-2.13.0.jar jodd-core/3.5.2//jodd-core-3.5.2.jar jpam/1.1//jpam-1.1.jar json/1.8//json-1.8.jar @@ -184,6 +188,7 @@ lapack/3.0.3//lapack-3.0.3.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.16.0//libthrift-0.16.0.jar +listenablefuture/9999.0-empty-to-avoid-conflict-with-guava//listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar log4j-1.2-api/2.22.1//log4j-1.2-api-2.22.1.jar log4j-api/2.22.1//log4j-api-2.22.1.jar log4j-core/2.22.1//log4j-core-2.22.1.jar @@ -207,12 +212,12 @@ netty-common/4.1.110.Final//netty-common-4.1.110.Final.jar netty-handler-proxy/4.1.110.Final//netty-handler-proxy-4.1.110.Final.jar netty-handler/4.1.110.Final//netty-handler-4.1.110.Final.jar netty-resolver/4.1.110.Final//netty-resolver-4.1.110.Final.jar -netty-tcnative-boringssl-static/2.0.65.Final/linux-aarch_64/netty-tcnative-boringssl-static-2.0.65.Final-linux-aarch_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/linux-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-linux-x86_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/osx-aarch_64/netty-tcnative-boringssl-static-2.0.65.Final-osx-aarch_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/osx-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-osx-x86_64.jar -netty-tcnative-boringssl-static/2.0.65.Final/windows-x86_64/netty-tcnative-boringssl-static-2.0.65.Final-windows-x86_64.jar -netty-tcnative-classes/2.0.65.Final//netty-tcnative-classes-2.0.65.Final.jar +netty-tcnative-boringssl-static/2.0.66.Final/linux-aarch_64/netty-tcnative-boringssl-static-2.0.66.Final-linux-aarch_64.jar +netty-tcnative-boringssl-static/2.0.66.Final/linux-x86_64/netty-tcnative-boringssl-static-2.0.66.Final-linux-x86_64.jar +netty-tcnative-boringssl-static/2.0.66.Final/osx-aarch_64/netty-tcnative-boringssl-static-2.0.66.Final-osx-aarch_64.jar +netty-tcnative-boringssl-static/2.0.66.Final/osx-x86_64/netty-tcnative-boringssl-static-2.0.66.Final-osx-x86_64.jar +netty-tcnative-boringssl-static/2.0.66.Final/windows-x86_64/netty-tcnative-boringssl-static-2.0.66.Final-windows-x86_64.jar +netty-tcnative-classes/2.0.66.Final//netty-tcnative-classes-2.0.66.Final.jar netty-transport-classes-epoll/4.1.110.Final//netty-transport-classes-epoll-4.1.110.Final.jar netty-transport-classes-kqueue/4.1.110.Final//netty-transport-classes-kqueue-4.1.110.Final.jar netty-transport-native-epoll/4.1.110.Final/linux-aarch_64/netty-transport-native-epoll-4.1.110.Final-linux-aarch_64.jar @@ -236,12 +241,12 @@ orc-shims/2.0.2//orc-shims-2.0.2.jar oro/2.0.8//oro-2.0.8.jar osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar paranamer/2.8//paranamer-2.8.jar -parquet-column/1.14.1//parquet-column-1.14.1.jar -parquet-common/1.14.1//parquet-common-1.14.1.jar -parquet-encoding/1.14.1//parquet-encoding-1.14.1.jar 
-parquet-format-structures/1.14.1//parquet-format-structures-1.14.1.jar -parquet-hadoop/1.14.1//parquet-hadoop-1.14.1.jar -parquet-jackson/1.14.1//parquet-jackson-1.14.1.jar +parquet-column/1.14.2//parquet-column-1.14.2.jar +parquet-common/1.14.2//parquet-common-1.14.2.jar +parquet-encoding/1.14.2//parquet-encoding-1.14.2.jar +parquet-format-structures/1.14.2//parquet-format-structures-1.14.2.jar +parquet-hadoop/1.14.2//parquet-hadoop-1.14.2.jar +parquet-jackson/1.14.2//parquet-jackson-1.14.2.jar pickle/1.5//pickle-1.5.jar py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar @@ -256,7 +261,7 @@ scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar slf4j-api/2.0.16//slf4j-api-2.0.16.jar snakeyaml-engine/2.7//snakeyaml-engine-2.7.jar snakeyaml/2.2//snakeyaml-2.2.jar -snappy-java/1.1.10.6//snappy-java-1.1.10.6.jar +snappy-java/1.1.10.7//snappy-java-1.1.10.7.jar spire-macros_2.13/0.18.0//spire-macros_2.13-0.18.0.jar spire-platform_2.13/0.18.0//spire-platform_2.13-0.18.0.jar spire-util_2.13/0.18.0//spire-util_2.13-0.18.0.jar @@ -265,7 +270,7 @@ stax-api/1.0.1//stax-api-1.0.1.jar stream/2.9.8//stream-2.9.8.jar super-csv/2.2.0//super-csv-2.2.0.jar threeten-extra/1.7.1//threeten-extra-1.7.1.jar -tink/1.14.1//tink-1.14.1.jar +tink/1.15.0//tink-1.15.0.jar transaction-api/1.1//transaction-api-1.1.jar univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar wildfly-openssl/1.1.3.Final//wildfly-openssl-1.1.3.Final.jar @@ -275,4 +280,4 @@ xz/1.10//xz-1.10.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper-jute/3.9.2//zookeeper-jute-3.9.2.jar zookeeper/3.9.2//zookeeper-3.9.2.jar -zstd-jni/1.5.6-4//zstd-jni-1.5.6-4.jar +zstd-jni/1.5.6-5//zstd-jni-1.5.6-5.jar diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index b01e3c50e28d3..5939e429b2f35 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image" # Overwrite this label to avoid exposing the underlying Ubuntu OS version label LABEL org.opencontainers.image.version="" -ENV FULL_REFRESH_DATE 20240318 +ENV FULL_REFRESH_DATE 20240903 ENV DEBIAN_FRONTEND noninteractive ENV DEBCONF_NONINTERACTIVE_SEEN true @@ -88,13 +88,13 @@ ENV R_LIBS_SITE "/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library RUN add-apt-repository ppa:pypy/ppa RUN mkdir -p /usr/local/pypy/pypy3.9 && \ curl -sqL https://downloads.python.org/pypy/pypy3.9-v7.3.16-linux64.tar.bz2 | tar xjf - -C /usr/local/pypy/pypy3.9 --strip-components=1 && \ - ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.8 && \ + ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 -RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml +RUN pypy3 -m pip install 'numpy==1.26.4' 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml -ARG BASIC_PIP_PKGS="numpy pyarrow>=15.0.0 six==1.16.0 pandas==2.2.2 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" +ARG BASIC_PIP_PKGS="numpy==1.26.4 pyarrow>=15.0.0 six==1.16.0 pandas==2.2.2 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" # Python deps for Spark Connect ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==4.25.1 googleapis-common-protos==1.56.4 graphviz==0.20.3" diff --git a/dev/lint-scala b/dev/lint-scala index 
98b850da68838..23df146a8d1b4 100755 --- a/dev/lint-scala +++ b/dev/lint-scala @@ -29,6 +29,7 @@ ERRORS=$(./build/mvn \ -Dscalafmt.skip=false \ -Dscalafmt.validateOnly=true \ -Dscalafmt.changedOnly=false \ + -pl sql/api \ -pl sql/connect/common \ -pl sql/connect/server \ -pl connector/connect/client/jvm \ @@ -38,7 +39,7 @@ ERRORS=$(./build/mvn \ if test ! -z "$ERRORS"; then echo -e "The scalafmt check failed on sql/connect or connector/connect at following occurrences:\n\n$ERRORS\n" echo "Before submitting your change, please make sure to format your code using the following command:" - echo "./build/mvn scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl sql/connect/common -pl sql/connect/server -pl connector/connect/client/jvm" + echo "./build/mvn scalafmt:format -Dscalafmt.skip=false -Dscalafmt.validateOnly=false -Dscalafmt.changedOnly=false -pl sql/api -pl sql/connect/common -pl sql/connect/server -pl connector/connect/client/jvm" exit 1 else echo -e "Scalafmt checks passed." diff --git a/dev/py-cleanup b/dev/py-cleanup new file mode 100755 index 0000000000000..6a2edd1040171 --- /dev/null +++ b/dev/py-cleanup @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility for temporary files cleanup in 'python'. 
+# usage: ./dev/py-cleanup + +set -ex + +SPARK_HOME="$(cd "`dirname $0`"/..; pwd)" +cd "$SPARK_HOME" + +rm -rf python/target +rm -rf python/lib/pyspark.zip +rm -rf python/docs/build +rm -rf python/docs/source/reference/*/api diff --git a/dev/requirements.txt b/dev/requirements.txt index e0216a63ba790..cafc73405aaa8 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -3,11 +3,11 @@ py4j>=0.10.9.7 # PySpark dependencies (optional) numpy>=1.21 -pyarrow>=4.0.0 +pyarrow>=10.0.0 six==1.16.0 -pandas>=1.4.4 +pandas>=2.0.0 scipy -plotly +plotly>=4.8 mlflow>=2.3.1 scikit-learn matplotlib diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 181cd28cda78d..b9a4bed715f67 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -206,7 +206,6 @@ def __hash__(self): sbt_test_goals=[ "core/test", ], - build_profile_flags=["-Popentelemetry"], ) api = Module( @@ -549,6 +548,8 @@ def __hash__(self): "pyspark.sql.tests.test_udtf", "pyspark.sql.tests.test_utils", "pyspark.sql.tests.test_resources", + "pyspark.sql.tests.plot.test_frame_plot", + "pyspark.sql.tests.plot.test_frame_plot_plotly", ], ) @@ -1052,6 +1053,8 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map", "pyspark.sql.tests.connect.test_parity_python_datasource", "pyspark.sql.tests.connect.test_parity_python_streaming_datasource", + "pyspark.sql.tests.connect.test_parity_frame_plot", + "pyspark.sql.tests.connect.test_parity_frame_plot_plotly", "pyspark.sql.tests.connect.test_utils", "pyspark.sql.tests.connect.client.test_artifact", "pyspark.sql.tests.connect.client.test_artifact_localcluster", diff --git a/docs/_data/menu-streaming.yaml b/docs/_data/menu-streaming.yaml new file mode 100644 index 0000000000000..b1dd024451125 --- /dev/null +++ b/docs/_data/menu-streaming.yaml @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +- text: Overview + url: streaming/index.html +- text: Getting Started + url: streaming/getting-started.html + subitems: + - text: Quick Example + url: streaming/getting-started.html#quick-example + - text: Programming Model + url: streaming/getting-started.html#programming-model +- text: APIs on DataFrames and Datasets + url: streaming/apis-on-dataframes-and-datasets.html + subitems: + - text: Creating Streaming DataFrames and Streaming Datasets + url: streaming/apis-on-dataframes-and-datasets.html#creating-streaming-dataframes-and-streaming-datasets + - text: Operations on Streaming DataFrames/Datasets + url: streaming/apis-on-dataframes-and-datasets.html#operations-on-streaming-dataframesdatasets + - text: Starting Streaming Queries + url: streaming/apis-on-dataframes-and-datasets.html#starting-streaming-queries + - text: Managing Streaming Queries + url: streaming/apis-on-dataframes-and-datasets.html#managing-streaming-queries + - text: Monitoring Streaming Queries + url: streaming/apis-on-dataframes-and-datasets.html#monitoring-streaming-queries + - text: Recovering from Failures with Checkpointing + url: streaming/apis-on-dataframes-and-datasets.html#recovering-from-failures-with-checkpointing + - text: Recovery Semantics after Changes in a Streaming Query + url: streaming/apis-on-dataframes-and-datasets.html#recovery-semantics-after-changes-in-a-streaming-query +- text: Performance Tips + url: streaming/performance-tips.html + subitems: + - text: Asynchronous Progress Tracking + url: streaming/performance-tips.html#asynchronous-progress-tracking + - text: Continuous Processing + url: streaming/performance-tips.html#continuous-processing +- text: Additional Information + url: streaming/additional-information.html + subitems: + - text: Miscellaneous Notes + url: streaming/additional-information.html#miscellaneous-notes + - text: Related Resources + url: streaming/additional-information.html#related-resources + - text: Migration Guide + url: streaming/additional-information.html#migration-guide diff --git a/docs/_includes/nav-left-wrapper-streaming.html b/docs/_includes/nav-left-wrapper-streaming.html new file mode 100644 index 0000000000000..82849f8140f5d --- /dev/null +++ b/docs/_includes/nav-left-wrapper-streaming.html @@ -0,0 +1,22 @@ +{% comment %} +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +{% endcomment %} +
      +
      +

      Structured Streaming Programming Guide

      + {% include nav-left.html nav=include.nav-streaming %} +
      +
      diff --git a/docs/_includes/nav-left.html b/docs/_includes/nav-left.html index 19d68fd191635..935ed0c732ee6 100644 --- a/docs/_includes/nav-left.html +++ b/docs/_includes/nav-left.html @@ -2,7 +2,7 @@
        {% for item in include.nav %}
      • - + {% if navurl contains item.url %} {{ item.text }} {% else %} diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index c61c9349a6d7e..a85fd16451469 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -1,3 +1,9 @@ +{% assign current_page_segments = page.dir | split: "/" | where_exp: "element","element != ''" %} +{% assign rel_path_to_root = "" %} +{% for i in (1..current_page_segments.size) %} + {% assign rel_path_to_root = rel_path_to_root | append: "../" %} +{% endfor %} + @@ -21,12 +27,12 @@ - - + + - + - + {% production %} @@ -51,8 +57,8 @@