Skip to content

Commit 9a3d94c

Browse files
committed
PIR & PNNS benchmarks include {de}serialization.
1 parent 9e4c6c2 commit 9a3d94c

4 files changed

Lines changed: 88 additions & 45 deletions

File tree

Sources/ApplicationProtobuf/PirConversion.swift

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public import _CryptoExtras
1616
import Crypto
1717
import Foundation
1818
public import HomomorphicEncryption
19+
import HomomorphicEncryptionProtobuf
1920
public import PrivateInformationRetrieval
2021
import SwiftProtobuf
2122

@@ -293,19 +294,37 @@ extension [KeywordValuePair.Keyword: KeywordValuePair.Value] {
293294

294295
extension Query {
295296
package func size() throws -> Int {
296-
try proto().serializedData().count
297+
try proto().size()
298+
}
299+
}
300+
301+
extension Apple_SwiftHomomorphicEncryption_Pir_V1_EncryptedIndices {
302+
package func size() throws -> Int {
303+
try serializedData().count
297304
}
298305
}
299306

300307
extension Response {
301308
package func size() throws -> Int {
302-
try proto().serializedData().count
309+
try proto().size()
310+
}
311+
}
312+
313+
extension Apple_SwiftHomomorphicEncryption_Api_Pir_V1_PIRResponse {
314+
package func size() throws -> Int {
315+
try serializedData().count
303316
}
304317
}
305318

306319
extension EvaluationKey {
307320
package func size() throws -> Int {
308-
try serialize().proto().serializedData().count
321+
try serialize().proto().size()
322+
}
323+
}
324+
325+
extension Apple_SwiftHomomorphicEncryption_V1_SerializedEvaluationKey {
326+
package func size() throws -> Int {
327+
try serializedData().count
309328
}
310329
}
311330

Sources/ApplicationProtobuf/PnnsConversion.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,13 @@ extension [Apple_SwiftHomomorphicEncryption_Pnns_V1_SerializedCiphertextMatrix]
325325

326326
extension Query {
327327
package func size() throws -> Int {
328-
try proto().map { matrix in try matrix.serializedData().count }.sum()
328+
try proto().size()
329+
}
330+
}
331+
332+
extension [Apple_SwiftHomomorphicEncryption_Pnns_V1_SerializedCiphertextMatrix] {
333+
package func size() throws -> Int {
334+
try map { matrix in try matrix.serializedData().count }.sum()
329335
}
330336
}
331337

Sources/_BenchmarkUtilities/PirBenchmarkUtilities.swift

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,8 @@ public struct PirBenchmarkConfig<Scalar: ScalarType> {
115115

116116
extension PrivateInformationRetrieval.Response {
117117
func scaledNoiseBudget(using secretKey: Scheme.SecretKey) throws -> Int {
118-
try Int(
119-
noiseBudget(using: secretKey, variableTime: true) * Double(
120-
noiseBudgetScale))
118+
try Int(noiseBudget(using: secretKey, variableTime: true) *
119+
Double(noiseBudgetScale))
121120
}
122121
}
123122

@@ -178,12 +177,14 @@ public func pirProcessBenchmark<PirUtil: PirUtilProtocol>(
178177
struct IndexPirBenchmarkContext<Server: IndexPirServer, Client: IndexPirClient>
179178
where Server.Scheme == Client.Scheme
180179
{
180+
typealias Scheme = Server.Scheme
181181
let processedDatabase: Server.Database
182182
let server: Server
183183
let client: Client
184-
let secretKey: SecretKey<Client.Scheme>
185-
let evaluationKey: Server.Scheme.EvaluationKey
186-
let query: Client.Query
184+
let context: Scheme.Context
185+
let secretKey: SecretKey<Scheme>
186+
let evaluationKey: Apple_SwiftHomomorphicEncryption_V1_SerializedEvaluationKey
187+
let query: Apple_SwiftHomomorphicEncryption_Pir_V1_EncryptedIndices
187188
let evaluationKeySize: Int
188189
let evaluationKeyCount: Int
189190
let querySize: Int
@@ -198,9 +199,8 @@ struct IndexPirBenchmarkContext<Server: IndexPirServer, Client: IndexPirClient>
198199
pirConfig: IndexPirConfig,
199200
encryptionConfig: EncryptionParametersConfig) async throws
200201
{
201-
let encryptParameter: EncryptionParameters<Server.Scheme.Scalar> =
202-
try EncryptionParameters(from: encryptionConfig)
203-
let context = try Server.Scheme.Context(encryptionParameters: encryptParameter)
202+
let encryptParameter: EncryptionParameters<Scheme.Scalar> = try EncryptionParameters(from: encryptionConfig)
203+
self.context = try Scheme.Context(encryptionParameters: encryptParameter)
204204
let indexPirParameters = Server.generateParameter(config: pirConfig, with: context)
205205
let database = getDatabaseForTesting(
206206
numberOfEntries: pirConfig.entryCount,
@@ -210,8 +210,7 @@ struct IndexPirBenchmarkContext<Server: IndexPirServer, Client: IndexPirClient>
210210
self.server = try Server(parameter: indexPirParameters, context: context, database: processedDatabase)
211211
self.client = Client(parameter: indexPirParameters, context: context)
212212
self.secretKey = try context.generateSecretKey()
213-
self.evaluationKey = try client.generateEvaluationKey(using: secretKey)
214-
self.query = try client.generateQuery(at: [0], using: secretKey)
213+
let evaluationKey = try client.generateEvaluationKey(using: secretKey)
215214

216215
// Validate correctness
217216
let queryIndex = Int.random(in: 0..<pirConfig.entryCount)
@@ -222,9 +221,11 @@ struct IndexPirBenchmarkContext<Server: IndexPirServer, Client: IndexPirClient>
222221
fatalError("Incorrect PIR response")
223222
}
224223

225-
self.evaluationKeySize = try evaluationKey.size()
224+
self.query = try query.proto()
225+
self.evaluationKey = evaluationKey.serialize().proto()
226+
self.evaluationKeySize = try self.evaluationKey.size()
226227
self.evaluationKeyCount = evaluationKey.config.keyCount
227-
self.querySize = try query.size()
228+
self.querySize = try self.query.size()
228229
self.queryCiphertextCount = query.ciphertexts.count
229230
self.responseSize = try response.size()
230231
self.responseCiphertextCount = response.ciphertexts.count
@@ -252,9 +253,11 @@ public func indexPirBenchmark<PirUtil: PirUtilProtocol>(
252253
benchmark,
253254
benchmarkContext: IndexPirBenchmarkContext<MulPirServer<PirUtil>, MulPirClient<PirUtil>>) in
254255
for _ in benchmark.scaledIterations {
255-
try await blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query,
256-
using: benchmarkContext
257-
.evaluationKey))
256+
let query: Query<PirUtil.Scheme> = try benchmarkContext.query.native(context: benchmarkContext.context)
257+
let evaluationKey: PirUtil.Scheme.EvaluationKey = try benchmarkContext.evaluationKey
258+
.native(context: benchmarkContext.context)
259+
let response = try await benchmarkContext.server.computeResponse(to: query, using: evaluationKey)
260+
try blackHole(response.proto())
258261
}
259262
benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize)
260263
benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount)
@@ -280,11 +283,13 @@ struct KeywordPirBenchmarkContext<IndexServer: IndexPirServer, IndexClient: Inde
280283
{
281284
typealias Server = KeywordPirServer<IndexServer>
282285
typealias Client = KeywordPirClient<IndexClient>
286+
typealias Scheme = IndexServer.Scheme
283287
let server: Server
284288
let client: Client
285-
let secretKey: SecretKey<Client.Scheme>
286-
let evaluationKey: Server.Scheme.EvaluationKey
287-
let query: Client.Query
289+
let context: Scheme.Context
290+
let secretKey: SecretKey<Scheme>
291+
let evaluationKey: Apple_SwiftHomomorphicEncryption_V1_SerializedEvaluationKey
292+
let query: Apple_SwiftHomomorphicEncryption_Pir_V1_EncryptedIndices
288293
let evaluationKeySize: Int
289294
let evaluationKeyCount: Int
290295
let querySize: Int
@@ -293,10 +298,10 @@ struct KeywordPirBenchmarkContext<IndexServer: IndexPirServer, IndexClient: Inde
293298
let responseCiphertextCount: Int
294299
let noiseBudget: Int
295300

296-
init(config: PirBenchmarkConfig<Server.Scheme.Scalar>) async throws {
297-
let encryptParameter: EncryptionParameters<Server.Scheme.Scalar> =
301+
init(config: PirBenchmarkConfig<Scheme.Scalar>) async throws {
302+
let encryptParameter: EncryptionParameters<Scheme.Scalar> =
298303
try EncryptionParameters(from: config.encryptionConfig)
299-
let context = try Server.Scheme.Context(encryptionParameters: encryptParameter)
304+
self.context = try Server.Scheme.Context(encryptionParameters: encryptParameter)
300305
let rows = (0..<config.databaseConfig.entryCount).map { index in KeywordValuePair(
301306
keyword: [UInt8](String(index).utf8),
302307
value: (0..<config.databaseConfig.entrySizeInBytes).map { _ in UInt8.random(in: 0..<UInt8.max) })
@@ -337,8 +342,7 @@ struct KeywordPirBenchmarkContext<IndexServer: IndexPirServer, IndexClient: Inde
337342
pirParameter: processed.pirParameter,
338343
context: context)
339344
self.secretKey = try context.generateSecretKey()
340-
self.evaluationKey = try client.generateEvaluationKey(using: secretKey)
341-
self.query = try client.generateQuery(at: [UInt8]("0".utf8), using: secretKey)
345+
let evaluationKey = try client.generateEvaluationKey(using: secretKey)
342346

343347
// Validate correctness
344348
let queryIndex = Int.random(in: 0..<config.databaseConfig.entryCount)
@@ -355,9 +359,11 @@ struct KeywordPirBenchmarkContext<IndexServer: IndexPirServer, IndexClient: Inde
355359
fatalError("Incorrect PIR response")
356360
}
357361

358-
self.evaluationKeySize = try evaluationKey.size()
362+
self.query = try query.proto()
363+
self.evaluationKey = evaluationKey.serialize().proto()
364+
self.evaluationKeySize = try self.evaluationKey.size()
359365
self.evaluationKeyCount = evaluationKey.config.keyCount
360-
self.querySize = try query.size()
366+
self.querySize = try self.query.size()
361367
self.queryCiphertextCount = query.ciphertexts.count
362368
self.responseSize = try response.size()
363369
self.responseCiphertextCount = response.ciphertexts.count
@@ -380,10 +386,16 @@ public func keywordPirBenchmark<PirUtil: PirUtilProtocol>(
380386
"entrySize=\(config.databaseConfig.entrySizeInBytes)",
381387
"keyCompression=\(config.keywordPirConfig.keyCompression)",
382388
].joined(separator: "/")
383-
Benchmark(benchmarkName, configuration: config.benchmarkConfig) { benchmark, benchmarkContext in
389+
// swiftlint:disable closure_parameter_position
390+
Benchmark(benchmarkName, configuration: config.benchmarkConfig) { (
391+
benchmark,
392+
benchmarkContext: KeywordPirBenchmarkContext<MulPirServer<PirUtil>, MulPirClient<PirUtil>>) in
384393
for _ in benchmark.scaledIterations {
385-
try await blackHole(benchmarkContext.server.computeResponse(to: benchmarkContext.query,
386-
using: benchmarkContext.evaluationKey))
394+
let query: Query<PirUtil.Scheme> = try benchmarkContext.query.native(context: benchmarkContext.context)
395+
let evaluationKey: PirUtil.Scheme.EvaluationKey = try benchmarkContext.evaluationKey
396+
.native(context: benchmarkContext.context)
397+
let response = try await benchmarkContext.server.computeResponse(to: query, using: evaluationKey)
398+
try blackHole(response.proto())
387399
}
388400
benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize)
389401
benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount)
@@ -392,7 +404,9 @@ public func keywordPirBenchmark<PirUtil: PirUtilProtocol>(
392404
benchmark.measurement(.responseSize, benchmarkContext.responseSize)
393405
benchmark.measurement(.responseCiphertextCount, benchmarkContext.responseCiphertextCount)
394406
benchmark.measurement(.noiseBudget, benchmarkContext.noiseBudget)
395-
} setup: {
407+
}
408+
// swiftlint:enable closure_parameter_position
409+
setup: {
396410
try await KeywordPirBenchmarkContext<MulPirServer<PirUtil>, MulPirClient<PirUtil>>(
397411
config: config)
398412
}

Sources/_BenchmarkUtilities/PnnsBenchmarkUtilities.swift

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,11 @@ public func cosineSimilarityBenchmark<Scheme: HeScheme>(_: Scheme.Type,
146146
benchmark,
147147
benchmarkContext: PnnsBenchmarkContext<Scheme>) in
148148
for _ in benchmark.scaledIterations {
149-
try await blackHole(
150-
benchmarkContext.server.computeResponse(
151-
to: benchmarkContext.query,
152-
using: benchmarkContext.evaluationKey))
149+
let context = benchmarkContext.server.contexts[0]
150+
let evaluationKey: EvaluationKey<Scheme> = try benchmarkContext.evaluationKey.native(context: context)
151+
let query: Query<Scheme> = try benchmarkContext.query.native(context: context)
152+
let response = try await benchmarkContext.server.computeResponse(to: query, using: evaluationKey)
153+
try blackHole(response.proto())
153154
}
154155
benchmark.measurement(.evaluationKeySize, benchmarkContext.evaluationKeySize)
155156
benchmark.measurement(.evaluationKeyCount, benchmarkContext.evaluationKeyCount)
@@ -236,11 +237,12 @@ struct PnnsBenchmarkContext<Scheme: HeScheme> {
236237
let processedDatabase: ProcessedDatabase<Scheme>
237238
let server: Server<Scheme>
238239
let client: Client<Scheme>
240+
let contexts: [Scheme.Context]
239241
let secretKey: SecretKey<Scheme>
240-
let evaluationKey: Scheme.EvaluationKey
242+
let evaluationKey: Apple_SwiftHomomorphicEncryption_V1_SerializedEvaluationKey
241243
let evaluationKeyCount: Int
242-
let query: Query<Scheme>
243244
let evaluationKeySize: Int
245+
let query: [Apple_SwiftHomomorphicEncryption_Pnns_V1_SerializedCiphertextMatrix]
244246
let querySize: Int
245247
let queryCiphertextCount: Int
246248
let responseSize: Int
@@ -293,18 +295,18 @@ struct PnnsBenchmarkContext<Scheme: HeScheme> {
293295
databasePacking: .diagonal(babyStepGiantStep: babyStepGiantStep))
294296

295297
let database = getDatabaseForTesting(config: databaseConfig)
296-
let contexts = try clientConfig.encryptionParameters
298+
self.contexts = try clientConfig.encryptionParameters
297299
.map { encryptionParameters in try Scheme.Context(encryptionParameters: encryptionParameters) }
298300
self.processedDatabase = try await database.process(config: serverConfig, contexts: contexts)
299301
self.client = try Client(config: clientConfig, contexts: contexts)
300302
self.server = try Server(database: processedDatabase)
301303
self.secretKey = try client.generateSecretKey()
302-
self.evaluationKey = try client.generateEvaluationKey(using: secretKey)
304+
let evaluationKey = try client.generateEvaluationKey(using: secretKey)
303305

304306
// We query exact matches from rows in the database
305307
let databaseVectors = Array2d(data: database.rows.map { row in row.vector })
306308
let queryVectors = Array2d(data: database.rows.prefix(queryCount).map { row in row.vector })
307-
self.query = try client.generateQuery(for: queryVectors, using: secretKey)
309+
let query = try client.generateQuery(for: queryVectors, using: secretKey)
308310

309311
let response = try await server.computeResponse(to: query, using: evaluationKey)
310312
let decrypted = try client.decrypt(response: response, using: secretKey)
@@ -317,9 +319,11 @@ struct PnnsBenchmarkContext<Scheme: HeScheme> {
317319
scalingFactor: Float(clientConfig.scalingFactor))
318320
precondition(decrypted.distances.data == expected.data, "Wrong response")
319321

320-
self.evaluationKeySize = try evaluationKey.size()
322+
self.evaluationKey = evaluationKey.serialize().proto()
323+
self.evaluationKeySize = try self.evaluationKey.size()
321324
self.evaluationKeyCount = evaluationKey.config.keyCount
322-
self.querySize = try query.size()
325+
self.query = try query.proto()
326+
self.querySize = try self.query.size()
323327
self.queryCiphertextCount = query.ciphertextMatrices.map { matrix in matrix.ciphertexts.count }.sum()
324328
self.responseSize = try response.size()
325329
self.responseCiphertextCount = response.ciphertextMatrices

0 commit comments

Comments
 (0)