binius_m3/gadgets/indexed_lookup/
incr.rs

1// Copyright 2025 Irreducible Inc.
2
//! This module provides gadgets for incrementing 8-bit values with a carry bit via an indexed
//! lookup table. It includes types and functions for constructing, populating, and testing the
//! increment lookup table and its associated circuits.
7use std::{iter, slice};
8
9use binius_core::constraint_system::channel::ChannelId;
10use binius_field::{
11	PackedExtension, PackedFieldIndexable, ext_basis,
12	packed::{get_packed_slice, set_packed_slice},
13};
14use binius_math::{ArithCircuit, ArithExpr};
15
16use crate::{
17	builder::{
18		B1, B8, B32, B128, Col, IndexedLookup, TableBuilder, TableFiller, TableId,
19		TableWitnessSegment, upcast_col,
20	},
21	gadgets::lookup::LookupProducer,
22};
23
24/// Represents an increment operation with carry in a lookup table.
25///
26/// This struct holds columns for an 8-bit increment operation where:
27/// - `input` is the 8-bit value to be incremented
28/// - `carry_in` is a 1-bit carry input
29/// - `output` is the 8-bit result of the increment
30/// - `carry_out` is a 1-bit carry output
31/// - `merged` is a 32-bit encoding of all inputs and outputs for lookup
32///
33/// The increment operation computes: output = input + carry_in, with carry_out
34/// set if the result overflows 8 bits.
35pub struct Incr {
36	pub input: Col<B8>,
37	pub carry_in: Col<B1>,
38	pub output: Col<B8>,
39	pub carry_out: Col<B1>,
40	pub merged: Col<B32>,
41}
42
43impl Incr {
44	/// Constructs a new increment gadget, registering the necessary columns in the table.
45	///
46	/// # Arguments
47	/// * `table` - The table builder to register columns with.
48	/// * `lookup_chan` - The channel for lookup operations.
49	/// * `input` - The input column (8 bits).
50	/// * `carry_in` - The carry-in column (1 bit).
51	///
52	/// # Returns
53	/// An `Incr` struct with all columns set up.
54	pub fn new(
55		table: &mut TableBuilder,
56		lookup_chan: ChannelId,
57		input: Col<B8>,
58		carry_in: Col<B1>,
59	) -> Self {
60		let output = table.add_committed::<B8, 1>("output");
61		let carry_out = table.add_committed::<B1, 1>("carry_out");
62		let merged = merge_incr_cols(table, input, carry_in, output, carry_out);
63
64		table.pull(lookup_chan, [merged]);
65
66		Self {
67			input,
68			carry_in,
69			output,
70			carry_out,
71			merged,
72		}
73	}
74
75	/// Populates the witness segment for this increment operation.
76	///
77	/// # Arguments
78	/// * `witness` - The witness segment to populate.
79	///
80	/// # Returns
81	/// `Ok(())` if successful, or an error otherwise.
82	pub fn populate<P>(&self, witness: &mut TableWitnessSegment<P>) -> anyhow::Result<()>
83	where
84		P: PackedFieldIndexable<Scalar = B128>
85			+ PackedExtension<B1>
86			+ PackedExtension<B8>
87			+ PackedExtension<B32>,
88	{
89		let input = witness.get_as::<u8, _, 1>(self.input)?;
90		let carry_in = witness.get(self.carry_in)?;
91		let mut output = witness.get_mut_as::<u8, _, 1>(self.output)?;
92		let mut carry_out = witness.get_mut(self.carry_out)?;
93		let mut merged = witness.get_mut_as::<u32, _, 1>(self.merged)?;
94
95		for i in 0..witness.size() {
96			let input_i = input[i];
97			let carry_in_bit = bool::from(get_packed_slice(&carry_in, i).val());
98
99			let (output_i, carry_out_bit) = input_i.overflowing_add(carry_in_bit.into());
100			output[i] = output_i;
101			set_packed_slice(&mut carry_out, i, B1::from(carry_out_bit));
102			merged[i] = ((carry_out_bit as u32) << 17)
103				| ((carry_in_bit as u32) << 16)
104				| ((output_i as u32) << 8)
105				| input_i as u32;
106		}
107
108		Ok(())
109	}
110}
111
112/// Helper struct for producing increment lookups from input/carry pairs.
113pub struct IncrLooker {
114	pub input: Col<B8>,
115	pub carry_in: Col<B1>,
116	incr: Incr,
117}
118
119impl IncrLooker {
120	/// Constructs a new increment looker, registering columns in the table.
121	pub fn new(table: &mut TableBuilder, lookup_chan: ChannelId) -> Self {
122		let input = table.add_committed::<B8, 1>("input");
123		let carry_in = table.add_committed::<B1, 1>("carry_in");
124		let incr = Incr::new(table, lookup_chan, input, carry_in);
125
126		Self {
127			input,
128			carry_in,
129			incr,
130		}
131	}
132
133	/// Populates the witness segment for a sequence of (input, carry_in) events.
134	pub fn populate<'a, P>(
135		&self,
136		witness: &mut TableWitnessSegment<P>,
137		events: impl Iterator<Item = &'a (u8, bool)>,
138	) -> anyhow::Result<()>
139	where
140		P: PackedFieldIndexable<Scalar = B128>
141			+ PackedExtension<B1>
142			+ PackedExtension<B8>
143			+ PackedExtension<B32>,
144	{
145		{
146			let mut input = witness.get_mut_as::<u8, _, 1>(self.input)?;
147			let mut carry_in = witness.get_mut(self.carry_in)?;
148
149			for (i, &(input_i, carry_in_bit)) in events.enumerate() {
150				input[i] = input_i;
151				set_packed_slice(&mut carry_in, i, B1::from(carry_in_bit));
152			}
153		}
154
155		self.incr.populate(witness)?;
156		Ok(())
157	}
158}
159
160/// Represents the increment lookup table, supporting filling and permutation checks.
161pub struct IncrLookup {
162	table_id: TableId,
163	entries_ordered: Col<B32>,
164	entries_sorted: Col<B32>,
165	lookup_producer: LookupProducer,
166}
167
168impl IncrLookup {
169	/// Constructs a new increment lookup table.
170	///
171	/// # Arguments
172	/// * `table` - The table builder.
173	/// * `chan` - The lookup channel.
174	/// * `permutation_chan` - The channel for permutation checks.
175	/// * `n_multiplicity_bits` - Number of bits for multiplicity.
176	pub fn new(
177		table: &mut TableBuilder,
178		chan: ChannelId,
179		permutation_chan: ChannelId,
180		n_multiplicity_bits: usize,
181	) -> Self {
182		table.require_fixed_size(IncrIndexedLookup.log_size());
183
184		// The entries_ordered column is the one that is filled with the lookup table entries.
185		let entries_ordered = table.add_fixed("incr_lookup", incr_circuit());
186		let entries_sorted = table.add_committed::<B32, 1>("entries_sorted");
187
188		// Use flush to check that entries_sorted is a permutation of entries_ordered.
189		table.push(permutation_chan, [entries_ordered]);
190		table.pull(permutation_chan, [entries_sorted]);
191
192		let lookup_producer =
193			LookupProducer::new(table, chan, &[entries_sorted], n_multiplicity_bits);
194		Self {
195			table_id: table.id(),
196			entries_ordered,
197			entries_sorted,
198			lookup_producer,
199		}
200	}
201}
202
203/// Implements filling for the increment lookup table.
204impl TableFiller for IncrLookup {
205	// Tuple of index and count
206	type Event = (usize, u32);
207
208	fn id(&self) -> TableId {
209		self.table_id
210	}
211
212	fn fill(&self, rows: &[Self::Event], witness: &mut TableWitnessSegment) -> anyhow::Result<()> {
213		// Fill the entries_ordered column
214		{
215			let mut col_data = witness.get_scalars_mut(self.entries_ordered)?;
216			let start_index = witness.index() << witness.log_size();
217			for (i, col_data_i) in col_data.iter_mut().enumerate() {
218				let mut entry_128b = B128::default();
219				IncrIndexedLookup.index_to_entry(start_index + i, slice::from_mut(&mut entry_128b));
220				*col_data_i = B32::try_from(entry_128b).expect("guaranteed by IncrIndexedLookup");
221			}
222		}
223
224		// Fill the entries_sorted column
225		{
226			let mut entries_sorted = witness.get_scalars_mut(self.entries_sorted)?;
227			for (merged_i, &(index, _)) in iter::zip(&mut *entries_sorted, rows.iter()) {
228				let mut entry_128b = B128::default();
229				IncrIndexedLookup.index_to_entry(index, slice::from_mut(&mut entry_128b));
230				*merged_i = B32::try_from(entry_128b).expect("guaranteed by IncrIndexedLookup");
231			}
232		}
233
234		self.lookup_producer
235			.populate(witness, rows.iter().map(|&(_i, count)| count))?;
236		Ok(())
237	}
238}
239
/// Stateless index/entry codec for the increment lookup table.
///
/// Table index `i` encodes `input = i & 0xFF` and `carry_in = (i >> 8) & 1`. The corresponding
/// entry is the packed increment value produced by `merge_incr_vals`.
#[derive(Debug, Clone, Copy, Default)]
pub struct IncrIndexedLookup;
242
243impl IndexedLookup<B128> for IncrIndexedLookup {
244	/// Returns the log2 size of the table (9 for 8 bits + 1 carry).
245	fn log_size(&self) -> usize {
246		// Input is an 8-bit value plus 1-bit carry-in
247		8 + 1
248	}
249
250	/// Converts a table entry to its index.
251	fn entry_to_index(&self, entry: &[B128]) -> usize {
252		debug_assert_eq!(entry.len(), 1);
253		let merged_val = entry[0].val() as u32;
254		let input = merged_val & 0xFF;
255		let carry_in_bit = (merged_val >> 16) & 1 == 1;
256		(carry_in_bit as usize) << 8 | input as usize
257	}
258
259	/// Converts an index to a table entry.
260	fn index_to_entry(&self, index: usize, entry: &mut [B128]) {
261		debug_assert_eq!(entry.len(), 1);
262		let input = (index % (1 << 8)) as u8;
263		let carry_in_bit = (index >> 8) & 1 == 1;
264		let (output, carry_out_bit) = input.overflowing_add(carry_in_bit.into());
265		let entry_u32 = merge_incr_vals(input, carry_in_bit, output, carry_out_bit);
266		entry[0] = B32::new(entry_u32).into();
267	}
268}
269
270/// Returns a circuit that describes the carry-in for the i_th bit of incrementing an 8-bit
271/// number by a carry-in bit. The circuit is a product of the lower bits.
272pub fn carry_in_circuit(i: usize) -> ArithExpr<B128> {
273	// The circuit is a lookup table for the increment operation, which takes an 8-bit input and
274	// returns an 8-bit output and a carry bit. The circuit is defined as follows:
275	let mut circuit = ArithExpr::Var(8);
276	for var in 0..i {
277		circuit *= ArithExpr::Var(var)
278	}
279	circuit
280}
281
282/// Returns a circuit that describes the increment operation for an 8-bit addition.
283/// The circuit encodes input, output, carry-in, and carry-out into a single value.
284pub fn incr_circuit() -> ArithCircuit<B128> {
285	// The circuit is a lookup table for the increment operation, which takes an 8-bit input and
286	// returns an 8-bit output and a carry bit. The circuit is defined as follows:
287	let mut circuit = ArithExpr::zero();
288	for i in 0..8 {
289		circuit += ArithExpr::Var(i) * ArithExpr::Const(B128::from(1 << i));
290
291		let carry = carry_in_circuit(i);
292		circuit += (ArithExpr::Var(i) + carry.clone()) * ArithExpr::Const(B128::from(1 << (i + 8)));
293	}
294	circuit += ArithExpr::Var(8) * ArithExpr::Const(B128::from(1 << 16));
295	let carry_out = carry_in_circuit(8);
296	circuit += carry_out * ArithExpr::Const(B128::from(1 << 17));
297	circuit.into()
298}
299
300/// Merges the input, output, carry-in, and carry-out columns into a single B32 column for lookup.
301pub fn merge_incr_cols(
302	table: &mut TableBuilder,
303	input: Col<B8>,
304	carry_in: Col<B1>,
305	output: Col<B8>,
306	carry_out: Col<B1>,
307) -> Col<B32> {
308	let beta_1 = ext_basis::<B32, B8>(1);
309	let beta_2_0 = ext_basis::<B32, B8>(2);
310	let beta_2_1 = beta_2_0 * ext_basis::<B8, B1>(1);
311	table.add_computed(
312		"merged",
313		upcast_col(input)
314			+ upcast_col(output) * beta_1
315			+ upcast_col(carry_in) * beta_2_0
316			+ upcast_col(carry_out) * beta_2_1,
317	)
318}
319
/// Packs the increment operands into the `u32` encoding used for lookup entries.
///
/// Layout, least significant first: bits 0..8 `input`, bits 8..16 `output`, bit 16 `carry_in`,
/// bit 17 `carry_out`. All higher bits are zero.
pub fn merge_incr_vals(input: u8, carry_in: bool, output: u8, carry_out: bool) -> u32 {
	let low_byte = u32::from(input);
	let second_byte = u32::from(output) << 8;
	let carry_bits = (u32::from(carry_in) << 16) | (u32::from(carry_out) << 17);
	carry_bits | second_byte | low_byte
}
324
#[cfg(test)]
mod tests {
	//! Tests for the increment indexed lookup gadgets.
	use std::{cmp::Reverse, iter::repeat_with};

