binius_m3/gadgets/hash/keccak/
lookedup.rs

1// Copyright 2025 Irreducible Inc.
2
3//! The version of the Keccakf permutation taking the arithmetization approach that is based on
4//! stacked columns and lookups.
5
// This implementation tries to stay as close as possible to the
// [Keccak Specification Summary][keccak_spec_summary], so it is highly recommended to
// get familiar with it first. A lot of terminology is carried over from that spec.
9//
10// [keccak_spec_summary]: https://keccak.team/keccak_specs_summary.html
11
12use std::array;
13
14use anyhow::Result;
15use array_util::ArrayExt as _;
16use binius_core::{constraint_system::channel::ChannelId, oracle::ShiftVariant};
17use binius_field::{
18	Field, PackedExtension, PackedFieldIndexable, PackedSubfield, TowerField,
19	linear_transformation::PackedTransformationFactory,
20};
21use trace::PermutationTrace;
22
23use super::{
24	RC, RHO,
25	state::{StateMatrix, StateRow},
26	trace,
27};
28use crate::{
29	builder::{B1, B8, B32, B64, B128, Col, Expr, TableBuilder, TableWitnessSegment, upcast_expr},
30	gadgets::indexed_lookup::and::{merge_and_columns, merge_bitand_vals},
31};
32
/// 8x 64-bit lanes packed[^packed] in the single column.
///
/// The column holds `64 * 8 = 512` bits per table row: one 64-bit lane for each of the 8
/// tracks of a batch.
///
/// For the motivation see [`KeccakfLookedup`] documentation.
///
/// [^packed]: here it means in the SIMD sense, not as in packed columns.
pub type PackedLane8 = Col<B1, { 64 * 8 }>;
39
/// Number of round batches that make up one full permutation.
const BATCHES_PER_PERMUTATION: usize = 3;
/// Number of tracks (rounds) handled by a single batch.
const TRACKS_PER_BATCH: usize = 8;
/// Track of the 0th batch whose `state_in` carries the permutation input.
const STATE_IN_TRACK: usize = 0;
/// Track of the last batch whose `state_out` carries the permutation output.
const STATE_OUT_TRACK: usize = 7;
44
/// Keccak-f\[1600\] permutation function verification gadget.
///
/// This gadget consists of 3x horizontally combined batches of 8x rounds each, 24 rounds in total.
/// You can think about it as 8x wide SIMD performing one permutation per table row. Below is
/// the graphical representation of the layout.
///
/// ```plain
/// | Batch 0  | Batch 1  | Batch 2  |
/// |----------|----------|----------|
/// | Round 00 | Round 01 | Round 02 |
/// | Round 03 | Round 04 | Round 05 |
/// | Round 06 | Round 07 | Round 08 |
/// | Round 09 | Round 10 | Round 11 |
/// | Round 12 | Round 13 | Round 14 |
/// | Round 15 | Round 16 | Round 17 |
/// | Round 18 | Round 19 | Round 20 |
/// | Round 21 | Round 22 | Round 23 |
/// ```
///
/// We refer to each individual round within a batch as a **track**. For example, the 7th (
/// zero-based here and henceforth) track of the 1st batch is responsible for the 22nd round.
///
/// Each batch exposes two notable columns: `state_in` and `state_out`, which are the inputs and
/// outputs respectively of the rounds in that batch. Both of those have the type
/// [`StateMatrix`] containing [`PackedLane8`]. Let's break those down.
///
/// [`StateMatrix`] is a concept coming from keccak which represents a 5x5 matrix. In keccak
/// each cell is a 64-bit integer called a lane. In our case however, because of the SIMD-like
/// approach, each cell is represented by a pack of columns - one for each track - and this is
/// what [`PackedLane8`] represents.
///
/// To feed the input to the permutation, you need to initialize the `state_in` column of the 0th
/// batch with the input state matrix. See [`Self::populate_state_in`] if you have values handy.
pub struct KeccakfLookedup {
	/// The round batches, in order; the `state_out` of one batch is the `state_in` of the next.
	batches: [LookedupRoundBatch; BATCHES_PER_PERMUTATION],
	/// The lanes of the input and output state columns. These are exposed to make it convenient to
	/// use the gadget along with flushing.
	pub input: StateMatrix<Col<B64>>,
	/// The lanes of the permutation output, one 64-bit value per row. See [`Self::input`].
	pub output: StateMatrix<Col<B64>>,

