Skip to content

Commit

Permalink
fix an unintended variable reuse in UdpEncryption object (#502)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #502

`indexOffset_` is meant to record how much data has been processed. And my data and peer's data should be recorded separately. However the original implementation mistakenly used the same variable and the unit test failed to catch that. This diff fixes the bug and adds necessary tests.

Reviewed By: haochenuw

Differential Revision: D43746711

fbshipit-source-id: 94abdfdf4bd6890c32c7ce5019889a14e03fe3ab
  • Loading branch information
Ruiyu Zhu authored and facebook-github-bot committed Mar 9, 2023
1 parent 353a000 commit f2bae56
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void UdpEncryption::prepareToProcessMyData(size_t myDataWidth) {
statusOfProcessingMyData_ = Status::inProgress;
myDataWidth_ = myDataWidth;
prgKey_ = fbpcf::engine::util::getRandomM128iFromSystemNoise();
indexOffset_ = 0;
myDataIndexOffset_ = 0;
}

void UdpEncryption::processMyData(
Expand All @@ -45,12 +45,12 @@ void UdpEncryption::processMyData(
" but get " + std::to_string(plaintextData.at(0).size()));
}
auto [ciphertext, nonce] =
UdpUtil::localEncryption(plaintextData, prgKey_, indexOffset_);
UdpUtil::localEncryption(plaintextData, prgKey_, myDataIndexOffset_);
agent_->send(nonce);
for (size_t i = 0; i < ciphertext.size(); i++) {
agent_->send(ciphertext.at(i));
}
indexOffset_ += plaintextData.size();
myDataIndexOffset_ += plaintextData.size();
}

void UdpEncryption::prepareToProcessPeerData(
Expand All @@ -67,7 +67,7 @@ void UdpEncryption::prepareToProcessPeerData(
}

peerDataWidth_ = peerDataWidth;
indexOffset_ = 0;
peerDataIndexOffset_ = 0;

cherryPickedEncryption_ =
std::vector<std::vector<unsigned char>>(indexes.size());
Expand All @@ -87,17 +87,17 @@ void UdpEncryption::processPeerData(size_t dataSize) {

for (size_t i = 0; i < dataSize; i++) {
auto ciphertext = agent_->receive(peerDataWidth_);
auto pos = indexToOrderMap_.find(i + indexOffset_);
auto pos = indexToOrderMap_.find(i + peerDataIndexOffset_);
if (pos != indexToOrderMap_.end()) {
// this ciphertext should be picked up
cherryPickedEncryption_.at(pos->second) = std::move(ciphertext);
cherryPickedNonce_.at(pos->second) = nonce;
cherryPickedIndex_.at(pos->second) = i + indexOffset_;
cherryPickedIndex_.at(pos->second) = i + peerDataIndexOffset_;
indexToOrderMap_.erase(pos);
// TODO: this can be further optimized by not copying duplicated nonce.
}
}
indexOffset_ += dataSize;
peerDataIndexOffset_ += dataSize;
}

} // namespace fbpcf::mpc_std_lib::unified_data_process::data_processor
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ class UdpEncryption {
std::unique_ptr<fbpcf::engine::communication::IPartyCommunicationAgent>
agent_;

uint64_t indexOffset_;
uint64_t myDataIndexOffset_;
uint64_t peerDataIndexOffset_;

size_t myDataWidth_;
__m128i prgKey_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,23 @@ std::tuple<std::vector<T>, std::vector<T>> split(
std::vector<T>(src.begin() + cutPosition, src.end())};
}

std::vector<std::vector<uint8_t>> convertToVecs(
size_t rowCount,
size_t dataWidth,
const std::vector<std::vector<bool>>& src) {
std::vector<std::vector<uint8_t>> rst(
rowCount, std::vector<uint8_t>(dataWidth));

for (size_t i = 0; i < dataWidth; i++) {
for (uint8_t j = 0; j < 8; j++) {
for (size_t k = 0; k < rowCount; k++) {
rst.at(k).at(i) += (src.at(i * 8 + j).at(k) << j);
}
}
}
return rst;
}

