binius_m3/gadgets/
lookup.rs

1// Copyright 2025 Irreducible Inc.
2
3use anyhow::Result;
4use binius_core::constraint_system::channel::ChannelId;
5use binius_field::{ExtensionField, PackedExtension, PackedField, PackedSubfield, TowerField};
6use itertools::Itertools;
7
8use crate::builder::{Col, FlushOpts, TableBuilder, TableWitnessSegment, B1, B128};
9
10/// A lookup producer gadget is used to create a lookup table.
11///
12/// The lookup producer pushes the value columns to a channel with prover-chosen multiplicities.
13/// This allows consumers of the channel can read any value in the table an arbitrary number of
14/// times. Table values are given as tuples of column entries.
15#[derive(Debug)]
16pub struct LookupProducer {
17	multiplicity_bits: Vec<Col<B1>>,
18}
19
20impl LookupProducer {
21	pub fn new<FSub>(
22		table: &mut TableBuilder,
23		chan: ChannelId,
24		value_cols: &[Col<FSub>],
25		n_multiplicity_bits: usize,
26	) -> Self
27	where
28		B128: ExtensionField<FSub>,
29		FSub: TowerField,
30	{
31		let multiplicity_bits = (0..n_multiplicity_bits)
32			.map(|i| table.add_committed::<B1, 1>(format!("multiplicity_bits[{i}]")))
33			.collect::<Vec<_>>();
34
35		for (i, &multiplicity_col) in multiplicity_bits.iter().enumerate() {
36			table.push_with_opts(
37				chan,
38				value_cols.iter().copied(),
39				FlushOpts {
40					multiplicity: 1 << i,
41					selector: Some(multiplicity_col),
42				},
43			);
44		}
45
46		Self { multiplicity_bits }
47	}
48
49	/// Populate the multiplicity witness columns.
50	///
51	/// ## Pre-condition
52	///
53	/// * Multiplicities must be sorted in ascending order.
54	pub fn populate<P>(
55		&self,
56		index: &mut TableWitnessSegment<P>,
57		counts: impl Iterator<Item = u32> + Clone,
58	) -> Result<(), anyhow::Error>
59	where
60		P: PackedExtension<B1>,
61		P::Scalar: TowerField,
62	{
63		// TODO: Optimize the gadget for bit-transposing u32s
64		for (j, &multiplicity_col) in self.multiplicity_bits.iter().enumerate().take(32) {
65			let mut multiplicity_col = index.get_mut(multiplicity_col)?;
66			for (packed, counts) in multiplicity_col
67				.iter_mut()
68				.zip(&counts.clone().chunks(<PackedSubfield<P, B1>>::WIDTH))
69			{
70				for (i, count) in counts.enumerate() {
71					packed.set(i, B1::from((count >> j) & 1 == 1))
72				}
73			}
74		}
75		Ok(())
76	}
77}
78
#[cfg(test)]
mod tests {
	use std::{cmp::Reverse, iter, iter::repeat_with};

	use binius_field::{arch::OptimalUnderlier128b, as_packed_field::PackedType};
	use bumpalo::Bump;
	use rand::{rngs::StdRng, Rng, SeedableRng};

	use super::*;
	use crate::builder::{test_utils::ClosureFiller, ConstraintSystem, Statement};

	/// Draws `n` entries of `table` uniformly at random and returns them in draw order.
	/// Increments `counts[idx]` for every index `idx` drawn.
	fn sample_lookups(
		rng: &mut StdRng,
		table: &[B128],
		counts: &mut [u32],
		n: usize,
	) -> Vec<B128> {
		(0..n)
			.map(|_| {
				let idx = rng.gen_range(0..table.len());
				counts[idx] += 1;
				table[idx]
			})
			.collect()
	}

	/// End-to-end check: one producer table with 8 multiplicity bits serves lookups from two
	/// consumer tables, and the resulting witness must validate.
	#[test]
	fn test_basic_lookup_producer() {
		let mut cs = ConstraintSystem::new();
		let chan = cs.add_channel("values");

		let mut lookup_table = cs.add_table("lookup");
		let lookup_table_id = lookup_table.id();
		let values_col = lookup_table.add_committed::<B128, 1>("values");
		let lookup_producer = LookupProducer::new(&mut lookup_table, chan, &[values_col], 8);

		// Two consumer tables, each pulling a single B128 column from the channel.
		let [(looker_1_id, looker_1_vals), (looker_2_id, looker_2_vals)] =
			["looker 1", "looker 2"].map(|name| {
				let mut looker = cs.add_table(name);
				let looker_id = looker.id();
				let looker_vals = looker.add_committed::<B128, 1>("values");
				looker.pull(chan, [looker_vals]);
				(looker_id, looker_vals)
			});

		let lookup_table_size = 45;
		let looker_1_size = 56;
		let looker_2_size = 67;

		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| B128::random(&mut rng))
			.take(lookup_table_size)
			.collect::<Vec<_>>();

		// `counts[i]` tracks how many times `values[i]` is looked up across both consumers.
		let mut counts = vec![0u32; lookup_table_size];
		let inputs_1 = sample_lookups(&mut rng, &values, &mut counts, looker_1_size);
		let inputs_2 = sample_lookups(&mut rng, &values, &mut counts, looker_2_size);

		// Order producer rows by descending multiplicity.
		let mut values_and_counts = iter::zip(values, counts).collect::<Vec<_>>();
		values_and_counts.sort_unstable_by_key(|&(_, count)| Reverse(count));

		let statement = Statement {
			boundaries: vec![],
			table_sizes: vec![lookup_table_size, looker_1_size, looker_2_size],
		};
		let allocator = Bump::new();
		let mut witness = cs
			.build_witness::<PackedType<OptimalUnderlier128b, B128>>(&allocator, &statement)
			.unwrap();

		// Producer rows: write the values column, then the multiplicity bits.
		witness
			.fill_table_sequential(
				&ClosureFiller::new(lookup_table_id, |rows, segment| {
					{
						// Scoped so the column borrow ends before `populate` reuses `segment`.
						let mut dst_values = segment.get_scalars_mut(values_col)?;
						for (dst, (value, _)) in iter::zip(&mut *dst_values, rows) {
							*dst = *value;
						}
					}
					lookup_producer.populate(segment, rows.iter().map(|(_, count)| *count))?;
					Ok(())
				}),
				&values_and_counts,
			)
			.unwrap();

		// Consumer rows: copy each looker's sampled inputs into its values column.
		for (looker_id, looker_vals, inputs) in [
			(looker_1_id, looker_1_vals, &inputs_1),
			(looker_2_id, looker_2_vals, &inputs_2),
		] {
			witness
				.fill_table_sequential(
					&ClosureFiller::new(looker_id, |rows, segment| {
						let mut dst_vals = segment.get_scalars_mut(looker_vals)?;
						for (dst, src) in iter::zip(&mut *dst_vals, rows) {
							*dst = **src;
						}
						Ok(())
					}),
					inputs,
				)
				.unwrap();
		}

		let ccs = cs.compile(&statement).unwrap();
		let witness = witness.into_multilinear_extension_index();

		binius_core::constraint_system::validate::validate_witness(&ccs, &[], &witness).unwrap();
	}
}