Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions Tests/PIRGenerateDatabaseTests/PIRGenerateDatabaseTests.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors
// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -13,12 +13,14 @@
// limitations under the License.

@testable import PIRGenerateDatabase
import XCTest
import Testing

class PIRGenerateDatabaseTests: XCTestCase {
func testValueSizeArguments() throws {
XCTAssertEqual(try XCTUnwrap(ValueSizeArguments(argument: "1")?.range), 1..<2)
XCTAssertEqual(try XCTUnwrap(ValueSizeArguments(argument: "1..<10")?.range), 1..<10)
XCTAssertEqual(try XCTUnwrap(ValueSizeArguments(argument: "1...10")?.range), 1..<11)
@Suite
struct PIRGenerateDatabaseTests {
@Test
func valueSizeArguments() throws {
#expect(try #require(ValueSizeArguments(argument: "1")?.range) == 1..<2)
#expect(try #require(ValueSizeArguments(argument: "1..<10")?.range) == 1..<10)
#expect(try #require(ValueSizeArguments(argument: "1...10")?.range) == 1..<11)
}
}
23 changes: 15 additions & 8 deletions Tests/PIRProcessDatabaseTests/ProcessDatabaseTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation
import HomomorphicEncryption
@testable import PIRProcessDatabase
import PrivateInformationRetrieval
import XCTest
import Testing

class ProcessDatabaseTests: XCTestCase {
func testArgumentsJsonParsing() throws {
@Suite
struct ProcessDatabaseTests {
@Test
func argumentsJsonParsing() throws {
do {
let configString = """
{
Expand All @@ -34,8 +37,10 @@ class ProcessDatabaseTests: XCTestCase {
}
}
"""
let configData = try XCTUnwrap(configString.data(using: .utf8))
let parsedConfig = try XCTUnwrap(JSONDecoder().decode(PIRProcessDatabase.Arguments.self, from: configData))
let configData = try #require(configString.data(using: .utf8))
let parsedConfig = try #require(try JSONDecoder().decode(
PIRProcessDatabase.Arguments.self,
from: configData))

let config = PIRProcessDatabase.Arguments(
inputDatabase: "input-database.txtpb",
Expand All @@ -45,14 +50,16 @@ class ProcessDatabaseTests: XCTestCase {
outputEvaluationKeyConfig: "output-evaluation-key-config.txtpb",
sharding: Sharding.shardCount(10),
trialsPerShard: 1)
XCTAssertEqual(parsedConfig, config)
#expect(parsedConfig == config)
}

// Can parse default JSON string
do {
let configString = PIRProcessDatabase.Arguments.defaultJsonString()
let configData = try XCTUnwrap(configString.data(using: .utf8))
XCTAssertNoThrow(try JSONDecoder().decode(PIRProcessDatabase.Arguments.self, from: configData))
let configData = try #require(configString.data(using: .utf8))
#expect(throws: Never.self) {
try JSONDecoder().decode(PIRProcessDatabase.Arguments.self, from: configData)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@

import _CryptoExtras
import Crypto
import Foundation
@testable import HomomorphicEncryption
@testable import PrivateInformationRetrieval
import PrivateInformationRetrievalProtobuf
import Testing
import TestUtilities
import XCTest

class ConversionTests: XCTestCase {
func testKeywordDatabase() throws {
@Suite
struct ConversionTests {
@Test
func keywordDatabase() throws {
let rowCount = 10
let payloadSize = 5
let databaseRows = (0..<rowCount).map { index in KeywordValuePair(
Expand All @@ -30,14 +33,15 @@ class ConversionTests: XCTestCase {
}

let proto = databaseRows.proto()
XCTAssertEqual(proto.rows.count, rowCount)
XCTAssert(proto.rows.map(\.value).allSatisfy { $0.count == payloadSize })
#expect(proto.rows.count == rowCount)
#expect(proto.rows.map(\.value).allSatisfy { $0.count == payloadSize })
let native = proto.native()

XCTAssertEqual(native, databaseRows)
#expect(native == databaseRows)
}

func testProcessedDatabaseWithParameters() throws {
@Test
func processedDatabaseWithParameters() throws {
let rows = (0..<10).map { KeywordValuePair(keyword: Array(String($0).utf8), value: Array(String($0).utf8)) }
let context: Context<Bfv<UInt32>> = try .init(encryptionParameters: .init(from: .n_4096_logq_27_28_28_logt_13))
let config = try KeywordPirConfig(
Expand All @@ -54,60 +58,61 @@ class ConversionTests: XCTestCase {

let pirParameters = try processedDatabaseWithParameters.proto(context: context)
let loadedProcessedDatabaseWithParameters = try pirParameters.native(database: processedDatabase)
XCTAssertEqual(loadedProcessedDatabaseWithParameters, processedDatabaseWithParameters)
#expect(loadedProcessedDatabaseWithParameters == processedDatabaseWithParameters)
}

func testPirAlgorithm() throws {
for algorithm in PirAlgorithm.allCases {
XCTAssertEqual(try algorithm.proto().native(), algorithm)
}
@Test(arguments: PirAlgorithm.allCases)
func pirAlgorithm(_ algorithm: PirAlgorithm) throws {
#expect(try algorithm.proto().native() == algorithm)
}

func testPirKeyCompressionStrategy() throws {
for strategy in PirKeyCompressionStrategy.allCases {
XCTAssertEqual(try strategy.proto().native(), strategy)
}
@Test(arguments: PirKeyCompressionStrategy.allCases)
func pirKeyCompressionStrategy(_ strategy: PirKeyCompressionStrategy) throws {
#expect(try strategy.proto().native() == strategy)
}

func testOprfQuery() throws {
@Test
func oprfQuery() throws {
let element =
"02a36bc90e6db34096346eaf8b7bc40ee1113582155ad3797003ce614c835a874343701d3f2debbd80d97cbe45de6e5f1f"
let query = try OprfQuery(oprfRepresentation: Data(XCTUnwrap(Array(hexEncoded: element))))
let query = try OprfQuery(oprfRepresentation: Data(#require(Array(hexEncoded: element))))
let roundTrip = try query.proto().native()
XCTAssertEqual(roundTrip.oprfRepresentation, query.oprfRepresentation)
#expect(roundTrip.oprfRepresentation == query.oprfRepresentation)
}

func testOprfResponse() throws {
@Test
func oprfResponse() throws {
let evaluatedElement =
try Data(
XCTUnwrap(Array(
#require(Array(
hexEncoded: """
02a7bba589b3e8672aa19e8fd258de2e6aae20101c8d761246de97a6b5ee9cf105febce4327a326\
255a3c604f63f600ef6
""")))

let proof =
try Data(
XCTUnwrap(Array(
#require(Array(
hexEncoded: """
bfc6cf3859127f5fe25548859856d6b7fa1c7459f0ba5712a806fc091a3000c42d8ba34ff45f32a52\
e40533efd2a03bc87f3bf4f9f58028297ccb9ccb18ae7182bcd1ef239df77e3be65ef147f3acf8bc9\
cbfc5524b702263414f043e3b7ca2e
""")))
let rawRepresentation = evaluatedElement + proof
let blindEvaluation = try OprfResponse(rawRepresentation: rawRepresentation)
XCTAssertEqual(try blindEvaluation.proto().native().rawRepresentation, blindEvaluation.rawRepresentation)
#expect(try blindEvaluation.proto().native().rawRepresentation == blindEvaluation.rawRepresentation)
}

func testKeywordPirParameter() throws {
@Test
func keywordPirParameter() throws {
var keywordPirParameter = KeywordPirParameter(hashFunctionCount: 2)
var roundTrip = keywordPirParameter.proto().native()
XCTAssertEqual(roundTrip, keywordPirParameter)
#expect(roundTrip == keywordPirParameter)
let symmetricPirClientConfig = SymmetricPirClientConfig(serverPublicKey: [],
configType: .OPRF_P384_AES_GCM_192_NONCE_96_TAG_128)
keywordPirParameter = KeywordPirParameter(hashFunctionCount: 2,
symmetricPirClientConfig: symmetricPirClientConfig)
roundTrip = try keywordPirParameter.proto().nativeWithSymmetricPirClientConfig()
XCTAssertEqual(roundTrip, keywordPirParameter)
#expect(roundTrip == keywordPirParameter)
}
}