	use binius_compute::cpu::alloc::CpuComputeAllocator;
	use binius_core::constraint_system::channel::{Boundary, FlushDirection};
	use binius_field::arch::OptimalUnderlier;
	use itertools::Itertools;
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;
	use crate::builder::{
		ConstraintSystem, WitnessIndex, tally,
		test_utils::{ClosureFiller, validate_system_witness},
	};

	/// Draws `count` random `(input, carry_in)` increment events from `rng`.
	fn random_incr_events(rng: &mut StdRng, count: usize) -> Vec<(u8, bool)> {
		repeat_with(|| {
			let input = rng.random::<u8>();
			let carry_in_bit = rng.random_bool(0.5);
			(input, carry_in_bit)
		})
		.take(count)
		.collect()
	}

	/// Unit test for a fixed lookup table, which requires counting lookups during witness
	/// generation of the looker tables.
	///
	/// Two looker tables and a set of boundary pulls consume increment entries. The per-entry
	/// counts are tallied from the witness, and the lookup table is then filled with its entries
	/// sorted by descending count.
	#[test]
	fn test_fixed_lookup_producer() {
		let mut cs = ConstraintSystem::new();
		let incr_lookup_chan = cs.add_channel("incr lookup");
		let incr_lookup_perm_chan = cs.add_channel("incr lookup permutation");

		let n_multiplicity_bits = 8;

		let mut incr_table = cs.add_table("increment");
		let incr_lookup = IncrLookup::new(
			&mut incr_table,
			incr_lookup_chan,
			incr_lookup_perm_chan,
			n_multiplicity_bits,
		);

		let mut looker_1 = cs.add_table("looker 1");
		let looker_1_id = looker_1.id();
		let incr_1 = IncrLooker::new(&mut looker_1, incr_lookup_chan);

		let mut looker_2 = cs.add_table("looker 2");
		let looker_2_id = looker_2.id();
		let incr_2 = IncrLooker::new(&mut looker_2, incr_lookup_chan);

		let (looker_1_size, looker_2_size) = (5, 6);

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();
		let mut witness = WitnessIndex::new(&cs, &allocator);

		let mut rng = StdRng::seed_from_u64(0);
		let events_1 = random_incr_events(&mut rng, looker_1_size);
		let events_2 = random_incr_events(&mut rng, looker_2_size);

		witness
			.fill_table_sequential(
				&ClosureFiller::new(looker_1_id, |events, segment| {
					incr_1.populate(segment, events.iter())
				}),
				&events_1,
			)
			.unwrap();

		witness
			.fill_table_sequential(
				&ClosureFiller::new(looker_2_id, |events, segment| {
					incr_2.populate(segment, events.iter())
				}),
				&events_2,
			)
			.unwrap();

		// Extra pulls from outside the tables, covering both carry cases at 111 and 255.
		let boundaries = [
			(111, false, 111, false),
			(111, true, 112, false),
			(255, false, 255, false),
			(255, true, 0, true),
		]
		.into_iter()
		.map(|(input, carry_in, output, carry_out)| Boundary {
			values: vec![B32::new(merge_incr_vals(input, carry_in, output, carry_out)).into()],
			direction: FlushDirection::Pull,
			channel_id: incr_lookup_chan,
			multiplicity: 1,
		})
		.collect::<Vec<_>>();

		// Count how often each table entry is looked up by the lookers and the boundaries.
		let counts =
			tally(&cs, &mut witness, &boundaries, incr_lookup_chan, &IncrIndexedLookup).unwrap();

		// Emit (index, count) events with the most frequently used entries first.
		let sorted_counts = counts
			.into_iter()
			.enumerate()
			.sorted_by_key(|&(_, count)| Reverse(count))
			.collect::<Vec<_>>();

		witness
			.fill_table_sequential(&incr_lookup, &sorted_counts)
			.unwrap();

		validate_system_witness::<OptimalUnderlier>(&cs, witness, boundaries);
	}
}