Skip to content

Commit 8a3e297

Browse files
committed
feat(shortint): add kreyvium transcipher
1 parent b9f1d68 commit 8a3e297

10 files changed

Lines changed: 1619 additions & 148 deletions

File tree

tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_kreyvium.rs

Lines changed: 22 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use crate::integer::keycache::KEY_CACHE;
22
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
33
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey};
44
use crate::shortint::parameters::TestParameters;
5+
use crate::transciphering::ciphers::kreyvium::KreyviumPlainStream;
6+
use crate::transciphering::StreamCipher;
57
use rand::Rng;
68
use std::fmt::Write;
79
use std::sync::Arc;
@@ -21,120 +23,6 @@ fn decrypt_bits(cks: &RadixClientKey, ct: &RadixCiphertext) -> Vec<u8> {
2123
.collect()
2224
}
2325

24-
struct KreyviumRef {
25-
a: Vec<u64>,
26-
b: Vec<u64>,
27-
c: Vec<u64>,
28-
k: Vec<u64>,
29-
iv: Vec<u64>,
30-
cursor_a: usize,
31-
cursor_b: usize,
32-
cursor_c: usize,
33-
cursor_k: usize,
34-
cursor_iv: usize,
35-
}
36-
37-
impl KreyviumRef {
38-
fn new(key_bits: &[u64], iv_bits: &[u64]) -> Self {
39-
let mut a = vec![0u64; 93];
40-
let mut b = vec![0u64; 84];
41-
let mut c = vec![0u64; 111];
42-
let mut k = key_bits.to_vec();
43-
let mut iv = iv_bits.to_vec();
44-
45-
assert_eq!(k.len(), 128);
46-
assert_eq!(iv.len(), 128);
47-
48-
for i in 0..93 {
49-
a[i] = key_bits[128 - 93 + i];
50-
}
51-
for i in 0..84 {
52-
b[i] = iv_bits[128 - 84 + i];
53-
}
54-
for i in 0..44 {
55-
c[111 - 44 + i] = iv_bits[i];
56-
}
57-
for i in 0..66 {
58-
c[i + 1] = 1;
59-
}
60-
61-
k.reverse();
62-
iv.reverse();
63-
64-
let mut kreyvium = Self {
65-
a,
66-
b,
67-
c,
68-
k,
69-
iv,
70-
cursor_a: 0,
71-
cursor_b: 0,
72-
cursor_c: 0,
73-
cursor_k: 0,
74-
cursor_iv: 0,
75-
};
76-
77-
for _ in 0..1152 {
78-
kreyvium.next_bit();
79-
}
80-
81-
kreyvium
82-
}
83-
84-
fn next_bit(&mut self) -> u8 {
85-
let idx_a = |cursor: usize, i: usize| -> usize { (93 + cursor - i - 1) % 93 };
86-
let idx_b = |cursor: usize, i: usize| -> usize { (84 + cursor - i - 1) % 84 };
87-
let idx_c = |cursor: usize, i: usize| -> usize { (111 + cursor - i - 1) % 111 };
88-
let idx_k = |cursor: usize, i: usize| -> usize { (128 + cursor - i - 1) % 128 };
89-
let idx_iv = |cursor: usize, i: usize| -> usize { (128 + cursor - i - 1) % 128 };
90-
91-
let k_val = self.k[idx_k(self.cursor_k, 127)];
92-
let iv_val = self.iv[idx_iv(self.cursor_iv, 127)];
93-
94-
let a1 = self.a[idx_a(self.cursor_a, 65)];
95-
let a2 = self.a[idx_a(self.cursor_a, 92)];
96-
let a3 = self.a[idx_a(self.cursor_a, 91)];
97-
let a4 = self.a[idx_a(self.cursor_a, 90)];
98-
let a5 = self.a[idx_a(self.cursor_a, 68)];
99-
100-
let b1 = self.b[idx_b(self.cursor_b, 68)];
101-
let b2 = self.b[idx_b(self.cursor_b, 83)];
102-
let b3 = self.b[idx_b(self.cursor_b, 82)];
103-
let b4 = self.b[idx_b(self.cursor_b, 81)];
104-
let b5 = self.b[idx_b(self.cursor_b, 77)];
105-
106-
let c1 = self.c[idx_c(self.cursor_c, 65)];
107-
let c2 = self.c[idx_c(self.cursor_c, 110)];
108-
let c3 = self.c[idx_c(self.cursor_c, 109)];
109-
let c4 = self.c[idx_c(self.cursor_c, 108)];
110-
let c5 = self.c[idx_c(self.cursor_c, 86)];
111-
112-
let temp_a = a1 ^ a2;
113-
let temp_b = b1 ^ b2;
114-
let temp_c = c1 ^ c2 ^ k_val;
115-
116-
let new_a = (c3 & c4) ^ a5 ^ temp_c;
117-
let new_b = (a3 & a4) ^ b5 ^ temp_a ^ iv_val;
118-
let new_c = (b3 & b4) ^ c5 ^ temp_b;
119-
120-
let out = temp_a ^ temp_b ^ temp_c;
121-
122-
self.a[self.cursor_a] = new_a;
123-
self.cursor_a = (self.cursor_a + 1) % 93;
124-
125-
self.b[self.cursor_b] = new_b;
126-
self.cursor_b = (self.cursor_b + 1) % 84;
127-
128-
self.c[self.cursor_c] = new_c;
129-
self.cursor_c = (self.cursor_c + 1) % 111;
130-
131-
self.cursor_k = (self.cursor_k + 1) % 128;
132-
self.cursor_iv = (self.cursor_iv + 1) % 128;
133-
134-
out as u8
135-
}
136-
}
137-
13826
fn bits_to_hex(bits: &[u8]) -> String {
13927
let mut result = String::new();
14028
for chunk in bits.chunks(8) {
@@ -160,27 +48,6 @@ fn parse_hex_to_bits(s: &str) -> Vec<u64> {
16048
bits
16149
}
16250

163-
/// Tests the Rust reference implementation of Kreyvium against a known test vector.
164-
/// This ensures the logic in `KreyviumRef` is correct before comparing it to FHE.
165-
#[test]
166-
fn test_kreyvium_ref_consistency() {
167-
let key_hex = "0053A6F94C9FF24598EB000000000000";
168-
let iv_hex = "0D74DB42A91077DE45AC000000000000";
169-
let expected_out_hex = "D1F0303482061111";
170-
171-
let key_bits = parse_hex_to_bits(key_hex);
172-
let iv_bits = parse_hex_to_bits(iv_hex);
173-
174-
let mut kreyvium = KreyviumRef::new(&key_bits, &iv_bits);
175-
let mut output_bits = Vec::new();
176-
for _ in 0..64 {
177-
output_bits.push(kreyvium.next_bit());
178-
}
179-
180-
let hex_string = bits_to_hex(&output_bits);
181-
assert_eq!(hex_string, expected_out_hex);
182-
}
183-
18451
/// Tests the full FHE Kreyvium implementation against a known standard test vector.
18552
/// This verifies that the homomorphic circuit produces the exact same hex output as standard
18653
/// Kreyvium.
@@ -250,11 +117,15 @@ where
250117
let ct_key = encrypt_bits(&cks, &key_bits);
251118
let ct_iv = encrypt_bits(&cks, &iv_bits);
252119

253-
let mut ref_kreyvium = KreyviumRef::new(&key_bits, &iv_bits);
254-
let mut cpu_output = Vec::with_capacity(num_steps);
255-
for _ in 0..num_steps {
256-
cpu_output.push(ref_kreyvium.next_bit());
257-
}
120+
let key_bool: [bool; 128] = std::array::from_fn(|i| key_bits[i] == 1);
121+
let iv_bool: [bool; 128] = std::array::from_fn(|i| iv_bits[i] == 1);
122+
let mut ref_kreyvium = KreyviumPlainStream::new(key_bool, iv_bool);
123+
let ref_bytes = ref_kreyvium.next_keystream_bits(num_steps);
124+
let cpu_output = ref_bytes
125+
.iter()
126+
.flat_map(|&b| (0..8).map(move |i| (b >> i) & 1))
127+
.take(num_steps)
128+
.collect::<Vec<_>>();
258129

259130
let output_radix = executor.execute((&ct_key, &ct_iv, num_steps)).unwrap();
260131
let fhe_output = decrypt_bits(&cks, &output_radix);
@@ -264,9 +135,8 @@ where
264135
}
265136
}
266137

