From 968cba200a5a839e9cc63c947cc285cc8277ba94 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Mon, 1 Apr 2024 18:33:30 +0900
Subject: [PATCH] [SPARK-47664][PYTHON][CONNECT] Validate the column name with cached schema

### What changes were proposed in this pull request?
Improve the column name validation, trying our best to avoid an extra RPC.

### Why are the changes needed?
The existing validation contains two parts:
1. check whether the column name is in `self.columns` <- client-side validation;
2. if step 1 fails, validate the name with an additional RPC `df.select(...)` <- RPC.

The client-side validation is too simple, and this PR improves it to cover more cases:
1. backticks:
```
'`a`'
```
2. nested fields:
```
'a.b.c'
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI, added unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #45788 from zhengruifeng/column_name_validate.

Authored-by: Ruifeng Zheng
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/sql/connect/dataframe.py      |  8 +-
 python/pyspark/sql/connect/types.py          | 74 +++++++++++++++-
 .../sql/tests/connect/test_connect_basic.py  | 85 +++++++++++++++++++
 3 files changed, 163 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b2d0cc5fedcae..806c8c0284e6d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1746,9 +1746,11 @@ def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Colum
 
             # validate the column name
             if not hasattr(self._session, "is_mock_session"):
-                # Different from __getattr__, the name here can be quoted like df['`id`'].
-                # Only validate the name when it is not in the cached schema.
-                if item not in self.columns:
+                from pyspark.sql.connect.types import verify_col_name
+
+                # Try best to verify the column name with cached schema
+                # If fails, fall back to the server side validation
+                if not verify_col_name(item, self.schema):
                     self.select(item).isLocal()
 
             return self._col(item)
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 751d497657ee6..f058c6390612a 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -20,7 +20,7 @@
 
 import json
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List
 
 from pyspark.sql.types import (
     DataType,
@@ -315,3 +315,75 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
             error_class="UNSUPPORTED_OPERATION",
             message_parameters={"operation": f"data type {schema}"},
         )
+
+
+# The python version of org.apache.spark.sql.catalyst.util.AttributeNameParser
+def parse_attr_name(name: str) -> Optional[List[str]]:
+    name_parts: List[str] = []
+    tmp: str = ""
+
+    in_backtick = False
+    i = 0
+    n = len(name)
+    while i < n:
+        char = name[i]
+        if in_backtick:
+            if char == "`":
+                if i + 1 < n and name[i + 1] == "`":
+                    tmp += "`"
+                    i += 1
+                else:
+                    in_backtick = False
+                    if i + 1 < n and name[i + 1] != ".":
+                        return None
+            else:
+                tmp += char
+        else:
+            if char == "`":
+                if len(tmp) > 0:
+                    return None
+                in_backtick = True
+            elif char == ".":
+                if name[i - 1] == "." or i == n - 1:
+                    return None
+                name_parts.append(tmp)
+                tmp = ""
+            else:
+                tmp += char
+        i += 1
+
+    if in_backtick:
+        return None
+
+    name_parts.append(tmp)
+    return name_parts
+
+
+# Verify whether the input column name can be resolved with the given schema.
+# Note that this method can not 100% match the analyzer behavior, it is designed to
+# try the best to eliminate unnecessary validation RPCs.
+def verify_col_name(name: str, schema: StructType) -> bool:
+    parts = parse_attr_name(name)
+    if parts is None or len(parts) == 0:
+        return False
+
+    def _quick_verify(parts: List[str], schema: DataType) -> bool:
+        if len(parts) == 0:
+            return True
+
+        _schema: Optional[StructType] = None
+        if isinstance(schema, StructType):
+            _schema = schema
+        elif isinstance(schema, ArrayType) and isinstance(schema.elementType, StructType):
+            _schema = schema.elementType
+        else:
+            return False
+
+        part = parts[0]
+        for field in _schema:
+            if field.name == part:
+                return _quick_verify(parts[1:], field.dataType)
+
+        return False
+
+    return _quick_verify(parts, schema)
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 4776851ba73d9..786b9e2896c2c 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1209,6 +1209,91 @@ def test_df_caache(self):
         self.assert_eq(10, df.count())
         self.assertTrue(df.is_cached)
 
+    def test_parse_col_name(self):
+        from pyspark.sql.connect.types import parse_attr_name
+
+        self.assert_eq(parse_attr_name(""), [""])
+
+        self.assert_eq(parse_attr_name("a"), ["a"])
+        self.assert_eq(parse_attr_name("`a`"), ["a"])
+        self.assert_eq(parse_attr_name("`a"), None)
+        self.assert_eq(parse_attr_name("a`"), None)
+
+        self.assert_eq(parse_attr_name("a.b.c"), ["a", "b", "c"])
+        self.assert_eq(parse_attr_name("`a`.`b`.`c`"), ["a", "b", "c"])
+        self.assert_eq(parse_attr_name("a.`b`.c"), ["a", "b", "c"])
+
+        self.assert_eq(parse_attr_name("`a.b.c`"), ["a.b.c"])
+        self.assert_eq(parse_attr_name("a.`b.c`"), ["a", "b.c"])
+        self.assert_eq(parse_attr_name("`a.b`.c"), ["a.b", "c"])
+        self.assert_eq(parse_attr_name("`a.b.c"), None)
+        self.assert_eq(parse_attr_name("a.b.c`"), None)
+        self.assert_eq(parse_attr_name("`a.`b.`c"), None)
+        self.assert_eq(parse_attr_name("a`.b`.c`"), None)
+
+        self.assert_eq(parse_attr_name("`ab..c`e.f"), None)
+
+    def test_verify_col_name(self):
+        from pyspark.sql.connect.types import verify_col_name
+
+        cdf = (
+            self.connect.range(10)
+            .withColumn("v", CF.lit(123))
+            .withColumn("s", CF.struct("id", "v"))
+            .withColumn("m", CF.struct("s", "v"))
+            .withColumn("a", CF.array("s"))
+        )
+
+        # root
+        #  |-- id: long (nullable = false)
+        #  |-- v: integer (nullable = false)
+        #  |-- s: struct (nullable = false)
+        #  |    |-- id: long (nullable = false)
+        #  |    |-- v: integer (nullable = false)
+        #  |-- m: struct (nullable = false)
+        #  |    |-- s: struct (nullable = false)
+        #  |    |    |-- id: long (nullable = false)
+        #  |    |    |-- v: integer (nullable = false)
+        #  |    |-- v: integer (nullable = false)
+        #  |-- a: array (nullable = false)
+        #  |    |-- element: struct (containsNull = false)
+        #  |    |    |-- id: long (nullable = false)
+        #  |    |    |-- v: integer (nullable = false)
+
+        self.assertTrue(verify_col_name("id", cdf.schema))
+        self.assertTrue(verify_col_name("`id`", cdf.schema))
+
+        self.assertTrue(verify_col_name("v", cdf.schema))
+        self.assertTrue(verify_col_name("`v`", cdf.schema))
+
+        self.assertFalse(verify_col_name("x", cdf.schema))
+        self.assertFalse(verify_col_name("`x`", cdf.schema))
+
+        self.assertTrue(verify_col_name("s", cdf.schema))
+        self.assertTrue(verify_col_name("`s`", cdf.schema))
+        self.assertTrue(verify_col_name("s.id", cdf.schema))
+        self.assertTrue(verify_col_name("s.`id`", cdf.schema))
+        self.assertTrue(verify_col_name("`s`.id", cdf.schema))
+        self.assertTrue(verify_col_name("`s`.`id`", cdf.schema))
+        self.assertFalse(verify_col_name("`s.id`", cdf.schema))
+
+        self.assertTrue(verify_col_name("m", cdf.schema))
+        self.assertTrue(verify_col_name("`m`", cdf.schema))
+        self.assertTrue(verify_col_name("m.s.id", cdf.schema))
+        self.assertTrue(verify_col_name("m.s.`id`", cdf.schema))
+        self.assertTrue(verify_col_name("m.`s`.id", cdf.schema))
+        self.assertTrue(verify_col_name("`m`.`s`.`id`", cdf.schema))
+        self.assertFalse(verify_col_name("m.`s.id`", cdf.schema))
+        self.assertFalse(verify_col_name("m.`s.id`", cdf.schema))
+
+        self.assertTrue(verify_col_name("a", cdf.schema))
+        self.assertTrue(verify_col_name("`a`", cdf.schema))
+        self.assertTrue(verify_col_name("a.`v`", cdf.schema))
+        self.assertTrue(verify_col_name("a.`v`", cdf.schema))
+        self.assertTrue(verify_col_name("`a`.v", cdf.schema))
+        self.assertTrue(verify_col_name("`a`.`v`", cdf.schema))
+        self.assertFalse(verify_col_name("`a`.`x`", cdf.schema))
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401
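Editor's note: for readers who want to see the two new helpers in isolation, here is a minimal, self-contained sketch. It is not part of the commit; the schema and the expected values are illustrative, chosen to mirror the assertions in the unit tests above, and it assumes a PySpark build that already includes this patch (so that `pyspark.sql.connect.types` exposes `parse_attr_name` and `verify_col_name`).

```python
# Illustrative sketch only -- not part of the patch. Expected results follow the
# unit tests added in test_connect_basic.py.
from pyspark.sql.types import IntegerType, LongType, StructField, StructType
from pyspark.sql.connect.types import parse_attr_name, verify_col_name

# Hypothetical cached schema: id BIGINT, s STRUCT<v: INT>
schema = StructType(
    [
        StructField("id", LongType()),
        StructField("s", StructType([StructField("v", IntegerType())])),
    ]
)

assert parse_attr_name("`s`.v") == ["s", "v"]  # backticks stripped, then split on "."
assert parse_attr_name("`s.v`") == ["s.v"]     # a dot inside backticks is not a separator
assert parse_attr_name("`s") is None           # unbalanced backtick: cannot be parsed

assert verify_col_name("`s`.v", schema)        # resolvable against the cached schema
assert not verify_col_name("`s.v`", schema)    # not provably resolvable on the client
```

In `__getitem__`, a `True` result from `verify_col_name` is what lets the client skip the extra round trip; a `False` result only means the name could not be resolved against the cached schema, so the existing `self.select(item).isLocal()` validation still runs on the server.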