1use std::array;
4
5use anyhow::anyhow;
6use binius_core::{
7 oracle::{OracleId, ProjectionVariant, ShiftVariant},
8 transparent::multilinear_extension::MultilinearExtensionTransparent,
9};
10use binius_field::{
11 as_packed_field::PackedType, underlier::WithUnderlier, BinaryField1b, BinaryField64b, Field,
12 PackedField, TowerField,
13};
14use binius_macros::arith_expr;
15use bytemuck::{pod_collect_to_vec, Pod};
16
17use crate::{
18 builder::{
19 types::{F, U},
20 ConstraintSystemBuilder,
21 },
22 transparent::step_down,
23};
24
25#[derive(Default, Clone, Copy)]
26pub struct KeccakfState(pub [u64; STATE_SIZE]);
27
28pub struct KeccakfOracles {
29 pub input: [OracleId; STATE_SIZE],
30 pub output: [OracleId; STATE_SIZE],
31}
32
33pub fn keccakf(
34 builder: &mut ConstraintSystemBuilder,
35 input_witness: &Option<impl AsRef<[KeccakfState]>>,
36 log_size: usize,
37) -> Result<KeccakfOracles, anyhow::Error> {
38 let internal_log_size = log_size + LOG_BIT_ROWS_PER_PERMUTATION;
39 let round_consts_single: [OracleId; ROUNDS_PER_STATE_ROW] =
40 array::try_from_fn(|round_within_row| {
41 let round_within_row_rc: [_; STATE_ROWS_PER_PERMUTATION] =
42 array::from_fn(|row_within_perm| {
43 KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_perm + round_within_row]
44 });
45
46 let packed_vec = into_packed_vec::<PackedType<U, BinaryField1b>>(&round_within_row_rc);
47 let rc_single_mle =
48 MultilinearExtensionTransparent::<_, PackedType<U, F>, _>::from_values(packed_vec)?;
49 builder.add_transparent("round_consts_single", rc_single_mle)
50 })?;
51
52 let round_consts: [OracleId; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| {
53 builder.add_repeating(
54 "round_consts",
55 round_consts_single[round_within_row],
56 internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION,
57 )
58 })?;
59
60 if let Some(witness) = builder.witness() {
61 let mut round_consts_single =
62 round_consts_single.map(|id| witness.new_column::<BinaryField1b>(id));
63 let mut round_consts = round_consts.map(|id| witness.new_column::<BinaryField1b>(id));
64
65 let round_consts_single_u64 = round_consts_single
66 .each_mut()
67 .map(|col| col.as_mut_slice::<u64>());
68 let round_consts_u64 = round_consts.each_mut().map(|col| col.as_mut_slice::<u64>());
69
70 for row_within_permutation in 0..STATE_ROWS_PER_PERMUTATION {
71 for round_within_row in 0..ROUNDS_PER_STATE_ROW {
72 round_consts_single_u64[round_within_row][row_within_permutation] =
73 KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_permutation + round_within_row];
74 }
75 }
76
77 for state_row_idx in 0..1 << (internal_log_size - LOG_BIT_ROWS_PER_STATE_ROW) {
78 let row_within_permutation = state_row_idx % STATE_ROWS_PER_PERMUTATION;
79 for round_within_row in 0..ROUNDS_PER_STATE_ROW {
80 round_consts_u64[round_within_row][state_row_idx] =
81 KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_permutation + round_within_row];
82 }
83 }
84 }
85
86 let selector_single = step_down(
87 builder,
88 "selector_single",
89 LOG_BIT_ROWS_PER_PERMUTATION,
90 BIT_ROWS_PER_PERMUTATION - BIT_ROWS_PER_STATE_ROW,
91 )?;
92
93 let selector = builder.add_repeating(
94 "selector",
95 selector_single,
96 internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION,
97 )?;
98
99 let state: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW + 1] = array::from_fn(|_| {
100 builder.add_committed_multiple("state_in", internal_log_size, BinaryField1b::TOWER_LEVEL)
101 });
102
103 let state_in = state[0];
104
105 let state_out = state[ROUNDS_PER_STATE_ROW];
106
107 let packed_state_in: [OracleId; STATE_SIZE] = array::try_from_fn(|xy| {
108 builder.add_packed("packed state input", state_in[xy], LOG_BIT_ROWS_PER_STATE_ROW)
109 })?;
110
111 let input: [OracleId; STATE_SIZE] = array::try_from_fn(|xy| {
112 builder.add_projected(
113 "packed projected state input",
114 packed_state_in[xy],
115 vec![F::ZERO; LOG_STATE_ROWS_PER_PERMUTATION],
116 ProjectionVariant::FirstVars,
117 )
118 })?;
119
120 let packed_state_out: [OracleId; STATE_SIZE] = array::try_from_fn(|xy| {
121 builder.add_packed("packed state output", state_out[xy], LOG_BIT_ROWS_PER_STATE_ROW)
122 })?;
123
124 let output: [OracleId; STATE_SIZE] = array::try_from_fn(|xy| {
125 builder.add_projected(
126 "output",
127 packed_state_out[xy],
128 vec![Field::ONE; LOG_STATE_ROWS_PER_PERMUTATION],
129 ProjectionVariant::FirstVars,
130 )
131 })?;
132
133 let c: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| {
134 array::try_from_fn(|x| {
135 builder.add_linear_combination(
136 "c",
137 internal_log_size,
138 array::from_fn::<_, 5, _>(|offset| {
139 (state[round_within_row][x + 5 * offset], Field::ONE)
140 }),
141 )
142 })
143 })?;
144
145 let c_shift: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| {
146 array::try_from_fn(|x| {
147 builder.add_shifted(
148 format!("c[{x}]"),
149 c[round_within_row][x],
150 1,
151 6,
152 ShiftVariant::CircularLeft,
153 )
154 })
155 })?;
156
157 let d: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| {
158 array::try_from_fn(|x| {
159 builder.add_linear_combination(
160 "d",
161 internal_log_size,
162 [
163 (c[round_within_row][(x + 4) % 5], Field::ONE),
164 (c_shift[round_within_row][(x + 1) % 5], Field::ONE),
165 ],
166 )
167 })
168 })?;
169
170 let a_theta: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW] =
171 array::try_from_fn(|round_within_row| {
172 array::try_from_fn(|xy| {
173 let x = xy % 5;
174 builder.add_linear_combination(
175 format!("a_theta[{xy}]"),
176 internal_log_size,
177 [
178 (state[round_within_row][xy], Field::ONE),
179 (d[round_within_row][x], Field::ONE),
180 ],
181 )
182 })
183 })?;
184
185 let b: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW] =
186 array::try_from_fn(|round_within_row| {
187 array::try_from_fn(|xy| {
188 if xy == 0 {
189 Ok(a_theta[round_within_row][0])
190 } else {
191 builder.add_shifted(
192 format!("b[{xy}]"),
193 a_theta[round_within_row][PI[xy]],
194 RHO[xy] as usize,
195 6,
196 ShiftVariant::CircularLeft,
197 )
198 }
199 })
200 })?;
201
202 let next_state_in: [OracleId; STATE_SIZE] = array::try_from_fn(|xy| {
203 builder.add_shifted(
204 format!("next_state_in[{xy}]"),
205 state_in[xy],
206 64,
207 LOG_BIT_ROWS_PER_PERMUTATION,
208 ShiftVariant::LogicalRight,
209 )
210 })?;
211
212 if let Some(witness) = builder.witness() {
213 let input_witness = input_witness
214 .as_ref()
215 .ok_or_else(|| anyhow!("builder witness available and input witness is not"))?
216 .as_ref();
217
218 let mut input = input.map(|id| witness.new_column::<BinaryField64b>(id));
219
220 let mut packed_state_in =
221 packed_state_in.map(|id| witness.new_column::<BinaryField64b>(id));
222
223 let mut packed_state_out =
224 packed_state_out.map(|id| witness.new_column::<BinaryField64b>(id));
225
226 let mut output = output.map(|id| witness.new_column::<BinaryField64b>(id));
227
228 let mut state = state
229 .map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
230
231 let mut c =
232 c.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
233 let mut d =
234 d.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
235 let mut c_shift = c_shift
236 .map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
237 let mut a_theta = a_theta
238 .map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
239 let mut b =
240 b.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
241 let mut next_state_in = next_state_in.map(|id| witness.new_column::<BinaryField1b>(id));
242
243 let mut selector_single = witness.new_column::<BinaryField1b>(selector_single);
244
245 let mut selector = witness.new_column::<BinaryField1b>(selector);
246
247 let input_u64 = input.each_mut().map(|col| col.as_mut_slice::<u64>());
248
249 let packed_state_in_u64 = packed_state_in
250 .each_mut()
251 .map(|col| col.as_mut_slice::<u64>());
252
253 let mut packed_state_out_u64 = packed_state_out
254 .each_mut()
255 .map(|col| col.as_mut_slice::<u64>());
256
257 let output_u64 = output.each_mut().map(|col| col.as_mut_slice::<u64>());
258
259 let state_u64 = state
260 .each_mut()
261 .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
262 let c_u64 = c
263 .each_mut()
264 .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
265 let d_u64 = d
266 .each_mut()
267 .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
268 let c_shift_u64 = c_shift
269 .each_mut()
270 .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
271 let a_theta_u64 = a_theta
272 .each_mut()
273 .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
274 let b_u64 = b
275 .each_mut()
276 .map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
277 let next_state_in_u64 = next_state_in.each_mut().map(|col| col.as_mut_slice());
278 let selector_single_u64 = selector_single.as_mut_slice::<u64>();
279 let selector_u64 = selector.as_mut_slice();
280
281 for selector_single_u64_row in selector_single_u64
283 .iter_mut()
284 .take(STATE_ROWS_PER_PERMUTATION - 1)
285 {
286 *selector_single_u64_row = u64::MAX;
287 }
288
289 for perm_i in 0..1 << (internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION) {
292 let first_state_row_idx_in_perm = perm_i << LOG_STATE_ROWS_PER_PERMUTATION;
293
294 let input_this_perm = input_witness.get(perm_i).copied().unwrap_or_default().0;
295
296 for xy in 0..STATE_SIZE {
297 input_u64[xy][perm_i] = input_this_perm[xy];
298 }
299
300 let expected_output_this_perm = {
301 let mut output = input_this_perm;
302 tiny_keccak::keccakf(&mut output);
303 output
304 };
305
306 for xy in 0..STATE_SIZE {
307 output_u64[xy][perm_i] = expected_output_this_perm[xy];
308 }
309
310 for xy in 0..STATE_SIZE {
312 state_u64[0][xy][first_state_row_idx_in_perm] = input_this_perm[xy];
313 packed_state_in_u64[xy][first_state_row_idx_in_perm] = input_this_perm[xy];
314 }
315
316 for row_idx_within_permutation in 0..STATE_ROWS_PER_PERMUTATION {
317 let state_row_idx = first_state_row_idx_in_perm | row_idx_within_permutation;
318 for round_within_row in 0..ROUNDS_PER_STATE_ROW {
320 let keccakf_rc = KECCAKF_RC
321 [ROUNDS_PER_STATE_ROW * row_idx_within_permutation + round_within_row];
322
323 for x in 0..5 {
324 c_u64[round_within_row][x][state_row_idx] = (0..5).fold(0, |acc, y| {
325 acc ^ state_u64[round_within_row][x + 5 * y][state_row_idx]
326 });
327 c_shift_u64[round_within_row][x][state_row_idx] =
328 c_u64[round_within_row][x][state_row_idx].rotate_left(1);
329 }
330
331 for x in 0..5 {
332 d_u64[round_within_row][x][state_row_idx] = c_u64[round_within_row]
333 [(x + 4) % 5][state_row_idx]
334 ^ c_shift_u64[round_within_row][(x + 1) % 5][state_row_idx];
335 }
336
337 for x in 0..5 {
338 for y in 0..5 {
339 a_theta_u64[round_within_row][x + 5 * y][state_row_idx] = state_u64
340 [round_within_row][x + 5 * y][state_row_idx]
341 ^ d_u64[round_within_row][x][state_row_idx];
342 }
343 }
344
345 for xy in 0..STATE_SIZE {
346 b_u64[round_within_row][xy][state_row_idx] = a_theta_u64[round_within_row]
347 [PI[xy]][state_row_idx]
348 .rotate_left(RHO[xy]);
349 }
350
351 for x in 0..5 {
352 for y in 0..5 {
353 let b0 = b_u64[round_within_row][x + 5 * y][state_row_idx];
354 let b1 = b_u64[round_within_row][(x + 1) % 5 + 5 * y][state_row_idx];
355 let b2 = b_u64[round_within_row][(x + 2) % 5 + 5 * y][state_row_idx];
356
357 state_u64[round_within_row + 1][x + 5 * y][state_row_idx] =
358 b0 ^ (!b1 & b2);
359 }
360 }
361
362 state_u64[round_within_row + 1][0][state_row_idx] ^= keccakf_rc;
363 }
364
365 for (xy, packed_state_out_u64_row) in packed_state_out_u64.iter_mut().enumerate() {
366 packed_state_out_u64_row[state_row_idx] =
367 state_u64[ROUNDS_PER_STATE_ROW][xy][state_row_idx];
368 }
369
370 if row_idx_within_permutation < (STATE_ROWS_PER_PERMUTATION - 1) {
371 for xy in 0..STATE_SIZE {
372 let this_row_output = state_u64[ROUNDS_PER_STATE_ROW][xy][state_row_idx];
373
374 state_u64[0][xy][state_row_idx + 1] = this_row_output;
375 packed_state_in_u64[xy][state_row_idx + 1] = this_row_output;
376 next_state_in_u64[xy][state_row_idx] = this_row_output;
377 }
378 selector_u64[state_row_idx] = u64::MAX;
379 }
380 }
381
382 let last_state_row_idx_in_perm =
383 first_state_row_idx_in_perm + (STATE_ROWS_PER_PERMUTATION - 1);
384
385 let actual_output_this_perm =
386 array::from_fn(|i| state_u64[ROUNDS_PER_STATE_ROW][i][last_state_row_idx_in_perm]);
387
388 assert_eq!(expected_output_this_perm, actual_output_this_perm);
389 }
390 }
391
392 let chi_iota = arith_expr!([s, b0, b1, b2, rc] = s - (rc + b0 + (1 - b1) * b2));
393 let chi = arith_expr!([s, b0, b1, b2] = s - (b0 + (1 - b1) * b2));
394 for x in 0..5 {
395 for y in 0..5 {
396 for round_within_row in 0..ROUNDS_PER_STATE_ROW {
397 let this_round_output = state[round_within_row + 1][x + 5 * y];
398
399 if x == 0 && y == 0 {
400 builder.assert_zero(
401 format!("chi_iota(round_within_row={round_within_row}, x={x}, y={y})"),
402 [
403 this_round_output,
404 b[round_within_row][x + 5 * y],
405 b[round_within_row][(x + 1) % 5 + 5 * y],
406 b[round_within_row][(x + 2) % 5 + 5 * y],
407 round_consts[round_within_row],
408 ],
409 chi_iota.clone().convert_field(),
410 );
411 } else {
412 builder.assert_zero(
413 format!("chi(round_within_row={round_within_row}, x={x}, y={y})"),
414 [
415 this_round_output,
416 b[round_within_row][x + 5 * y],
417 b[round_within_row][(x + 1) % 5 + 5 * y],
418 b[round_within_row][(x + 2) % 5 + 5 * y],
419 ],
420 chi.clone().convert_field(),
421 )
422 }
423 }
424 }
425 }
426
427 let selector_consistency =
428 arith_expr!([state_out, next_state_in, select] = (state_out - next_state_in) * select);
429
430 for xy in 0..STATE_SIZE {
431 builder.assert_zero(
432 format!("next_state_in_is_state_out_{xy}"),
433 [state_out[xy], next_state_in[xy], selector],
434 selector_consistency.clone().convert_field(),
435 )
436 }
437
438 Ok(KeccakfOracles { input, output })
439}
440
441#[inline]
442fn into_packed_vec<P>(src: &[impl Pod]) -> Vec<P>
443where
444 P: PackedField + WithUnderlier,
445 P::Underlier: Pod,
446{
447 pod_collect_to_vec::<_, P::Underlier>(src)
448 .into_iter()
449 .map(P::from_underlier)
450 .collect()
451}
452
/// Number of 64-bit lanes in a Keccak-f[1600] state (a 5 x 5 grid).
const STATE_SIZE: usize = 5 * 5;
/// log2 of the number of state rows occupied by one permutation.
const LOG_STATE_ROWS_PER_PERMUTATION: usize = 3;
/// Number of state rows occupied by one permutation.
const STATE_ROWS_PER_PERMUTATION: usize = 1 << LOG_STATE_ROWS_PER_PERMUTATION;
/// Keccak rounds evaluated within a single state row.
const ROUNDS_PER_STATE_ROW: usize = 3;
/// log2 of the bit rows per state row (one bit row per bit of a 64-bit lane).
const LOG_BIT_ROWS_PER_STATE_ROW: usize = 6;
/// Bit rows per state row.
const BIT_ROWS_PER_STATE_ROW: usize = 1 << LOG_BIT_ROWS_PER_STATE_ROW;
/// log2 of the bit rows occupied by one permutation.
const LOG_BIT_ROWS_PER_PERMUTATION: usize =
	LOG_STATE_ROWS_PER_PERMUTATION + LOG_BIT_ROWS_PER_STATE_ROW;
