binius_circuits/
keccakf.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::array;
4
5use anyhow::anyhow;
6use binius_core::{
7	oracle::{OracleId, ShiftVariant},
8	transparent::multilinear_extension::MultilinearExtensionTransparent,
9};
10use binius_field::{
11	as_packed_field::PackedType, underlier::WithUnderlier, BinaryField1b, BinaryField64b, Field,
12	PackedField, TowerField,
13};
14use binius_macros::arith_expr;
15use bytemuck::{pod_collect_to_vec, Pod};
16
17use crate::{
18	builder::{
19		types::{F, U},
20		ConstraintSystemBuilder,
21	},
22	transparent::step_down,
23};
24
25#[derive(Default, Clone, Copy)]
26pub struct KeccakfState(pub [u64; STATE_SIZE]);
27
28pub struct KeccakfOracles {
29	pub input: [OracleId; STATE_SIZE],
30	pub output: [OracleId; STATE_SIZE],
31}
32
33pub fn keccakf(
34	builder: &mut ConstraintSystemBuilder,
35	input_witness: &Option<impl AsRef<[KeccakfState]>>,
36	log_size: usize,
37) -> Result<KeccakfOracles, anyhow::Error> {
38	let internal_log_size = log_size + LOG_BIT_ROWS_PER_PERMUTATION;
39	let round_consts_single: [OracleId; ROUNDS_PER_STATE_ROW] =
40		array_util::try_from_fn(|round_within_row| {
41			let round_within_row_rc: [_; STATE_ROWS_PER_PERMUTATION] =
42				array::from_fn(|row_within_perm| {
43					KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_perm + round_within_row]
44				});
45
46			let packed_vec = into_packed_vec::<PackedType<U, BinaryField1b>>(&round_within_row_rc);
47			let rc_single_mle =
48				MultilinearExtensionTransparent::<_, PackedType<U, F>, _>::from_values(packed_vec)?;
49			builder.add_transparent("round_consts_single", rc_single_mle)
50		})?;
51
52	let round_consts: [OracleId; ROUNDS_PER_STATE_ROW] =
53		array_util::try_from_fn(|round_within_row| {
54			builder.add_repeating(
55				"round_consts",
56				round_consts_single[round_within_row],
57				internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION,
58			)
59		})?;
60
61	if let Some(witness) = builder.witness() {
62		let mut round_consts_single =
63			round_consts_single.map(|id| witness.new_column::<BinaryField1b>(id));
64		let mut round_consts = round_consts.map(|id| witness.new_column::<BinaryField1b>(id));
65
66		let round_consts_single_u64 = round_consts_single
67			.each_mut()
68			.map(|col| col.as_mut_slice::<u64>());
69		let round_consts_u64 = round_consts.each_mut().map(|col| col.as_mut_slice::<u64>());
70
71		for row_within_permutation in 0..STATE_ROWS_PER_PERMUTATION {
72			for round_within_row in 0..ROUNDS_PER_STATE_ROW {
73				round_consts_single_u64[round_within_row][row_within_permutation] =
74					KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_permutation + round_within_row];
75			}
76		}
77
78		for state_row_idx in 0..1 << (internal_log_size - LOG_BIT_ROWS_PER_STATE_ROW) {
79			let row_within_permutation = state_row_idx % STATE_ROWS_PER_PERMUTATION;
80			for round_within_row in 0..ROUNDS_PER_STATE_ROW {
81				round_consts_u64[round_within_row][state_row_idx] =
82					KECCAKF_RC[ROUNDS_PER_STATE_ROW * row_within_permutation + round_within_row];
83			}
84		}
85	}
86
87	let selector_single = step_down(
88		builder,
89		"selector_single",
90		LOG_BIT_ROWS_PER_PERMUTATION,
91		BIT_ROWS_PER_PERMUTATION - BIT_ROWS_PER_STATE_ROW,
92	)?;
93
94	let selector = builder.add_repeating(
95		"selector",
96		selector_single,
97		internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION,
98	)?;
99
100	let state: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW + 1] = array::from_fn(|_| {
101		builder.add_committed_multiple("state_in", internal_log_size, BinaryField1b::TOWER_LEVEL)
102	});
103
104	let state_in = state[0];
105
106	let state_out = state[ROUNDS_PER_STATE_ROW];
107
108	let packed_state_in: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
109		builder.add_packed("packed state input", state_in[xy], LOG_BIT_ROWS_PER_STATE_ROW)
110	})?;
111
112	let input: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
113		builder.add_projected(
114			"packed projected state input",
115			packed_state_in[xy],
116			vec![F::ZERO; LOG_STATE_ROWS_PER_PERMUTATION],
117			0,
118		)
119	})?;
120
121	let packed_state_out: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
122		builder.add_packed("packed state output", state_out[xy], LOG_BIT_ROWS_PER_STATE_ROW)
123	})?;
124
125	let output: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
126		builder.add_projected(
127			"output",
128			packed_state_out[xy],
129			vec![Field::ONE; LOG_STATE_ROWS_PER_PERMUTATION],
130			0,
131		)
132	})?;
133
134	let c: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] = array_util::try_from_fn(|round_within_row| {
135		array_util::try_from_fn(|x| {
136			builder.add_linear_combination(
137				"c",
138				internal_log_size,
139				array::from_fn::<_, 5, _>(|offset| {
140					(state[round_within_row][x + 5 * offset], Field::ONE)
141				}),
142			)
143		})
144	})?;
145
146	let c_shift: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] =
147		array_util::try_from_fn(|round_within_row| {
148			array_util::try_from_fn(|x| {
149				builder.add_shifted(
150					format!("c[{x}]"),
151					c[round_within_row][x],
152					1,
153					6,
154					ShiftVariant::CircularLeft,
155				)
156			})
157		})?;
158
159	let d: [[OracleId; 5]; ROUNDS_PER_STATE_ROW] = array_util::try_from_fn(|round_within_row| {
160		array_util::try_from_fn(|x| {
161			builder.add_linear_combination(
162				"d",
163				internal_log_size,
164				[
165					(c[round_within_row][(x + 4) % 5], Field::ONE),
166					(c_shift[round_within_row][(x + 1) % 5], Field::ONE),
167				],
168			)
169		})
170	})?;
171
172	let a_theta: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW] =
173		array_util::try_from_fn(|round_within_row| {
174			array_util::try_from_fn(|xy| {
175				let x = xy % 5;
176				builder.add_linear_combination(
177					format!("a_theta[{xy}]"),
178					internal_log_size,
179					[
180						(state[round_within_row][xy], Field::ONE),
181						(d[round_within_row][x], Field::ONE),
182					],
183				)
184			})
185		})?;
186
187	let b: [[OracleId; STATE_SIZE]; ROUNDS_PER_STATE_ROW] =
188		array_util::try_from_fn(|round_within_row| {
189			array_util::try_from_fn(|xy| {
190				if xy == 0 {
191					Ok(a_theta[round_within_row][0])
192				} else {
193					builder.add_shifted(
194						format!("b[{xy}]"),
195						a_theta[round_within_row][PI[xy]],
196						RHO[xy] as usize,
197						6,
198						ShiftVariant::CircularLeft,
199					)
200				}
201			})
202		})?;
203
204	let next_state_in: [OracleId; STATE_SIZE] = array_util::try_from_fn(|xy| {
205		builder.add_shifted(
206			format!("next_state_in[{xy}]"),
207			state_in[xy],
208			64,
209			LOG_BIT_ROWS_PER_PERMUTATION,
210			ShiftVariant::LogicalRight,
211		)
212	})?;
213
214	if let Some(witness) = builder.witness() {
215		let input_witness = input_witness
216			.as_ref()
217			.ok_or_else(|| anyhow!("builder witness available and input witness is not"))?
218			.as_ref();
219
220		let mut input = input.map(|id| witness.new_column::<BinaryField64b>(id));
221
222		let mut packed_state_in =
223			packed_state_in.map(|id| witness.new_column::<BinaryField64b>(id));
224
225		let mut packed_state_out =
226			packed_state_out.map(|id| witness.new_column::<BinaryField64b>(id));
227
228		let mut output = output.map(|id| witness.new_column::<BinaryField64b>(id));
229
230		let mut state = state
231			.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
232
233		let mut c =
234			c.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
235		let mut d =
236			d.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
237		let mut c_shift = c_shift
238			.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
239		let mut a_theta = a_theta
240			.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
241		let mut b =
242			b.map(|round_oracles| round_oracles.map(|id| witness.new_column::<BinaryField1b>(id)));
243		let mut next_state_in = next_state_in.map(|id| witness.new_column::<BinaryField1b>(id));
244
245		let mut selector_single = witness.new_column::<BinaryField1b>(selector_single);
246
247		let mut selector = witness.new_column::<BinaryField1b>(selector);
248
249		let input_u64 = input.each_mut().map(|col| col.as_mut_slice::<u64>());
250
251		let packed_state_in_u64 = packed_state_in
252			.each_mut()
253			.map(|col| col.as_mut_slice::<u64>());
254
255		let mut packed_state_out_u64 = packed_state_out
256			.each_mut()
257			.map(|col| col.as_mut_slice::<u64>());
258
259		let output_u64 = output.each_mut().map(|col| col.as_mut_slice::<u64>());
260
261		let state_u64 = state
262			.each_mut()
263			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
264		let c_u64 = c
265			.each_mut()
266			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
267		let d_u64 = d
268			.each_mut()
269			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
270		let c_shift_u64 = c_shift
271			.each_mut()
272			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
273		let a_theta_u64 = a_theta
274			.each_mut()
275			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
276		let b_u64 = b
277			.each_mut()
278			.map(|round_cols| round_cols.each_mut().map(|col| col.as_mut_slice::<u64>()));
279		let next_state_in_u64 = next_state_in.each_mut().map(|col| col.as_mut_slice());
280		let selector_single_u64 = selector_single.as_mut_slice::<u64>();
281		let selector_u64 = selector.as_mut_slice();
282
283		// Fill in the non-repeating selector witness
284		for selector_single_u64_row in selector_single_u64
285			.iter_mut()
286			.take(STATE_ROWS_PER_PERMUTATION - 1)
287		{
288			*selector_single_u64_row = u64::MAX;
289		}
290
291		// Each round state is 64 rows
292		// Each permutation is 24 round states
293		for perm_i in 0..1 << (internal_log_size - LOG_BIT_ROWS_PER_PERMUTATION) {
294			let first_state_row_idx_in_perm = perm_i << LOG_STATE_ROWS_PER_PERMUTATION;
295
296			let input_this_perm = input_witness.get(perm_i).copied().unwrap_or_default().0;
297
298			for xy in 0..STATE_SIZE {
299				input_u64[xy][perm_i] = input_this_perm[xy];
300			}
301
302			let expected_output_this_perm = {
303				let mut output = input_this_perm;
304				tiny_keccak::keccakf(&mut output);
305				output
306			};
307
308			for xy in 0..STATE_SIZE {
309				output_u64[xy][perm_i] = expected_output_this_perm[xy];
310			}
311
312			// Assign the permutation inputs for the long table
313			for xy in 0..STATE_SIZE {
314				state_u64[0][xy][first_state_row_idx_in_perm] = input_this_perm[xy];
315				packed_state_in_u64[xy][first_state_row_idx_in_perm] = input_this_perm[xy];
316			}
317
318			for row_idx_within_permutation in 0..STATE_ROWS_PER_PERMUTATION {
319				let state_row_idx = first_state_row_idx_in_perm | row_idx_within_permutation;
320				// Expand trace columns for each round on the row
321				for round_within_row in 0..ROUNDS_PER_STATE_ROW {
322					let keccakf_rc = KECCAKF_RC
323						[ROUNDS_PER_STATE_ROW * row_idx_within_permutation + round_within_row];
324
325					for x in 0..5 {
326						c_u64[round_within_row][x][state_row_idx] = (0..5).fold(0, |acc, y| {
327							acc ^ state_u64[round_within_row][x + 5 * y][state_row_idx]
328						});
329						c_shift_u64[round_within_row][x][state_row_idx] =
330							c_u64[round_within_row][x][state_row_idx].rotate_left(1);
331					}
332
333					for x in 0..5 {
334						d_u64[round_within_row][x][state_row_idx] = c_u64[round_within_row]
335							[(x + 4) % 5][state_row_idx]
336							^ c_shift_u64[round_within_row][(x + 1) % 5][state_row_idx];
337					}
338
339					for x in 0..5 {
340						for y in 0..5 {
341							a_theta_u64[round_within_row][x + 5 * y][state_row_idx] = state_u64
342								[round_within_row][x + 5 * y][state_row_idx]
343								^ d_u64[round_within_row][x][state_row_idx];
344						}
345					}
346
347					for xy in 0..STATE_SIZE {
348						b_u64[round_within_row][xy][state_row_idx] = a_theta_u64[round_within_row]
349							[PI[xy]][state_row_idx]
350							.rotate_left(RHO[xy]);
351					}
352
353					for x in 0..5 {
354						for y in 0..5 {
355							let b0 = b_u64[round_within_row][x + 5 * y][state_row_idx];
356							let b1 = b_u64[round_within_row][(x + 1) % 5 + 5 * y][state_row_idx];
357							let b2 = b_u64[round_within_row][(x + 2) % 5 + 5 * y][state_row_idx];
358
359							state_u64[round_within_row + 1][x + 5 * y][state_row_idx] =
360								b0 ^ (!b1 & b2);
361						}
362					}
363
364					state_u64[round_within_row + 1][0][state_row_idx] ^= keccakf_rc;
365				}
366
367				for (xy, packed_state_out_u64_row) in packed_state_out_u64.iter_mut().enumerate() {
368					packed_state_out_u64_row[state_row_idx] =
369						state_u64[ROUNDS_PER_STATE_ROW][xy][state_row_idx];
370				}
371
372				if row_idx_within_permutation < (STATE_ROWS_PER_PERMUTATION - 1) {
373					for xy in 0..STATE_SIZE {
374						let this_row_output = state_u64[ROUNDS_PER_STATE_ROW][xy][state_row_idx];
375
376						state_u64[0][xy][state_row_idx + 1] = this_row_output;
377						packed_state_in_u64[xy][state_row_idx + 1] = this_row_output;
378						next_state_in_u64[xy][state_row_idx] = this_row_output;
379					}
380					selector_u64[state_row_idx] = u64::MAX;
381				}
382			}
383
384			let last_state_row_idx_in_perm =
385				first_state_row_idx_in_perm + (STATE_ROWS_PER_PERMUTATION - 1);
386
387			let actual_output_this_perm =
388				array::from_fn(|i| state_u64[ROUNDS_PER_STATE_ROW][i][last_state_row_idx_in_perm]);
389
390			assert_eq!(expected_output_this_perm, actual_output_this_perm);
391		}
392	}
393
394	let chi_iota = arith_expr!([s, b0, b1, b2, rc] = s - (rc + b0 + (1 - b1) * b2));
395	let chi = arith_expr!([s, b0, b1, b2] = s - (b0 + (1 - b1) * b2));
396	for x in 0..5 {
397		for y in 0..5 {
398			for round_within_row in 0..ROUNDS_PER_STATE_ROW {
399				let this_round_output = state[round_within_row + 1][x + 5 * y];
400
401				if x == 0 && y == 0 {
402					builder.assert_zero(
403						format!("chi_iota(round_within_row={round_within_row}, x={x}, y={y})"),
404						[
405							this_round_output,
406							b[round_within_row][x + 5 * y],
407							b[round_within_row][(x + 1) % 5 + 5 * y],
408							b[round_within_row][(x + 2) % 5 + 5 * y],
409							round_consts[round_within_row],
410						],
411						chi_iota.clone().convert_field(),
412					);
413				} else {
414					builder.assert_zero(
415						format!("chi(round_within_row={round_within_row}, x={x}, y={y})"),
416						[
417							this_round_output,
418							b[round_within_row][x + 5 * y],
419							b[round_within_row][(x + 1) % 5 + 5 * y],
420							b[round_within_row][(x + 2) % 5 + 5 * y],
421						],
422						chi.clone().convert_field(),
423					)
424				}
425			}
426		}
427	}
428
429	let selector_consistency =
430		arith_expr!([state_out, next_state_in, select] = (state_out - next_state_in) * select);
431
432	for xy in 0..STATE_SIZE {
433		builder.assert_zero(
434			format!("next_state_in_is_state_out_{xy}"),
435			[state_out[xy], next_state_in[xy], selector],
436			selector_consistency.clone().convert_field(),
437		)
438	}
439
440	Ok(KeccakfOracles { input, output })
441}
442
443#[inline]
444fn into_packed_vec<P>(src: &[impl Pod]) -> Vec<P>
445where
446	P: PackedField + WithUnderlier,
447	P::Underlier: Pod,
448{
449	pod_collect_to_vec::<_, P::Underlier>(src)
450		.into_iter()
451		.map(P::from_underlier)
452		.collect()
453}
454
/// Number of 64-bit lanes in a Keccak-f state (a 5 x 5 grid).
const STATE_SIZE: usize = 5 * 5;

