binius_m3/gadgets/
lookup.rs

1// Copyright 2025 Irreducible Inc.
2
3use anyhow::{Result, ensure};
4use binius_core::constraint_system::channel::ChannelId;
5use binius_field::{ExtensionField, PackedExtension, PackedField, PackedSubfield, TowerField};
6use itertools::Itertools;
7
8use crate::builder::{B1, B128, Col, FlushOpts, TableBuilder, TableWitnessSegment};
9
10/// A lookup producer gadget is used to create a lookup table.
11///
12/// The lookup producer pushes the value columns to a channel with prover-chosen multiplicities.
13/// This allows consumers of the channel to read any value in the table an arbitrary number of
14/// times. Table values are given as tuples of column entries.
15#[derive(Debug)]
16pub struct LookupProducer {
17	multiplicity_bits: Vec<Col<B1>>,
18}
19
20impl LookupProducer {
21	pub fn new<FSub>(
22		table: &mut TableBuilder,
23		chan: ChannelId,
24		value_cols: &[Col<FSub>],
25		n_multiplicity_bits: usize,
26	) -> Self
27	where
28		B128: ExtensionField<FSub>,
29		FSub: TowerField,
30	{
31		let multiplicity_bits = (0..n_multiplicity_bits)
32			.map(|i| table.add_committed::<B1, 1>(format!("multiplicity_bits[{i}]")))
33			.collect::<Vec<_>>();
34
35		for (i, &multiplicity_col) in multiplicity_bits.iter().enumerate() {
36			table.push_with_opts(
37				chan,
38				value_cols.iter().copied(),
39				FlushOpts {
40					multiplicity: 1 << i,
41					selectors: vec![multiplicity_col],
42				},
43			);
44		}
45
46		Self { multiplicity_bits }
47	}
48
49	/// Populate the multiplicity witness columns.
50	///
51	/// ## Pre-condition
52	///
53	/// * Multiplicities must be sorted in ascending order.
54	pub fn populate<P>(
55		&self,
56		index: &mut TableWitnessSegment<P>,
57		counts: impl Iterator<Item = u32> + Clone,
58	) -> Result<(), anyhow::Error>
59	where
60		P: PackedExtension<B1>,
61		P::Scalar: TowerField,
62	{
63		if self.multiplicity_bits.len() < u32::BITS as usize {
64			for count in counts.clone() {
65				ensure!(
66					count < (1 << self.multiplicity_bits.len()) as u32,
67					"count {count} exceeds maximum configured multiplicity; \
68					try raising the multiplicity bits in the constraint system"
69				);
70			}
71		}
72
73		// TODO: Optimize the gadget for bit-transposing u32s
74		for (j, &multiplicity_col) in self.multiplicity_bits.iter().enumerate().take(32) {
75			let mut multiplicity_col = index.get_mut(multiplicity_col)?;
76			for (packed, counts) in multiplicity_col
77				.iter_mut()
78				.zip(&counts.clone().chunks(<PackedSubfield<P, B1>>::WIDTH))
79			{
80				for (i, count) in counts.enumerate() {
81					packed.set(i, B1::from((count >> j) & 1 == 1))
82				}
83			}
84		}
85		Ok(())
86	}
87}
88
#[cfg(test)]
mod tests {
	use std::{cmp::Reverse, iter, iter::repeat_with};

	use binius_compute::cpu::alloc::CpuComputeAllocator;
	use binius_field::{arch::OptimalUnderlier128b, as_packed_field::PackedType};
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;
	use crate::builder::{
		ConstraintSystem, WitnessIndex,
		test_utils::{ClosureFiller, validate_system_witness},
	};

	/// Builds a test instance and passes its constraint system and witness to `f`.
	///
	/// The instance has a 64-row lookup table with 8 multiplicity bits and two looker tables
	/// (56 and 67 rows) that pull from it. The witness for all three tables is filled before
	/// `f` is called. When `no_zero_counts` is true, every lookup table row is looked up at
	/// least once.
	fn with_lookup_test_instance(
		no_zero_counts: bool,
		f: impl FnOnce(&ConstraintSystem<B128>, WitnessIndex<PackedType<OptimalUnderlier128b, B128>>),
	) {
		let mut cs = ConstraintSystem::new();
		let chan = cs.add_channel("values");

		let mut lookup_table = cs.add_table("lookup");
		lookup_table.require_power_of_two_size();
		let lookup_table_id = lookup_table.id();
		let values_col = lookup_table.add_committed::<B128, 1>("values");
		let lookup_producer = LookupProducer::new(&mut lookup_table, chan, &[values_col], 8);

		let mut looker_1 = cs.add_table("looker 1");
		let looker_1_id = looker_1.id();
		let looker_1_vals = looker_1.add_committed::<B128, 1>("values");
		looker_1.pull(chan, [looker_1_vals]);

		let mut looker_2 = cs.add_table("looker 2");
		let looker_2_id = looker_2.id();
		let looker_2_vals = looker_2.add_committed::<B128, 1>("values");
		looker_2.pull(chan, [looker_2_vals]);

		let lookup_table_size = 64;
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| B128::random(&mut rng))
			.take(lookup_table_size)
			.collect::<Vec<_>>();

