Skip to content

Commit 8b7cafb

Browse files
committed
made code cleaner
1 parent 3fada53 commit 8b7cafb

File tree

4 files changed

+34
-12
lines changed

4 files changed

+34
-12
lines changed

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,7 @@ a [numpy-like](https://numpy.org/) interface for use in encrypted machine learni
5050
- vstack
5151
- re: expensive transpose, we do a hstack on the transpose to get the resulting transpose without actually doing the
5252
entire thing
53+
54+
# Trivia
55+
56+
This library is pronounced Tensor as the "p" is silent.

linear_regression_ames.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ int main() {
7171
// If numFolds ==1, we shuffle the single dataset
7272
// If numFolds == 0, we keep the order
7373
int numFolds = 0;
74-
float _alpha = 0.075;
75-
float _l2_regularization_factor = 0.25;
74+
float _alpha = 0.5;
75+
float _l2_regularization_factor = -1;
7676

7777
uint8_t multDepth = 8;
7878
uint8_t scalingFactorBits = 45;

src/p_tensor.cpp

+6-10
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ cipherTensor pTensor::binaryOpAbstraction(const char *flag,
7373

7474
// Now, we know that it is broadcast-able. Therefore, they either have the same shape or one is shape 1 in rows
7575
cipherVector op_res;
76-
int ringDim = (*m_cc)->GetRingDimension();
7776
for (unsigned int i = 0; i < std::max(m_rows, other.m_rows); i++) {
7877
unsigned int lhsInd;
7978
unsigned int rhsInd;
@@ -95,11 +94,10 @@ cipherTensor pTensor::binaryOpAbstraction(const char *flag,
9594
if (other.isScalar() && !(other.m_isRepeated)) {
9695

9796
// Get how much to sum over and rotate.
98-
int rot = int(-ringDim / 4) + 1;
9997
// We've now summed it up and it should be projected into the back
100-
otherVec = (*m_cc)->EvalSum(other.m_ciphertexts[0], -rot);
98+
otherVec = (*m_cc)->EvalSum(other.m_ciphertexts[0], -getRepeatBatchSize());
10199
// The last rot entries are now populated with the value. We then rotate them back and we are done.
102-
otherVec = (*m_cc)->EvalAtIndex(otherVec, rot);
100+
otherVec = (*m_cc)->EvalAtIndex(otherVec, getRepeatBatchSize());
103101

104102
} else {
105103
otherVec = other.m_ciphertexts[rhsInd];
@@ -268,7 +266,7 @@ pTensor pTensor::dot(pTensor &other, bool asRowVector) {
268266
auto innerProd = (*m_cc)->EvalInnerProduct(
269267
m_ciphertexts[i],
270268
rhs.m_ciphertexts[HARDCODED_INDEX_FOR_OTHER_VECTOR],
271-
((*m_cc)->GetRingDimension() / 4));
269+
getBatchSize());
272270

273271
innerProd = (*m_cc)->EvalMult(innerProd, mask);
274272

@@ -350,7 +348,7 @@ pTensor pTensor::sum(int axis) {
350348
// Sum across the rows
351349
cipherTensor accumulator;
352350
for (const auto &item : m_ciphertexts) {
353-
auto resp = (*m_cc)->EvalSum(item, (*m_cc)->GetRingDimension() / 4);
351+
auto resp = (*m_cc)->EvalSum(item, getBatchSize());
354352
accumulator.emplace_back(resp);
355353
}
356354

@@ -562,16 +560,14 @@ pTensor pTensor::applyGradient(pTensor matrixOfWeights, pTensor vectorGradients)
562560

563561
lbcrypto::Plaintext pt;
564562
int index = 0;
565-
int ringDim = (*m_cc)->GetRingDimension();
566-
int rot = int(-ringDim / 4) + 1;
567563

568564
for (auto &row: maskedGradients.m_ciphertexts) {
569565
auto maskedVal = row;
570566
for (int i = 0; i < (index + 1); ++i) {
571567
maskedVal = (*m_cc)->EvalAtIndex(maskedVal, 1);
572568
}
573-
maskedVal = (*m_cc)->EvalSum(maskedVal, -rot);
574-
maskedVal = (*m_cc)->EvalAtIndex(maskedVal, rot);
569+
maskedVal = (*m_cc)->EvalSum(maskedVal, -getRepeatBatchSize());
570+
maskedVal = (*m_cc)->EvalAtIndex(maskedVal, getRepeatBatchSize());
575571
tensorCipherContainer.emplace_back(maskedVal);
576572
index += 1;
577573
}

src/p_tensor.h

+22
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,28 @@ class pTensor {
512512
return newTensor;
513513
}
514514

515+
/**
516+
* Get the batch size for repetition of elements. The value of interest must be in the -1-th index slot (they wrap around). We
517+
* effectively sum which is a cumulative sum. Now, only the last batchSize values are the values of interest so we then
518+
* rotate the entire cipher around
519+
*
520+
* For the summation, we use the negative of this value
521+
* For the rotation, we use the actual value
522+
* @return
523+
*/
524+
static int getRepeatBatchSize() {
525+
int ringDim = (*m_cc)->GetRingDimension();
526+
return int(-ringDim / 4) + 1;
527+
};
528+
529+
/**
530+
* Used for non-repeat sums and or inner products
531+
* @return
532+
*/
533+
static int getBatchSize() {
534+
return int((*m_cc)->GetRingDimension() / 4);
535+
}
536+
515537
private:
516538

517539
/**

0 commit comments

Comments
 (0)