Skip to content

Commit

Permalink
[SPARK-46505][CONNECT] Make bytes threshold configurable in `ProtoUti…
Browse files Browse the repository at this point in the history
…ls.abbreviate`

### What changes were proposed in this pull request?
Make bytes threshold configurable in `ProtoUtils.abbreviate`

### Why are the changes needed?
the bytes threshold should be also configurable, like string type

### Does this PR introduce _any_ user-facing change?
no, this function is only used internally

### How was this patch tested?
added ut

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #44486 from zhengruifeng/connect_ab_config.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Dec 26, 2023
1 parent 439ec6b commit 4911a5b
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,38 @@ import com.google.protobuf.Descriptors.FieldDescriptor

private[connect] object ProtoUtils {
private val format = java.text.NumberFormat.getInstance()
private val BYTES = "BYTES"
private val STRING = "STRING"
private val MAX_BYTES_SIZE = 8
private val MAX_STRING_SIZE = 1024

def abbreviate(message: Message, maxStringSize: Int = MAX_STRING_SIZE): Message = {
abbreviate(message, Map(STRING -> maxStringSize))
}

def abbreviate(message: Message, thresholds: Map[String, Int]): Message = {
val builder = message.toBuilder

message.getAllFields.asScala.iterator.foreach {
case (field: FieldDescriptor, string: String)
if field.getJavaType == FieldDescriptor.JavaType.STRING && string != null =>
val size = string.length
if (size > maxStringSize) {
builder.setField(field, createString(string.take(maxStringSize), size))
val threshold = thresholds.getOrElse(STRING, MAX_STRING_SIZE)
if (size > threshold) {
builder.setField(field, createString(string.take(threshold), size))
} else {
builder.setField(field, string)
}

case (field: FieldDescriptor, byteString: ByteString)
if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteString != null =>
val size = byteString.size
if (size > MAX_BYTES_SIZE) {
val threshold = thresholds.getOrElse(BYTES, MAX_BYTES_SIZE)
if (size > threshold) {
builder.setField(
field,
byteString
.substring(0, MAX_BYTES_SIZE)
.substring(0, threshold)
.concat(createTruncatedByteString(size)))
} else {
builder.setField(field, byteString)
Expand All @@ -56,11 +64,12 @@ private[connect] object ProtoUtils {
case (field: FieldDescriptor, byteArray: Array[Byte])
if field.getJavaType == FieldDescriptor.JavaType.BYTE_STRING && byteArray != null =>
val size = byteArray.length
if (size > MAX_BYTES_SIZE) {
val threshold = thresholds.getOrElse(BYTES, MAX_BYTES_SIZE)
if (size > threshold) {
builder.setField(
field,
ByteString
.copyFrom(byteArray, 0, MAX_BYTES_SIZE)
.copyFrom(byteArray, 0, threshold)
.concat(createTruncatedByteString(size)))
} else {
builder.setField(field, byteArray)
Expand All @@ -69,7 +78,7 @@ private[connect] object ProtoUtils {
// TODO(SPARK-43117): should also support 1, repeated msg; 2, map<xxx, msg>
case (field: FieldDescriptor, msg: Message)
if field.getJavaType == FieldDescriptor.JavaType.MESSAGE && msg != null =>
builder.setField(field, abbreviate(msg, maxStringSize))
builder.setField(field, abbreviate(msg, thresholds))

case (field: FieldDescriptor, value: Any) => builder.setField(field, value)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,34 @@ class AbbreviateSuite extends SparkFunSuite {
}
}
}

test("truncate bytes with threshold: simple python udf") {
val bytes = Array.ofDim[Byte](1024)
val message = proto.PythonUDF
.newBuilder()
.setEvalType(1)
.setOutputType(ProtoDataTypes.BinaryType)
.setCommand(ByteString.copyFrom(bytes))
.setPythonVer("3.12")
.build()

Seq(1, 16, 256, 512, 1024, 2048).foreach { threshold =>
val truncated = ProtoUtils.abbreviate(message, Map("BYTES" -> threshold))
assert(truncated.isInstanceOf[proto.PythonUDF])

val truncatedUDF = truncated.asInstanceOf[proto.PythonUDF]
assert(truncatedUDF.getEvalType === 1)
assert(truncatedUDF.getOutputType === ProtoDataTypes.BinaryType)
assert(truncatedUDF.getPythonVer === "3.12")

if (threshold < 1024) {
// with suffix: [truncated(size=...)]
assert(
threshold < truncatedUDF.getCommand.size() &&
truncatedUDF.getCommand.size() < threshold + 64)
} else {
assert(truncatedUDF.getCommand.size() === 1024)
}
}
}
}

0 comments on commit 4911a5b

Please sign in to comment.