1use std::array;
4
5use anyhow::anyhow;
6use binius_core::{
7 oracle::{OracleId, ShiftVariant},
8 transparent::multilinear_extension::MultilinearExtensionTransparent,
9};
10use binius_field::{
11 as_packed_field::PackedType, underlier::WithUnderlier, BinaryField1b, BinaryField64b, Field,
12 PackedField, TowerField,
13};
14use binius_macros::arith_expr;
15use bytemuck::{pod_collect_to_vec, Pod};
16
17use crate::{
18 builder::{
19 types::{F, U},
20 ConstraintSystemBuilder,
21 },
22 transparent::step_down,
23};
24
25#[derive(Default, Clone, Copy)]
26pub struct KeccakfState(pub [u64; STATE_SIZE]);
27
28pub struct KeccakfOracles {
29 pub input: [OracleId; STATE_SIZE],
30 pub output: [OracleId; STATE_SIZE],
31}
32
/// Adds a Keccak-f[1600] permutation gadget for `2^log_size` independent permutations.
///
/// Layout: every 64-bit lane is committed as a column of `BinaryField1b` values, with
/// `BIT_ROWS_PER_STATE_ROW` (64) bit rows per 64-bit word. One permutation spans
/// `STATE_ROWS_PER_PERMUTATION` (8) state rows. Each state row applies `ROUNDS_PER_STATE_ROW`
/// (3) rounds, so one permutation covers all 24 rounds in `BIT_ROWS_PER_PERMUTATION` (512)
/// bit rows. Within a permutation, the output of each state row is constrained to equal the
/// input of the following state row.
///
/// # Arguments
///
/// * `builder` - builder that receives the oracles, constraints and, if present, witness data.
/// * `input_witness` - input states, one per permutation. Only read when `builder` carries a
///   witness. Permutation slots without a corresponding entry use `KeccakfState::default()`.
///   Entries beyond the first `2^log_size` are ignored.
/// * `log_size` - base-2 logarithm of the number of permutations.
///
/// # Returns
///
/// [`KeccakfOracles`] holding, for each lane, two packed 64-bit oracles:
/// * the input oracle, a projection onto the first state row of each permutation;
/// * the output oracle, a projection onto the last state row.
///
/// # Errors
///
/// Propagates any error from adding oracles to `builder`. Also returns an error if `builder`
/// has a witness but `input_witness` is `None`.
///
/// # Panics
///
/// Panics if the witness computed for a permutation differs from `tiny_keccak::keccakf`
/// applied to the same input.
pub fn keccakf(
    builder: &mut ConstraintSystemBuilder,
    input_witness: &Option<impl AsRef<[KeccakfState]>>,
    log_size: usize,
) -> Result<KeccakfOracles, anyhow::Error> {
    // Number of variables of every bit-level column: permutations x bit rows per permutation.
    let internal_log_size = log_size + LOG_BIT_ROWS_PER_PERMUTATION;
    // Round constants for a single permutation. Slot `round_within_row` is a transparent
    // multilinear over the 8 state rows, holding the constant of round
    // `ROUNDS_PER_STATE_ROW * row + round_within_row` for each row.
    let round_consts_single: [OracleId; ROUNDS_PER_STATE_ROW] =
        array_util::try_from_fn(|round_within_row| {
            let round_within_row_rc: [_; STATE_ROWS_PER_PERMUTATION] =
                array::from_fn(|row_within_perm| {
                    KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_perm + round_within_row]
                });

            let packed_vec = into_packed_vec::<PackedType<U, BinaryField1b>>(&round_within_row_rc);
            let rc_single_mle =
                MultilinearExtensionTransparent::<_, PackedType<U, F>, _>::from_values(packed_vec)?;
            builder.add_transparent("round_consts_single", rc_single_mle)
        })?;

    // Repeat the single-permutation constant tables once per permutation (2^log_size times).
    let round_consts: [OracleId; ROUNDS_PER_STATE_ROW] =
        array_util::try_from_fn(|round_within_row| {
            builder.add_repeating(
                "round_consts",
                round_consts_single[round_within_row],
                internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION,
            )
        })?;

    if let Some(witness) = builder.witness() {
        // Witness for the round-constant columns, written as 64-bit words (one per state row).
        let mut round_consts_single =
            round_consts_single.map(|id| witness.new_column::<BinaryField1b>(id));
        let mut round_consts = round_consts.map(|id| witness.new_column::<BinaryField1b>(id));

        let round_consts_single_u64 = round_consts_single
            .each_mut()
            .map(|col| col.as_mut_slice::<u64>());
        let round_consts_u64 = round_consts.each_mut().map(|col| col.as_mut_slice::<u64>());

        for row_within_permutation in 0..STATE_ROWS_PER_PERMUTATION {
            for round_within_row in 0..ROUNDS_PER_STATE_ROW {
                round_consts_single_u64[round_within_row][row_within_permutation] =
                    KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_permutation + round_within_row];
            }
        }

        // The repeated column holds the same pattern, cycling every
        // STATE_ROWS_PER_PERMUTATION state rows.
        for state_row_idx in 0..1 << (internal_log_size - LOG_BIT_ROWS_PER_STATE_ROW) {
            let row_within_permutation = state_row_idx % STATE_ROWS_PER_PERMUTATION;
            for round_within_row in 0..ROUNDS_PER_STATE_ROW {
                round_consts_u64[round_within_row][state_row_idx] =
                    KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_permutation + round_within_row];
            }
        }
    }

    // Selector for one permutation: a step-down at bit index
    // BIT_ROWS_PER_PERMUTATION - BIT_ROWS_PER_STATE_ROW. The witness below sets it to all ones
    // on the first STATE_ROWS_PER_PERMUTATION - 1 state rows, leaving the last state row
    // (whose output has no successor row) unselected.
    let selector_single = step_down(
        builder,
        "selector_single",
        LOG_BIT_ROWS_PER_PERMUTATION,
        BIT_ROWS_PER_PERMUTATION - BIT_ROWS_PER_STATE_ROW,
    )?;

    let selector = builder.add_repeating(
        "selector",
        selector_single,
        internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION,
    )?;

    // state[0] is the input of a state row and state[r + 1] the output of round `r` within
    // that row, so state[ROUNDS_PER_STATE_ROW] is the row's output.
    let state: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW + 1] = array::from_fn(|_| {
        builder.add_committed_multiple("state_in", internal_log_size, BinaryField1b::TOWER_LEVEL)
    });

    let state_in = state[0];

    let state_out = state[ROUNDS_PER_STATE_ROW];

    // Pack each bit column into 64-bit words, one word per state row.
    let packed_state_in: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
        builder.add_packed("packed state input", state_in[xy], LOG_BIT_ROWS_PER_STATE_ROW)
    })?;

    // Fix the low LOG_STATE_ROWS_PER_PERMUTATION variables to 0 to select the first state row
    // of each permutation. This leaves one input word per permutation.
    let input: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
        builder.add_projected(
            "packed projected state input",
            packed_state_in[xy],
            vec![F::ZERO; LOG_STATE_ROWS_PER_PERMUTATION],
            0,
        )
    })?;

    let packed_state_out: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
        builder.add_packed("packed state output", state_out[xy], LOG_BIT_ROWS_PER_STATE_ROW)
    })?;

    // Fix the same variables to 1 to select the last state row, which holds the final output.
    let output: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
        builder.add_projected(
            "output",
            packed_state_out[xy],
            vec![Field::ONE; LOG_STATE_ROWS_PER_PERMUTATION],
            0,
        )
    })?;

    // Theta, step 1: column parities c[x] = sum over y of state[x + 5y]. Addition in a binary
    // field is XOR.
    let c: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] = array_util::try_from_fn(|round_within_row| {
        array_util::try_from_fn(|x| {
            builder.add_linear_combination(
                "c",
                internal_log_size,
                array::from_fn::<_, 5, _>(|offset| {
                    (state[round_within_row][x + 5 * offset], Field::ONE)
                }),
            )
        })
    })?;

    // Theta, step 2: c[x] circularly shifted by 1 within each 64-bit block. The witness uses
    // `rotate_left(1)` for this.
    let c_shift: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] =
        array_util::try_from_fn(|round_within_row| {
            array_util::try_from_fn(|x| {
                builder.add_shifted(
                    format!("c[{x}]"),
                    c[round_within_row][x],
                    1,
                    6,
                    ShiftVariant::CircularLeft,
                )
            })
        })?;

    // Theta, step 3: d[x] = c[x - 1] + rot(c[x + 1], 1), with indices taken mod 5.
    let d: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] = array_util::try_from_fn(|round_within_row| {
        array_util::try_from_fn(|x| {
            builder.add_linear_combination(
                "d",
                internal_log_size,
                [
                    (c[round_within_row][(x + 4) % 5], Field::ONE),
                    (c_shift[round_within_row][(x + 1) % 5], Field::ONE),
                ],
            )
        })
    })?;

    // Theta, step 4: a_theta[x + 5y] = state[x + 5y] + d[x].
    let a_theta: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW] =
        array_util::try_from_fn(|round_within_row| {
            array_util::try_from_fn(|xy| {
                let x = xy % 5;
                builder.add_linear_combination(
                    format!("a_theta[{xy}]"),
                    internal_log_size,
                    [
                        (state[round_within_row][xy], Field::ONE),
                        (d[round_within_row][x], Field::ONE),
                    ],
                )
            })
        })?;

    // Rho and pi combined: b[xy] = rot(a_theta[PI[xy]], RHO[xy]). Lane 0 has PI[0] == 0 and
    // RHO[0] == 0, so it reuses the a_theta oracle directly instead of adding a zero shift.
    let b: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW] =
        array_util::try_from_fn(|round_within_row| {
            array_util::try_from_fn(|xy| {
                if xy == 0 {
                    Ok(a_theta[round_within_row][0])
                } else {
                    builder.add_shifted(
                        format!("b[{xy}]"),
                        a_theta[round_within_row][PI[xy]],
                        RHO[xy] as usize,
                        6,
                        ShiftVariant::CircularLeft,
                    )
                }
            })
        })?;

    // state_in shifted by one state row (64 bit rows) within each permutation block. The
    // witness below fills state row `i` with the input of state row `i + 1` (rows 0..=6 only).
    let next_state_in: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
        builder.add_shifted(
            format!("next_state_in[{xy}]"),
            state_in[xy],
            64,
            LOG_BIT_ROWS_PER_PERMUTATION,
            ShiftVariant::LogicalRight,
        )
    })?;

    if let Some(witness) = builder.witness() {
        let input_witness = input_witness
            .as_ref()
            .ok_or_else(|| anyhow!("builder witness available and input witness is not"))?
            .as_ref();

        // Allocate witness columns. Packed/projected columns are 64-bit words; everything else
        // is 1-bit.
        let mut input = input.map(|id| witness.new_column::<BinaryField64b>(id));

        let mut packed_state_in =
            packed_state_in.map(|id| witness.new_column::<BinaryField64b>(id));

        let mut packed_state_out =
            packed_state_out.map(|id| witness.new_column::<BinaryField64b>(id));

        let mut output = output.map(|id| witness.new_column::<BinaryField64b>(id));

        let mut state = state
            .map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));

        let mut c =
            c.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
        let mut d =
            d.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
        let mut c_shift = c_shift
            .map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
        let mut a_theta = a_theta
            .map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
        let mut b =
            b.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
        let mut next_state_in = next_state_in.map(|id| witness.new_column::<BinaryField1b>(id));

        let mut selector_single = witness.new_column::<BinaryField1b>(selector_single);

        let mut selector = witness.new_column::<BinaryField1b>(selector);

        // View every column as u64 words. Index [word] addresses a state row for the bit
        // columns and a permutation for `input`/`output`.
        let input_u64 = input.each_mut().map(|col| col.as_mut_slice::<u64>());

        let packed_state_in_u64 = packed_state_in
            .each_mut()
            .map(|col| col.as_mut_slice::<u64>());

        let mut packed_state_out_u64 = packed_state_out
            .each_mut()
            .map(|col| col.as_mut_slice::<u64>());

        let output_u64 = output.each_mut().map(|col| col.as_mut_slice::<u64>());

        let state_u64 = state
            .each_mut()
            .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
        let c_u64 = c
            .each_mut()
            .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
        let d_u64 = d
            .each_mut()
            .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
        let c_shift_u64 = c_shift
            .each_mut()
            .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
        let a_theta_u64 = a_theta
            .each_mut()
            .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
        let b_u64 = b
            .each_mut()
            .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
        let next_state_in_u64 = next_state_in.each_mut().map(|col| col.as_mut_slice());
        let selector_single_u64 = selector_single.as_mut_slice::<u64>();
        let selector_u64 = selector.as_mut_slice();

        // selector_single: all ones on every state row except the last one of the permutation.
        for selector_single_u64_row in selector_single_u64
            .iter_mut()
            .take(STATE_ROWS_PER_PERMUTATION - 1)
        {
            *selector_single_u64_row = u64::MAX;
        }

        for perm_i in 0..1 << (internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION) {
            let first_state_row_idx_in_perm = perm_i << LOG_STATE_ROWS_PER_PERMUTATION;

            // Missing inputs fall back to the all-zero default state.
            let input_this_perm = input_witness.get(perm_i).copied().unwrap_or_default().0;

            for xy in 0..STATE_SIZE {
                input_u64[xy][perm_i] = input_this_perm[xy];
            }

            // Reference result, used both for the `output` column and the final sanity check.
            let expected_output_this_perm = {
                let mut output = input_this_perm;
                tiny_keccak::keccakf(&mut output);
                output
            };

            for xy in 0..STATE_SIZE {
                output_u64[xy][perm_i] = expected_output_this_perm[xy];
            }

            // Seed the first state row with the permutation input.
            for xy in 0..STATE_SIZE {
                state_u64[0][xy][first_state_row_idx_in_perm] = input_this_perm[xy];
                packed_state_in_u64[xy][first_state_row_idx_in_perm] = input_this_perm[xy];
            }

            // Rows must run in order: each row writes the input of the next one (below).
            for row_idx_within_permutation in 0..STATE_ROWS_PER_PERMUTATION {
                let state_row_idx = first_state_row_idx_in_perm | row_idx_within_permutation;
                for round_within_row in 0..ROUNDS_PER_STATE_ROW {
                    let keccakf_rc = KECCAKF_RC
                        [ROUNDS_PER_STATE_ROW * row_idx_within_permutation + round_within_row];

                    // Theta: column parities and their 1-bit rotation.
                    for x in 0..5 {
                        c_u64[round_within_row][x][state_row_idx] = (0..5).fold(0, |acc, y| {
                            acc ^ state_u64[round_within_row][x + 5 * y][state_row_idx]
                        });
                        c_shift_u64[round_within_row][x][state_row_idx] =
                            c_u64[round_within_row][x][state_row_idx].rotate_left(1);
                    }

                    for x in 0..5 {
                        d_u64[round_within_row][x][state_row_idx] = c_u64[round_within_row]
                            [(x + 4) % 5][state_row_idx]
                            ^ c_shift_u64[round_within_row][(x + 1) % 5][state_row_idx];
                    }

                    for x in 0..5 {
                        for y in 0..5 {
                            a_theta_u64[round_within_row][x + 5 * y][state_row_idx] = state_u64
                                [round_within_row][x + 5 * y][state_row_idx]
                                ^ d_u64[round_within_row][x][state_row_idx];
                        }
                    }

                    // Rho + pi.
                    for xy in 0..STATE_SIZE {
                        b_u64[round_within_row][xy][state_row_idx] = a_theta_u64[round_within_row]
                            [PI[xy]][state_row_idx]
                            .rotate_left(RHO[xy]);
                    }

                    // Chi.
                    for x in 0..5 {
                        for y in 0..5 {
                            let b0 = b_u64[round_within_row][x + 5 * y][state_row_idx];
                            let b1 = b_u64[round_within_row][(x + 1) % 5 + 5 * y][state_row_idx];
                            let b2 = b_u64[round_within_row][(x + 2) % 5 + 5 * y][state_row_idx];

                            state_u64[round_within_row + 1][x + 5 * y][state_row_idx] =
                                b0 ^ (!b1 & b2);
                        }
                    }

                    // Iota: only lane 0 receives the round constant.
                    state_u64[round_within_row + 1][0][state_row_idx] ^= keccakf_rc;
                }

                for (xy, packed_state_out_u64_row) in packed_state_out_u64.iter_mut().enumerate() {
                    packed_state_out_u64_row[state_row_idx] =
                        state_u64[ROUNDS_PER_STATE_ROW][xy][state_row_idx];
                }

                // Forward this row's output to the next row's input, except for the last row of
                // the permutation, which has no successor within the permutation.
                if row_idx_within_permutation < (STATE_ROWS_PER_PERMUTATION - 1) {
                    for xy in 0..STATE_SIZE {
                        let this_row_output = state_u64[ROUNDS_PER_STATE_ROW][xy][state_row_idx];

                        state_u64[0][xy][state_row_idx + 1] = this_row_output;
                        packed_state_in_u64[xy][state_row_idx + 1] = this_row_output;
                        next_state_in_u64[xy][state_row_idx] = this_row_output;
                    }
                    selector_u64[state_row_idx] = u64::MAX;
                }
            }

            let last_state_row_idx_in_perm =
                first_state_row_idx_in_perm + (STATE_ROWS_PER_PERMUTATION - 1);

            let actual_output_this_perm =
                array::from_fn(|i| state_u64[ROUNDS_PER_STATE_ROW][i][last_state_row_idx_in_perm]);

            // Sanity check: the generated witness must reproduce the reference permutation.
            assert_eq!(expected_output_this_perm, actual_output_this_perm);
        }
    }

    // Chi (and iota for lane 0) constraints. In a binary field `1 - b1` equals NOT b1 and `+`
    // is XOR, so these encode s = b0 ^ (!b1 & b2) [^ rc].
    let chi_iota = arith_expr!([s, b0, b1, b2, rc] = s - (rc + b0 + (1 - b1) * b2));
    let chi = arith_expr!([s, b0, b1, b2] = s - (b0 + (1 - b1) * b2));
    for x in 0..5 {
        for y in 0..5 {
            for round_within_row in 0..ROUNDS_PER_STATE_ROW {
                let this_round_output = state[round_within_row + 1][x + 5 * y];

                if x == 0 && y == 0 {
                    builder.assert_zero(
                        format!("chi_iota(round_within_row={round_within_row}, x={x}, y={y})"),
                        [
                            this_round_output,
                            b[round_within_row][x + 5 * y],
                            b[round_within_row][(x + 1) % 5 + 5 * y],
                            b[round_within_row][(x + 2) % 5 + 5 * y],
                            round_consts[round_within_row],
                        ],
                        chi_iota.clone().convert_field(),
                    );
                } else {
                    builder.assert_zero(
                        format!("chi(round_within_row={round_within_row}, x={x}, y={y})"),
                        [
                            this_round_output,
                            b[round_within_row][x + 5 * y],
                            b[round_within_row][(x + 1) % 5 + 5 * y],
                            b[round_within_row][(x + 2) % 5 + 5 * y],
                        ],
                        chi.clone().convert_field(),
                    )
                }
            }
        }
    }

    // Row continuity: wherever the selector is set, the row output must equal the next row's
    // input.
    let selector_consistency =
        arith_expr!([state_out, next_state_in, select] = (state_out - next_state_in) * select);

    for xy in 0..STATE_SIZE {
        builder.assert_zero(
            format!("next_state_in_is_state_out_{xy}"),
            [state_out[xy], next_state_in[xy], selector],
            selector_consistency.clone().convert_field(),
        )
    }

    Ok(KeccakfOracles { input, output })
}
442
443#[inline]
444fn into_packed_vec<P>(src: &[impl Pod]) -> Vec<P>
445where
446 P: PackedField + WithUnderlier,
447 P::Underlier: Pod,
448{
449 pod_collect_to_vec::<_, P::Underlier>(src)
450 .into_iter()
451 .map(P::from_underlier)
452 .collect()
453}
454
/// Lanes in a Keccak-f[1600] state: a 5 x 5 grid of 64-bit words.
const STATE_SIZE: usize = 5 * 5;
/// Log2 of the number of state rows one permutation occupies.
const LOG_STATE_ROWS_PER_PERMUTATION: usize = 3;
/// Number of state rows one permutation occupies.
const STATE_ROWS_PER_PERMUTATION: usize = 1 << LOG_STATE_ROWS_PER_PERMUTATION;
/// Keccak-f rounds evaluated inside a single state row.
const ROUNDS_PER_STATE_ROW: usize = 3;
/// Log2 of the bit rows in one state row (one bit row per bit of a 64-bit lane).
const LOG_BIT_ROWS_PER_STATE_ROW: usize = 6;
/// Bit rows in one state row.
const BIT_ROWS_PER_STATE_ROW: usize = 1 << LOG_BIT_ROWS_PER_STATE_ROW;
/// Log2 of the bit rows one permutation occupies.
const LOG_BIT_ROWS_PER_PERMUTATION: usize =
    LOG_STATE_ROWS_PER_PERMUTATION + LOG_BIT_ROWS_PER_STATE_ROW;