/// Rounds evaluated on a single state row.
const ROUNDS_PER_STATE_ROW: usize = 3;

/// A state row stores one `u64` per lane, i.e. 64 bit rows.
const BIT_ROWS_PER_STATE_ROW: usize = u64::BITS as usize;
/// Base-2 logarithm of [`BIT_ROWS_PER_STATE_ROW`].
const LOG_BIT_ROWS_PER_STATE_ROW: usize = BIT_ROWS_PER_STATE_ROW.trailing_zeros() as usize;

/// Base-2 logarithm of [`STATE_ROWS_PER_PERMUTATION`].
const LOG_STATE_ROWS_PER_PERMUTATION: usize = 3;
/// State rows that make up one permutation.
const STATE_ROWS_PER_PERMUTATION: usize = 1 << LOG_STATE_ROWS_PER_PERMUTATION;

/// Base-2 logarithm of [`BIT_ROWS_PER_PERMUTATION`].
const LOG_BIT_ROWS_PER_PERMUTATION: usize =
	LOG_STATE_ROWS_PER_PERMUTATION + LOG_BIT_ROWS_PER_STATE_ROW;
/// Bit rows that make up one permutation.
const BIT_ROWS_PER_PERMUTATION: usize = STATE_ROWS_PER_PERMUTATION * BIT_ROWS_PER_STATE_ROW;

