Skip to content

Commit

Permalink
[SPARK-48713][SQL] Add index range check for UnsafeRow.pointTo when b…
Browse files Browse the repository at this point in the history
…aseObject is byte array

### What changes were proposed in this pull request?

This PR proposes to add an index range check for `UnsafeRow.pointTo()` when `baseObject` is a byte array.

### Why are the changes needed?

All the other places, such as `readExternal()` and `read()`, ensure that `sizeInBytes` can't be larger than the length of `baseObject` when it is a byte array; the only exception is `pointTo()`. Adding this check gives us a clearer error, with a useful stack trace, at the point where the index first goes wrong.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47087 from Ngone51/ur.

Authored-by: Yi Wu <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
  • Loading branch information
Ngone51 authored and yaooqinn committed Jun 26, 2024
1 parent 4cf5450 commit 313479c
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.util.Map;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;

import org.apache.spark.SparkIllegalArgumentException;
import org.apache.spark.SparkUnsupportedOperationException;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.types.*;
Expand Down Expand Up @@ -155,6 +157,17 @@ public UnsafeRow() {}
public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
assert numFields >= 0 : "numFields (" + numFields + ") should >= 0";
assert sizeInBytes % 8 == 0 : "sizeInBytes (" + sizeInBytes + ") should be a multiple of 8";
if (baseObject instanceof byte[] bytes) {
  // Validate that the region [offset, offset + sizeInBytes) lies inside the backing byte array.
  // The offset is kept as a long and the upper bound is checked by subtraction rather than
  // addition: narrowing the offset to int, or computing offset + sizeInBytes in int arithmetic,
  // can wrap around and let an out-of-range row pass the check silently.
  long offsetInByteArray = baseOffset - Platform.BYTE_ARRAY_OFFSET;
  if (baseOffset < Platform.BYTE_ARRAY_OFFSET || sizeInBytes < 0 ||
      // sizeInBytes >= 0 at this point, so bytes.length - sizeInBytes cannot overflow.
      offsetInByteArray > bytes.length - sizeInBytes) {
    throw new SparkIllegalArgumentException(
      "INTERNAL_ERROR",
      Map.of("message", "Invalid byte array backed UnsafeRow: byte array length=" +
        bytes.length + ", offset=" + offsetInByteArray + ", byte size=" + sizeInBytes)
    );
  }
}
this.baseObject = baseObject;
this.baseOffset = baseOffset;
this.sizeInBytes = sizeInBytes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql

import java.io.ByteArrayOutputStream

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.{SparkConf, SparkFunSuite, SparkIllegalArgumentException}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
Expand Down Expand Up @@ -188,4 +188,40 @@ class UnsafeRowSuite extends SparkFunSuite {
unsafeRow.setDecimal(0, d2, 38)
assert(unsafeRow.getDecimal(0, 38, 18) === null)
}

test("SPARK-48713: throw SparkIllegalArgumentException for illegal UnsafeRow.pointTo") {
  val row = UnsafeRow.createFromByteArray(64, 2)
  val buffer = new Array[Byte](64)

  // Points `row` at `buffer` with arguments that are expected to be rejected and
  // returns the message of the resulting SparkIllegalArgumentException.
  def rejectionMessage(baseOffset: Long, sizeInBytes: Int): String = {
    val thrown = intercept[SparkIllegalArgumentException] {
      row.pointTo(buffer, baseOffset, sizeInBytes)
    }
    thrown.getMessage
  }

  // The requested region runs past the end of the 64-byte array.
  val outOfBounds = rejectionMessage(Platform.BYTE_ARRAY_OFFSET + 50, 32)
  assert(
    outOfBounds.contains(
      "Invalid byte array backed UnsafeRow: byte array length=64, offset=50, byte size=32"
    )
  )

  // The requested size is negative.
  val negativeSize = rejectionMessage(Platform.BYTE_ARRAY_OFFSET + 50, -32)
  assert(
    negativeSize.contains(
      "Invalid byte array backed UnsafeRow: byte array length=64, offset=50, byte size=-32"
    )
  )

  // The base offset lies before the first element of the array.
  val negativeOffset = rejectionMessage(-5, 32)
  assert(
    negativeOffset.contains(
      s"Invalid byte array backed UnsafeRow: byte array length=64, " +
        s"offset=${-5 - Platform.BYTE_ARRAY_OFFSET}, byte size=32"
    )
  )
}
}

0 comments on commit 313479c

Please sign in to comment.