/// Bit rows occupied by one permutation.
const BIT_ROWS_PER_PERMUTATION: usize = 1 << LOG_BIT_ROWS_PER_PERMUTATION;
/// Total rounds in one permutation; this must match the length of the round-constant table.
const ROUNDS_PER_PERMUTATION: usize = STATE_ROWS_PER_PERMUTATION * ROUNDS_PER_STATE_ROW;
463
464#[rustfmt::skip]
465const RHO: [u32; STATE_SIZE] = [
466 0, 44, 43, 21, 14,
467 28, 20, 3, 45, 61,
468 1, 6, 25, 8, 18,
469 27, 36, 10, 15, 56,
470 62, 55, 39, 41, 2,
471];
472
473#[rustfmt::skip]
474const PI: [usize; STATE_SIZE] = [
475 0, 6, 12, 18, 24,
476 3, 9, 10, 16, 22,
477 1, 7, 13, 19, 20,
478 4, 5, 11, 17, 23,
479 2, 8, 14, 15, 21,
480];
481
482const KECCAKF_RC: [u64; ROUNDS_PER_PERMUTATION] = [
483 0x0000000000000001,
484 0x0000000000008082,
485 0x800000000000808A,
486 0x8000000080008000,
487 0x000000000000808B,
488 0x0000000080000001,
489 0x8000000080008081,
490 0x8000000000008009,
491 0x000000000000008A,
492 0x0000000000000088,
493 0x0000000080008009,
494 0x000000008000000A,
495 0x000000008000808B,
496 0x800000000000008B,
497 0x8000000000008089,
498 0x8000000000008003,
499 0x8000000000008002,
500 0x8000000000000080,
501 0x000000000000800A,
502 0x800000008000000A,
503 0x8000000080008081,
504 0x8000000000008080,
505 0x0000000080000001,
506 0x8000000080008008,
507];
508
#[cfg(test)]
mod tests {
	use rand::{rngs::StdRng, Rng, SeedableRng};

	use super::{keccakf, KeccakfState};
	use crate::builder::test_utils::test_circuit;

	/// A single random input. The remaining permutations of the batch are zero-padded.
	#[test]
	fn test_keccakf() {
		test_circuit(|builder| {
			let log_size = 5;
			let mut rng = StdRng::seed_from_u64(0);
			let input_states = vec![KeccakfState(rng.gen())];
			let _state_out = keccakf(builder, &Some(input_states), log_size)?;
			Ok(vec![])
		})
		.unwrap();
	}

	/// Every permutation of the batch gets its own random input. This exercises the witness
	/// indexing and the row-chaining constraints beyond the first permutation.
	#[test]
	fn test_keccakf_full_batch() {
		test_circuit(|builder| {
			let log_size = 5;
			let mut rng = StdRng::seed_from_u64(1);
			let input_states: Vec<KeccakfState> =
				(0..1 << log_size).map(|_| KeccakfState(rng.gen())).collect();
			let _state_out = keccakf(builder, &Some(input_states), log_size)?;
			Ok(vec![])
		})
		.unwrap();
	}
}
527}