Skip to content

Commit 04d29d3

Browse files
authored
Adopt Swift Testing in PlaintextMatrixTests and ConversionTests (#165)
* Adopt Swift Testing in PlaintextMatrixTests * Adopt Swift Testing in ConversionTests
1 parent 838fb0b commit 04d29d3

File tree

2 files changed

+96
-75
lines changed

2 files changed

+96
-75
lines changed

Tests/PrivateNearestNeighborSearchProtobufTests/ConversionTests.swift

Lines changed: 47 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors
1+
// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
1616
@testable import PrivateNearestNeighborSearch
1717
import PrivateNearestNeighborSearchProtobuf
1818

19-
import XCTest
19+
import Testing
2020

2121
func increasingData<T: ScalarType>(dimensions: MatrixDimensions, modulus: T) -> [[T]] {
2222
(0..<dimensions.rowCount).map { rowIndex in
@@ -27,32 +27,36 @@ func increasingData<T: ScalarType>(dimensions: MatrixDimensions, modulus: T) ->
2727
}
2828
}
2929

30-
class ConversionTests: XCTestCase {
31-
func testDistanceMetric() throws {
30+
@Suite
31+
struct ConversionTests {
32+
@Test
33+
func distanceMetric() throws {
3234
for metric in DistanceMetric.allCases {
33-
XCTAssertEqual(try metric.proto().native(), metric)
35+
#expect(try metric.proto().native() == metric)
3436
}
3537
}
3638

37-
func testPacking() throws {
38-
XCTAssertEqual(
39-
try MatrixPacking.denseColumn.proto().native(),
40-
MatrixPacking.denseColumn)
41-
XCTAssertEqual(
42-
try MatrixPacking.denseRow.proto().native(),
43-
MatrixPacking.denseRow)
39+
@Test
40+
func packing() throws {
41+
#expect(
42+
try MatrixPacking.denseColumn.proto().native() ==
43+
MatrixPacking.denseColumn)
44+
#expect(
45+
try MatrixPacking.denseRow.proto().native() ==
46+
MatrixPacking.denseRow)
4447
let bsgs = BabyStepGiantStep(vectorDimension: 128)
45-
XCTAssertEqual(bsgs.proto().native(), bsgs)
48+
#expect(bsgs.proto().native() == bsgs)
4649

47-
XCTAssertEqual(
50+
#expect(
4851
try MatrixPacking
4952
.diagonal(babyStepGiantStep: bsgs)
5053
.proto()
51-
.native(),
52-
MatrixPacking.diagonal(babyStepGiantStep: bsgs))
54+
.native() ==
55+
MatrixPacking.diagonal(babyStepGiantStep: bsgs))
5356
}
5457

55-
func testClientAndServerConfig() throws {
58+
@Test
59+
func clientAndServerConfig() throws {
5660
func runTest<Scheme: HeScheme>(_: Scheme.Type) throws {
5761
let vectorDimension = 4
5862
let clientConfig = try ClientConfig<Scheme>(
@@ -67,35 +71,37 @@ class ConversionTests: XCTestCase {
6771
significantBitCounts: [15],
6872
preferringSmall: true,
6973
nttDegree: 8))
70-
XCTAssertEqual(try clientConfig.proto().native(), clientConfig)
74+
#expect(try clientConfig.proto().native() == clientConfig)
7175

7276
let serverConfig = ServerConfig<Scheme>(
7377
clientConfig: clientConfig,
7478
databasePacking: MatrixPacking
7579
.diagonal(
7680
babyStepGiantStep: BabyStepGiantStep(vectorDimension: vectorDimension)))
77-
XCTAssertEqual(try serverConfig.proto().native(), serverConfig)
81+
#expect(try serverConfig.proto().native() == serverConfig)
7882
}
7983

8084
try runTest(Bfv<UInt32>.self)
8185
try runTest(Bfv<UInt64>.self)
8286
}
8387

84-
func testDatabase() throws {
88+
@Test
89+
func database() throws {
8590
let rows = (0...10).map { rowIndex in
8691
DatabaseRow(
8792
entryId: rowIndex,
8893
entryMetadata: rowIndex.littleEndianBytes,
8994
vector: [Float(rowIndex)])
9095
}
9196
for row in rows {
92-
XCTAssertEqual(row.proto().native(), row)
97+
#expect(row.proto().native() == row)
9398
}
9499
let database = Database(rows: rows)
95-
XCTAssertEqual(database.proto().native(), database)
100+
#expect(database.proto().native() == database)
96101
}
97102

98-
func testSerializedPlaintextMatrix() throws {
103+
@Test
104+
func serializedPlaintextMatrix() throws {
99105
func runTest<Scheme: HeScheme>(_: Scheme.Type) throws {
100106
let encryptionParameters = try EncryptionParameters<Scheme>(from: .insecure_n_8_logq_5x18_logt_5)
101107
let context = try Context<Scheme>(encryptionParameters: encryptionParameters)
@@ -110,27 +116,28 @@ class ConversionTests: XCTestCase {
110116
packing: .denseColumn,
111117
values: scalars.flatMap { $0 })
112118
let serialized = try plaintextMatrix.serialize()
113-
XCTAssertEqual(try serialized.proto().native(), serialized)
119+
#expect(try serialized.proto().native() == serialized)
114120
let deserialized = try PlaintextMatrix(deserialize: serialized, context: context)
115-
XCTAssertEqual(deserialized, plaintextMatrix)
121+
#expect(deserialized == plaintextMatrix)
116122

117123
for moduliCount in 1..<encryptionParameters.coefficientModuli.count {
118124
let evalPlaintextMatrix = try plaintextMatrix.convertToEvalFormat(moduliCount: moduliCount)
119125
let serialized = try evalPlaintextMatrix.serialize()
120-
XCTAssertEqual(try serialized.proto().native(), serialized)
126+
#expect(try serialized.proto().native() == serialized)
121127
let deserialized = try PlaintextMatrix(
122128
deserialize: serialized,
123129
context: context,
124130
moduliCount: moduliCount)
125-
XCTAssertEqual(deserialized, evalPlaintextMatrix)
131+
#expect(deserialized == evalPlaintextMatrix)
126132
}
127133
}
128134

129135
try runTest(Bfv<UInt32>.self)
130136
try runTest(Bfv<UInt64>.self)
131137
}
132138

133-
func testSerializedCiphertextMatrix() throws {
139+
@Test
140+
func serializedCiphertextMatrix() throws {
134141
func runTest<Scheme: HeScheme>(_: Scheme.Type) throws {
135142
let encryptionParameters = try EncryptionParameters<Scheme>(from: .insecure_n_8_logq_5x18_logt_5)
136143
let context = try Context<Scheme>(encryptionParameters: encryptionParameters)
@@ -150,18 +157,18 @@ class ConversionTests: XCTestCase {
150157
let ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey)
151158
let serialized = try ciphertextMatrix.serialize()
152159
let serializedProto = try serialized.proto()
153-
XCTAssertEqual(try serializedProto.native(), serialized)
160+
#expect(try serializedProto.native() == serialized)
154161
}
155162
// Check Evaluation format
156163
do {
157164
let ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey)
158165
let evalCiphertextMatrix = try ciphertextMatrix.convertToEvalFormat()
159166
let serialized = try evalCiphertextMatrix.serialize()
160-
XCTAssertEqual(try serialized.proto().native(), serialized)
167+
#expect(try serialized.proto().native() == serialized)
161168
let deserialized = try CiphertextMatrix<Scheme, Eval>(
162169
deserialize: serialized,
163170
context: context)
164-
XCTAssertEqual(deserialized, evalCiphertextMatrix)
171+
#expect(deserialized == evalCiphertextMatrix)
165172
}
166173
// Check serializeForDecryption
167174
do {
@@ -174,20 +181,21 @@ class ConversionTests: XCTestCase {
174181
let serializedProto = try serialized.proto()
175182
let serializedSize = try serializedProto.serializedData().count
176183

177-
XCTAssertLessThan(serializedForDecryptionSize, serializedSize)
184+
#expect(serializedForDecryptionSize < serializedSize)
178185
let deserialized = try CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat>(
179186
deserialize: serializedForDecryption,
180187
context: context, moduliCount: 1)
181188
let decrypted = try deserialized.decrypt(using: secretKey)
182-
XCTAssertEqual(decrypted, plaintextMatrix)
189+
#expect(decrypted == plaintextMatrix)
183190
}
184191
}
185192

186193
try runTest(Bfv<UInt32>.self)
187194
try runTest(Bfv<UInt64>.self)
188195
}
189196

190-
func testQuery() throws {
197+
@Test
198+
func query() throws {
191199
func runTest<Scheme: HeScheme>(_: Scheme.Type) throws {
192200
let encryptionParameters = try EncryptionParameters<Scheme>(from: .insecure_n_8_logq_5x18_logt_5)
193201
let context = try Context<Scheme>(encryptionParameters: encryptionParameters)
@@ -208,13 +216,14 @@ class ConversionTests: XCTestCase {
208216

209217
let query = Query(ciphertextMatrices: ciphertextMatrices)
210218
let roundtrip = try query.proto().native(context: context)
211-
XCTAssertEqual(roundtrip, query)
219+
#expect(roundtrip == query)
212220
}
213221
try runTest(Bfv<UInt32>.self)
214222
try runTest(Bfv<UInt64>.self)
215223
}
216224

217-
func testSerializedProcessedDatabase() throws {
225+
@Test
226+
func serializedProcessedDatabase() throws {
218227
func runTest<Scheme: HeScheme>(_: Scheme.Type) throws {
219228
let encryptionParameters = try EncryptionParameters<Scheme>(from: .insecure_n_8_logq_5x18_logt_5)
220229
let vectorDimension = 4
@@ -226,7 +235,7 @@ class ConversionTests: XCTestCase {
226235
vector: Array(repeating: Float(rowIndex), count: vectorDimension))
227236
}
228237
for row in rows {
229-
XCTAssertEqual(row.proto().native(), row)
238+
#expect(row.proto().native() == row)
230239
}
231240
let database = Database(rows: rows)
232241

@@ -250,7 +259,7 @@ class ConversionTests: XCTestCase {
250259

251260
let processed = try database.process(config: serverConfig)
252261
let serialized = try processed.serialize()
253-
XCTAssertEqual(try serialized.proto().native(), serialized)
262+
#expect(try serialized.proto().native() == serialized)
254263
}
255264
try runTest(Bfv<UInt32>.self)
256265
try runTest(Bfv<UInt64>.self)

0 commit comments

Comments
 (0)