267-
// Integration test verifying the correctness of the stateful FHE Kreyvium implementation by
268-
// comparing consecutive keystream chunks against a cleartext CPU reference.
269-
//
138+
/// Integration test verifying the correctness of the stateful FHE Kreyvium implementation by
139+
/// comparing consecutive keystream chunks against a cleartext CPU reference.
270140
pub fn kreyvium_stateful_comparison_test<P, E>(param: P, mut executor: E)
271141
where
272142
P: Into<TestParameters>,
@@ -297,11 +167,15 @@ where
297167
let ct_key = encrypt_bits(&cks, &key_bits);
298168
let ct_iv = encrypt_bits(&cks, &iv_bits);
299169

300-
let mut ref_kreyvium = KreyviumRef::new(&key_bits, &iv_bits);
301-
let mut cpu_output = Vec::with_capacity(total_steps);
302-
for _ in 0..total_steps {
303-
cpu_output.push(ref_kreyvium.next_bit());
304-
}
170+
let key_bool: [bool; 128] = std::array::from_fn(|i| key_bits[i] == 1);
171+
let iv_bool: [bool; 128] = std::array::from_fn(|i| iv_bits[i] == 1);
172+
let mut ref_kreyvium = KreyviumPlainStream::new(key_bool, iv_bool);
173+
let ref_bytes = ref_kreyvium.next_keystream_bits(total_steps);
174+
let cpu_output = ref_bytes
175+
.iter()
176+
.flat_map(|&b| (0..8).map(move |i| (b >> i) & 1))
177+
.take(total_steps)
178+
.collect::<Vec<_>>();
305179

306180
let output_radixes = executor.execute((&ct_key, &ct_iv, &step_chunks)).unwrap();
307181

tfhe/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ pub mod integer;
100100
/// cbindgen:ignore
101101
pub mod shortint;
102102

103+
#[cfg(feature = "shortint")]
104+
pub mod transciphering;
105+
103106
#[cfg(feature = "pbs-stats")]
104107
pub use shortint::server_key::pbs_stats::*;
105108

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
use tfhe_versionable::VersionsDispatch;
2+
3+
use crate::transciphering::{StreamCipherKind, StreamCiphertext};
4+
5+
#[derive(VersionsDispatch)]
6+
pub enum StreamCipherKindVersions {
7+
V0(StreamCipherKind),
8+
}
9+
10+
#[derive(VersionsDispatch)]
11+
pub enum StreamCiphertextVersions {
12+
V0(StreamCiphertext),
13+
}

0 commit comments

Comments
 (0)