		let mut counts = vec![0u32; lookup_table_size];

		let looker_1_size = 56;
		let looker_2_size = 67;
		let total_lookups = looker_1_size + looker_2_size;

		// Choose looked-up indices randomly. When `no_zero_counts` is true, first look up every
		// index once so that no count is zero. This tests an edge case.
		let mut look_indices = Vec::with_capacity(total_lookups);
		if no_zero_counts {
			look_indices.extend(0..lookup_table_size);
		}
		// Derive the remainder from the intended total, not from `capacity()`.
		// `Vec::with_capacity` only guarantees a capacity of *at least* the requested size.
		let remaining = total_lookups - look_indices.len();
		look_indices.extend(repeat_with(|| rng.random_range(0..lookup_table_size)).take(remaining));

		let look_values = look_indices
			.into_iter()
			.map(|index| {
				counts[index] += 1;
				values[index]
			})
			.collect::<Vec<_>>();

		let (inputs_1, inputs_2) = look_values.split_at(looker_1_size);

		// Order the lookup table rows by descending count before filling the table.
		let values_and_counts = iter::zip(values, counts)
			.sorted_unstable_by_key(|&(_val, count)| Reverse(count))
			.collect::<Vec<_>>();

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();
		let mut witness =
			WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

		// Fill the lookup table
		witness
			.fill_table_sequential(
				&ClosureFiller::new(lookup_table_id, |values_and_counts, witness| {
					{
						let mut values_col = witness.get_scalars_mut(values_col)?;
						for (dst, (val, _)) in iter::zip(&mut *values_col, values_and_counts) {
							*dst = *val;
						}
					}
					lookup_producer
						.populate(witness, values_and_counts.iter().map(|(_, count)| *count))?;
					Ok(())
				}),
				&values_and_counts,
			)
			.unwrap();

		// Fill looker tables
		witness
			.fill_table_sequential(
				&ClosureFiller::new(looker_1_id, |inputs_1, witness| {
					let mut looker_1_vals = witness.get_scalars_mut(looker_1_vals)?;
					for (dst, src) in iter::zip(&mut *looker_1_vals, inputs_1) {
						*dst = *src;
					}
					Ok(())
				}),
				inputs_1,
			)
			.unwrap();

		witness
			.fill_table_sequential(
				&ClosureFiller::new(looker_2_id, |inputs_2, witness| {
					let mut looker_2_vals = witness.get_scalars_mut(looker_2_vals)?;
					for (dst, src) in iter::zip(&mut *looker_2_vals, inputs_2) {
						*dst = *src;
					}
					Ok(())
				}),
				inputs_2,
			)
			.unwrap();

		f(&cs, witness)
	}

	/// Random lookups, where some table rows may have a zero count.
	#[test]
	fn test_basic_lookup_producer() {
		with_lookup_test_instance(false, |cs, witness| {
			validate_system_witness::<OptimalUnderlier128b>(cs, witness, vec![])
		});
	}

	/// Every table row is looked up at least once.
	#[test]
	fn test_lookup_producer_no_zero_counts() {
		with_lookup_test_instance(true, |cs, witness| {
			validate_system_witness::<OptimalUnderlier128b>(cs, witness, vec![])
		});
	}

	/// A single multiplicity bit can encode only counts 0 and 1, so populating with count 9
	/// must fail.
	#[test]
	fn test_lookup_overflows_max_multiplicity() {
		let mut cs = ConstraintSystem::new();
		let chan = cs.add_channel("values");

		let mut lookup_table = cs.add_table("lookup");
		lookup_table.require_power_of_two_size();
		let lookup_table_id = lookup_table.id();
		let values_col = lookup_table.add_committed::<B128, 1>("values");
		let lookup_producer = LookupProducer::new(&mut lookup_table, chan, &[values_col], 1);

		let lookup_table_size = 64;
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| B128::random(&mut rng))
			.take(lookup_table_size)
			.collect::<Vec<_>>();

		let counts = vec![9; lookup_table_size];
		let values_and_counts = iter::zip(values, counts).collect::<Vec<_>>();

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();
		let mut witness =
			WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

		// Attempt to fill the lookup table
		let result = witness.fill_table_sequential(
			&ClosureFiller::new(lookup_table_id, |values_and_counts, witness| {
				{
					let mut values_col = witness.get_scalars_mut(values_col)?;
					for (dst, (val, _)) in iter::zip(&mut *values_col, values_and_counts) {
						*dst = *val;
					}
				}
				lookup_producer
					.populate(witness, values_and_counts.iter().map(|(_, count)| *count))?;
				Ok(())
			}),
			&values_and_counts,
		);
		assert!(result.is_err());
	}
}
269}