1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15+ public import Algorithms
1516public import ModularArithmetic
17+ import Foundation
1618
1719/// Brakerski-Fan-Vercauteren cryptosystem.
1820public enum Bfv < T: ScalarType > : HeScheme {
@@ -218,6 +220,97 @@ public enum Bfv<T: ScalarType>: HeScheme {
218220
219221 // MARK: Inner product
220222
223+ @inlinable
224+ static func validateInnerProductInput( _ lhs: CanonicalCiphertext ,
225+ _ rhs: CanonicalCiphertext ) throws
226+ {
227+ try validateEquality ( of: lhs. context, and: rhs. context)
228+ guard lhs. polys. count == freshCiphertextPolyCount, lhs. correctionFactor == 1 else {
229+ throw HeError . invalidCiphertext ( lhs)
230+ }
231+ guard rhs. polys. count == freshCiphertextPolyCount, rhs. correctionFactor == 1 else {
232+ throw HeError . invalidCiphertext ( rhs)
233+ }
234+ }
235+
236+ @inlinable
237+ public static func innerProductAsync( _ lhs: some Collection < CanonicalCiphertext > ,
238+ _ rhs: some Collection < CanonicalCiphertext > ) async throws
239+ -> CanonicalCiphertext
240+ {
241+ // swiftlint:disable:next local_doc_comment
242+ /// Computes accumulator += lhs * rhs
243+ @Sendable
244+ func lazyMultiply(
245+ _ lhs: CanonicalCiphertext ,
246+ _ rhs: CanonicalCiphertext ,
247+ to accumulator: inout [ Array2d < T . DoubleWidth > ] ) throws
248+ {
249+ try validateInnerProductInput ( lhs, rhs)
250+
251+ let lhsPolys = try computeBehzPolys ( ciphertext: lhs)
252+ let rhsPolys = try computeBehzPolys ( ciphertext: rhs)
253+ PolyRq . addingLazyProduct ( lhsPolys [ 0 ] , rhsPolys [ 0 ] , to: & accumulator[ 0 ] )
254+ PolyRq . addingLazyProduct ( lhsPolys [ 0 ] , rhsPolys [ 1 ] , to: & accumulator[ 1 ] )
255+ PolyRq . addingLazyProduct ( lhsPolys [ 1 ] , rhsPolys [ 0 ] , to: & accumulator[ 1 ] )
256+ PolyRq . addingLazyProduct ( lhsPolys [ 1 ] , rhsPolys [ 1 ] , to: & accumulator[ 2 ] )
257+ }
258+
259+ let firstCiphertext = lhs [ lhs. startIndex]
260+ let rnsTool = firstCiphertext. context. getRnsTool ( moduliCount: firstCiphertext. moduli. count)
261+ let moduliCount = rnsTool. qBskContext. moduli. count
262+ let poly = firstCiphertext. polys [ 0 ]
263+ let maxProductCount = rnsTool. qBskContext. maxLazyProductAccumulationCount ( ) / 2
264+ let polyContext = rnsTool. qBskContext
265+
266+ let pairs = Array ( zip ( lhs, rhs) )
267+ let chunks = pairs. evenlyChunked ( in: Util . activeProcessorCount)
268+
269+ // Each task accumulates its chunk independently then reduces to [0, modulus).
270+ let partials = try await chunks. concurrentMap ( ordered: false ) { chunk in
271+ var accumulator = Array (
272+ repeating: Array2d (
273+ data: Array ( repeating: T . DoubleWidth ( 0 ) , count: moduliCount * poly. degree) ,
274+ rowCount: moduliCount, columnCount: poly. degree) ,
275+ count: 3 )
276+ var reduceCount = 0
277+ for (lhsCipher, rhsCipher) in chunk {
278+ try lazyMultiply ( lhsCipher, rhsCipher, to: & accumulator)
279+ reduceCount += 1
280+ if reduceCount >= maxProductCount {
281+ reduceCount = 0
282+ reduceInPlace ( accumulator: & accumulator, polyContext: polyContext)
283+ }
284+ }
285+ reduceInPlace ( accumulator: & accumulator, polyContext: polyContext)
286+ return accumulator
287+ }
288+
289+ // Merge: values are in [0, modulus) after reduceInPlace, so summing fits in DoubleWidth.
290+ var merged = partials [ 0 ]
291+ for partial in partials. dropFirst ( ) {
292+ for polyIndex in merged. indices {
293+ for i in merged [ polyIndex] . data. indices {
294+ #if DEBUG
295+ let ( result, didOverflow) = merged [ polyIndex] . data [ i]
296+ . addingReportingOverflow ( partial [ polyIndex] . data [ i] )
297+ assert ( !didOverflow, " Overflow in polynomial addition at polyIndex \( polyIndex) , index \( i) " )
298+ merged [ polyIndex] . data [ i] = result
299+ #else
300+ merged [ polyIndex] . data [ i] &+= partial [ polyIndex] . data [ i]
301+ #endif
302+ }
303+ }
304+ }
305+
306+ var sum = try EvalCiphertext (
307+ context: firstCiphertext. context,
308+ polys: Array ( repeating: . zero( context: rnsTool. qBskContext) , count: 3 ) ,
309+ correctionFactor: 1 )
310+ reduceToCiphertext ( accumulator: merged, result: & sum)
311+ return try dropExtendedBase ( from: sum)
312+ }
313+
221314 @inlinable
222315 public static func innerProduct( _ lhs: some Collection < CanonicalCiphertext > ,
223316 _ rhs: some Collection < CanonicalCiphertext > ) throws -> CanonicalCiphertext
@@ -229,13 +322,7 @@ public enum Bfv<T: ScalarType>: HeScheme {
229322 _ rhs: CanonicalCiphertext ,
230323 to accumulator: inout [ Array2d < T . DoubleWidth > ] ) throws
231324 {
232- try validateEquality ( of: lhs. context, and: rhs. context)
233- guard lhs. polys. count == freshCiphertextPolyCount, lhs. correctionFactor == 1 else {
234- throw HeError . invalidCiphertext ( lhs)
235- }
236- guard rhs. polys. count == freshCiphertextPolyCount, rhs. correctionFactor == 1 else {
237- throw HeError . invalidCiphertext ( rhs)
238- }
325+ try validateInnerProductInput ( lhs, rhs)
239326
240327 let lhsPolys = try computeBehzPolys ( ciphertext: lhs)
241328 let rhsPolys = try computeBehzPolys ( ciphertext: rhs)
@@ -306,26 +393,89 @@ public enum Bfv<T: ScalarType>: HeScheme {
306393 }
307394 }
308395
396+ /// Computes accumulator += ciphertext * plaintext
309397 @inlinable
310- public static func innerProduct( ciphertexts: some Collection < EvalCiphertext > ,
311- plaintexts: some Collection < EvalPlaintext ? > ) throws -> EvalCiphertext
398+ static func lazyMultiply(
399+ ciphertext: EvalCiphertext ,
400+ plaintext: EvalPlaintext ,
401+ to accumulator: inout [ Array2d < T . DoubleWidth > ] ) throws
312402 {
313- // swiftlint:disable:next local_doc_comment
314- /// Computes accumulator += ciphertext * plaintext
315- func lazyMultiply(
316- ciphertext: EvalCiphertext ,
317- plaintext: EvalPlaintext ,
318- to accumulator: inout [ Array2d < T . DoubleWidth > ] ) throws
319- {
320- try validateEquality ( of: ciphertext. context, and: plaintext. context)
321- guard ciphertext. moduli. count == plaintext. moduli. count else {
322- throw HeError . incompatibleCiphertextAndPlaintext ( ciphertext: ciphertext, plaintext: plaintext)
403+ try validateEquality ( of: ciphertext. context, and: plaintext. context)
404+ guard ciphertext. moduli. count == plaintext. moduli. count else {
405+ throw HeError . incompatibleCiphertextAndPlaintext ( ciphertext: ciphertext, plaintext: plaintext)
406+ }
407+ for (polyIndex, ciphertextPoly) in ciphertext. polys. enumerated ( ) {
408+ PolyRq . addingLazyProduct ( ciphertextPoly, plaintext. poly, to: & accumulator[ polyIndex] )
409+ }
410+ }
411+
412+ @inlinable
413+ public static func innerProductAsync( ciphertexts: some Collection < EvalCiphertext > ,
414+ plaintexts: some Collection < EvalPlaintext ? > ) async throws -> EvalCiphertext
415+ {
416+ precondition ( plaintexts. count == ciphertexts. count)
417+ guard var result = ciphertexts. first else {
418+ preconditionFailure ( " Empty ciphertexts " )
419+ }
420+ precondition ( ciphertexts. allSatisfy { $0. polys. count == result. polys. count } ,
421+ " All ciphertexts must have the same polynomial count " )
422+ let poly = result. polys [ 0 ]
423+ let maxProductCount = poly. context. maxLazyProductAccumulationCount ( )
424+ let polyContext = result. polyContext ( )
425+ let polyCount = result. polys. count
426+ let dataCount = poly. data. count
427+ let moduliCount = poly. moduli. count
428+ let degree = poly. degree
429+
430+ let pairs = Array ( zip ( ciphertexts, plaintexts) )
431+ let chunks = pairs. evenlyChunked ( in: Util . activeProcessorCount)
432+
433+ // Each task accumulates its chunk independently then reduces to [0, modulus).
434+ let partials = try await chunks. concurrentMap ( ordered: false ) { chunk in
435+ var accumulator = Array (
436+ repeating: Array2d (
437+ data: Array ( repeating: T . DoubleWidth ( 0 ) , count: dataCount) ,
438+ rowCount: moduliCount, columnCount: degree) ,
439+ count: polyCount)
440+ var reduceCount = 0
441+ for (ciphertext, plaintext) in chunk {
442+ guard let plaintext else { continue }
443+ try lazyMultiply ( ciphertext: ciphertext, plaintext: plaintext, to: & accumulator)
444+ reduceCount += 1
445+ if reduceCount >= maxProductCount {
446+ reduceCount = 0
447+ reduceInPlace ( accumulator: & accumulator, polyContext: polyContext)
448+ }
323449 }
324- for (polyIndex, ciphertextPoly) in ciphertext. polys. enumerated ( ) {
325- PolyRq . addingLazyProduct ( ciphertextPoly, plaintext. poly, to: & accumulator[ polyIndex] )
450+ reduceInPlace ( accumulator: & accumulator, polyContext: polyContext)
451+ return accumulator
452+ }
453+
454+ // Merge: values are in [0, modulus) after reduceInPlace, so summing fits in DoubleWidth.
455+ var merged = partials [ 0 ]
456+ for partial in partials. dropFirst ( ) {
457+ for polyIndex in merged. indices {
458+ for i in merged [ polyIndex] . data. indices {
459+ #if DEBUG
460+ let ( result, didOverflow) = merged [ polyIndex] . data [ i]
461+ . addingReportingOverflow ( partial [ polyIndex] . data [ i] )
462+ assert ( !didOverflow, " Overflow in polynomial addition at polyIndex \( polyIndex) , index \( i) " )
463+ merged [ polyIndex] . data [ i] = result
464+ #else
465+ merged [ polyIndex] . data [ i] &+= partial [ polyIndex] . data [ i]
466+ #endif
467+ }
326468 }
327469 }
328470
471+ reduceToCiphertext ( accumulator: merged, result: & result)
472+ return result
473+ }
474+
475+ @inlinable
476+ public static func innerProduct( ciphertexts: some Collection < EvalCiphertext > ,
477+ plaintexts: some Collection < EvalPlaintext ? > ) throws -> EvalCiphertext
478+ {
329479 precondition ( plaintexts. count == ciphertexts. count)
330480 guard var result = ciphertexts. first else {
331481 preconditionFailure ( " Empty ciphertexts " )
0 commit comments