Skip to content

Commit 2cede38

Browse files
committed
Implement NTT based on 2n-th root of unity
1 parent 2267517 commit 2cede38

File tree

2 files changed

+319
-0
lines changed

2 files changed

+319
-0
lines changed

src/ntt.zig

Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
const std = @import("std");
2+
const testing = std.testing;
3+
const Allocator = std.mem.Allocator;
4+
5+
const utils = @import("utils.zig");
6+
7+
/// Instance of a Number Theoretic Transform based on 2n-th root of unity.
8+
/// This implementation is using the polynomial quotient ring Z_q[x]/(x^n + 1)
9+
/// where q is the coefficient modulus and n is the degree of the cyclotomic
10+
/// polynomial.
11+
/// The NTT algorithm is based on the paper "Low-Cost and Area-Efficient FPGA
12+
// Implementations of Lattice-Based Cryptography" by Aysu et al.
13+
// See: https://schaumont.dyn.wpi.edu/schaum/pdf/papers/2013hostb.pdf
14+
pub const NTT = struct {
15+
const Self = @This();
16+
17+
/// Coefficient modulus.
18+
q: i64,
19+
/// Degree of cyclotomic polynomial.
20+
n: i64,
21+
/// Inverse of degree of cyclotomic polynomial.
22+
n_inverse: i64,
23+
/// Powers of psi from psi^0 to psi^(n - 1).
24+
psi_powers: []i64,
25+
/// Powers of psi^-1 = ipsi from ipsi^0 to ipsi^(n - 1).
26+
psi_inverse_powers: []i64,
27+
/// Allocator used for internal memory allocations.
28+
allocator: Allocator,
29+
30+
/// Initializes a new NTT instance with the given coefficient modulus and
31+
/// cyclotomic polynomial degree (which MUST be a power of 2).
32+
pub fn init(allocator: Allocator, q: i64, n: i64) !Self {
33+
// Check if degree of cyclotomic polynomial is a power of 2.
34+
if (!utils.isPowerOfTwo(n)) {
35+
return error.InvalidDegree;
36+
}
37+
38+
// Compute psi, the 2n-th root of unity and its inverse.
39+
const psi = try utils.findRootOfUnity(allocator, 2 * n, q);
40+
const psi_inverse = try utils.modInv(psi, q);
41+
42+
// Compute the inverse of the degree of the cyclotomic polynomial.
43+
const n_inverse = try utils.modInv(n, q);
44+
45+
// Compute powers of psi.
46+
var powers = try std.ArrayList(i64).initCapacity(allocator, @intCast(n));
47+
defer powers.deinit();
48+
try powers.append(1);
49+
for (1..@intCast(n)) |_| {
50+
const power = @mod(powers.getLast() * psi, q);
51+
try powers.append(power);
52+
}
53+
const psi_powers = try powers.toOwnedSlice();
54+
55+
// Compute powers of psi^-1.
56+
var inverse_powers = try std.ArrayList(i64).initCapacity(allocator, @intCast(n));
57+
defer inverse_powers.deinit();
58+
try inverse_powers.append(1);
59+
for (1..@intCast(n)) |_| {
60+
const power = @mod(inverse_powers.getLast() * psi_inverse, q);
61+
try inverse_powers.append(power);
62+
}
63+
const psi_inverse_powers = try inverse_powers.toOwnedSlice();
64+
65+
return NTT{
66+
.q = q,
67+
.n = n,
68+
.n_inverse = n_inverse,
69+
.psi_powers = psi_powers,
70+
.psi_inverse_powers = psi_inverse_powers,
71+
.allocator = allocator,
72+
};
73+
}
74+
75+
/// Release all allocated memory.
76+
pub fn deinit(self: Self) void {
77+
self.allocator.free(self.psi_powers);
78+
self.allocator.free(self.psi_inverse_powers);
79+
}
80+
81+
/// Runs a forward pass of NTT with the given coefficients.
82+
/// The caller owns the returned memory.
83+
pub fn fwd(self: Self, coefficients: []const i64) ![]const i64 {
84+
// Length of coefficients must equal the degree of the cyclotomic polynomial.
85+
if (coefficients.len != self.n) {
86+
return error.InvalidLength;
87+
}
88+
89+
var input = try self.allocator.alloc(i64, coefficients.len);
90+
defer self.allocator.free(input);
91+
92+
for (0..coefficients.len) |i| {
93+
input[i] = @mod(coefficients[i] * self.psi_powers[i], self.q);
94+
}
95+
96+
return self.ntt(input, self.psi_powers);
97+
}
98+
99+
/// Runs INTT (Inverse NTT) with the given coefficients.
100+
/// The caller owns the returned memory.
101+
pub fn inv(self: Self, coefficients: []const i64) ![]const i64 {
102+
// Length of coefficients must equal the degree of the cyclotomic polynomial.
103+
if (coefficients.len != self.n) {
104+
return error.InvalidLength;
105+
}
106+
107+
const unscaled_result = try self.ntt(coefficients, self.psi_inverse_powers);
108+
defer self.allocator.free(unscaled_result);
109+
110+
var result = try self.allocator.alloc(i64, coefficients.len);
111+
112+
for (0..coefficients.len) |i| {
113+
result[i] = @mod(unscaled_result[i] * self.psi_inverse_powers[i] * self.n_inverse, self.q);
114+
}
115+
116+
return result;
117+
}
118+
119+
/// Runs an iterative version of NTT with the given coefficients and twiddles.
120+
/// The caller owns the returned memory.
121+
fn ntt(self: Self, coefficients: []const i64, twiddles: []const i64) ![]const i64 {
122+
// Length of coefficients and twiddles must be the same.
123+
if (coefficients.len != twiddles.len) {
124+
return error.InvalidLength;
125+
}
126+
127+
const log2_n = std.math.log2_int(usize, @intCast(self.n));
128+
const reversed = try utils.bitReverseSlice(self.allocator, coefficients);
129+
defer self.allocator.free(reversed);
130+
131+
var result = try self.allocator.dupe(i64, reversed);
132+
133+
for (0..log2_n) |i| {
134+
var temp_twiddle: i64 = 1;
135+
var final_twiddle: i64 = 1;
136+
const twiddle = twiddles[i + 1];
137+
138+
const in_1 = @as(usize, 1) << @intCast(i); // 2^i
139+
const in_2 = @as(usize, 1) << @intCast(i + 1); // 2^(i + 1)
140+
const in_3 = @as(usize, @intCast(self.n)) / in_2; // n / (2^(i + 1))
141+
142+
for (0..in_1) |j| {
143+
for (0..in_3) |t| {
144+
const index_1 = (t * in_2) + j; // (t * 2^(i + 1)) + j
145+
const index_2 = index_1 + in_1; // (t * 2^(i + 1)) + j + 2^i
146+
147+
const c = result[index_1];
148+
const d = result[index_2];
149+
150+
const butterfly_plus = @mod(c + (final_twiddle * d), self.q);
151+
const butterfly_minus = @mod(c - (final_twiddle * d), self.q);
152+
153+
result[index_1] = butterfly_plus;
154+
result[index_2] = butterfly_minus;
155+
156+
temp_twiddle = @mod(temp_twiddle * twiddle, self.q);
157+
}
158+
final_twiddle = temp_twiddle;
159+
}
160+
}
161+
162+
return result;
163+
}
164+
};
165+
166+
test "ntt - init" {
167+
const allocator = testing.allocator;
168+
169+
{
170+
const q = 7681;
171+
const n = 5;
172+
173+
const expected = error.InvalidDegree;
174+
const result = NTT.init(allocator, q, n);
175+
176+
try testing.expectError(expected, result);
177+
}
178+
}
179+
180+
test "ntt - fwd" {
181+
const allocator = testing.allocator;
182+
183+
{
184+
const q = 7681;
185+
const n = 4;
186+
187+
const ntt = try NTT.init(allocator, q, n);
188+
defer ntt.deinit();
189+
190+
const coefficients = [_]i64{ 1, 2, 3, 4 };
191+
const expected = [_]i64{ 1467, 2807, 3471, 7621 };
192+
193+
const result = try ntt.fwd(&coefficients);
194+
defer allocator.free(result);
195+
196+
try testing.expectEqualSlices(i64, &expected, result);
197+
}
198+
199+
{
200+
const q = 7681;
201+
const n = 4;
202+
203+
const ntt = try NTT.init(allocator, q, n);
204+
defer ntt.deinit();
205+
206+
const coefficients = [_]i64{ 5, 6, 7, 8 };
207+
const expected = [_]i64{ 2489, 7489, 6478, 6607 };
208+
209+
const result = try ntt.fwd(&coefficients);
210+
defer allocator.free(result);
211+
212+
try testing.expectEqualSlices(i64, &expected, result);
213+
}
214+
215+
{
216+
const q = 7681;
217+
const n = 4;
218+
219+
const ntt = try NTT.init(allocator, q, n);
220+
defer ntt.deinit();
221+
222+
const coefficients = [_]i64{ 1, 2, 3 };
223+
224+
const expected = error.InvalidLength;
225+
const result = ntt.fwd(&coefficients);
226+
227+
try testing.expectError(expected, result);
228+
}
229+
}
230+
231+
test "ntt - inv" {
232+
const allocator = testing.allocator;
233+
234+
{
235+
const q = 7681;
236+
const n = 4;
237+
238+
const ntt = try NTT.init(allocator, q, n);
239+
defer ntt.deinit();
240+
241+
const coefficients = [_]i64{ 1467, 2807, 3471, 7621 };
242+
const expected = [_]i64{ 1, 2, 3, 4 };
243+
244+
const result = try ntt.inv(&coefficients);
245+
defer allocator.free(result);
246+
247+
try testing.expectEqualSlices(i64, &expected, result);
248+
}
249+
250+
{
251+
const q = 7681;
252+
const n = 4;
253+
254+
const ntt = try NTT.init(allocator, q, n);
255+
defer ntt.deinit();
256+
257+
const coefficients = [_]i64{ 2489, 7489, 6478, 6607 };
258+
const expected = [_]i64{ 5, 6, 7, 8 };
259+
260+
const result = try ntt.inv(&coefficients);
261+
defer allocator.free(result);
262+
263+
try testing.expectEqualSlices(i64, &expected, result);
264+
}
265+
266+
{
267+
const q = 7681;
268+
const n = 4;
269+
270+
const ntt = try NTT.init(allocator, q, n);
271+
defer ntt.deinit();
272+
273+
const coefficients = [_]i64{ 1, 2, 3 };
274+
275+
const expected = error.InvalidLength;
276+
const result = ntt.inv(&coefficients);
277+
278+
try testing.expectError(expected, result);
279+
}
280+
}
281+
282+
test "ntt - ntt" {
283+
const allocator = testing.allocator;
284+
285+
{
286+
const q = 7681;
287+
const n = 4;
288+
289+
const ntt = try NTT.init(allocator, q, n);
290+
defer ntt.deinit();
291+
292+
const twiddles = ntt.psi_powers;
293+
const coefficients = [_]i64{ 1, 2, 3, 4 };
294+
const expected = [_]i64{ 10, 913, 7679, 6764 };
295+
296+
const result = try ntt.ntt(&coefficients, twiddles);
297+
defer allocator.free(result);
298+
299+
try testing.expectEqualSlices(i64, &expected, result);
300+
}
301+
302+
{
303+
const q = 7681;
304+
const n = 4;
305+
306+
const ntt = try NTT.init(allocator, q, n);
307+
defer ntt.deinit();
308+
309+
const twiddles = ntt.psi_powers;
310+
const coefficients = [_]i64{ 1, 2 };
311+
312+
const expected = error.InvalidLength;
313+
const result = ntt.ntt(&coefficients, twiddles);
314+
315+
try testing.expectError(expected, result);
316+
}
317+
}

src/root.zig

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
//! start with main.zig instead.
44
const std = @import("std");
55
const testing = std.testing;
6+
7+
pub const ntt = @import("ntt.zig");
68
pub const utils = @import("utils.zig");
79

810
pub export fn add(a: i32, b: i32) i32 {

0 commit comments

Comments
 (0)