// binius_circuits/keccakf.rs

// Copyright 2024-2025 Irreducible Inc.

use std::array;

use anyhow::{anyhow, ensure};
use binius_core::{
	oracle::{OracleId, ProjectionVariant, ShiftVariant},
	transparent::multilinear_extension::MultilinearExtensionTransparent,
};
use binius_field::{
	as_packed_field::PackedType, underlier::WithUnderlier, BinaryField1b, BinaryField64b, Field,
	PackedField, TowerField,
};
use binius_macros::arith_expr;
use bytemuck::{pod_collect_to_vec, Pod};

use crate::{
	builder::{
		types::{F, U},
		ConstraintSystemBuilder,
	},
	transparent::step_down,
};

25#[derive(Default, Clone, Copy)]
26pub struct KeccakfState(pub [u64; STATE_SIZE]);
27
28pub struct KeccakfOracles {
29	pub input: [OracleId; STATE_SIZE],
30	pub output: [OracleId; STATE_SIZE],
31}
32
33pub fn keccakf(
34	builder: &mut ConstraintSystemBuilder,
35	input_witness: &Option<impl AsRef<[KeccakfState]>>,
36	log_size: usize,
37) -> Result<KeccakfOracles, anyhow::Error> {
38	let internal_log_size = log_size + LOG_BIT_ROWS_PER_PERMUTATION;
39	let round_consts_single: [OracleId; ROUNDS_PER_STATE_ROW] =
40		array::try_from_fn(|round_within_row| {
41			let round_within_row_rc: [_; STATE_ROWS_PER_PERMUTATION] =
42				array::from_fn(|row_within_perm| {
43					KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_perm + round_within_row]
44				});
45
46			let packed_vec = into_packed_vec::<PackedType<U, BinaryField1b>>(&round_within_row_rc);
47			let rc_single_mle =
48				MultilinearExtensionTransparent::<_, PackedType<U, F>, _>::from_values(packed_vec)?;
49			builder.add_transparent("round_consts_single", rc_single_mle)
50		})?;
51
52	let round_consts: [OracleId; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| {
53		builder.add_repeating(
54			"round_consts",
55			round_consts_single[round_within_row],
56			internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION,
57		)
58	})?;
59
60	if let Some(witness) = builder.witness() {
61		let mut round_consts_single =
62			round_consts_single.map(|id| witness.new_column::<BinaryField1b>(id));
63		let mut round_consts = round_consts.map(|id| witness.new_column::<BinaryField1b>(id));
64
65		let round_consts_single_u64 = round_consts_single
66			.each_mut()
67			.map(|col| col.as_mut_slice::<u64>());
68		let round_consts_u64 = round_consts.each_mut().map(|col| col.as_mut_slice::<u64>());
69
70		for row_within_permutation in 0..STATE_ROWS_PER_PERMUTATION {
71			for round_within_row in 0..ROUNDS_PER_STATE_ROW {
72				round_consts_single_u64[round_within_row][row_within_permutation] =
73					KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_permutation + round_within_row];
74			}
75		}
76
77		for state_row_idx in 0..1 << (internal_log_size - LOG_BIT_ROWS_PER_STATE_ROW) {
78			let row_within_permutation = state_row_idx % STATE_ROWS_PER_PERMUTATION;
79			for round_within_row in 0..ROUNDS_PER_STATE_ROW {
80				round_consts_u64[round_within_row][state_row_idx] =
81					KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_permutation + round_within_row];
82			}
83		}
84	}
85
86	let selector_single = step_down(
87		builder,
88		"selector_single",
89		LOG_BIT_ROWS_PER_PERMUTATION,
90		BIT_ROWS_PER_PERMUTATION - BIT_ROWS_PER_STATE_ROW,
91	)?;
92
93	let selector = builder.add_repeating(
94		"selector",
95		selector_single,
96		internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION,
97	)?;
98
99	let state: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW + 1] = array::from_fn(|_| {
100		builder.add_committed_multiple("state_in", internal_log_size, BinaryField1b::TOWER_LEVEL)
101	});
102
103	let state_in = state[0];
104
105	let state_out = state[ROUNDS_PER_STATE_ROW];
106
107	let packed_state_in: [OracleId; STATE_SIZE] = array::try_from_fn(|xy| {
108		builder.add_packed("packed state input", state_in[xy], LOG_BIT_ROWS_PER_STATE_ROW)
109	})?;
110
111	let input: [OracleId; STATE_SIZE] = array::try_from_fn(|xy| {
112		builder.add_projected(
113			"packed projected state input",
114			packed_state_in[xy],
115			vec![F::ZERO; LOG_STATE_ROWS_PER_PERMUTATION],
116			ProjectionVariant::FirstVars,
117		)
118	})?;
119
120	let packed_state_out: [OracleId; STATE_SIZE] = array::try_from_fn(|xy| {
121		builder.add_packed("packed state output", state_out[xy], LOG_BIT_ROWS_PER_STATE_ROW)
122	})?;
123
124	let output: [OracleId; STATE_SIZE] = array::try_from_fn(|xy| {
125		builder.add_projected(
126			"output",
127			packed_state_out[xy],
128			vec![Field::ONE; LOG_STATE_ROWS_PER_PERMUTATION],
129			ProjectionVariant::FirstVars,
130		)
131	})?;
132
133	let c: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| {
134		array::try_from_fn(|x| {
135			builder.add_linear_combination(
136				"c",
137				internal_log_size,
138				array::from_fn::<_, 5, _>(|offset| {
139					(state[round_within_row][x + 5 * offset], Field::ONE)
140				}),
141			)
142		})
143	})?;
144
145	let c_shift: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| {
146		array::try_from_fn(|x| {
147			builder.add_shifted(
148				format!("c[{x}]"),
149				c[round_within_row][x],
150				1,
151				6,
152				ShiftVariant::CircularLeft,
153			)
154		})
155	})?;
156
157	let d: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] = array::try_from_fn(|round_within_row| {
158		array::try_from_fn(|x| {
159			builder.add_linear_combination(
160				"d",
161				internal_log_size,
162				[
163					(c[round_within_row][(x + 4) % 5], Field::ONE),
164					(c_shift[round_within_row][(x + 1) % 5], Field::ONE),
165				],
166			)
167		})
168	})?;
169
170	let a_theta: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW] =
171		array::try_from_fn(|round_within_row| {
172			array::try_from_fn(|xy| {
173				let x = xy % 5;
174				builder.add_linear_combination(
175					format!("a_theta[{xy}]"),
176					internal_log_size,
177					[
178						(state[round_within_row][xy], Field::ONE),
179						(d[round_within_row][x], Field::ONE),
180					],
181				)
182			})
183		})?;
184
185	let b: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW] =
186		array::try_from_fn(|round_within_row| {
187			array::try_from_fn(|xy| {
188				if xy == 0 {
189					Ok(a_theta[round_within_row][0])
190				} else {
191					builder.add_shifted(
192						format!("b[{xy}]"),
193						a_theta[round_within_row][PI[xy]],
194						RHO[xy] as usize,
195						6,
196						ShiftVariant::CircularLeft,
197					)
198				}
199			})
200		})?;
201
202	let next_state_in: [OracleId; STATE_SIZE] = array::try_from_fn(|xy| {
203		builder.add_shifted(
204			format!("next_state_in[{xy}]"),
205			state_in[xy],
206			64,
207			LOG_BIT_ROWS_PER_PERMUTATION,
208			ShiftVariant::LogicalRight,
209		)
210	})?;
211
212	if let Some(witness) = builder.witness() {
213		let input_witness = input_witness
214			.as_ref()
215			.ok_or_else(|| anyhow!("builder witness available and input witness is not"))?
216			.as_ref();
217
218		let mut input = input.map(|id| witness.new_column::<BinaryField64b>(id));
219
220		let mut packed_state_in =
221			packed_state_in.map(|id| witness.new_column::<BinaryField64b>(id));
222
223		let mut packed_state_out =
224			packed_state_out.map(|id| witness.new_column::<BinaryField64b>(id));
225
226		let mut output = output.map(|id| witness.new_column::<BinaryField64b>(id));
227
228		let mut state = state
229			.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
230
231		let mut c =
232			c.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
233		let mut d =
234			d.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
235		let mut c_shift = c_shift
236			.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
237		let mut a_theta = a_theta
238			.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
239		let mut b =
240			b.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
241		let mut next_state_in = next_state_in.map(|id| witness.new_column::<BinaryField1b>(id));
242
243		let mut selector_single = witness.new_column::<BinaryField1b>(selector_single);
244
245		let mut selector = witness.new_column::<BinaryField1b>(selector);
246
247		let input_u64 = input.each_mut().map(|col| col.as_mut_slice::<u64>());
248
249		let packed_state_in_u64 = packed_state_in
250			.each_mut()
251			.map(|col| col.as_mut_slice::<u64>());
252
253		let mut packed_state_out_u64 = packed_state_out
254			.each_mut()
255			.map(|col| col.as_mut_slice::<u64>());
256
257		let output_u64 = output.each_mut().map(|col| col.as_mut_slice::<u64>());
258
259		let state_u64 = state
260			.each_mut()
261			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
262		let c_u64 = c
263			.each_mut()
264			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
265		let d_u64 = d
266			.each_mut()
267			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
268		let c_shift_u64 = c_shift
269			.each_mut()
270			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
271		let a_theta_u64 = a_theta
272			.each_mut()
273			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
274		let b_u64 = b
275			.each_mut()
276			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
277		let next_state_in_u64 = next_state_in.each_mut().map(|col| col.as_mut_slice());
278		let selector_single_u64 = selector_single.as_mut_slice::<u64>();
279		let selector_u64 = selector.as_mut_slice();
280
281		// Fill in the non-repeating selector witness
282		for selector_single_u64_row in selector_single_u64
283			.iter_mut()
284			.take(STATE_ROWS_PER_PERMUTATION - 1)
285		{
286			*selector_single_u64_row = u64::MAX;
287		}
288
289		// Each round state is 64 rows
290		// Each permutation is 24 round states
291		for perm_i in 0..1 << (internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION) {
292			let first_state_row_idx_in_perm = perm_i << LOG_STATE_ROWS_PER_PERMUTATION;
293
294			let input_this_perm = input_witness.get(perm_i).copied().unwrap_or_default().0;
295
296			for xy in 0..STATE_SIZE {
297				input_u64[xy][perm_i] = input_this_perm[xy];
298			}
299
300			let expected_output_this_perm = {
301				let mut output = input_this_perm;
302				tiny_keccak::keccakf(&mut output);
303				output
304			};
305
306			for xy in 0..STATE_SIZE {
307				output_u64[xy][perm_i] = expected_output_this_perm[xy];
308			}
309
310			// Assign the permutation inputs for the long table
311			for xy in 0..STATE_SIZE {
312				state_u64[0][xy][first_state_row_idx_in_perm] = input_this_perm[xy];
313				packed_state_in_u64[xy][first_state_row_idx_in_perm] = input_this_perm[xy];
314			}
315
316			for row_idx_within_permutation in 0..STATE_ROWS_PER_PERMUTATION {
317				let state_row_idx = first_state_row_idx_in_perm | row_idx_within_permutation;
318				// Expand trace columns for each round on the row
319				for round_within_row in 0..ROUNDS_PER_STATE_ROW {
320					let keccakf_rc = KECCAKF_RC
321						[ROUNDS_PER_STATE_ROW * row_idx_within_permutation + round_within_row];
322
323					for x in 0..5 {
324						c_u64[round_within_row][x][state_row_idx] = (0..5).fold(0, |acc, y| {
325							acc ^ state_u64[round_within_row][x + 5 * y][state_row_idx]
326						});
327						c_shift_u64[round_within_row][x][state_row_idx] =
328							c_u64[round_within_row][x][state_row_idx].rotate_left(1);
329					}
330
331					for x in 0..5 {
332						d_u64[round_within_row][x][state_row_idx] = c_u64[round_within_row]
333							[(x + 4) % 5][state_row_idx]
334							^ c_shift_u64[round_within_row][(x + 1) % 5][state_row_idx];
335					}
336
337					for x in 0..5 {
338						for y in 0..5 {
339							a_theta_u64[round_within_row][x + 5 * y][state_row_idx] = state_u64
340								[round_within_row][x + 5 * y][state_row_idx]
341								^ d_u64[round_within_row][x][state_row_idx];
342						}
343					}
344
345					for xy in 0..STATE_SIZE {
346						b_u64[round_within_row][xy][state_row_idx] = a_theta_u64[round_within_row]
347							[PI[xy]][state_row_idx]
348							.rotate_left(RHO[xy]);
349					}
350
351					for x in 0..5 {
352						for y in 0..5 {
353							let b0 = b_u64[round_within_row][x + 5 * y][state_row_idx];
354							let b1 = b_u64[round_within_row][(x + 1) % 5 + 5 * y][state_row_idx];
355							let b2 = b_u64[round_within_row][(x + 2) % 5 + 5 * y][state_row_idx];
356
357							state_u64[round_within_row + 1][x + 5 * y][state_row_idx] =
358								b0 ^ (!b1 & b2);
359						}
360					}
361
362					state_u64[round_within_row + 1][0][state_row_idx] ^= keccakf_rc;
363				}
364
365				for (xy, packed_state_out_u64_row) in packed_state_out_u64.iter_mut().enumerate() {
366					packed_state_out_u64_row[state_row_idx] =
367						state_u64[ROUNDS_PER_STATE_ROW][xy][state_row_idx];
368				}
369
370				if row_idx_within_permutation < (STATE_ROWS_PER_PERMUTATION - 1) {
371					for xy in 0..STATE_SIZE {
372						let this_row_output = state_u64[ROUNDS_PER_STATE_ROW][xy][state_row_idx];
373
374						state_u64[0][xy][state_row_idx + 1] = this_row_output;
375						packed_state_in_u64[xy][state_row_idx + 1] = this_row_output;
376						next_state_in_u64[xy][state_row_idx] = this_row_output;
377					}
378					selector_u64[state_row_idx] = u64::MAX;
379				}
380			}
381
382			let last_state_row_idx_in_perm =
383				first_state_row_idx_in_perm + (STATE_ROWS_PER_PERMUTATION - 1);
384
385			let actual_output_this_perm =
386				array::from_fn(|i| state_u64[ROUNDS_PER_STATE_ROW][i][last_state_row_idx_in_perm]);
387
388			assert_eq!(expected_output_this_perm, actual_output_this_perm);
389		}
390	}
391
392	let chi_iota = arith_expr!([s, b0, b1, b2, rc] = s - (rc + b0 + (1 - b1) * b2));
393	let chi = arith_expr!([s, b0, b1, b2] = s - (b0 + (1 - b1) * b2));
394	for x in 0..5 {
395		for y in 0..5 {
396			for round_within_row in 0..ROUNDS_PER_STATE_ROW {
397				let this_round_output = state[round_within_row + 1][x + 5 * y];
398
399				if x == 0 && y == 0 {
400					builder.assert_zero(
401						format!("chi_iota(round_within_row={round_within_row}, x={x}, y={y})"),
402						[
403							this_round_output,
404							b[round_within_row][x + 5 * y],
405							b[round_within_row][(x + 1) % 5 + 5 * y],
406							b[round_within_row][(x + 2) % 5 + 5 * y],
407							round_consts[round_within_row],
408						],
409						chi_iota.clone().convert_field(),
410					);
411				} else {
412					builder.assert_zero(
413						format!("chi(round_within_row={round_within_row}, x={x}, y={y})"),
414						[
415							this_round_output,
416							b[round_within_row][x + 5 * y],
417							b[round_within_row][(x + 1) % 5 + 5 * y],
418							b[round_within_row][(x + 2) % 5 + 5 * y],
419						],
420						chi.clone().convert_field(),
421					)
422				}
423			}
424		}
425	}
426
427	let selector_consistency =
428		arith_expr!([state_out, next_state_in, select] = (state_out - next_state_in) * select);
429
430	for xy in 0..STATE_SIZE {
431		builder.assert_zero(
432			format!("next_state_in_is_state_out_{xy}"),
433			[state_out[xy], next_state_in[xy], selector],
434			selector_consistency.clone().convert_field(),
435		)
436	}
437
438	Ok(KeccakfOracles { input, output })
439}
440
441#[inline]
442fn into_packed_vec<P>(src: &[impl Pod]) -> Vec<P>
443where
444	P: PackedField + WithUnderlier,
445	P::Underlier: Pod,
446{
447	pod_collect_to_vec::<_, P::Underlier>(src)
448		.into_iter()
449		.map(P::from_underlier)
450		.collect()
451}
452
/// Number of 64-bit lanes in the state, addressed as `x + 5 * y` for `x, y` in `0..5`.
const STATE_SIZE: usize = 5 * 5;