/// Bit rows one permutation occupies.
const BIT_ROWS_PER_PERMUTATION: usize = BIT_ROWS_PER_STATE_ROW * STATE_ROWS_PER_PERMUTATION;
/// Total Keccak-f rounds per permutation.
const ROUNDS_PER_PERMUTATION: usize = STATE_ROWS_PER_PERMUTATION * ROUNDS_PER_STATE_ROW;
465
/// Rho rotation amounts, indexed by the destination lane `x + 5 * y` of the combined rho/pi
/// step. Lane `xy` of `b` is lane `PI[xy]` of `a_theta` rotated left by `RHO[xy]` bits.
#[rustfmt::skip]
const RHO: [u32; STATE_SIZE] = [
     0, 44, 43, 21, 14,
    28, 20,  3, 45, 61,
     1,  6, 25,  8, 18,
    27, 36, 10, 15, 56,
    62, 55, 39, 41,  2,
];
474
/// Pi lane permutation. `PI[xy]` is the source lane of `a_theta` that ends up in destination
/// lane `xy` (indexed `x + 5 * y`) of `b`.
#[rustfmt::skip]
const PI: [usize; STATE_SIZE] = [
    0,  6, 12, 18, 24,
    3,  9, 10, 16, 22,
    1,  7, 13, 19, 20,
    4,  5, 11, 17, 23,
    2,  8, 14, 15, 21,
];
483
/// Iota round constants, one per round, in round order.
///
/// Within a permutation, the round applied at state row `row` and slot `round_within_row` is
/// `ROUNDS_PER_STATE_ROW * row + round_within_row`. The constant is XORed into lane 0 only.
const KECCAKF_RC: [u64; ROUNDS_PER_PERMUTATION] = [
    0x0000000000000001,
    0x0000000000008082,
    0x800000000000808A,
    0x8000000080008000,
    0x000000000000808B,
    0x0000000080000001,
    0x8000000080008081,
    0x8000000000008009,
    0x000000000000008A,
    0x0000000000000088,
    0x0000000080008009,
    0x000000008000000A,
    0x000000008000808B,
    0x800000000000008B,
    0x8000000000008089,
    0x8000000000008003,
    0x8000000000008002,
    0x8000000000000080,
    0x000000000000800A,
    0x800000008000000A,
    0x8000000080008081,
    0x8000000000008080,
    0x0000000080000001,
    0x8000000080008008,
];
510
#[cfg(test)]
mod tests {
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::{keccakf, KeccakfState};
    use crate::builder::test_utils::test_circuit;

    /// Builds, proves and verifies a `keccakf` circuit for `2^log_size` permutations.
    ///
    /// The first `num_inputs` slots are fed random states; any remaining slots take the
    /// default (all-zero) state.
    fn run_keccakf(log_size: usize, num_inputs: usize) {
        test_circuit(|builder| {
            let mut rng = StdRng::seed_from_u64(0);
            let input_states: Vec<KeccakfState> =
                (0..num_inputs).map(|_| KeccakfState(rng.gen())).collect();
            let _state_out = keccakf(builder, &Some(input_states), log_size)?;
            Ok(vec![])
        })
        .unwrap();
    }

    /// A single random input. The other slots exercise the zero-state padding path.
    #[test]
    fn test_keccakf() {
        run_keccakf(5, 1);
    }

    /// Random input in every permutation slot, so every slot is checked with
    /// non-trivial data.
    #[test]
    fn test_keccakf_all_slots_random() {
        let log_size = 5;
        run_keccakf(log_size, 1 << log_size);
    }
}
529}