Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 100 additions & 89 deletions lib/std/crypto/kangarootwelve.zig
Original file line number Diff line number Diff line change
Expand Up @@ -230,58 +230,61 @@ fn keccakP1600timesN(comptime N: usize, states: *[5][5]@Vector(N, u64)) void {
break :blk offsets;
};

inline for (RC) |rc| {
// θ (theta)
var C: [5]@Vector(N, u64) = undefined;
inline for (0..5) |x| {
C[x] = states[x][0] ^ states[x][1] ^ states[x][2] ^ states[x][3] ^ states[x][4];
}
var round: usize = 0;
while (round < 12) : (round += 2) {
inline for (0..2) |i| {
// θ (theta)
var C: [5]@Vector(N, u64) = undefined;
inline for (0..5) |x| {
C[x] = states[x][0] ^ states[x][1] ^ states[x][2] ^ states[x][3] ^ states[x][4];
}

var D: [5]@Vector(N, u64) = undefined;
inline for (0..5) |x| {
D[x] = C[(x + 4) % 5] ^ rol64Vec(N, C[(x + 1) % 5], 1);
}
var D: [5]@Vector(N, u64) = undefined;
inline for (0..5) |x| {
D[x] = C[(x + 4) % 5] ^ rol64Vec(N, C[(x + 1) % 5], 1);
}

// Apply D to all lanes
inline for (0..5) |x| {
states[x][0] ^= D[x];
states[x][1] ^= D[x];
states[x][2] ^= D[x];
states[x][3] ^= D[x];
states[x][4] ^= D[x];
}
// Apply D to all lanes
inline for (0..5) |x| {
states[x][0] ^= D[x];
states[x][1] ^= D[x];
states[x][2] ^= D[x];
states[x][3] ^= D[x];
states[x][4] ^= D[x];
}

// ρ (rho) and π (pi) - optimized with pre-computed offsets
var current = states[1][0];
var px: usize = 1;
var py: usize = 0;
inline for (rho_offsets) |rot| {
const next_y = (2 * px + 3 * py) % 5;
const next = states[py][next_y];
states[py][next_y] = rol64Vec(N, current, rot);
current = next;
px = py;
py = next_y;
}
// ρ (rho) and π (pi) - optimized with pre-computed offsets
var current = states[1][0];
var px: usize = 1;
var py: usize = 0;
inline for (rho_offsets) |rot| {
const next_y = (2 * px + 3 * py) % 5;
const next = states[py][next_y];
states[py][next_y] = rol64Vec(N, current, rot);
current = next;
px = py;
py = next_y;
}

// χ (chi) - optimized with better register usage
inline for (0..5) |y| {
const t0 = states[0][y];
const t1 = states[1][y];
const t2 = states[2][y];
const t3 = states[3][y];
const t4 = states[4][y];

states[0][y] = t0 ^ (~t1 & t2);
states[1][y] = t1 ^ (~t2 & t3);
states[2][y] = t2 ^ (~t3 & t4);
states[3][y] = t3 ^ (~t4 & t0);
states[4][y] = t4 ^ (~t0 & t1);
}
// χ (chi) - optimized with better register usage
inline for (0..5) |y| {
const t0 = states[0][y];
const t1 = states[1][y];
const t2 = states[2][y];
const t3 = states[3][y];
const t4 = states[4][y];

states[0][y] = t0 ^ (~t1 & t2);
states[1][y] = t1 ^ (~t2 & t3);
states[2][y] = t2 ^ (~t3 & t4);
states[3][y] = t3 ^ (~t4 & t0);
states[4][y] = t4 ^ (~t0 & t1);
}

// ι (iota)
const rc_splat: @Vector(N, u64) = @splat(rc);
states[0][0] ^= rc_splat;
// ι (iota)
const rc_splat: @Vector(N, u64) = @splat(RC[round + i]);
states[0][0] ^= rc_splat;
}
}
}

Expand Down Expand Up @@ -323,46 +326,49 @@ fn keccakP(state: *[200]u8) void {
}

