Skip to content

Commit b76f0b9

Browse files
committed
[SPARK-49047][PYTHON][CONNECT] Truncate the message for logging
### What changes were proposed in this pull request? Truncate the message for logging, by truncating the bytes and string fields ### Why are the changes needed? existing implementation generates too massive logging ### Does this PR introduce _any_ user-facing change? No, logging only ``` In [7]: df = spark.createDataFrame([('a B c'), ('X y Z'), ], ['abc']) In [8]: plan = df._plan.to_proto(spark._client) In [9]: spark._client._proto_to_string(plan, False) Out[9]: 'root { common { plan_id: 4 } to_df { input { common { plan_id: 3 } local_relation { data: "\\377\\377\\377\\377p\\000\\000\\000\\020\\000\\000\\000\\000\\000\\n\\000\\014\\000\\006\\000\\005\\000\\010\\000\\n\\000\\000\\000\\000\\001\\004\\000\\014\\000\\000\\000\\010\\000\\010\\000\\000\\000\\004\\000\\010\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\024\\000\\000\\000\\020\\000\\024\\000\\010\\000\\006\\000\\007\\000\\014\\000\\000\\000\\020\\000\\020\\000\\000\\000\\000\\000\\001\\005\\020\\000\\000\\000\\030\\000\\000\\000\\004\\000\\000\\000\\000\\000\\000\\000\\003\\000\\000\\000abc\\000\\004\\000\\004\\000\\004\\000\\000\\000\\000\\000\\000\\000\\377\\377\\377\\377\\230\\000\\000\\000\\024\\000\\000\\000\\000\\000\\000\\000\\014\\000\\026\\000\\006\\000\\005\\000\\010\\000\\014\\000\\014\\000\\000\\000\\000\\003\\004\\000\\030\\000\\000\\000 
\\000\\000\\000\\000\\000\\000\\000\\000\\000\\n\\000\\030\\000\\014\\000\\004\\000\\010\\000\\n\\000\\000\\000L\\000\\000\\000\\020\\000\\000\\000\\002\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\003\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\014\\000\\000\\000\\000\\000\\000\\000\\020\\000\\000\\000\\000\\000\\000\\000\\n\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\001\\000\\000\\000\\002\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\005\\000\\000\\000\\n\\000\\000\\000\\000\\000\\000\\000a B cX y Z\\000\\000\\000\\000\\000\\000\\377\\377\\377\\377\\000\\000\\000\\000" schema: "{\\"fields\\":[{\\"metadata\\":{},\\"name\\":\\"abc\\",\\"nullable\\":true,\\"type\\":\\"string\\"}],\\"type\\":\\"struct\\"}" } } column_names: "abc" } }' In [10]: spark._client._proto_to_string(plan, True) Out[10]: 'root { common { plan_id: 4 } to_df { input { common { plan_id: 3 } local_relation { data: "\\377\\377\\377\\377p\\000\\000\\000[truncated]" schema: "{\\"fields\\":[{\\"metadata\\":{},\\"name\\"[truncated]" } } column_names: "abc" } }' ``` ### How was this patch tested? added UT ### Was this patch authored or co-authored using generative AI tooling? No Closes #47554 from zhengruifeng/py_client_truncate. Authored-by: Ruifeng Zheng <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 4e69e16 commit b76f0b9

File tree

2 files changed: +83 additions, -9 deletions

python/pyspark/sql/connect/client/core.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import logging
2828
import threading
2929
import os
30+
import copy
3031
import platform
3132
import urllib.parse
3233
import uuid
@@ -864,7 +865,7 @@ def to_table_as_iterator(
864865
Return given plan as a PyArrow Table iterator.
865866
"""
866867
if logger.isEnabledFor(logging.INFO):
867-
logger.info(f"Executing plan {self._proto_to_string(plan)}")
868+
logger.info(f"Executing plan {self._proto_to_string(plan, True)}")
868869
req = self._execute_plan_request_with_metadata()
869870
req.plan.CopyFrom(plan)
870871
with Progress(handlers=self._progress_handlers, operation_id=req.operation_id) as progress:
@@ -881,7 +882,7 @@ def to_table(
881882
Return given plan as a PyArrow Table.
882883
"""
883884
if logger.isEnabledFor(logging.INFO):
884-
logger.info(f"Executing plan {self._proto_to_string(plan)}")
885+
logger.info(f"Executing plan {self._proto_to_string(plan, True)}")
885886
req = self._execute_plan_request_with_metadata()
886887
req.plan.CopyFrom(plan)
887888
table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req, observations)
@@ -898,7 +899,7 @@ def to_pandas(
898899
Return given plan as a pandas DataFrame.
899900
"""
900901
if logger.isEnabledFor(logging.INFO):
901-
logger.info(f"Executing plan {self._proto_to_string(plan)}")
902+
logger.info(f"Executing plan {self._proto_to_string(plan, True)}")
902903
req = self._execute_plan_request_with_metadata()
903904
req.plan.CopyFrom(plan)
904905
(self_destruct_conf,) = self.get_config_with_defaults(
@@ -978,7 +979,7 @@ def to_pandas(
978979
pdf.attrs["observed_metrics"] = observed_metrics
979980
return pdf, ei
980981

981-
def _proto_to_string(self, p: google.protobuf.message.Message) -> str:
982+
def _proto_to_string(self, p: google.protobuf.message.Message, truncate: bool = False) -> str:
982983
"""
983984
Helper method to generate a one line string representation of the plan.
984985
@@ -992,16 +993,62 @@ def _proto_to_string(self, p: google.protobuf.message.Message) -> str:
992993
Single line string of the serialized proto message.
993994
"""
994995
try:
995-
return text_format.MessageToString(p, as_one_line=True)
996+
p2 = self._truncate(p) if truncate else p
997+
return text_format.MessageToString(p2, as_one_line=True)
996998
except RecursionError:
997999
return "<Truncated message due to recursion error>"
9981000

1001+
def _truncate(self, p: google.protobuf.message.Message) -> google.protobuf.message.Message:
1002+
"""
1003+
Helper method to truncate the protobuf message.
1004+
Refer to 'org.apache.spark.sql.connect.common.Abbreviator' in the server side.
1005+
"""
1006+
1007+
def truncate_str(s: str) -> str:
1008+
if len(s) > 1024:
1009+
return s[:1024] + "[truncated]"
1010+
return s
1011+
1012+
def truncate_bytes(b: bytes) -> bytes:
1013+
if len(b) > 8:
1014+
return b[:8] + b"[truncated]"
1015+
return b
1016+
1017+
p2 = copy.deepcopy(p)
1018+
1019+
for descriptor, value in p.ListFields():
1020+
if value is not None:
1021+
field_name = descriptor.name
1022+
1023+
if descriptor.type == descriptor.TYPE_MESSAGE:
1024+
if descriptor.label == descriptor.LABEL_REPEATED:
1025+
p2.ClearField(field_name)
1026+
getattr(p2, field_name).extend([self._truncate(v) for v in value])
1027+
else:
1028+
getattr(p2, field_name).CopyFrom(self._truncate(value))
1029+
1030+
elif descriptor.type == descriptor.TYPE_STRING:
1031+
if descriptor.label == descriptor.LABEL_REPEATED:
1032+
p2.ClearField(field_name)
1033+
getattr(p2, field_name).extend([truncate_str(v) for v in value])
1034+
else:
1035+
setattr(p2, field_name, truncate_str(value))
1036+
1037+
elif descriptor.type == descriptor.TYPE_BYTES:
1038+
if descriptor.label == descriptor.LABEL_REPEATED:
1039+
p2.ClearField(field_name)
1040+
getattr(p2, field_name).extend([truncate_bytes(v) for v in value])
1041+
else:
1042+
setattr(p2, field_name, truncate_bytes(value))
1043+
1044+
return p2
1045+
9991046
def schema(self, plan: pb2.Plan) -> StructType:
10001047
"""
10011048
Return schema for given plan.
10021049
"""
10031050
if logger.isEnabledFor(logging.INFO):
1004-
logger.info(f"Schema for plan: {self._proto_to_string(plan)}")
1051+
logger.info(f"Schema for plan: {self._proto_to_string(plan, True)}")
10051052
schema = self._analyze(method="schema", plan=plan).schema
10061053
assert schema is not None
10071054
# Server side should populate the struct field which is the schema.
@@ -1013,7 +1060,9 @@ def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:
10131060
Return explain string for given plan.
10141061
"""
10151062
if logger.isEnabledFor(logging.INFO):
1016-
logger.info(f"Explain (mode={explain_mode}) for plan {self._proto_to_string(plan)}")
1063+
logger.info(
1064+
f"Explain (mode={explain_mode}) for plan {self._proto_to_string(plan, True)}"
1065+
)
10171066
result = self._analyze(
10181067
method="explain", plan=plan, explain_mode=explain_mode
10191068
).explain_string
@@ -1027,7 +1076,7 @@ def execute_command(
10271076
Execute given command.
10281077
"""
10291078
if logger.isEnabledFor(logging.INFO):
1030-
logger.info(f"Execute command for command {self._proto_to_string(command)}")
1079+
logger.info(f"Execute command for command {self._proto_to_string(command, True)}")
10311080
req = self._execute_plan_request_with_metadata()
10321081
if self._user_id:
10331082
req.user_context.user_id = self._user_id
@@ -1049,7 +1098,9 @@ def execute_command_as_iterator(
10491098
Execute given command. Similar to execute_command, but the value is returned using yield.
10501099
"""
10511100
if logger.isEnabledFor(logging.INFO):
1052-
logger.info(f"Execute command as iterator for command {self._proto_to_string(command)}")
1101+
logger.info(
1102+
f"Execute command as iterator for command {self._proto_to_string(command, True)}"
1103+
)
10531104
req = self._execute_plan_request_with_metadata()
10541105
if self._user_id:
10551106
req.user_context.user_id = self._user_id

python/pyspark/sql/tests/connect/test_connect_basic.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,29 @@ def test_verify_col_name(self):
14051405
self.assertTrue(verify_col_name("`m```.`s.s`.v", cdf.schema))
14061406
self.assertTrue(verify_col_name("`m```.`s.s`.`v`", cdf.schema))
14071407

1408+
def test_truncate_message(self):
1409+
cdf1 = self.connect.createDataFrame(
1410+
[
1411+
("a B c"),
1412+
("X y Z"),
1413+
],
1414+
["a" * 4096],
1415+
)
1416+
plan1 = cdf1._plan.to_proto(self.connect._client)
1417+
1418+
proto_string_1 = self.connect._client._proto_to_string(plan1, False)
1419+
self.assertTrue(len(proto_string_1) > 10000, len(proto_string_1))
1420+
proto_string_truncated_1 = self.connect._client._proto_to_string(plan1, True)
1421+
self.assertTrue(len(proto_string_truncated_1) < 4000, len(proto_string_truncated_1))
1422+
1423+
cdf2 = cdf1.select("a" * 4096, "a" * 4096, "a" * 4096)
1424+
plan2 = cdf2._plan.to_proto(self.connect._client)
1425+
1426+
proto_string_2 = self.connect._client._proto_to_string(plan2, False)
1427+
self.assertTrue(len(proto_string_2) > 20000, len(proto_string_2))
1428+
proto_string_truncated_2 = self.connect._client._proto_to_string(plan2, True)
1429+
self.assertTrue(len(proto_string_truncated_2) < 8000, len(proto_string_truncated_2))
1430+
14081431

14091432
class SparkConnectGCTests(SparkConnectSQLTestCase):
14101433
@classmethod

0 commit comments