void testUdpEncryptionAndDecryptionObjects(
std::unique_ptr<engine::communication::IPartyCommunicationAgent> agent0,
std::unique_ptr<engine::communication::IPartyCommunicationAgent> agent1) {
Expand All @@ -162,69 +179,94 @@ void testUdpEncryptionAndDecryptionObjects(
auto udpDec10 = std::make_unique<UdpDecryption<2>>(0, 1);
auto udpDec11 = std::make_unique<UdpDecryption<3>>(1, 0);

auto task0 = [numberOfInputShards](
std::unique_ptr<UdpEncryption> udpEnc,
std::unique_ptr<UdpDecryption<0>> udpDec0,
std::unique_ptr<UdpDecryption<2>> udpDec1,
const std::vector<std::vector<std::vector<unsigned char>>>&
plaintextDataInShards,
size_t dataWidth,
size_t outputSize) {
auto task0 = [](std::unique_ptr<UdpEncryption> udpEnc,
std::unique_ptr<UdpDecryption<0>> udpDec0,
std::unique_ptr<UdpDecryption<2>> udpDec1,
const std::vector<std::vector<std::vector<unsigned char>>>&
plaintextDataInShards,
size_t dataWidth,
size_t outputSize,
const std::vector<int32_t>& indexes,
const std::vector<size_t>& sizes) {
udpEnc->prepareToProcessMyData(dataWidth);
for (size_t i = 0; i < numberOfInputShards; i++) {
for (size_t i = 0; i < plaintextDataInShards.size(); i++) {
udpEnc->processMyData(plaintextDataInShards.at(i));
};
udpEnc->prepareToProcessPeerData(dataWidth, indexes);
for (size_t i = 0; i < sizes.size(); i++) {
udpEnc->processPeerData(sizes.at(i));
}

size_t outputShard0Size = outputSize / 2;
size_t outputShard1Size = outputSize - outputShard0Size;
auto key = udpEnc->getExpandedKey();

auto result0 = udpDec0->decryptMyData(key, dataWidth, outputShard0Size)
.openToParty(0)
.getValue();
auto result1 = udpDec1->decryptMyData(key, dataWidth, outputShard1Size)
.openToParty(0)
.getValue();

std::vector<std::vector<uint8_t>> rst0(
outputShard0Size, std::vector<uint8_t>(dataWidth));
for (size_t i = 0; i < dataWidth; i++) {
for (uint8_t j = 0; j < 8; j++) {
for (size_t k = 0; k < outputShard0Size; k++) {
rst0.at(k).at(i) += (result0.at(i * 8 + j).at(k) << j);
}
}
}

std::vector<std::vector<uint8_t>> rst1(
outputShard1Size, std::vector<uint8_t>(dataWidth));
for (size_t i = 0; i < dataWidth; i++) {
for (uint8_t j = 0; j < 8; j++) {
for (size_t k = 0; k < outputShard1Size; k++) {
rst1.at(k).at(i) += (result1.at(i * 8 + j).at(k) << j);
}
}
}
auto rst0 = convertToVecs(outputShard0Size, dataWidth, result0);
auto rst1 = convertToVecs(outputShard1Size, dataWidth, result1);
rst0.insert(rst0.end(), rst1.begin(), rst1.end());
return rst0;

auto [intersection, nonces, pickedIndexes] = udpEnc->getProcessedData();

auto [intersection0, intersection1] = split(intersection, outputShard0Size);
auto [nonces0, nonces1] = split(nonces, outputShard0Size);
auto [indexes0, indexes1] = split(pickedIndexes, outputShard0Size);

auto rst2 = convertToVecs(
outputShard0Size,
dataWidth,
udpDec0->decryptPeerData(intersection0, nonces0, indexes0)
.openToParty(0)
.getValue());
auto rst3 = convertToVecs(
outputShard1Size,
dataWidth,
udpDec1->decryptPeerData(intersection1, nonces1, indexes1)
.openToParty(0)
.getValue());
rst2.insert(rst2.end(), rst3.begin(), rst3.end());
return std::make_tuple(rst0, rst2);
};

auto task1 = [numberOfInputShards, &dataWidth](
std::unique_ptr<UdpEncryption> udpEnc,
std::unique_ptr<UdpDecryption<1>> udpDec0,
std::unique_ptr<UdpDecryption<3>> udpDec1,
const std::vector<int32_t>& indexes,
const std::vector<size_t>& sizes) {
auto task1 = [](std::unique_ptr<UdpEncryption> udpEnc,
std::unique_ptr<UdpDecryption<1>> udpDec0,
std::unique_ptr<UdpDecryption<3>> udpDec1,
const std::vector<int32_t>& indexes,
const std::vector<size_t>& sizes,
const std::vector<std::vector<std::vector<unsigned char>>>&
plaintextDataInShards,
size_t dataWidth) {
udpEnc->prepareToProcessPeerData(dataWidth, indexes);
for (size_t i = 0; i < numberOfInputShards; i++) {
for (size_t i = 0; i < sizes.size(); i++) {
udpEnc->processPeerData(sizes.at(i));
}
udpEnc->prepareToProcessMyData(dataWidth);
for (size_t i = 0; i < plaintextDataInShards.size(); i++) {
udpEnc->processMyData(plaintextDataInShards.at(i));
};

auto [intersection, nonces, pickedIndexes] = udpEnc->getProcessedData();

size_t outputShard0Size = intersection.size() / 2;
size_t outputShard1Size = intersection.size() - outputShard0Size;

auto [intersection0, intersection1] = split(intersection, outputShard0Size);
auto [nonces0, nonces1] = split(nonces, outputShard0Size);
auto [indexes0, indexes1] = split(pickedIndexes, outputShard0Size);

udpDec0->decryptPeerData(intersection0, nonces0, indexes0).openToParty(0);
udpDec1->decryptPeerData(intersection1, nonces1, indexes1).openToParty(0);

auto key = udpEnc->getExpandedKey();

udpDec0->decryptMyData(key, dataWidth, outputShard0Size).openToParty(0);
udpDec1->decryptMyData(key, dataWidth, outputShard1Size).openToParty(0);
};

auto future1 = std::async(
Expand All @@ -233,18 +275,25 @@ void testUdpEncryptionAndDecryptionObjects(
std::move(udpDec01),
std::move(udpDec11),
indexes,
sizes);
sizes,
shard,
dataWidth);
auto rst = task0(
std::move(udpEnc0),
std::move(udpDec00),
std::move(udpDec10),
shard,
dataWidth,
outputSize);
outputSize,
indexes,
sizes);
future1.get();

for (size_t i = 0; i < outputSize; i++) {
fbpcf::testVectorEq(rst.at(i), expectedOutput.at(i));
fbpcf::testVectorEq(std::get<0>(rst).at(i), expectedOutput.at(i));
}
for (size_t i = 0; i < outputSize; i++) {
fbpcf::testVectorEq(std::get<1>(rst).at(i), expectedOutput.at(i));
}
}

Expand Down

0 comments on commit f2bae56

Please sign in to comment.