// Apply 12 rounds
inline for (RC) |rc| {
// θ
var C: [5]u64 = undefined;
inline for (0..5) |x| {
C[x] = lanes[x][0] ^ lanes[x][1] ^ lanes[x][2] ^ lanes[x][3] ^ lanes[x][4];
}
var D: [5]u64 = undefined;
inline for (0..5) |x| {
D[x] = C[(x + 4) % 5] ^ std.math.rotl(u64, C[(x + 1) % 5], 1);
}
inline for (0..5) |x| {
inline for (0..5) |y| {
lanes[x][y] ^= D[x];
var round: usize = 0;
while (round < 12) : (round += 2) {
inline for (0..2) |i| {
// θ
var C: [5]u64 = undefined;
inline for (0..5) |x| {
C[x] = lanes[x][0] ^ lanes[x][1] ^ lanes[x][2] ^ lanes[x][3] ^ lanes[x][4];
}
var D: [5]u64 = undefined;
inline for (0..5) |x| {
D[x] = C[(x + 4) % 5] ^ std.math.rotl(u64, C[(x + 1) % 5], 1);
}
inline for (0..5) |x| {
inline for (0..5) |y| {
lanes[x][y] ^= D[x];
}
}
}

// ρ and π
var current = lanes[1][0];
var px: usize = 1;
var py: usize = 0;
inline for (0..24) |t| {
const temp = lanes[py][(2 * px + 3 * py) % 5];
const rot_amount = ((t + 1) * (t + 2) / 2) % 64;
lanes[py][(2 * px + 3 * py) % 5] = std.math.rotl(u64, current, @as(u6, @intCast(rot_amount)));
current = temp;
const temp_x = py;
py = (2 * px + 3 * py) % 5;
px = temp_x;
}
// ρ and π
var current = lanes[1][0];
var px: usize = 1;
var py: usize = 0;
inline for (0..24) |t| {
const temp = lanes[py][(2 * px + 3 * py) % 5];
const rot_amount = ((t + 1) * (t + 2) / 2) % 64;
lanes[py][(2 * px + 3 * py) % 5] = std.math.rotl(u64, current, @as(u6, @intCast(rot_amount)));
current = temp;
const temp_x = py;
py = (2 * px + 3 * py) % 5;
px = temp_x;
}

// χ
inline for (0..5) |y| {
const T = [5]u64{ lanes[0][y], lanes[1][y], lanes[2][y], lanes[3][y], lanes[4][y] };
inline for (0..5) |x| {
lanes[x][y] = T[x] ^ (~T[(x + 1) % 5] & T[(x + 2) % 5]);
// χ
inline for (0..5) |y| {
const T = [5]u64{ lanes[0][y], lanes[1][y], lanes[2][y], lanes[3][y], lanes[4][y] };
inline for (0..5) |x| {
lanes[x][y] = T[x] ^ (~T[(x + 1) % 5] & T[(x + 2) % 5]);
}
}
}

// ι
lanes[0][0] ^= rc;
// ι
lanes[0][0] ^= RC[round + i];
}
}

// Store lanes back to state
Expand Down Expand Up @@ -759,32 +765,37 @@ fn ktMultiThreaded(
const all_scratch = try allocator.alloc(u8, thread_count * scratch_size);
defer allocator.free(all_scratch);

var group: Io.Group = .init;
const contexts = try allocator.alloc(LeafBatchContext, thread_count);
defer allocator.free(contexts);

var leaves_assigned: usize = 0;
var thread_idx: usize = 0;
var context_count: usize = 0;

while (leaves_assigned < total_leaves) {
const batch_count = @min(leaves_per_thread, total_leaves - leaves_assigned);
const batch_start = chunk_size + leaves_assigned * chunk_size;
const cvs_offset = leaves_assigned * cv_size;

const ctx = LeafBatchContext{
contexts[context_count] = LeafBatchContext{
.output_cvs = cvs[cvs_offset .. cvs_offset + batch_count * cv_size],
.batch_start = batch_start,
.batch_count = batch_count,
.view = view,
.scratch_buffer = all_scratch[thread_idx * scratch_size .. (thread_idx + 1) * scratch_size],
.scratch_buffer = all_scratch[context_count * scratch_size .. (context_count + 1) * scratch_size],
.total_len = total_len,
};

leaves_assigned += batch_count;
context_count += 1;
}

var group: Io.Group = .init;
for (contexts[0..context_count]) |ctx| {
group.async(io, struct {
fn process(c: LeafBatchContext) void {
processLeafBatch(Variant, c);
}
}.process, .{ctx});

leaves_assigned += batch_count;
thread_idx += 1;
}

// Wait for all threads to complete
Expand Down