Skip to content

Commit

Permalink
Fix comparison of protobuf expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimirg-db committed Mar 7, 2025
1 parent 27c6b5b commit 23021f6
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,35 @@ private[sql] case class CatalystDataToProtobuf(

override protected def withNewChildInternal(newChild: Expression): CatalystDataToProtobuf =
copy(child = newChild)

override def equals(that: Any): Boolean = {
that match {
case that: CatalystDataToProtobuf =>
this.child == that.child &&
this.messageName == that.messageName &&
(
(this.binaryFileDescriptorSet.isEmpty && that.binaryFileDescriptorSet.isEmpty) ||
(
this.binaryFileDescriptorSet.nonEmpty && that.binaryFileDescriptorSet.nonEmpty &&
this.binaryFileDescriptorSet.get.sameElements(that.binaryFileDescriptorSet.get)
)
) &&
this.options == that.options
case _ => false
}
}

override def hashCode(): Int = {
val prime = 31
var result = 1
var i = 0
while (i < binaryFileDescriptorSet.map(_.length).getOrElse(0)) {
result = prime * result + binaryFileDescriptorSet.get.apply(i).hashCode
i += 1
}
result = prime * result + child.hashCode
result = prime * result + messageName.hashCode
result = prime * result + options.hashCode
result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,35 @@ private[sql] case class ProtobufDataToCatalyst(

override protected def withNewChildInternal(newChild: Expression): ProtobufDataToCatalyst =
copy(child = newChild)

override def equals(that: Any): Boolean = {
that match {
case that: ProtobufDataToCatalyst =>
this.child == that.child &&
this.messageName == that.messageName &&
(
(this.binaryFileDescriptorSet.isEmpty && that.binaryFileDescriptorSet.isEmpty) ||
(
this.binaryFileDescriptorSet.nonEmpty && that.binaryFileDescriptorSet.nonEmpty &&
this.binaryFileDescriptorSet.get.sameElements(that.binaryFileDescriptorSet.get)
)
) &&
this.options == that.options
case _ => false
}
}

override def hashCode(): Int = {
val prime = 31
var result = 1
var i = 0
while (i < binaryFileDescriptorSet.map(_.length).getOrElse(0)) {
result = prime * result + binaryFileDescriptorSet.get.apply(i).hashCode
i += 1
}
result = prime * result + child.hashCode
result = prime * result + messageName.hashCode
result = prime * result + options.hashCode
result
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,177 @@ class ProtobufCatalystDataConversionSuite
testFileDesc, "org.apache.spark.sql.protobuf.protos.BytesMsg")
assert(withFullName.findFieldByName("bytes_type") != null)
}

test("CatalystDataToProtobuf equals") {
val catalystDataToProtobuf = generateCatalystDataToProtobuf()

assert(
catalystDataToProtobuf
== catalystDataToProtobuf.copy()
)
assert(
catalystDataToProtobuf
!= catalystDataToProtobuf.copy(options = Map("mode" -> "FAILFAST"))
)
assert(
catalystDataToProtobuf
!= catalystDataToProtobuf.copy(messageName = "otherMessage")
)
assert(
catalystDataToProtobuf
!= catalystDataToProtobuf.copy(child = Literal.create(0, IntegerType))
)
assert(
catalystDataToProtobuf
!= catalystDataToProtobuf.copy(binaryFileDescriptorSet = None)
)

val testFileDescCopy = new Array[Byte](testFileDesc.length)
testFileDesc.copyToArray(testFileDescCopy)
assert(
catalystDataToProtobuf
== catalystDataToProtobuf.copy(binaryFileDescriptorSet = Some(testFileDescCopy))
)

testFileDescCopy(0) = '0'
assert(
catalystDataToProtobuf
!= catalystDataToProtobuf.copy(binaryFileDescriptorSet = Some(testFileDescCopy))
)
}

test("CatalystDataToProtobuf hashCode") {
val catalystDataToProtobuf = generateCatalystDataToProtobuf()

assert(
catalystDataToProtobuf.hashCode == 18619165
)
assert(
catalystDataToProtobuf.copy(options = Map("mode" -> "FAILFAST")).hashCode == -1634963844
)
assert(
catalystDataToProtobuf.copy(messageName = "otherMessage").hashCode == -1751271943
)
assert(
catalystDataToProtobuf.copy(child = Literal.create(0, IntegerType)).hashCode == -2051339781
)
assert(
catalystDataToProtobuf.copy(binaryFileDescriptorSet = None).hashCode == 866765483
)

val testFileDescCopy = new Array[Byte](testFileDesc.length)
testFileDesc.copyToArray(testFileDescCopy)
assert(
catalystDataToProtobuf.copy(
binaryFileDescriptorSet = Some(testFileDescCopy)
).hashCode == -937893175
)

testFileDescCopy(0) = '0'
assert(
catalystDataToProtobuf.copy(
binaryFileDescriptorSet = Some(testFileDescCopy)
).hashCode == -1493098769
)
}

test("ProtobufDataToCatalyst equals") {
val catalystDataToProtobuf = generateCatalystDataToProtobuf()
val protobufDataToCatalyst = ProtobufDataToCatalyst(
catalystDataToProtobuf,
"message",
Some(testFileDesc),
Map("mode" -> "PERMISSIVE")
)

assert(
protobufDataToCatalyst
== protobufDataToCatalyst.copy()
)
assert(
protobufDataToCatalyst
!= protobufDataToCatalyst.copy(options = Map("mode" -> "FAILFAST"))
)
assert(
protobufDataToCatalyst
!= protobufDataToCatalyst.copy(messageName = "otherMessage")
)
assert(
protobufDataToCatalyst
!= protobufDataToCatalyst.copy(child = Literal.create(0, IntegerType))
)
assert(
protobufDataToCatalyst
!= protobufDataToCatalyst.copy(binaryFileDescriptorSet = None)
)

val testFileDescCopy = new Array[Byte](testFileDesc.length)
testFileDesc.copyToArray(testFileDescCopy)
assert(
protobufDataToCatalyst
== protobufDataToCatalyst.copy(binaryFileDescriptorSet = Some(testFileDescCopy))
)

testFileDescCopy(0) = '0'
assert(
protobufDataToCatalyst
!= protobufDataToCatalyst.copy(binaryFileDescriptorSet = Some(testFileDescCopy))
)
}

test("ProtobufDataToCatalyst hashCode") {
val catalystDataToProtobuf = generateCatalystDataToProtobuf()
val protobufDataToCatalyst = ProtobufDataToCatalyst(
catalystDataToProtobuf,
"message",
Some(testFileDesc),
Map("mode" -> "PERMISSIVE")
)

assert(
protobufDataToCatalyst.hashCode == -937893175
)
assert(
protobufDataToCatalyst.copy(options = Map("mode" -> "FAILFAST")).hashCode == -1634963844
)
assert(
protobufDataToCatalyst.copy(messageName = "otherMessage").hashCode == -1751271943
)
assert(
protobufDataToCatalyst.copy(child = Literal.create(0, IntegerType)).hashCode == -133420428
)
assert(
protobufDataToCatalyst.copy(binaryFileDescriptorSet = None).hashCode == 866765483
)

val testFileDescCopy = new Array[Byte](testFileDesc.length)
testFileDesc.copyToArray(testFileDescCopy)
assert(
protobufDataToCatalyst.copy(
binaryFileDescriptorSet = Some(testFileDescCopy)
).hashCode == -937893175
)

testFileDescCopy(0) = '0'
assert(
protobufDataToCatalyst.copy(
binaryFileDescriptorSet = Some(testFileDescCopy)
).hashCode == -1493098769
)
}

private def generateCatalystDataToProtobuf() = {
val schema = StructType(
Seq(
StructField("a", StringType),
StructField("b", IntegerType)
)
)
val messageName = "message"
val data = RandomDataGenerator.randomRow(new scala.util.Random(3), schema)
val converter = CatalystTypeConverters.createToCatalystConverter(schema)
val dataLiteral = Literal.create(converter(data), schema)

CatalystDataToProtobuf(dataLiteral, messageName, Some(testFileDesc))
}
}

0 comments on commit 23021f6

Please sign in to comment.