/// Rounds evaluated within a single state row.
const ROUNDS_PER_STATE_ROW: usize = 3;

/// `log2` of the number of state rows one permutation spans.
const LOG_STATE_ROWS_PER_PERMUTATION: usize = 3;
/// Number of state rows one permutation spans.
const STATE_ROWS_PER_PERMUTATION: usize = 1 << LOG_STATE_ROWS_PER_PERMUTATION;
/// Total number of rounds in one permutation.
const ROUNDS_PER_PERMUTATION: usize = STATE_ROWS_PER_PERMUTATION * ROUNDS_PER_STATE_ROW;

/// `log2` of the bit-rows per state row (one bit-row per bit of a 64-bit lane).
const LOG_BIT_ROWS_PER_STATE_ROW: usize = 6;
/// Bit-rows per state row.
const BIT_ROWS_PER_STATE_ROW: usize = 1 << LOG_BIT_ROWS_PER_STATE_ROW;
/// `log2` of the bit-rows one permutation spans.
const LOG_BIT_ROWS_PER_PERMUTATION: usize =
	LOG_STATE_ROWS_PER_PERMUTATION + LOG_BIT_ROWS_PER_STATE_ROW;
/// Bit-rows one permutation spans.
const BIT_ROWS_PER_PERMUTATION: usize = 1 << LOG_BIT_ROWS_PER_PERMUTATION;

