// binius_m3/gadgets/indexed_lookup/and.rs

// Copyright 2025 Irreducible Inc.

//! This module provides gadgets for performing indexed lookup operations for the bitwise AND
//! of two 8-bit values, using lookup tables. It includes types and functions for constructing,
//! populating, and testing AND lookup tables and their associated circuits.
6use std::{iter, slice};
7
8use binius_core::constraint_system::channel::ChannelId;
9use binius_field::ext_basis;
10use binius_math::{ArithCircuit, ArithExpr};
11
12use crate::{
13	builder::{
14		B8, B32, B128, Col, IndexedLookup, TableBuilder, TableFiller, TableId, TableWitnessSegment,
15		column::upcast_col,
16	},
17	gadgets::lookup::LookupProducer,
18};
19
20/// A gadget that computes the logical AND of two boolean columns using a lookup table.
21///
22/// This struct holds columns for an 8-bit AND operation where:
23/// - `entries_ordered` is the fixed column containing all possible AND table entries
24/// - `entries_sorted` is a committed column for sorted entries
25/// - `lookup_producer` manages lookup multiplicities and constraints
26pub struct BitAndLookup {
27	/// The table ID
28	pub table_id: TableId,
29	entries_ordered: Col<B32>,
30	entries_sorted: Col<B32>,
31	lookup_producer: LookupProducer,
32}
33
34pub struct BitAnd<const V: usize = 1> {
35	/// Input column A (8 bits)
36	in_a: Col<B8, V>,
37	/// Input column B (8 bits)
38	in_b: Col<B8, V>,
39	/// Output column (8 bits), result of in_a & in_b
40	pub output: Col<B8, V>,
41	/// Merged column for lookup (32 bits)
42	merged: Col<B32, V>,
43}
44/// Constructs a new bitwise-AND gadget, registering the necessary columns in the table.
45impl<const V: usize> BitAnd<V> {
46	///
47	/// # Arguments
48	/// * `table` - The table builder to register columns with.
49	/// * `lookup_chan` - The channel for lookup operations.
50	/// * `in_a` - The first input column (8 bits).
51	/// * `in_b` - The second input column (8 bits).
52	///
53	/// # Returns
54	/// An `And` struct with all columns set up.
55	pub fn new(
56		table: &mut TableBuilder,
57		lookup_chan: ChannelId,
58		in_a: Col<B8, V>,
59		in_b: Col<B8, V>,
60	) -> Self {
61		let output = table.add_committed::<B8, V>("output");
62		let merged = merge_and_columns(table, in_a, in_b, output);
63		table.read(lookup_chan, [merged]);
64		Self {
65			in_a,
66			in_b,
67			output,
68			merged,
69		}
70	}
71
72	/// Populates the witness segment for this AND operation.
73	///
74	/// # Arguments
75	/// * `witness` - The witness segment to populate.
76	///
77	/// # Returns
78	/// `Ok(())` if successful, or an error otherwise.
79	pub fn populate(&self, witness: &mut TableWitnessSegment) -> anyhow::Result<()> {
80		let in_a_col = witness.get_scalars(self.in_a)?;
81		let in_b_col = witness.get_scalars(self.in_b)?;
82		let mut output_col: std::cell::RefMut<'_, [B8]> = witness.get_scalars_mut(self.output)?;
83		let mut merged_col: std::cell::RefMut<'_, [B32]> = witness.get_scalars_mut(self.merged)?;
84
85		for i in 0..witness.size() {
86			let in_a = in_a_col[i].val();
87			let in_b = in_b_col[i].val();
88			let output = in_a & in_b;
89			output_col[i] = output.into();
90
91			// Merge the values into a single u32
92			merged_col[i] = merge_bitand_vals(in_a, in_b, output).into();
93		}
94		Ok(())
95	}
96}
97
98/// Merges the input and output columns into a single B32 column for lookup.
99pub fn merge_and_columns<const V: usize>(
100	table: &mut TableBuilder,
101	in_a: Col<B8, V>,
102	in_b: Col<B8, V>,
103	output: Col<B8, V>,
104) -> Col<B32, V> {
105	table.add_computed(
106		"merged",
107		upcast_col(in_a)
108			+ upcast_col(in_b) * ext_basis::<B32, B8>(1)
109			+ upcast_col(output) * ext_basis::<B32, B8>(2),
110	)
111}
112
/// Packs `in_a`, `in_b`, and `output` into bytes 0, 1, and 2 of a `u32`; byte 3 is left zero.
///
/// This is the integer counterpart of the in-circuit packing done by `merge_and_columns`.
pub fn merge_bitand_vals(in_a: u8, in_b: u8, output: u8) -> u32 {
	u32::from_le_bytes([in_a, in_b, output, 0])
}
117
118/// Returns an arithmetic expression that represents the bitwise-AND operation as a lookup circuit.
119/// The circuit encodes input A, input B, and output into a single value.
120pub fn bitand_circuit() -> ArithCircuit<B128> {
121	// The circuit is a lookup table for the and operation, which takes 2 8-bit inputs and
122	// returns a field element which is the result of the bitwise andconcatenated with the inputs.
123	let mut circuit = ArithExpr::zero();
124	for i in 0..8 {
125		circuit += ArithExpr::Var(i) * ArithExpr::Const(B32::new(1 << i));
126		circuit += ArithExpr::Var(i + 8) * ArithExpr::Const(B32::new(1 << (i + 8)));
127		circuit +=
128			ArithExpr::Var(i) * ArithExpr::Var(i + 8) * ArithExpr::Const(B32::new(1 << (i + 16)));
129	}
130	ArithCircuit::<B32>::from(circuit)
131		.try_convert_field()
132		.expect("And circuit should convert to B128")
133}
134
135impl BitAndLookup {
136	/// Constructs a new AND lookup table.
137	///
138	/// # Arguments
139	/// * `table` - The table builder.
140	/// * `chan` - The lookup channel.
141	/// * `permutation_chan` - The channel for permutation checks.
142	/// * `n_multiplicity_bits` - Number of bits for multiplicity.
143	pub fn new(
144		table: &mut TableBuilder,
145		chan: ChannelId,
146		permutation_chan: ChannelId,
147		n_multiplicity_bits: usize,
148	) -> Self {
149		table.require_fixed_size(BitAndIndexedLookup.log_size());
150
151		// The entries_ordered column is the one that is filled with the lookup table entries.
152		let entries_ordered = table.add_fixed("bitand_lookup", bitand_circuit());
153		let entries_sorted = table.add_committed::<B32, 1>("entries_sorted");
154
155		// Use flush to check that entries_sorted is a permutation of entries_ordered.
156		table.push(permutation_chan, [entries_ordered]);
157		table.pull(permutation_chan, [entries_sorted]);
158
159		let lookup_producer =
160			LookupProducer::new(table, chan, &[entries_sorted], n_multiplicity_bits);
161		Self {
162			table_id: table.id(),
163			entries_ordered,
164			entries_sorted,
165			lookup_producer,
166		}
167	}
168}
169
170/// The 2 columns that are the inputs to the AND operation, and the gadget exposes an output column
171/// that corresponds to the bitwise AND of the two inputs.
172pub struct BitAndLooker {
173	/// Input column A (8 bits)
174	pub in_a: Col<B8>,
175	/// Input column B (8 bits)
176	pub in_b: Col<B8>,
177	/// Internal AND gadget
178	and: BitAnd,
179}
180
181impl BitAndLooker {
182	/// Constructs a new AND looker, registering columns in the table.
183	pub fn new(table: &mut TableBuilder, lookup_chan: ChannelId) -> Self {
184		let in_a = table.add_committed::<B8, 1>("in_a");
185		let in_b = table.add_committed::<B8, 1>("in_b");
186		// Create the And gadget which will compute the AND of in_a and in_b
187		let and = BitAnd::new(table, lookup_chan, in_a, in_b);
188		Self { in_a, in_b, and }
189	}
190
191	/// Populates the witness segment for a sequence of (in_a, in_b) events.
192	pub fn populate<'a>(
193		&self,
194		witness: &'a mut TableWitnessSegment,
195		inputs: impl Iterator<Item = &'a (u8, u8)> + Clone,
196	) -> anyhow::Result<()> {
197		{
198			let mut in_a_col: std::cell::RefMut<'_, [u8]> = witness.get_mut_as(self.in_a)?;
199			let mut in_b_col: std::cell::RefMut<'_, [u8]> = witness.get_mut_as(self.in_b)?;
200
201			for (i, &(in_a, in_b)) in inputs.enumerate() {
202				in_a_col[i] = in_a;
203				in_b_col[i] = in_b;
204			}
205		}
206
207		self.and.populate(witness)?;
208		Ok(())
209	}
210}
211
212/// Internal struct for indexed lookup logic for AND operations.
213pub struct BitAndIndexedLookup;
214
215impl IndexedLookup<B128> for BitAndIndexedLookup {
216	/// Returns the log2 size of the table (16 for 8 bits + 8 bits).
217	fn log_size(&self) -> usize {
218		16
219	}
220
221	/// Converts a table entry to its index.
222	fn entry_to_index(&self, entry: &[B128]) -> usize {
223		debug_assert_eq!(entry.len(), 1, "AndLookup entry must be a single B128 field");
224		let merged_val = entry[0].val() as u32;
225		(merged_val & 0xFFFF) as usize
226	}
227
228	/// Converts an index to a table entry.
229	fn index_to_entry(&self, index: usize, entry: &mut [B128]) {
230		debug_assert_eq!(entry.len(), 1, "AndLookup entry must be a single B128 field");
231		let in_a = index & 0xFF;
232		let in_b = (index >> 8) & 0xFF;
233		let output = in_a & in_b;
234		let merged = merge_bitand_vals(in_a as u8, in_b as u8, output as u8);
235		entry[0] = B128::from(merged as u128);
236	}
237}
238
239/// Implements filling for the AND lookup table.
240impl TableFiller for BitAndLookup {
241	// Tuple of index and count
242	type Event = (usize, u32);
243
244	fn id(&self) -> TableId {
245		self.table_id
246	}
247
248	fn fill(&self, rows: &[Self::Event], witness: &mut TableWitnessSegment) -> anyhow::Result<()> {
249		// Fill the entries_ordered column
250		{
251			let mut col_data = witness.get_scalars_mut(self.entries_ordered)?;
252			let start_index = witness.index() << witness.log_size();
253			for (i, col_data_i) in col_data.iter_mut().enumerate() {
254				let mut entry_128b = B128::default();
255				BitAndIndexedLookup
256					.index_to_entry(start_index + i, slice::from_mut(&mut entry_128b));
257				*col_data_i = B32::try_from(entry_128b).expect("guaranteed by BitAndIndexedLookup");
258			}
259		}
260
261		// Fill the entries_sorted column
262		{
263			let mut entries_sorted = witness.get_scalars_mut(self.entries_sorted)?;
264			for (merged_i, &(index, _)) in iter::zip(&mut *entries_sorted, rows) {
265				let mut entry_128b = B128::default();
266				BitAndIndexedLookup.index_to_entry(index, slice::from_mut(&mut entry_128b));
267				*merged_i = B32::try_from(entry_128b).expect("guaranteed by BitAndIndexedLookup");
268			}
269		}
270
271		self.lookup_producer
272			.populate(witness, rows.iter().map(|&(_i, count)| count))?;
273		Ok(())
274	}
275}
276
#[cfg(test)]
mod tests {
	//! Tests for the AND indexed lookup gadgets.

