From 23021f6861e67b6f5a43a8e7cdc844cbe7c56f54 Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Fri, 7 Mar 2025 15:12:47 +0000 Subject: [PATCH] Fix comparison of protobuf expressions --- .../sql/protobuf/CatalystDataToProtobuf.scala | 31 ++++ .../sql/protobuf/ProtobufDataToCatalyst.scala | 31 ++++ .../ProtobufCatalystDataConversionSuite.scala | 173 ++++++++++++++++++ 3 files changed, 235 insertions(+) diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala index 8805d935093e8..0564eee1602a7 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala @@ -55,4 +55,35 @@ private[sql] case class CatalystDataToProtobuf( override protected def withNewChildInternal(newChild: Expression): CatalystDataToProtobuf = copy(child = newChild) + + override def equals(that: Any): Boolean = { + that match { + case that: CatalystDataToProtobuf => + this.child == that.child && + this.messageName == that.messageName && + ( + (this.binaryFileDescriptorSet.isEmpty && that.binaryFileDescriptorSet.isEmpty) || + ( + this.binaryFileDescriptorSet.nonEmpty && that.binaryFileDescriptorSet.nonEmpty && + this.binaryFileDescriptorSet.get.sameElements(that.binaryFileDescriptorSet.get) + ) + ) && + this.options == that.options + case _ => false + } + } + + override def hashCode(): Int = { + val prime = 31 + var result = 1 + var i = 0 + while (i < binaryFileDescriptorSet.map(_.length).getOrElse(0)) { + result = prime * result + binaryFileDescriptorSet.get.apply(i).hashCode + i += 1 + } + result = prime * result + child.hashCode + result = prime * result + messageName.hashCode + result = prime * result + options.hashCode + result + } } diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala index a182ac854b28b..b3225d61eb01a 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala @@ -142,4 +142,35 @@ private[sql] case class ProtobufDataToCatalyst( override protected def withNewChildInternal(newChild: Expression): ProtobufDataToCatalyst = copy(child = newChild) + + override def equals(that: Any): Boolean = { + that match { + case that: ProtobufDataToCatalyst => + this.child == that.child && + this.messageName == that.messageName && + ( + (this.binaryFileDescriptorSet.isEmpty && that.binaryFileDescriptorSet.isEmpty) || + ( + this.binaryFileDescriptorSet.nonEmpty && that.binaryFileDescriptorSet.nonEmpty && + this.binaryFileDescriptorSet.get.sameElements(that.binaryFileDescriptorSet.get) + ) + ) && + this.options == that.options + case _ => false + } + } + + override def hashCode(): Int = { + val prime = 31 + var result = 1 + var i = 0 + while (i < binaryFileDescriptorSet.map(_.length).getOrElse(0)) { + result = prime * result + binaryFileDescriptorSet.get.apply(i).hashCode + i += 1 + } + result = prime * result + child.hashCode + result = prime * result + messageName.hashCode + result = prime * result + options.hashCode + result + } } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala index abae1d622d3cf..bcb0cba614937 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala @@ -244,4 +244,177 @@ class ProtobufCatalystDataConversionSuite testFileDesc, "org.apache.spark.sql.protobuf.protos.BytesMsg") assert(withFullName.findFieldByName("bytes_type") != null) } + + test("CatalystDataToProtobuf equals") { + val catalystDataToProtobuf = generateCatalystDataToProtobuf() + + assert( + catalystDataToProtobuf + == catalystDataToProtobuf.copy() + ) + assert( + catalystDataToProtobuf + != catalystDataToProtobuf.copy(options = Map("mode" -> "FAILFAST")) + ) + assert( + catalystDataToProtobuf + != catalystDataToProtobuf.copy(messageName = "otherMessage") + ) + assert( + catalystDataToProtobuf + != catalystDataToProtobuf.copy(child = Literal.create(0, IntegerType)) + ) + assert( + catalystDataToProtobuf + != catalystDataToProtobuf.copy(binaryFileDescriptorSet = None) + ) + + val testFileDescCopy = new Array[Byte](testFileDesc.length) + testFileDesc.copyToArray(testFileDescCopy) + assert( + catalystDataToProtobuf + == catalystDataToProtobuf.copy(binaryFileDescriptorSet = Some(testFileDescCopy)) + ) + + testFileDescCopy(0) = '0' + assert( + catalystDataToProtobuf + != catalystDataToProtobuf.copy(binaryFileDescriptorSet = Some(testFileDescCopy)) + ) + } + + test("CatalystDataToProtobuf hashCode") { + val catalystDataToProtobuf = generateCatalystDataToProtobuf() + + assert( + catalystDataToProtobuf.hashCode == 18619165 + ) + assert( + catalystDataToProtobuf.copy(options = Map("mode" -> "FAILFAST")).hashCode == -1634963844 + ) + assert( + catalystDataToProtobuf.copy(messageName = "otherMessage").hashCode == -1751271943 + ) + assert( + catalystDataToProtobuf.copy(child = Literal.create(0, IntegerType)).hashCode == -2051339781 + ) + assert( + catalystDataToProtobuf.copy(binaryFileDescriptorSet = None).hashCode == 866765483 + ) + + val testFileDescCopy = new Array[Byte](testFileDesc.length) + testFileDesc.copyToArray(testFileDescCopy) + assert( + catalystDataToProtobuf.copy( + binaryFileDescriptorSet = Some(testFileDescCopy) + ).hashCode == -937893175 + ) + + testFileDescCopy(0) = '0' + assert( + catalystDataToProtobuf.copy( + binaryFileDescriptorSet = Some(testFileDescCopy) + ).hashCode == -1493098769 + ) + } + + test("ProtobufDataToCatalyst equals") { + val catalystDataToProtobuf = generateCatalystDataToProtobuf() + val protobufDataToCatalyst = ProtobufDataToCatalyst( + catalystDataToProtobuf, + "message", + Some(testFileDesc), + Map("mode" -> "PERMISSIVE") + ) + + assert( + protobufDataToCatalyst + == protobufDataToCatalyst.copy() + ) + assert( + protobufDataToCatalyst + != protobufDataToCatalyst.copy(options = Map("mode" -> "FAILFAST")) + ) + assert( + protobufDataToCatalyst + != protobufDataToCatalyst.copy(messageName = "otherMessage") + ) + assert( + protobufDataToCatalyst + != protobufDataToCatalyst.copy(child = Literal.create(0, IntegerType)) + ) + assert( + protobufDataToCatalyst + != protobufDataToCatalyst.copy(binaryFileDescriptorSet = None) + ) + + val testFileDescCopy = new Array[Byte](testFileDesc.length) + testFileDesc.copyToArray(testFileDescCopy) + assert( + protobufDataToCatalyst + == protobufDataToCatalyst.copy(binaryFileDescriptorSet = Some(testFileDescCopy)) + ) + + testFileDescCopy(0) = '0' + assert( + protobufDataToCatalyst + != protobufDataToCatalyst.copy(binaryFileDescriptorSet = Some(testFileDescCopy)) + ) + } + + test("ProtobufDataToCatalyst hashCode") { + val catalystDataToProtobuf = generateCatalystDataToProtobuf() + val protobufDataToCatalyst = ProtobufDataToCatalyst( + catalystDataToProtobuf, + "message", + Some(testFileDesc), + Map("mode" -> "PERMISSIVE") + ) + + assert( + protobufDataToCatalyst.hashCode == -937893175 + ) + assert( + protobufDataToCatalyst.copy(options = Map("mode" -> "FAILFAST")).hashCode == -1634963844 + ) + assert( + protobufDataToCatalyst.copy(messageName = "otherMessage").hashCode == -1751271943 + ) + assert( + protobufDataToCatalyst.copy(child = Literal.create(0, IntegerType)).hashCode == -133420428 + ) + assert( + protobufDataToCatalyst.copy(binaryFileDescriptorSet = None).hashCode == 866765483 + ) + + val testFileDescCopy = new Array[Byte](testFileDesc.length) + testFileDesc.copyToArray(testFileDescCopy) + assert( + protobufDataToCatalyst.copy( + binaryFileDescriptorSet = Some(testFileDescCopy) + ).hashCode == -937893175 + ) + + testFileDescCopy(0) = '0' + assert( + protobufDataToCatalyst.copy( + binaryFileDescriptorSet = Some(testFileDescCopy) + ).hashCode == -1493098769 + ) + } + + private def generateCatalystDataToProtobuf() = { + val schema = StructType( + Seq( + StructField("a", StringType), + StructField("b", IntegerType) + ) + ) + val messageName = "message" + val data = RandomDataGenerator.randomRow(new scala.util.Random(3), schema) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val dataLiteral = Literal.create(converter(data), schema) + + CatalystDataToProtobuf(dataLiteral, messageName, Some(testFileDesc)) + } }