Skip to content

Commit

Permalink
[SPARK-49047][PYTHON][CONNECT] Truncate the message for logging
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Truncate the message for logging, by truncating the bytes and string fields

### Why are the changes needed?
existing implementation generates too massive logging

### Does this PR introduce _any_ user-facing change?
No, logging only

```
In [7]: df = spark.createDataFrame([('a B c'), ('X y Z'), ], ['abc'])

In [8]: plan = df._plan.to_proto(spark._client)

In [9]: spark._client._proto_to_string(plan, False)
Out[9]: 'root { common { plan_id: 4 } to_df { input { common { plan_id: 3 } local_relation { data: "\\377\\377\\377\\377p\\000\\000\\000\\020\\000\\000\\000\\000\\000\\n\\000\\014\\000\\006\\000\\005\\000\\010\\000\\n\\000\\000\\000\\000\\001\\004\\000\\014\\000\\000\\000\\010\\000\\010\\000\\000\\000\\004\\000\\010\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\024\\000\\000\\000\\020\\000\\024\\000\\010\\000\\006\\000\\007\\000\\014\\000\\000\\000\\020\\000\\020\\000\\000\\000\\000\\000\\001\\005\\020\\000\\000\\000\\030\\000\\000\\000\\004\\000\\000\\000\\000\\000\\000\\000\\003\\000\\000\\000abc\\000\\004\\000\\004\\000\\004\\000\\000\\000\\000\\000\\000\\000\\377\\377\\377\\377\\230\\000\\000\\000\\024\\000\\000\\000\\000\\000\\000\\000\\014\\000\\026\\000\\006\\000\\005\\000\\010\\000\\014\\000\\014\\000\\000\\000\\000\\003\\004\\000\\030\\000\\000\\000 \\000\\000\\000\\000\\000\\000\\000\\000\\000\\n\\000\\030\\000\\014\\000\\004\\000\\010\\000\\n\\000\\000\\000L\\000\\000\\000\\020\\000\\000\\000\\002\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\003\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\014\\000\\000\\000\\000\\000\\000\\000\\020\\000\\000\\000\\000\\000\\000\\000\\n\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\001\\000\\000\\000\\002\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\005\\000\\000\\000\\n\\000\\000\\000\\000\\000\\000\\000a B cX y Z\\000\\000\\000\\000\\000\\000\\377\\377\\377\\377\\000\\000\\000\\000" schema: "{\\"fields\\":[{\\"metadata\\":{},\\"name\\":\\"abc\\",\\"nullable\\":true,\\"type\\":\\"string\\"}],\\"type\\":\\"struct\\"}" } } column_names: "abc" } }'

In [10]: spark._client._proto_to_string(plan, True)
Out[10]: 'root { common { plan_id: 4 } to_df { input { common { plan_id: 3 } local_relation { data: "\\377\\377\\377\\377p\\000\\000\\000[truncated]" schema: "{\\"fields\\":[{\\"metadata\\":{},\\"name\\"[truncated]" } } column_names: "abc" } }'
```