	use std::{cmp::Reverse, iter::repeat_with};

	use binius_compute::cpu::alloc::CpuComputeAllocator;
	use binius_core::constraint_system::channel::{Boundary, FlushDirection};
	use binius_field::arch::OptimalUnderlier;
	use itertools::Itertools;
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;
	use crate::builder::{
		ConstraintSystem, WitnessIndex, tally,
		test_utils::{ClosureFiller, validate_system_witness},
	};

	/// End-to-end check of the AND lookup.
	///
	/// A looker table and a set of boundary pulls read AND results from the lookup table. The
	/// lookup table is then filled from the tallied lookup counts, and the complete witness is
	/// validated.
	#[test]
	fn test_and_lookup() {
		let mut cs: ConstraintSystem<B128> = ConstraintSystem::new();
		let lookup_chan = cs.add_channel("lookup");
		let permutation_chan = cs.add_channel("permutation");
		let mut and_table = cs.add_table("bitand_lookup");
		let n_multiplicity_bits = 8;

		let bitand_lookup =
			BitAndLookup::new(&mut and_table, lookup_chan, permutation_chan, n_multiplicity_bits);
		let mut and_looker = cs.add_table("bitand_looker");

		let bitand_1 = BitAndLooker::new(&mut and_looker, lookup_chan);

		let looker_1_size = 5;
		let looker_id = and_looker.id();

		let mut allocator = CpuComputeAllocator::new(1 << 16);
		let allocator = allocator.into_bump_allocator();
		let mut witness = WitnessIndex::new(&cs, &allocator);

		// Seeded RNG so the test is deterministic. The looker inputs are drawn first, then the
		// boundary values below.
		let mut rng = StdRng::seed_from_u64(0);
		let inputs_1 = repeat_with(|| {
			let in_a = rng.random::<u8>();
			let in_b = rng.random::<u8>();
			(in_a, in_b)
		})
		.take(looker_1_size)
		.collect::<Vec<_>>();

		witness
			.fill_table_parallel(
				&ClosureFiller::new(looker_id, |inputs, segment| {
					bitand_1.populate(segment, inputs.iter())
				}),
				&inputs_1,
			)
			.unwrap();

		// Additional lookups that come from boundary values rather than from a table: each is a
		// packed (a, b, a & b) value pulled once from the lookup channel.
		let boundary_reads = (0..5)
			.map(|_| {
				let in_a = rng.random::<u8>();
				let in_b = rng.random::<u8>();
				merge_bitand_vals(in_a, in_b, in_a & in_b)
			})
			.collect::<Vec<_>>();

		let boundaries = boundary_reads
			.into_iter()
			.map(|val| Boundary {
				values: vec![B32::new(val).into()],
				direction: FlushDirection::Pull,
				channel_id: lookup_chan,
				multiplicity: 1,
			})
			.collect::<Vec<_>>();

		// Tally the lookup counts from the looker tables
		let counts =
			tally(&cs, &mut witness, &boundaries, lookup_chan, &BitAndIndexedLookup).unwrap();

		// Fill the lookup table with the sorted counts: (table index, count) events ordered by
		// descending count. This order becomes the row order of the `entries_sorted` column.
		let sorted_counts = counts
			.into_iter()
			.enumerate()
			.sorted_by_key(|(_, count)| Reverse(*count))
			.collect::<Vec<_>>();

		witness
			.fill_table_parallel(&bitand_lookup, &sorted_counts)
			.unwrap();

		validate_system_witness::<OptimalUnderlier>(&cs, witness, boundaries);
	}
}