	/// Represents a variation of the `state_in` state matrix of the 0th batch where each track is
	/// shifted in place of the previous one, meaning that the 0th track will store the `state_in`
	/// for the 3rd round.
	///
	/// This is used for the state-in to state-out linking rule.
	next_state_in: StateMatrix<Col<B64, TRACKS_PER_BATCH>>,

	/// Link selector.
	///
	/// This is all ones for the first 7 tracks and all zeroes for the last one.
	///
	/// Used to turn off the state-in to state-out forwarding check for the last track.
	link_sel: Col<B1, TRACKS_PER_BATCH>,
}
99
100impl KeccakfLookedup {
101	/// Creates a new instance of the gadget.
102	///
103	/// See the struct documentation for more details.
104	pub fn new(table: &mut TableBuilder, lookup_chan: ChannelId) -> Self {
105		let state_in: StateMatrix<PackedLane8> =
106			StateMatrix::from_fn(|(x, y)| table.add_committed(format!("state_in[{x},{y}]")));
107
108		let mut state = state_in;
109
110		// Declaring packed state_in columns for exposing in the struct.
111		let state_in_packed: StateMatrix<Col<B64, 8>> = StateMatrix::from_fn(|(x, y)| {
112			table.add_packed(format!("state_in_packed[{x},{y}]"), state[(x, y)])
113		});
114
115		// Constructing the batches of rounds. The final value of `state` will be the permutation
116		// output.
117		let batches = array::from_fn(|batch_no| {
118			let batch = LookedupRoundBatch::new(
119				&mut table.with_namespace(format!("batch[{batch_no}]")),
120				state.clone(),
121				lookup_chan,
122				batch_no,
123			);
124			state = batch.state_out.clone();
125			batch
126		});
127
128		// Declaring packed state_out columns to be exposed in the struct.
129		let state_out_packed: StateMatrix<Col<B64, 8>> = StateMatrix::from_fn(|(x, y)| {
130			table.add_packed(format!("state_out_packed[{x},{y}]"), state[(x, y)])
131		});
132
133		let input = StateMatrix::from_fn(|(x, y)| {
134			table.add_selected(format!("input[{x},{y}]"), state_in_packed[(x, y)], 0)
135		});
136
137		let output = StateMatrix::from_fn(|(x, y)| {
138			table.add_selected(format!("output[{x},{y}]"), state_out_packed[(x, y)], 7)
139		});
140
141		let link_sel = table.add_constant(
142			"link_sel",
143			array::from_fn(|bit_index| if bit_index < 7 { B1::ONE } else { B1::ZERO }),
144		);
145		let next_state_in = StateMatrix::from_fn(|(x, y)| {
146			table.add_shifted(
147				format!("next_state_in[{x},{y}]"),
148				state_in_packed[(x, y)],
149				TRACKS_PER_BATCH.ilog2() as usize,
150				1,
151				ShiftVariant::LogicalRight,
152			)
153		});
154		for x in 0..5 {
155			for y in 0..5 {
156				table.assert_zero(
157					"link_out_to_next_in",
158					(state_out_packed[(x, y)] - next_state_in[(x, y)])
159						* upcast_expr(Expr::from(link_sel)),
160				);
161			}
162		}
163		Self {
164			batches,
165			next_state_in,
166			input,
167			output,
168			link_sel,
169		}
170	}
171
172	/// Populate the gadget.
173	///
174	/// Requires state in already to be populated. To populate with known values use
175	/// [`Self::populate_state_in`].
176	pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<()>
177	where
178		P: PackedFieldIndexable<Scalar = B128>
179			+ PackedExtension<B1>
180			+ PackedExtension<B8>
181			+ PackedExtension<B32>
182			+ PackedExtension<B64>,
183		PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
184		PackedSubfield<P, B32>: PackedFieldIndexable<Scalar = B32>,
185	{
186		// `state_in` for the first track of the first batch specifies the initial state for
187		// permutation. Read it out, gather trace and populate each batch.
188		let permutation_traces = self.batches[0]
189			.read_state_ins(index, 0)?
190			.map(trace::keccakf_trace)
191			.collect::<Vec<PermutationTrace>>();
192		for batch in &self.batches {
193			batch.populate(index, &permutation_traces)?;
194		}
195
196		for k in 0..permutation_traces.len() {
197			for x in 0..5 {
198				for y in 0..5 {
199					let mut next_state_in: std::cell::RefMut<'_, [u64]> =
200						index.get_mut_as(self.next_state_in[(x, y)])?;
201					let batch_0_state_in: std::cell::Ref<'_, [u64]> =
202						index.get_as(self.batches[0].state_in[(x, y)])?;
203					let batch_2_state_out: std::cell::Ref<'_, [u64]> =
204						index.get_as(self.batches[2].state_out[(x, y)])?;
205					for track in 0..TRACKS_PER_BATCH - 1 {
206						next_state_in[TRACKS_PER_BATCH * k + track] =
207							batch_0_state_in[TRACKS_PER_BATCH * k + track + 1];
208						assert_eq!(
209							next_state_in[TRACKS_PER_BATCH * k + track],
210							batch_2_state_out[TRACKS_PER_BATCH * k + track]
211						);
212					}
213					// Populating the packed and selected input and output columns.
214					let mut input: std::cell::RefMut<'_, [u64]> =
215						index.get_mut_as(self.input[(x, y)])?;
216					let mut output: std::cell::RefMut<'_, [u64]> =
217						index.get_mut_as(self.output[(x, y)])?;
218
219					input[k] = permutation_traces[k].input()[(x, y)];
220					output[k] = permutation_traces[k].output()[(x, y)];
221				}
222			}
223		}
224
225		{
226			let mut link_sel: std::cell::RefMut<'_, [u8]> = index.get_mut_as(self.link_sel)?;
227			for link_sel_i in link_sel.iter_mut() {
228				*link_sel_i = 0x7f;
229			}
230		}
231
232		Ok(())
233	}
234
235	/// Returns the `state_in` column for the 0th batch. The input to the permutation is at the
236	/// 0th track.
237	pub fn packed_state_in(&self) -> &StateMatrix<PackedLane8> {
238		&self.batches[0].state_in
239	}
240
241	/// Returns the `state_out` column for the 2nd batch. The output of the permutation is at the
242	/// 7th track.
243	pub fn packed_state_out(&self) -> &StateMatrix<PackedLane8> {
244		&self.batches[2].state_out
245	}
246
247	/// Populate the input state of the permutation.
248	pub fn populate_state_in<'a, P>(
249		&self,
250		index: &mut TableWitnessSegment<P>,
251		state_ins: impl IntoIterator<Item = &'a StateMatrix<u64>>,
252	) -> Result<()>
253	where
254		P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
255		PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
256	{
257		self.batches[0].populate_state_in(index, STATE_IN_TRACK, state_ins)?;
258		Ok(())
259	}
260
261	/// Read the resulting states of permutation, one item per row.
262	///
263	/// Only makes sense to call after [`Self::populate`] was called.
264	pub fn read_state_outs<'a, P>(
265		&self,
266		index: &'a TableWitnessSegment<P>,
267	) -> Result<impl Iterator<Item = StateMatrix<u64>> + 'a>
268	where
269		P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
270		PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
271	{
272		self.batches[2].read_state_outs(index, STATE_OUT_TRACK)
273	}
274}
275
/// A gadget of a batch of keccak-f\[1600\] permutation rounds.
///
/// This batch runs 8 rounds of keccak-f, one per track. Since keccak-f\[1600\] (as used by
/// e.g. SHA3-256) is defined to have 24 rounds, you would need to use 3 of these gadgets to
/// implement a full permutation.
struct LookedupRoundBatch {
	/// Index of this batch within the permutation; determines which rounds its tracks perform.
	batch_no: usize,
	/// Input state of every round in this batch.
	state_in: StateMatrix<PackedLane8>,
	/// Output state of every round in this batch (after the χ and ι steps).
	state_out: StateMatrix<PackedLane8>,
	/// θ step: `C[x]`, the xor of the five lanes `A[x, 0..5]`.
	c: StateRow<PackedLane8>,
	/// θ step: `C[x]` rotated left by one bit within each 64-bit lane.
	c_shift: StateRow<PackedLane8>,
	/// θ step: `D[x] = C[x-1] xor rot(C[x+1], 1)`.
	d: StateRow<PackedLane8>,
	/// State after the θ step: `A[x,y] xor D[x]`.
	a_theta: StateMatrix<PackedLane8>,
	/// State after the ρ and π steps.
	b: StateMatrix<PackedLane8>,
	/// Committed bitwise and of `B[x+1,y]` and `B[x+2,y]`, verified through the lookup.
	b1_and_b2: StateMatrix<PackedLane8>,
	/// Merged (input, input, output) columns of the bitwise-and lookup.
	merged: StateMatrix<Col<B32, 64>>,
	/// Round constants of the rounds performed by this batch, one 64-bit constant per track.
	round_const: PackedLane8,
}
293
294impl LookedupRoundBatch {
295	fn new(
296		table: &mut TableBuilder,
297		state_in: StateMatrix<PackedLane8>,
298		lookup_chan: ChannelId,
299		batch_no: usize,
300	) -> Self {
301		assert!(batch_no < BATCHES_PER_PERMUTATION);
302		// # θ step
303		//
304		// for x in 0…4:
305		//   C[x] = A[x,0] xor A[x,1] xor A[x,2] xor A[x,3] xor A[x,4],
306		//   D[x] = C[x-1] xor rot(C[x+1],1),
307		let c = StateRow::from_fn(|x| {
308			table.add_computed(
309				format!("c[{x}]"),
310				sum_expr(array::from_fn::<_, 5, _>(|offset| state_in[(x, offset)])),
311			)
312		});
313		let c_shift = StateRow::from_fn(|x| {
314			table.add_shifted(format!("c[{x}]"), c[x], 6, 1, ShiftVariant::CircularLeft)
315		});
316		let d =
317			StateRow::from_fn(|x| table.add_computed(format!("d[{x}]"), c[x + 4] + c_shift[x + 1]));
318
319		// for (x,y) in (0…4,0…4):
320		//   A[x,y] = A[x,y] xor D[x]
321		let a_theta = StateMatrix::from_fn(|(x, y)| {
322			table.add_computed(format!("a_theta[{x},{y}]"), state_in[(x, y)] + d[x])
323		});
324
325		// # ρ and π steps
326		let b = StateMatrix::from_fn(|(x, y)| {
327			if (x, y) == (0, 0) {
328				a_theta[(0, 0)]
329			} else {
330				const INV2: usize = 3;
331				let dy = x; // 0‥4
332				let tmp = (y + 2 * x) % 5;
333				let dx = (INV2 * tmp) % 5;
334				let rot = RHO[dx][dy] as usize;
335				const LOG2_64: usize = 6;
336				table.add_shifted(
337					format!("b[{x},{y}]"),
338					a_theta[(dx, dy)],
339					LOG2_64,
340					rot,
341					ShiftVariant::CircularLeft,
342				)
343			}
344		});
345
346		// The columns need to be packed to compute bitwise and using lookups.
347		let b_packed: StateMatrix<Col<B8, 64>> = StateMatrix::from_fn(|(x, y)| {
348			table.add_packed(format!("b_packed[{x},{y}]"), b[(x, y)])
349		});
350
351		let b1_and_b2: StateMatrix<Col<B1, { 64 * 8 }>> = StateMatrix::from_fn(|(x, y)| {
352			// B2[x,y] = B[x+2,y] and not B[x+1,y]
353			table.add_committed(format!("b1&b2[{x},{y}]"))
354		});
355
356		let b1_and_b2_packed: StateMatrix<Col<B8, 64>> = StateMatrix::from_fn(|(x, y)| {
357			// B2[x,y] = B[x+2,y] and not B[x+1,y]
358			table.add_packed(format!("b1&b2_packed[{x},{y}]"), b1_and_b2[(x, y)])
359		});
360
361		let merged = StateMatrix::from_fn(|(x, y)| {
362			let b2 = b_packed[(x + 2, y)];
363			let b1 = b_packed[(x + 1, y)];
364			let b1_and_b2 = b1_and_b2_packed[(x, y)];
365			let col = merge_and_columns(table, b1, b2, b1_and_b2);
366			table.read(lookup_chan, [col]);
367			col
368		});
369
370		let round_const = table
371			.add_constant(format!("round_const[{batch_no}]"), round_consts_for_batch(batch_no));
372
373		let state_out = StateMatrix::from_fn(|(x, y)| {
374			if (x, y) == (0, 0) {
375				table.add_computed(
376					format!("chi_iota[{x},{y}]"),
377					round_const + b[(x, y)] + b[(x + 2, y)] + b1_and_b2[(x, y)],
378				)
379			} else {
380				table.add_computed(
381					format!("chi[{x},{y}]"),
382					b[(x, y)] + b[(x + 2, y)] + b1_and_b2[(x, y)],
383				)
384			}
385		});
386
387		Self {
388			batch_no,
389			state_in,
390			state_out,
391			c,
392			c_shift,
393			d,
394			a_theta,
395			b,
396			b1_and_b2,
397			merged,
398			round_const,
399		}
400	}
401
	/// Populates the witness columns of this batch from the given permutation traces.
	///
	/// The `k`-th trace fills the `k`-th table row. Within a row, each track takes its values
	/// from the round trace that [`PerBatchLens`] selects for this batch. Once all traces are
	/// written, the bitwise-and outputs and the merged lookup columns are derived from the
	/// populated `b` columns.
	///
	/// # Errors
	///
	/// Returns an error if any of the witness columns cannot be borrowed from `index`.
	fn populate<P>(
		&self,
		index: &mut TableWitnessSegment<P>,
		permutation_traces: &[trace::PermutationTrace],
	) -> Result<()>
	where
		P: PackedFieldIndexable<Scalar = B128>
			+ PackedExtension<B1>
			+ PackedExtension<B8>
			+ PackedExtension<B32>,
		PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
		PackedSubfield<P, B32>: PackedFieldIndexable<Scalar = B32>,
	{
		for (k, trace) in permutation_traces.iter().enumerate() {
			// Gather all batch round traces for the batch number.
			let brt = PerBatchLens::new(trace, self.batch_no);

			// Fill in round_const witness with the corresponding round constants
			let mut round_const: std::cell::RefMut<'_, [u64]> =
				index.get_mut_as(self.round_const)?;
			for track in 0..TRACKS_PER_BATCH {
				round_const[TRACKS_PER_BATCH * k + track] = brt[track].rc;
			}
			drop(round_const);

			for x in 0..5 {
				// θ step intermediates, one 64-bit lane per track.
				let mut c: std::cell::RefMut<'_, [u64]> = index.get_mut_as(self.c[x])?;
				let mut c_shift: std::cell::RefMut<'_, [u64]> =
					index.get_mut_as(self.c_shift[x])?;
				let mut d: std::cell::RefMut<'_, [u64]> = index.get_mut_as(self.d[x])?;
				for track in 0..TRACKS_PER_BATCH {
					let cell_pos = TRACKS_PER_BATCH * k + track;
					c[cell_pos] = brt[track].c[x];
					// Derived from `c` directly rather than taken from the trace.
					c_shift[cell_pos] = c[cell_pos].rotate_left(1);
					d[cell_pos] = brt[track].d[x];
				}

				for y in 0..5 {
					let mut state_in: std::cell::RefMut<'_, [u64]> =
						index.get_mut_as(self.state_in[(x, y)])?;
					let mut state_out: std::cell::RefMut<'_, [u64]> =
						index.get_mut_as(self.state_out[(x, y)])?;
					let mut a_theta: std::cell::RefMut<'_, [u64]> =
						index.get_mut_as(self.a_theta[(x, y)])?;

					// b[0,0] is defined as a_theta[0,0].
					//
					// That means two things:
					// 1. The value is assigned by `a_theta`.
					// 2. We have to skip mutably borrowing it here because that would overlap with
					//    the mutable borrow of `a_theta` above.
					let mut b: Option<std::cell::RefMut<'_, [u64]>> = if (x, y) != (0, 0) {
						Some(index.get_mut_as(self.b[(x, y)])?)
					} else {
						None
					};
					for track in 0..TRACKS_PER_BATCH {
						let cell_pos = TRACKS_PER_BATCH * k + track;
						state_in[cell_pos] = brt[track].state_in[(x, y)];
						state_out[cell_pos] = brt[track].state_out[(x, y)];
						a_theta[cell_pos] = brt[track].a_theta[(x, y)];
						if let Some(ref mut b) = b {
							b[cell_pos] = brt[track].b[(x, y)];
						}
					}
				}
			}
		}
		// Second pass: with every `b` column populated, compute the bitwise-and outputs and the
		// merged lookup values byte by byte.
		for xy in 0..25 {
			let x = xy % 5;
			let y = xy / 5;
			let mut merged: std::cell::RefMut<'_, [B32]> =
				index.get_scalars_mut(self.merged[(x, y)])?;

			// The bit columns are viewed as bytes here, matching the packed lookup columns.
			let b1: std::cell::Ref<'_, [B8]> = index.get_as(self.b[(x + 1, y)])?;

			let b2: std::cell::Ref<'_, [B8]> = index.get_as(self.b[(x + 2, y)])?;
			let mut b1_and_b2: std::cell::RefMut<'_, [B8]> =
				index.get_mut_as(self.b1_and_b2[(x, y)])?;
			for i in 0..b2.len() {
				// B[x,y] xor ((not B[x+1,y]) and B[x+2,y])
				//
				// Only the `B[x+1,y] and B[x+2,y]` part is computed here.
				let in_a = b1[i].val();
				let in_b = b2[i].val();
				let output = in_a & in_b;
				b1_and_b2[i] = output.into();
				merged[i] = merge_bitand_vals(in_a, in_b, output).into();
			}
		}
		Ok(())
	}