### How was this patch tested?
added UT

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47554 from zhengruifeng/py_client_truncate.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Aug 5, 2024
1 parent 4e69e16 commit b76f0b9
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 9 deletions.
69 changes: 60 additions & 9 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import logging
import threading
import os
import copy
import platform
import urllib.parse
import uuid
Expand Down Expand Up @@ -864,7 +865,7 @@ def to_table_as_iterator(
Return given plan as a PyArrow Table iterator.
"""
if logger.isEnabledFor(logging.INFO):
logger.info(f"Executing plan {self._proto_to_string(plan)}")
logger.info(f"Executing plan {self._proto_to_string(plan, True)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
with Progress(handlers=self._progress_handlers, operation_id=req.operation_id) as progress:
Expand All @@ -881,7 +882,7 @@ def to_table(
Return given plan as a PyArrow Table.
"""
if logger.isEnabledFor(logging.INFO):
logger.info(f"Executing plan {self._proto_to_string(plan)}")
logger.info(f"Executing plan {self._proto_to_string(plan, True)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req, observations)
Expand All @@ -898,7 +899,7 @@ def to_pandas(
Return given plan as a pandas DataFrame.
"""
if logger.isEnabledFor(logging.INFO):
logger.info(f"Executing plan {self._proto_to_string(plan)}")
logger.info(f"Executing plan {self._proto_to_string(plan, True)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
(self_destruct_conf,) = self.get_config_with_defaults(
Expand Down Expand Up @@ -978,7 +979,7 @@ def to_pandas(
pdf.attrs["observed_metrics"] = observed_metrics
return pdf, ei

def _proto_to_string(self, p: google.protobuf.message.Message) -> str:
def _proto_to_string(self, p: google.protobuf.message.Message, truncate: bool = False) -> str:
"""
Helper method to generate a one line string representation of the plan.
Expand All @@ -992,16 +993,62 @@ def _proto_to_string(self, p: google.protobuf.message.Message) -> str:
Single line string of the serialized proto message.
"""
try:
return text_format.MessageToString(p, as_one_line=True)
p2 = self._truncate(p) if truncate else p
return text_format.MessageToString(p2, as_one_line=True)
except RecursionError:
return "<Truncated message due to recursion error>"

def _truncate(self, p: google.protobuf.message.Message) -> google.protobuf.message.Message:
"""
Helper method to truncate the protobuf message.
Refer to 'org.apache.spark.sql.connect.common.Abbreviator' in the server side.
"""

def truncate_str(s: str) -> str:
if len(s) > 1024:
return s[:1024] + "[truncated]"
return s

def truncate_bytes(b: bytes) -> bytes:
if len(b) > 8:
return b[:8] + b"[truncated]"
return b

p2 = copy.deepcopy(p)

for descriptor, value in p.ListFields():
if value is not None:
field_name = descriptor.name

if descriptor.type == descriptor.TYPE_MESSAGE:
if descriptor.label == descriptor.LABEL_REPEATED:
p2.ClearField(field_name)
getattr(p2, field_name).extend([self._truncate(v) for v in value])
else:
getattr(p2, field_name).CopyFrom(self._truncate(value))

elif descriptor.type == descriptor.TYPE_STRING:
if descriptor.label == descriptor.LABEL_REPEATED:
p2.ClearField(field_name)
getattr(p2, field_name).extend([truncate_str(v) for v in value])
else:
setattr(p2, field_name, truncate_str(value))

elif descriptor.type == descriptor.TYPE_BYTES:
if descriptor.label == descriptor.LABEL_REPEATED:
p2.ClearField(field_name)
getattr(p2, field_name).extend([truncate_bytes(v) for v in value])
else:
setattr(p2, field_name, truncate_bytes(value))

return p2

def schema(self, plan: pb2.Plan) -> StructType:
"""
Return schema for given plan.
"""
if logger.isEnabledFor(logging.INFO):
logger.info(f"Schema for plan: {self._proto_to_string(plan)}")
logger.info(f"Schema for plan: {self._proto_to_string(plan, True)}")
schema = self._analyze(method="schema", plan=plan).schema
assert schema is not None
# Server side should populate the struct field which is the schema.
Expand All @@ -1013,7 +1060,9 @@ def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:
Return explain string for given plan.
"""
if logger.isEnabledFor(logging.INFO):
logger.info(f"Explain (mode={explain_mode}) for plan {self._proto_to_string(plan)}")
logger.info(
f"Explain (mode={explain_mode}) for plan {self._proto_to_string(plan, True)}"
)
result = self._analyze(
method="explain", plan=plan, explain_mode=explain_mode
).explain_string
Expand All @@ -1027,7 +1076,7 @@ def execute_command(
Execute given command.
"""
if logger.isEnabledFor(logging.INFO):
logger.info(f"Execute command for command {self._proto_to_string(command)}")
logger.info(f"Execute command for command {self._proto_to_string(command, True)}")
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
Expand All @@ -1049,7 +1098,9 @@ def execute_command_as_iterator(
Execute given command. Similar to execute_command, but the value is returned using yield.
"""
if logger.isEnabledFor(logging.INFO):
logger.info(f"Execute command as iterator for command {self._proto_to_string(command)}")
logger.info(
f"Execute command as iterator for command {self._proto_to_string(command, True)}"
)
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
Expand Down
23 changes: 23 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,29 @@ def test_verify_col_name(self):
self.assertTrue(verify_col_name("`m```.`s.s`.v", cdf.schema))
self.assertTrue(verify_col_name("`m```.`s.s`.`v`", cdf.schema))

def test_truncate_message(self):
cdf1 = self.connect.createDataFrame(
[
("a B c"),
("X y Z"),
],
["a" * 4096],
)
plan1 = cdf1._plan.to_proto(self.connect._client)

proto_string_1 = self.connect._client._proto_to_string(plan1, False)
self.assertTrue(len(proto_string_1) > 10000, len(proto_string_1))
proto_string_truncated_1 = self.connect._client._proto_to_string(plan1, True)
self.assertTrue(len(proto_string_truncated_1) < 4000, len(proto_string_truncated_1))

cdf2 = cdf1.select("a" * 4096, "a" * 4096, "a" * 4096)
plan2 = cdf2._plan.to_proto(self.connect._client)

proto_string_2 = self.connect._client._proto_to_string(plan2, False)
self.assertTrue(len(proto_string_2) > 20000, len(proto_string_2))
proto_string_truncated_2 = self.connect._client._proto_to_string(plan2, True)
self.assertTrue(len(proto_string_truncated_2) < 8000, len(proto_string_truncated_2))


class SparkConnectGCTests(SparkConnectSQLTestCase):
@classmethod
Expand Down

0 comments on commit b76f0b9

Please sign in to comment.