464#[rustfmt::skip]
465const RHO: [u32; STATE_SIZE] = [
466	 0, 44, 43, 21, 14,
467	28, 20,  3, 45, 61,
468	 1,  6, 25,  8, 18,
469	27, 36, 10, 15, 56,
470	62, 55, 39, 41,  2,
471];
472
473#[rustfmt::skip]
474const PI: [usize; STATE_SIZE] = [
475	0, 6, 12, 18, 24,
476	3, 9, 10, 16, 22,
477	1, 7, 13, 19, 20,
478	4, 5, 11, 17, 23,
479	2, 8, 14, 15, 21,
480];
481
482const KECCAKF_RC: [u64; ROUNDS_PER_PERMUTATION] = [
483	0x0000000000000001,
484	0x0000000000008082,
485	0x800000000000808A,
486	0x8000000080008000,
487	0x000000000000808B,
488	0x0000000080000001,
489	0x8000000080008081,
490	0x8000000000008009,
491	0x000000000000008A,
492	0x0000000000000088,
493	0x0000000080008009,
494	0x000000008000000A,
495	0x000000008000808B,
496	0x800000000000008B,
497	0x8000000000008089,
498	0x8000000000008003,
499	0x8000000000008002,
500	0x8000000000000080,
501	0x000000000000800A,
502	0x800000008000000A,
503	0x8000000080008081,
504	0x8000000000008080,
505	0x0000000080000001,
506	0x8000000080008008,
507];
508
#[cfg(test)]
mod tests {
	use rand::{rngs::StdRng, Rng, SeedableRng};

	use super::{keccakf, KeccakfState};
	use crate::builder::test_utils::test_circuit;

	/// Builds and checks a batch of `2^5` permutations whose first slot holds a seeded random
	/// state; `keccakf` zero-pads the remaining slots.
	#[test]
	fn test_keccakf() {
		const LOG_SIZE: usize = 5;

		let outcome = test_circuit(|builder| {
			let mut rng = StdRng::seed_from_u64(0);
			let first_state = KeccakfState(rng.gen());
			let witness_states = Some(vec![first_state]);
			keccakf(builder, &witness_states, LOG_SIZE)?;
			Ok(vec![])
		});
		outcome.unwrap();
	}
}