@@ -73,7 +73,6 @@ cipherTensor pTensor::binaryOpAbstraction(const char *flag,
73
73
74
74
// Now, we know that it is broadcast-able. Therefore, they either have the same shape or one is shape 1 in rows
75
75
cipherVector op_res;
76
- int ringDim = (*m_cc)->GetRingDimension ();
77
76
for (unsigned int i = 0 ; i < std::max (m_rows, other.m_rows ); i++) {
78
77
unsigned int lhsInd;
79
78
unsigned int rhsInd;
@@ -95,11 +94,10 @@ cipherTensor pTensor::binaryOpAbstraction(const char *flag,
95
94
if (other.isScalar () && !(other.m_isRepeated )) {
96
95
97
96
// Get how much to sum over and rotate.
98
- int rot = int (-ringDim / 4 ) + 1 ;
99
97
// We've now summed it up and it should be projected into the back
100
- otherVec = (*m_cc)->EvalSum (other.m_ciphertexts [0 ], -rot );
98
+ otherVec = (*m_cc)->EvalSum (other.m_ciphertexts [0 ], -getRepeatBatchSize () );
101
99
// The last rot entries are now populated with the value. We then rotate them back and we are done.
102
- otherVec = (*m_cc)->EvalAtIndex (otherVec, rot );
100
+ otherVec = (*m_cc)->EvalAtIndex (otherVec, getRepeatBatchSize () );
103
101
104
102
} else {
105
103
otherVec = other.m_ciphertexts [rhsInd];
@@ -268,7 +266,7 @@ pTensor pTensor::dot(pTensor &other, bool asRowVector) {
268
266
auto innerProd = (*m_cc)->EvalInnerProduct (
269
267
m_ciphertexts[i],
270
268
rhs.m_ciphertexts [HARDCODED_INDEX_FOR_OTHER_VECTOR],
271
- ((*m_cc)-> GetRingDimension () / 4 ));
269
+ getBatchSize ( ));
272
270
273
271
innerProd = (*m_cc)->EvalMult (innerProd, mask);
274
272
@@ -350,7 +348,7 @@ pTensor pTensor::sum(int axis) {
350
348
// Sum across the rows
351
349
cipherTensor accumulator;
352
350
for (const auto &item : m_ciphertexts) {
353
- auto resp = (*m_cc)->EvalSum (item, (*m_cc)-> GetRingDimension () / 4 );
351
+ auto resp = (*m_cc)->EvalSum (item, getBatchSize () );
354
352
accumulator.emplace_back (resp);
355
353
}
356
354
@@ -562,16 +560,14 @@ pTensor pTensor::applyGradient(pTensor matrixOfWeights, pTensor vectorGradients)
562
560
563
561
lbcrypto::Plaintext pt;
564
562
int index = 0 ;
565
- int ringDim = (*m_cc)->GetRingDimension ();
566
- int rot = int (-ringDim / 4 ) + 1 ;
567
563
568
564
for (auto &row: maskedGradients.m_ciphertexts ) {
569
565
auto maskedVal = row;
570
566
for (int i = 0 ; i < (index + 1 ); ++i) {
571
567
maskedVal = (*m_cc)->EvalAtIndex (maskedVal, 1 );
572
568
}
573
- maskedVal = (*m_cc)->EvalSum (maskedVal, -rot );
574
- maskedVal = (*m_cc)->EvalAtIndex (maskedVal, rot );
569
+ maskedVal = (*m_cc)->EvalSum (maskedVal, -getRepeatBatchSize () );
570
+ maskedVal = (*m_cc)->EvalAtIndex (maskedVal, getRepeatBatchSize () );
575
571
tensorCipherContainer.emplace_back (maskedVal);
576
572
index += 1 ;
577
573
}
0 commit comments