Skip to content

Commit dfda59e

Browse files
authored
multi-threading (#300)
1 parent 2945adc commit dfda59e

File tree

17 files changed

+867
-149
lines changed

17 files changed

+867
-149
lines changed

Benchmarks/PrivateInformationRetrievalBenchmark/PirBenchmark.swift

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024-2025 Apple Inc. and the Swift Homomorphic Encryption project authors
1+
// Copyright 2024-2026 Apple Inc. and the Swift Homomorphic Encryption project authors
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -24,9 +24,13 @@ nonisolated(unsafe) let benchmarks: () -> Void = {
2424
pirProcessBenchmark(PirUtil<Bfv<UInt32>>.self)()
2525
pirProcessBenchmark(PirUtil<Bfv<UInt64>>.self)()
2626

27-
indexPirBenchmark(PirUtil<Bfv<UInt32>>.self)()
28-
indexPirBenchmark(PirUtil<Bfv<UInt64>>.self)()
27+
indexPirBenchmark(PirUtil<Bfv<UInt32>>.self, callOptions: .multiThreaded)()
28+
indexPirBenchmark(PirUtil<Bfv<UInt32>>.self, callOptions: .singleThreaded)()
29+
indexPirBenchmark(PirUtil<Bfv<UInt64>>.self, callOptions: .multiThreaded)()
30+
indexPirBenchmark(PirUtil<Bfv<UInt64>>.self, callOptions: .singleThreaded)()
2931

30-
keywordPirBenchmark(PirUtil<Bfv<UInt32>>.self)()
31-
keywordPirBenchmark(PirUtil<Bfv<UInt64>>.self)()
32+
keywordPirBenchmark(PirUtil<Bfv<UInt32>>.self, callOptions: .multiThreaded)()
33+
keywordPirBenchmark(PirUtil<Bfv<UInt32>>.self, callOptions: .singleThreaded)()
34+
keywordPirBenchmark(PirUtil<Bfv<UInt64>>.self, callOptions: .multiThreaded)()
35+
keywordPirBenchmark(PirUtil<Bfv<UInt64>>.self, callOptions: .singleThreaded)()
3236
}

Package.resolved

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ let package = Package(
8181
.package(url: "https://github.com/apple/swift-algorithms", from: "1.2.0"),
8282
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.0"),
8383
.package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.2"),
84+
.package(url: "https://github.com/apple/swift-collections.git", from: "1.3.0"),
8485
.package(url: "https://github.com/apple/swift-crypto.git", from: "3.10.0"),
8586
.package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"),
8687
.package(url: "https://github.com/apple/swift-numerics", from: "1.0.0"),
@@ -102,6 +103,7 @@ let package = Package(
102103
.target(
103104
name: "HomomorphicEncryption",
104105
dependencies: [
106+
.product(name: "Algorithms", package: "swift-algorithms"),
105107
.product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
106108
.product(name: "Crypto", package: "swift-crypto"),
107109
.product(name: "_CryptoExtras", package: "swift-crypto"),
@@ -123,6 +125,7 @@ let package = Package(
123125
name: "PrivateInformationRetrieval",
124126
dependencies: ["HomomorphicEncryption",
125127
.product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
128+
.product(name: "Collections", package: "swift-collections"),
126129
.product(name: "Numerics", package: "swift-numerics")],
127130
swiftSettings: librarySettings),
128131
.target(

Sources/HomomorphicEncryption/Bfv/Bfv.swift

Lines changed: 171 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
public import Algorithms
1516
public import ModularArithmetic
17+
import Foundation
1618

1719
/// Brakerski-Fan-Vercauteren cryptosystem.
1820
public enum Bfv<T: ScalarType>: HeScheme {
@@ -218,6 +220,97 @@ public enum Bfv<T: ScalarType>: HeScheme {
218220

219221
// MARK: Inner product
220222

223+
@inlinable
224+
static func validateInnerProductInput(_ lhs: CanonicalCiphertext,
225+
_ rhs: CanonicalCiphertext) throws
226+
{
227+
try validateEquality(of: lhs.context, and: rhs.context)
228+
guard lhs.polys.count == freshCiphertextPolyCount, lhs.correctionFactor == 1 else {
229+
throw HeError.invalidCiphertext(lhs)
230+
}
231+
guard rhs.polys.count == freshCiphertextPolyCount, rhs.correctionFactor == 1 else {
232+
throw HeError.invalidCiphertext(rhs)
233+
}
234+
}
235+
236+
@inlinable
237+
public static func innerProductAsync(_ lhs: some Collection<CanonicalCiphertext>,
238+
_ rhs: some Collection<CanonicalCiphertext>) async throws
239+
-> CanonicalCiphertext
240+
{
241+
// swiftlint:disable:next local_doc_comment
242+
/// Computes accumulator += lhs * rhs
243+
@Sendable
244+
func lazyMultiply(
245+
_ lhs: CanonicalCiphertext,
246+
_ rhs: CanonicalCiphertext,
247+
to accumulator: inout [Array2d<T.DoubleWidth>]) throws
248+
{
249+
try validateInnerProductInput(lhs, rhs)
250+
251+
let lhsPolys = try computeBehzPolys(ciphertext: lhs)
252+
let rhsPolys = try computeBehzPolys(ciphertext: rhs)
253+
PolyRq.addingLazyProduct(lhsPolys[0], rhsPolys[0], to: &accumulator[0])
254+
PolyRq.addingLazyProduct(lhsPolys[0], rhsPolys[1], to: &accumulator[1])
255+
PolyRq.addingLazyProduct(lhsPolys[1], rhsPolys[0], to: &accumulator[1])
256+
PolyRq.addingLazyProduct(lhsPolys[1], rhsPolys[1], to: &accumulator[2])
257+
}
258+
259+
let firstCiphertext = lhs[lhs.startIndex]
260+
let rnsTool = firstCiphertext.context.getRnsTool(moduliCount: firstCiphertext.moduli.count)
261+
let moduliCount = rnsTool.qBskContext.moduli.count
262+
let poly = firstCiphertext.polys[0]
263+
let maxProductCount = rnsTool.qBskContext.maxLazyProductAccumulationCount() / 2
264+
let polyContext = rnsTool.qBskContext
265+
266+
let pairs = Array(zip(lhs, rhs))
267+
let chunks = pairs.evenlyChunked(in: Util.activeProcessorCount)
268+
269+
// Each task accumulates its chunk independently then reduces to [0, modulus).
270+
let partials = try await chunks.concurrentMap(ordered: false) { chunk in
271+
var accumulator = Array(
272+
repeating: Array2d(
273+
data: Array(repeating: T.DoubleWidth(0), count: moduliCount * poly.degree),
274+
rowCount: moduliCount, columnCount: poly.degree),
275+
count: 3)
276+
var reduceCount = 0
277+
for (lhsCipher, rhsCipher) in chunk {
278+
try lazyMultiply(lhsCipher, rhsCipher, to: &accumulator)
279+
reduceCount += 1
280+
if reduceCount >= maxProductCount {
281+
reduceCount = 0
282+
reduceInPlace(accumulator: &accumulator, polyContext: polyContext)
283+
}
284+
}
285+
reduceInPlace(accumulator: &accumulator, polyContext: polyContext)
286+
return accumulator
287+
}
288+
289+
// Merge: values are in [0, modulus) after reduceInPlace, so summing fits in DoubleWidth.
290+
var merged = partials[0]
291+
for partial in partials.dropFirst() {
292+
for polyIndex in merged.indices {
293+
for i in merged[polyIndex].data.indices {
294+
#if DEBUG
295+
let (result, didOverflow) = merged[polyIndex].data[i]
296+
.addingReportingOverflow(partial[polyIndex].data[i])
297+
assert(!didOverflow, "Overflow in polynomial addition at polyIndex \(polyIndex), index \(i)")
298+
merged[polyIndex].data[i] = result
299+
#else
300+
merged[polyIndex].data[i] &+= partial[polyIndex].data[i]
301+
#endif
302+
}
303+
}
304+
}
305+
306+
var sum = try EvalCiphertext(
307+
context: firstCiphertext.context,
308+
polys: Array(repeating: .zero(context: rnsTool.qBskContext), count: 3),
309+
correctionFactor: 1)
310+
reduceToCiphertext(accumulator: merged, result: &sum)
311+
return try dropExtendedBase(from: sum)
312+
}
313+
221314
@inlinable
222315
public static func innerProduct(_ lhs: some Collection<CanonicalCiphertext>,
223316
_ rhs: some Collection<CanonicalCiphertext>) throws -> CanonicalCiphertext
@@ -229,13 +322,7 @@ public enum Bfv<T: ScalarType>: HeScheme {
229322
_ rhs: CanonicalCiphertext,
230323
to accumulator: inout [Array2d<T.DoubleWidth>]) throws
231324
{
232-
try validateEquality(of: lhs.context, and: rhs.context)
233-
guard lhs.polys.count == freshCiphertextPolyCount, lhs.correctionFactor == 1 else {
234-
throw HeError.invalidCiphertext(lhs)
235-
}
236-
guard rhs.polys.count == freshCiphertextPolyCount, rhs.correctionFactor == 1 else {
237-
throw HeError.invalidCiphertext(rhs)
238-
}
325+
try validateInnerProductInput(lhs, rhs)
239326

240327
let lhsPolys = try computeBehzPolys(ciphertext: lhs)
241328
let rhsPolys = try computeBehzPolys(ciphertext: rhs)
@@ -306,26 +393,89 @@ public enum Bfv<T: ScalarType>: HeScheme {
306393
}
307394
}
308395

396+
/// Computes accumulator += ciphertext * plaintext
309397
@inlinable
310-
public static func innerProduct(ciphertexts: some Collection<EvalCiphertext>,
311-
plaintexts: some Collection<EvalPlaintext?>) throws -> EvalCiphertext
398+
static func lazyMultiply(
399+
ciphertext: EvalCiphertext,
400+
plaintext: EvalPlaintext,
401+
to accumulator: inout [Array2d<T.DoubleWidth>]) throws
312402
{
313-
// swiftlint:disable:next local_doc_comment
314-
/// Computes accumulator += ciphertext * plaintext
315-
func lazyMultiply(
316-
ciphertext: EvalCiphertext,
317-
plaintext: EvalPlaintext,
318-
to accumulator: inout [Array2d<T.DoubleWidth>]) throws
319-
{
320-
try validateEquality(of: ciphertext.context, and: plaintext.context)
321-
guard ciphertext.moduli.count == plaintext.moduli.count else {
322-
throw HeError.incompatibleCiphertextAndPlaintext(ciphertext: ciphertext, plaintext: plaintext)
403+
try validateEquality(of: ciphertext.context, and: plaintext.context)
404+
guard ciphertext.moduli.count == plaintext.moduli.count else {
405+
throw HeError.incompatibleCiphertextAndPlaintext(ciphertext: ciphertext, plaintext: plaintext)
406+
}
407+
for (polyIndex, ciphertextPoly) in ciphertext.polys.enumerated() {
408+
PolyRq.addingLazyProduct(ciphertextPoly, plaintext.poly, to: &accumulator[polyIndex])
409+
}
410+
}
411+
412+
@inlinable
413+
public static func innerProductAsync(ciphertexts: some Collection<EvalCiphertext>,
414+
plaintexts: some Collection<EvalPlaintext?>) async throws -> EvalCiphertext
415+
{
416+
precondition(plaintexts.count == ciphertexts.count)
417+
guard var result = ciphertexts.first else {
418+
preconditionFailure("Empty ciphertexts")
419+
}
420+
precondition(ciphertexts.allSatisfy { $0.polys.count == result.polys.count },
421+
"All ciphertexts must have the same polynomial count")
422+
let poly = result.polys[0]
423+
let maxProductCount = poly.context.maxLazyProductAccumulationCount()
424+
let polyContext = result.polyContext()
425+
let polyCount = result.polys.count
426+
let dataCount = poly.data.count
427+
let moduliCount = poly.moduli.count
428+
let degree = poly.degree
429+
430+
let pairs = Array(zip(ciphertexts, plaintexts))
431+
let chunks = pairs.evenlyChunked(in: Util.activeProcessorCount)
432+
433+
// Each task accumulates its chunk independently then reduces to [0, modulus).
434+
let partials = try await chunks.concurrentMap(ordered: false) { chunk in
435+
var accumulator = Array(
436+
repeating: Array2d(
437+
data: Array(repeating: T.DoubleWidth(0), count: dataCount),
438+
rowCount: moduliCount, columnCount: degree),
439+
count: polyCount)
440+
var reduceCount = 0
441+
for (ciphertext, plaintext) in chunk {
442+
guard let plaintext else { continue }
443+
try lazyMultiply(ciphertext: ciphertext, plaintext: plaintext, to: &accumulator)
444+
reduceCount += 1
445+
if reduceCount >= maxProductCount {
446+
reduceCount = 0
447+
reduceInPlace(accumulator: &accumulator, polyContext: polyContext)
448+
}
323449
}
324-
for (polyIndex, ciphertextPoly) in ciphertext.polys.enumerated() {
325-
PolyRq.addingLazyProduct(ciphertextPoly, plaintext.poly, to: &accumulator[polyIndex])
450+
reduceInPlace(accumulator: &accumulator, polyContext: polyContext)
451+
return accumulator
452+
}
453+
454+
// Merge: values are in [0, modulus) after reduceInPlace, so summing fits in DoubleWidth.
455+
var merged = partials[0]
456+
for partial in partials.dropFirst() {
457+
for polyIndex in merged.indices {
458+
for i in merged[polyIndex].data.indices {
459+
#if DEBUG
460+
let (result, didOverflow) = merged[polyIndex].data[i]
461+
.addingReportingOverflow(partial[polyIndex].data[i])
462+
assert(!didOverflow, "Overflow in polynomial addition at polyIndex \(polyIndex), index \(i)")
463+
merged[polyIndex].data[i] = result
464+
#else
465+
merged[polyIndex].data[i] &+= partial[polyIndex].data[i]
466+
#endif
467+
}
326468
}
327469
}
328470

471+
reduceToCiphertext(accumulator: merged, result: &result)
472+
return result
473+
}
474+
475+
@inlinable
476+
public static func innerProduct(ciphertexts: some Collection<EvalCiphertext>,
477+
plaintexts: some Collection<EvalPlaintext?>) throws -> EvalCiphertext
478+
{
329479
precondition(plaintexts.count == ciphertexts.count)
330480
guard var result = ciphertexts.first else {
331481
preconditionFailure("Empty ciphertexts")

Sources/HomomorphicEncryption/HeScheme.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,7 @@ public protocol HeScheme: Sendable {
922922
/// - Parameters:
923923
/// - ciphertext: Ciphertext to transform.
924924
/// - element: Galois element of the transformation. Must be odd in `[1, 2 * N - 1]` where `N` is the RLWE ring
925-
/// dimension, given by ``EncryptionParameters/polyDegree``.
925+
/// dimension, given by ``EncryptionParameters/polyDegree``.
926926
/// - key: Evaluation key. Must contain Galois element `element`.
927927
/// - Throws: Error upon failure to apply the Galois transformation.
928928
/// - seealso: ``applyGaloisAsync(ciphertext:element:using:)`` for an async version of this API

Sources/HomomorphicEncryption/Util.swift

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,82 @@ extension Array where Element: ScalarType {
135135
return width32.reduce(Width32<Self.Element>(1)) { $0 * $1 }
136136
}
137137
}
138+
139+
extension Collection where Self.Element: Sendable {
140+
@inlinable
141+
package func concurrentMap<T: Sendable>(
142+
ordered: Bool = true,
143+
_ transform: @Sendable @escaping (Element) async throws -> T) async throws -> [T]
144+
{
145+
// Fast path for empty collection
146+
if isEmpty {
147+
return []
148+
}
149+
150+
// We use enumerated indices (Int) for stable ordering.
151+
return try await withThrowingTaskGroup(of: (Int, T).self) { group in
152+
for (index, element) in self.enumerated() {
153+
group.addTask {
154+
try await (index, transform(element))
155+
}
156+
}
157+
158+
if ordered {
159+
var orderedBuffer = [T?](repeating: nil, count: self.count)
160+
for try await (index, value) in group {
161+
orderedBuffer[index] = value
162+
}
163+
return orderedBuffer.compactMap(\.self)
164+
}
165+
166+
var unorderedResults: [T] = []
167+
unorderedResults.reserveCapacity(self.count)
168+
for try await (_, value) in group {
169+
unorderedResults.append(value)
170+
}
171+
return unorderedResults
172+
}
173+
}
174+
175+
@inlinable
176+
package mutating func concurrentConsumingMap<T: Sendable>(
177+
ordered: Bool = true,
178+
_ transform: @Sendable @escaping (consuming Element) async throws -> T) async throws -> [T]
179+
{
180+
// Fast path for empty collection
181+
if isEmpty {
182+
return []
183+
}
184+
185+
// We use enumerated indices (Int) for stable ordering.
186+
return try await withThrowingTaskGroup(of: (Int, T).self) { group in
187+
for (index, element) in self.enumerated() {
188+
group.addTask {
189+
try await (index, transform(element))
190+
}
191+
}
192+
193+
if ordered {
194+
var orderedBuffer = [T?](repeating: nil, count: self.count)
195+
for try await (index, value) in group {
196+
orderedBuffer[index] = value
197+
}
198+
return orderedBuffer.compactMap(\.self)
199+
}
200+
201+
var unorderedResults: [T] = []
202+
unorderedResults.reserveCapacity(self.count)
203+
for try await (_, value) in group {
204+
unorderedResults.append(value)
205+
}
206+
return unorderedResults
207+
}
208+
}
209+
}
210+
211+
@usableFromInline
212+
enum Util {
213+
@usableFromInline static var activeProcessorCount: Int {
214+
ProcessInfo.processInfo.activeProcessorCount
215+
}
216+
}

0 commit comments

Comments
 (0)