/// Total rounds in one permutation.
const ROUNDS_PER_PERMUTATION: usize = STATE_ROWS_PER_PERMUTATION * ROUNDS_PER_STATE_ROW;
465
466#[rustfmt::skip]
467const RHO: [u32; STATE_SIZE] = [
468	 0, 44, 43, 21, 14,
469	28, 20,  3, 45, 61,
470	 1,  6, 25,  8, 18,
471	27, 36, 10, 15, 56,
472	62, 55, 39, 41,  2,
473];
474
475#[rustfmt::skip]
476const PI: [usize; STATE_SIZE] = [
477	0, 6, 12, 18, 24,
478	3, 9, 10, 16, 22,
479	1, 7, 13, 19, 20,
480	4, 5, 11, 17, 23,
481	2, 8, 14, 15, 21,
482];
483
484const KECCAKF_RC: [u64; ROUNDS_PER_PERMUTATION] = [
485	0x0000000000000001,
486	0x0000000000008082,
487	0x800000000000808A,
488	0x8000000080008000,
489	0x000000000000808B,
490	0x0000000080000001,
491	0x8000000080008081,
492	0x8000000000008009,
493	0x000000000000008A,
494	0x0000000000000088,
495	0x0000000080008009,
496	0x000000008000000A,
497	0x000000008000808B,
498	0x800000000000008B,
499	0x8000000000008089,
500	0x8000000000008003,
501	0x8000000000008002,
502	0x8000000000000080,
503	0x000000000000800A,
504	0x800000008000000A,
505	0x8000000080008081,
506	0x8000000000008080,
507	0x0000000080000001,
508	0x8000000080008008,
509];
510
#[cfg(test)]
mod tests {
	use rand::{rngs::StdRng, Rng, SeedableRng};

	use super::{keccakf, KeccakfState};
	use crate::builder::test_utils::test_circuit;

	/// Builds the gadget for a table of `2^5` permutations with one random input state
	/// (seeded, so the run is reproducible) and checks that the circuit is accepted.
	#[test]
	fn test_keccakf() {
		let outcome = test_circuit(|builder| {
			let log_size = 5;
			let mut rng = StdRng::seed_from_u64(0);

			let first_state = KeccakfState(rng.gen());
			let input_states = vec![first_state];

			let _state_out = keccakf(builder, &Some(input_states), log_size)?;
			Ok(vec![])
		});

		outcome.unwrap();
	}
}