492
493	fn populate_state_in<'a, P>(
494		&self,
495		index: &mut TableWitnessSegment<P>,
496		track: usize,
497		state_ins: impl IntoIterator<Item = &'a StateMatrix<u64>>,
498	) -> Result<()>
499	where
500		P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
501		PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
502	{
503		for (k, state_in) in state_ins.into_iter().enumerate() {
504			for x in 0..5 {
505				for y in 0..5 {
506					let mut state_in_data: std::cell::RefMut<'_, [u64]> =
507						index.get_mut_as(self.state_in[(x, y)])?;
508					state_in_data[TRACKS_PER_BATCH * k + track] = state_in[(x, y)];
509				}
510			}
511		}
512		Ok(())
513	}
514
515	fn read_state_ins<'a, P>(
516		&self,
517		index: &'a TableWitnessSegment<P>,
518		track: usize,
519	) -> Result<impl Iterator<Item = StateMatrix<u64>> + 'a>
520	where
521		P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
522		PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
523	{
524		let state_in = self
525			.state_in
526			.as_inner()
527			.try_map_ext(|col| index.get_mut_as(col))?;
528
529		let iter = (0..index.size()).map(move |k| {
530			StateMatrix::from_fn(|(x, y)| state_in[x + 5 * y][TRACKS_PER_BATCH * k + track])
531		});
532		Ok(iter)
533	}
534
535	fn read_state_outs<'a, P>(
536		&self,
537		index: &'a TableWitnessSegment<P>,
538		track: usize,
539	) -> Result<impl Iterator<Item = StateMatrix<u64>> + 'a>
540	where
541		P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1> + PackedExtension<B8>,
542		PackedSubfield<P, B8>: PackedTransformationFactory<PackedSubfield<P, B8>>,
543	{
544		let state_out = self
545			.state_out
546			.as_inner()
547			.try_map_ext(|col| index.get_mut_as(col))?;
548		let iter = (0..index.size()).map(move |k| {
549			StateMatrix::from_fn(|(x, y)| state_out[x + 5 * y][TRACKS_PER_BATCH * k + track])
550		});
551		Ok(iter)
552	}
553}
554
555/// Returns an expression representing a sum over the given values.
556fn sum_expr<F, const V: usize, const N: usize>(values: [Col<F, V>; N]) -> Expr<F, V>
557where
558	F: TowerField,
559{
560	assert!(!values.is_empty());
561	let mut expr: Expr<F, V> = values[0].into();
562	for value in &values[1..] {
563		expr = expr + *value;
564	}
565	expr
566}
567
568/// Returns the RC for every round/track in the given batch.
569///
570/// The return type is basically 8 tracks of 64-bit round constants represented as bit patterns.
571fn round_consts_for_batch(batch_no: usize) -> [B1; 64 * 8] {
572	assert!(batch_no < BATCHES_PER_PERMUTATION);
573	let mut batch_rc = [B1::from(0); 64 * 8];
574	for track in 0..TRACKS_PER_BATCH {
575		let rc = RC[nth_round_per_batch(batch_no, track)];
576		for bit in 0..64 {
577			let bit_value = ((rc >> bit) & 1) as u8;
578			batch_rc[track * 64 + bit] = B1::from(bit_value);
579		}
580	}
581	batch_rc
582}
583
584/// Calculates the round number that is performed at the `nth` track of the `batch_no` batch.
585fn nth_round_per_batch(batch_no: usize, nth: usize) -> usize {
586	assert!(batch_no < BATCHES_PER_PERMUTATION);
587	assert!(nth < TRACKS_PER_BATCH);
588	nth * BATCHES_PER_PERMUTATION + batch_no
589}
590
591/// This is specialization of [`PermutationTrace`] for a particular batch.
592/// See [`PermutationTrace::per_batch`].
593struct PerBatchLens<'a> {
594	pt: &'a PermutationTrace,
595	batch_no: usize,
596}
597
598impl PerBatchLens<'_> {
599	/// View the trace from the given `batch_no` lens.
600	fn new(pt: &PermutationTrace, batch_no: usize) -> PerBatchLens<'_> {
601		PerBatchLens { pt, batch_no }
602	}
603}
604
605impl std::ops::Index<usize> for PerBatchLens<'_> {
606	type Output = trace::RoundTrace;
607	fn index(&self, index: usize) -> &Self::Output {
608		let round = nth_round_per_batch(self.batch_no, index);
609		&self.pt[round]
610	}
611}
612
613#[cfg(test)]
614mod tests {
615	use std::cmp::Reverse;
616
617	use binius_compute::cpu::alloc::CpuComputeAllocator;
618	use binius_field::{arch::OptimalUnderlier, as_packed_field::PackedType};
619	use itertools::Itertools;
620
621	use super::*;
622	use crate::{
623		builder::{ConstraintSystem, WitnessIndex, tally},
624		gadgets::{
625			hash::keccak::test_vector::TEST_VECTOR,
626			indexed_lookup::and::{BitAndIndexedLookup, BitAndLookup},
627		},
628	};
629
630	#[test]
631	fn ensure_committed_bits_per_row() {
632		let mut cs = ConstraintSystem::new();
633		let lookup_chan = cs.add_channel("lookup_channel");
634		let mut table = cs.add_table("stacked permutation");
635		let _ = KeccakfLookedup::new(&mut table, lookup_chan);
636		let id = table.id();
637		let stat = cs.tables[id].stat();
638		assert_eq!(stat.bits_per_row_committed(), 51200);
639	}
640
641	#[test]
642	fn test_round_gadget() {
643		const N_ROWS: usize = 1;
644
645		let mut cs = ConstraintSystem::new();
646		let lookup_chan = cs.add_channel("lookup_channel");
647		let permutation = cs.add_channel("permutation_channel");
648
649		let mut lookup = cs.add_table("bitand_lookup");
650		let bitand_lookup = BitAndLookup::new(&mut lookup, lookup_chan, permutation, 20);
651		let mut table = cs.add_table("test");
652
653		let state_in = StateMatrix::from_fn(|(x, y)| table.add_committed(format!("in[{x},{y}]")));
654		let rb = LookedupRoundBatch::new(&mut table, state_in, lookup_chan, 0);
655
656		let mut allocator = CpuComputeAllocator::new(1 << 16);
657		let allocator = allocator.into_bump_allocator();
658		let table_id = table.id();
659
660		let mut witness = WitnessIndex::<PackedType<OptimalUnderlier, B128>>::new(&cs, &allocator);
661		let table_witness = witness.init_table(table_id, N_ROWS).unwrap();
662		let mut segment = table_witness.full_segment();
663
664		let trace = trace::keccakf_trace(StateMatrix::default());
665		rb.populate(&mut segment, &[trace]).unwrap();
666
667		let counts = tally(&cs, &mut witness, &[], lookup_chan, &BitAndIndexedLookup).unwrap();
668		// Fill the lookup table with the sorted counts
669		let sorted_counts = counts
670			.into_iter()
671			.enumerate()
672			.sorted_by_key(|(_, count)| Reverse(*count))
673			.collect::<Vec<_>>();
674
675		witness
676			.fill_table_parallel(&bitand_lookup, &sorted_counts)
677			.unwrap();
678		let table_sizes = witness.table_sizes();
679		let ccs = cs.compile().unwrap();
680		let witness = witness.into_multilinear_extension_index();
681
682		binius_core::constraint_system::validate::validate_witness(
683			&ccs,
684			&[],
685			&table_sizes,
686			&witness,
687		)
688		.unwrap();
689	}
690
691	#[test]
692	fn test_permutation() {
693		const N_ROWS: usize = TEST_VECTOR.len();
694
695		let mut cs = ConstraintSystem::new();
696		let lookup_chan = cs.add_channel("lookup_channel");
697		let permutation = cs.add_channel("permutation_channel");
698		let mut lookup = cs.add_table("bitand_lookup");
699		let bitand_lookup = BitAndLookup::new(&mut lookup, lookup_chan, permutation, 20);
700		let mut table = cs.add_table("test");
701		let keccakf = KeccakfLookedup::new(&mut table, lookup_chan);
702
703		let mut allocator = CpuComputeAllocator::new(1 << 17);
704		let allocator = allocator.into_bump_allocator();
705		let table_id = table.id();
706
707		let mut witness = WitnessIndex::<PackedType<OptimalUnderlier, B128>>::new(&cs, &allocator);
708		let table_witness = witness.init_table(table_id, N_ROWS).unwrap();
709		let mut segment = table_witness.full_segment();
710
711		let state_ins = TEST_VECTOR
712			.iter()
713			.map(|&[state_in, _]| StateMatrix::from_values(state_in))
714			.collect::<Vec<_>>();
715
716		keccakf.populate_state_in(&mut segment, &state_ins).unwrap();
717		keccakf.populate(&mut segment).unwrap();
718
719		let state_outs = keccakf
720			.read_state_outs(&segment)
721			.unwrap()
722			.collect::<Vec<_>>();
723		for (i, actual_out) in state_outs.iter().enumerate() {
724			let expected_out = StateMatrix::from_values(TEST_VECTOR[i][1]);
725			if *actual_out != expected_out {
726				panic!("Mismatch at index {i}: expected {expected_out:#?}, got {actual_out:#?}",);
727			}
728		}
729
730		let counts = tally(&cs, &mut witness, &[], lookup_chan, &BitAndIndexedLookup).unwrap();
731		// Fill the lookup table with the sorted counts
732		let sorted_counts = counts
733			.into_iter()
734			.enumerate()
735			.sorted_by_key(|(_, count)| Reverse(*count))
736			.collect::<Vec<_>>();
737
738		witness
739			.fill_table_parallel(&bitand_lookup, &sorted_counts)
740			.unwrap();
741
742		let table_sizes = witness.table_sizes();
743		let ccs = cs.compile().unwrap();
744		let witness = witness.into_multilinear_extension_index();
745
746		binius_core::constraint_system::validate::validate_witness(
747			&ccs,
748			&[],
749			&table_sizes,
750			&witness,
751		)
752		.unwrap();
753	}
754}