[SPARK-47664][PYTHON][CONNECT] Validate the column name with cached schema

### What changes were proposed in this pull request?
Improve the column name validation so that, whenever the cached schema is sufficient to resolve the name, the extra validation RPC is avoided.

### Why are the changes needed?
The existing validation contains two parts:

1. check whether the column name is in `self.columns` <- client-side validation;
2. if step 1 fails, validate with an additional RPC `df.select(...)` <- RPC;

The client-side validation is too simple, and this PR improves it to cover more cases (see the sketch after this list):
1. backticks:
```
'`a`'
```
2. nested fields:
```
'a.b.c'
```
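
For illustration (a hypothetical DataFrame `df` with a struct column `a` holding a nested field `b.c`), these are exactly the lookups that the old `item not in self.columns` check could not resolve, so they always fell back to the RPC:
```
df["`a`"]    # backticked name: the string "`a`" is not literally in df.columns
df["a.b.c"]  # nested field: "a.b.c" is not in df.columns either
```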

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
CI, and added unit tests.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #45788 from zhengruifeng/column_name_validate.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
zhengruifeng authored and HyukjinKwon committed Apr 1, 2024
1 parent cf02b1a commit 968cba2
Showing 3 changed files with 163 additions and 4 deletions.
8 changes: 5 additions & 3 deletions python/pyspark/sql/connect/dataframe.py
@@ -1746,9 +1746,11 @@ def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Colum

             # validate the column name
             if not hasattr(self._session, "is_mock_session"):
-                # Different from __getattr__, the name here can be quoted like df['`id`'].
-                # Only validate the name when it is not in the cached schema.
-                if item not in self.columns:
+                from pyspark.sql.connect.types import verify_col_name
+
+                # Try best to verify the column name with cached schema
+                # If fails, fall back to the server side validation
+                if not verify_col_name(item, self.schema):
                     self.select(item).isLocal()

             return self._col(item)
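
As a rough sketch of the intended effect (not part of the diff above; it assumes a Spark Connect session `spark` and made-up column names), lookups that the cached schema can resolve no longer need the `df.select(...).isLocal()` round trip:
```
df = spark.range(10).selectExpr("id", "named_struct('x', id) AS s")

df["id"]     # found in df.columns: no RPC, same as before
df["`id`"]   # backticked: now verified against the cached schema, no RPC
df["s.x"]    # nested field: now verified against the cached schema, no RPC
df["s.`x`"]  # mixed quoting: handled by the new parser, no RPC
# df["y.z"]  # cannot be resolved client-side: still validated on the server (raises here)
```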
74 changes: 73 additions & 1 deletion python/pyspark/sql/connect/types.py
@@ -20,7 +20,7 @@

import json

-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List

from pyspark.sql.types import (
    DataType,
@@ -315,3 +315,75 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
            error_class="UNSUPPORTED_OPERATION",
            message_parameters={"operation": f"data type {schema}"},
        )


# The python version of org.apache.spark.sql.catalyst.util.AttributeNameParser
def parse_attr_name(name: str) -> Optional[List[str]]:
    name_parts: List[str] = []
    tmp: str = ""

    in_backtick = False
    i = 0
    n = len(name)
    while i < n:
        char = name[i]
        if in_backtick:
            if char == "`":
                if i + 1 < n and name[i + 1] == "`":
                    tmp += "`"
                    i += 1
                else:
                    in_backtick = False
                    if i + 1 < n and name[i + 1] != ".":
                        return None
            else:
                tmp += char
        else:
            if char == "`":
                if len(tmp) > 0:
                    return None
                in_backtick = True
            elif char == ".":
                if name[i - 1] == "." or i == n - 1:
                    return None
                name_parts.append(tmp)
                tmp = ""
            else:
                tmp += char
        i += 1

    if in_backtick:
        return None

    name_parts.append(tmp)
    return name_parts


# Verify whether the input column name can be resolved with the given schema.
# Note that this method can not 100% match the analyzer behavior, it is designed to
# try the best to eliminate unnecessary validation RPCs.
def verify_col_name(name: str, schema: StructType) -> bool:
    parts = parse_attr_name(name)
    if parts is None or len(parts) == 0:
        return False

    def _quick_verify(parts: List[str], schema: DataType) -> bool:
        if len(parts) == 0:
            return True

        _schema: Optional[StructType] = None
        if isinstance(schema, StructType):
            _schema = schema
        elif isinstance(schema, ArrayType) and isinstance(schema.elementType, StructType):
            _schema = schema.elementType
        else:
            return False

        part = parts[0]
        for field in _schema:
            if field.name == part:
                return _quick_verify(parts[1:], field.dataType)

        return False

    return _quick_verify(parts, schema)
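
Outside the diff, a minimal sketch for reference: `verify_col_name` needs only a client-side `StructType`, so it can be exercised without any server round trip (the schema and names below are made up):
```
from pyspark.sql.types import ArrayType, LongType, StructField, StructType
from pyspark.sql.connect.types import verify_col_name

# Made-up schema: a long column plus an array of structs.
schema = StructType(
    [
        StructField("id", LongType()),
        StructField("a", ArrayType(StructType([StructField("v", LongType())]))),
    ]
)

assert verify_col_name("id", schema)         # plain top-level name
assert verify_col_name("`a`.v", schema)      # array of structs: element field resolves
assert not verify_col_name("`a.v`", schema)  # backticks make "a.v" one (missing) name part
```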
85 changes: 85 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1209,6 +1209,91 @@ def test_df_caache(self):
        self.assert_eq(10, df.count())
        self.assertTrue(df.is_cached)

    def test_parse_col_name(self):
        from pyspark.sql.connect.types import parse_attr_name

        self.assert_eq(parse_attr_name(""), [""])

        self.assert_eq(parse_attr_name("a"), ["a"])
        self.assert_eq(parse_attr_name("`a`"), ["a"])
        self.assert_eq(parse_attr_name("`a"), None)
        self.assert_eq(parse_attr_name("a`"), None)

        self.assert_eq(parse_attr_name("a.b.c"), ["a", "b", "c"])
        self.assert_eq(parse_attr_name("`a`.`b`.`c`"), ["a", "b", "c"])
        self.assert_eq(parse_attr_name("a.`b`.c"), ["a", "b", "c"])

        self.assert_eq(parse_attr_name("`a.b.c`"), ["a.b.c"])
        self.assert_eq(parse_attr_name("a.`b.c`"), ["a", "b.c"])
        self.assert_eq(parse_attr_name("`a.b`.c"), ["a.b", "c"])
        self.assert_eq(parse_attr_name("`a.b.c"), None)
        self.assert_eq(parse_attr_name("a.b.c`"), None)
        self.assert_eq(parse_attr_name("`a.`b.`c"), None)
        self.assert_eq(parse_attr_name("a`.b`.c`"), None)

        self.assert_eq(parse_attr_name("`ab..c`e.f"), None)

    def test_verify_col_name(self):
        from pyspark.sql.connect.types import verify_col_name

        cdf = (
            self.connect.range(10)
            .withColumn("v", CF.lit(123))
            .withColumn("s", CF.struct("id", "v"))
            .withColumn("m", CF.struct("s", "v"))
            .withColumn("a", CF.array("s"))
        )

        # root
        #  |-- id: long (nullable = false)
        #  |-- v: integer (nullable = false)
        #  |-- s: struct (nullable = false)
        #  |    |-- id: long (nullable = false)
        #  |    |-- v: integer (nullable = false)
        #  |-- m: struct (nullable = false)
        #  |    |-- s: struct (nullable = false)
        #  |    |    |-- id: long (nullable = false)
        #  |    |    |-- v: integer (nullable = false)
        #  |    |-- v: integer (nullable = false)
        #  |-- a: array (nullable = false)
        #  |    |-- element: struct (containsNull = false)
        #  |    |    |-- id: long (nullable = false)
        #  |    |    |-- v: integer (nullable = false)

        self.assertTrue(verify_col_name("id", cdf.schema))
        self.assertTrue(verify_col_name("`id`", cdf.schema))

        self.assertTrue(verify_col_name("v", cdf.schema))
        self.assertTrue(verify_col_name("`v`", cdf.schema))

        self.assertFalse(verify_col_name("x", cdf.schema))
        self.assertFalse(verify_col_name("`x`", cdf.schema))

        self.assertTrue(verify_col_name("s", cdf.schema))
        self.assertTrue(verify_col_name("`s`", cdf.schema))
        self.assertTrue(verify_col_name("s.id", cdf.schema))
        self.assertTrue(verify_col_name("s.`id`", cdf.schema))
        self.assertTrue(verify_col_name("`s`.id", cdf.schema))
        self.assertTrue(verify_col_name("`s`.`id`", cdf.schema))
        self.assertFalse(verify_col_name("`s.id`", cdf.schema))

        self.assertTrue(verify_col_name("m", cdf.schema))
        self.assertTrue(verify_col_name("`m`", cdf.schema))
        self.assertTrue(verify_col_name("m.s.id", cdf.schema))
        self.assertTrue(verify_col_name("m.s.`id`", cdf.schema))
        self.assertTrue(verify_col_name("m.`s`.id", cdf.schema))
        self.assertTrue(verify_col_name("`m`.`s`.`id`", cdf.schema))
        self.assertFalse(verify_col_name("m.`s.id`", cdf.schema))
        self.assertFalse(verify_col_name("m.`s.id`", cdf.schema))

        self.assertTrue(verify_col_name("a", cdf.schema))
        self.assertTrue(verify_col_name("`a`", cdf.schema))
        self.assertTrue(verify_col_name("a.`v`", cdf.schema))
        self.assertTrue(verify_col_name("a.`v`", cdf.schema))
        self.assertTrue(verify_col_name("`a`.v", cdf.schema))
        self.assertTrue(verify_col_name("`a`.`v`", cdf.schema))
        self.assertFalse(verify_col_name("`a`.`x`", cdf.schema))


if __name__ == "__main__":
    from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401