binius_circuits/
plain_lookup.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_core::{constraint_system::channel::FlushDirection, oracle::OracleId};
4use binius_field::{
5	as_packed_field::PackScalar, packed::set_packed_slice, BinaryField1b, ExtensionField, Field,
6	TowerField,
7};
8use bytemuck::Pod;
9
10use crate::builder::{
11	types::{F, U},
12	ConstraintSystemBuilder,
13};
14
15/// Checks values in `lookup_values` to be in `table`.
16///
17/// # Introduction
18/// This is a gadget for performing a "lookup", wherein a set of values are claimed by the prover to be a subset of a set of values known to the verifier.
19/// We call the set of values known to the verifier as the "table", and we call the set of values held by the prover as the "lookup values."
20/// We represent these sets using oracles `table` and `lookup_values` as lists of values.
21/// This gadget performs the lookup by verifying that every value in the oracle `lookup_vales` appears somewhere in the oracle `table`.
22///
23/// # Parameters
24/// - `builder`: a mutable reference to the `ConstraintSystemBuilder`.
25/// - `table`: an oracle holding the table of valid lookup values.
26/// - `lookup_values`: an oracle holding the values to be looked up.
27/// - `lookup_values_count`: only the first `lookup_values_count` values in `lookup_values` will be looked up.
28///
29/// # How this Works
30/// We create a single channel for this lookup.
31/// We let the prover push all values in `lookup_values`, that is all values to be looked up, into the channel.
32/// We also must pull valid table values (i.e. values that appear in `table`) from the channel if the channel is to balance.
33/// By ensuring that only valid table values get pulled from the channel, and observing the channel to balance, we ensure that only valid table values get pushed (by the prover) into the channel.
34/// Therefore our construction is sound.
35/// In order for the construction to be complete, allowing an honest prover to pass, we must pull each
36/// table value from the channel with exactly the same multiplicity (duplicate count) that the prover pushed that table value into the channel.
37/// To do so, we allow the prover to commit information on the multiplicity of each table value.
38///
39/// The prover counts the multiplicity of each table value, and creates a bit column for
40/// each of the LOG_MAX_MULTIPLICITY bits in the bit-decomposition of the multiplicities.
41/// Then we flush the table values LOG_MAX_MULTIPLICITY times, each time using a different bit column as the 'selector' oracle to select which values in the
42/// table actually get pushed into the channel flushed. When flushing the table with the i'th bit column as the selector, we flush with multiplicity 1 << i.
43///
44pub fn plain_lookup<FS, const LOG_MAX_MULTIPLICITY: usize>(
45	builder: &mut ConstraintSystemBuilder,
46	table: OracleId,
47	lookup_values: OracleId,
48	lookup_values_count: usize,
49) -> Result<(), anyhow::Error>
50where
51	U: PackScalar<FS> + Pod,
52	F: ExtensionField<FS>,
53	FS: TowerField + Pod,
54{
55	let n_vars = builder.log_rows([table])?;
56
57	let channel = builder.add_channel();
58
59	builder.send(channel, lookup_values_count, [lookup_values])?;
60
61	let mut multiplicities = None;
62	// have prover compute and fill the multiplicities
63	if let Some(witness) = builder.witness() {
64		let table_slice = witness.get::<FS>(table)?.as_slice::<FS>();
65		let values_slice = witness.get::<FS>(lookup_values)?.as_slice::<FS>();
66
67		multiplicities = Some(count_multiplicities(
68			&table_slice[0..1 << n_vars],
69			&values_slice[0..lookup_values_count],
70			false,
71		)?);
72	}
73
74	let bits: [OracleId; LOG_MAX_MULTIPLICITY] = get_bits(builder, table, multiplicities)?;
75	bits.into_iter().enumerate().try_for_each(|(i, bit)| {
76		builder.flush_custom(FlushDirection::Pull, channel, bit, [table], 1 << i)
77	})?;
78
79	Ok(())
80}
81
82// the `i`'th returned bit column holds the `i`'th multiplicity bit.
83fn get_bits<FS, const LOG_MAX_MULTIPLICITY: usize>(
84	builder: &mut ConstraintSystemBuilder,
85	table: OracleId,
86	multiplicities: Option<Vec<usize>>,
87) -> Result<[OracleId; LOG_MAX_MULTIPLICITY], anyhow::Error>
88where
89	U: PackScalar<FS>,
90	F: ExtensionField<FS>,
91	FS: TowerField + Pod,
92{
93	let n_vars = builder.log_rows([table])?;
94
95	let bits: [OracleId; LOG_MAX_MULTIPLICITY] = builder
96		.add_committed_multiple::<LOG_MAX_MULTIPLICITY>("bits", n_vars, BinaryField1b::TOWER_LEVEL);
97
98	if let Some(witness) = builder.witness() {
99		let multiplicities =
100			multiplicities.ok_or_else(|| anyhow::anyhow!("multiplicities empty for prover"))?;
101		debug_assert_eq!(1 << n_vars, multiplicities.len());
102
103		// check all multiplicities are in range
104		if multiplicities
105			.iter()
106			.any(|&multiplicity| multiplicity >= 1 << LOG_MAX_MULTIPLICITY)
107		{
108			return Err(anyhow::anyhow!(
109				"one or more multiplicities exceed `1 << LOG_MAX_MULTIPLICITY`"
110			));
111		}
112
113		// create the columns for the bits
114		let mut bit_cols = bits.map(|bit| witness.new_column::<BinaryField1b>(bit));
115		let mut packed_bit_cols = bit_cols.each_mut().map(|bit_col| bit_col.packed());
116
117		multiplicities
118			.iter()
119			.enumerate()
120			.for_each(|(i, multiplicity)| {
121				(0..LOG_MAX_MULTIPLICITY).for_each(|j| {
122					let bit_set = multiplicity & (1 << j) != 0;
123					set_packed_slice(
124						packed_bit_cols[j],
125						i,
126						match bit_set {
127							true => BinaryField1b::ONE,
128							false => BinaryField1b::ZERO,
129						},
130					);
131				})
132			});
133	}
134
135	Ok(bits)
136}
137
#[cfg(test)]
pub mod test_plain_lookup {
	use binius_field::BinaryField32b;
	use binius_maybe_rayon::prelude::*;

	use super::*;
	use crate::transparent;

	/// Packs the multiplication claim `x * y = z` into a single 32-bit word:
	/// `x` in bits 0..8, `y` in bits 8..16 and `z` in bits 16..32.
	const fn into_lookup_claim(x: u8, y: u8, z: u16) -> u32 {
		(x as u32) | ((y as u32) << 8) | ((z as u32) << 16)
	}

	/// Builds every claim of the 8-bit multiplication table, with `x` as the outer index
	/// and `y` as the inner index.
	fn generate_u8_mul_table() -> Vec<u32> {
		(0..=u8::MAX)
			.flat_map(|x| {
				(0..=u8::MAX).map(move |y| into_lookup_claim(x, y, x as u16 * y as u16))
			})
			.collect()
	}

	/// Overwrites each slot of `claims` with a correct claim for a random pair of bytes.
	fn generate_random_u8_mul_claims(claims: &mut [u32]) {
		use rand::Rng;
		claims.par_iter_mut().for_each(|claim| {
			let mut rng = rand::thread_rng();
			let (lhs, rhs) = (rng.gen_range(0..=255u8), rng.gen_range(0..=255u8));
			*claim = into_lookup_claim(lhs, rhs, lhs as u16 * rhs as u16);
		});
	}

	/// Looks up `1 << log_lookup_count` random u8 multiplication claims in the full
	/// u8 multiplication table; the witness is filled in only if `builder` carries one.
	pub fn test_u8_mul_lookup<const LOG_MAX_MULTIPLICITY: usize>(
		builder: &mut ConstraintSystemBuilder,
		log_lookup_count: usize,
	) -> Result<(), anyhow::Error> {
		let mul_table = generate_u8_mul_table();
		let table = transparent::make_transparent(
			builder,
			"u8_mul_table",
			bytemuck::cast_slice::<_, BinaryField32b>(&mul_table),
		)?;

		let lookup_values =
			builder.add_committed("lookup_values", log_lookup_count, BinaryField32b::TOWER_LEVEL);
		let lookup_values_count = 1 << log_lookup_count;

		if let Some(witness) = builder.witness() {
			let mut claims_col = witness.new_column::<BinaryField32b>(lookup_values);
			let claims = claims_col.as_mut_slice::<u32>();
			generate_random_u8_mul_claims(&mut claims[..lookup_values_count]);
		}

		plain_lookup::<BinaryField32b, LOG_MAX_MULTIPLICITY>(
			builder,
			table,
			lookup_values,
			lookup_values_count,
		)
	}
}
204
205fn count_multiplicities<T: Eq + std::hash::Hash + Clone + std::fmt::Debug>(
206	table: &[T],
207	values: &[T],
208	check_inclusion: bool,
209) -> Result<Vec<usize>, anyhow::Error> {
210	use std::collections::{HashMap, HashSet};
211
212	if check_inclusion {
213		let table_set: HashSet<_> = table.iter().cloned().collect();
214		if let Some(invalid_value) = values.iter().find(|value| !table_set.contains(value)) {
215			return Err(anyhow::anyhow!("value {:?} not in table", invalid_value));
216		}
217	}
218
219	let counts: HashMap<_, usize> = values.iter().fold(HashMap::new(), |mut acc, value| {
220		*acc.entry(value).or_insert(0) += 1;
221		acc
222	});
223
224	let multiplicities = table
225		.iter()
226		.map(|item| counts.get(item).copied().unwrap_or(0))
227		.collect();
228
229	Ok(multiplicities)
230}
231
232#[cfg(test)]
233mod count_multiplicity_tests {
234	use super::*;
235
236	#[test]
237	fn test_basic_functionality() {
238		let table = vec![1, 2, 3, 4];
239		let values = vec![1, 2, 2, 3, 3, 3];
240		let result = count_multiplicities(&table, &values, true).unwrap();
241		assert_eq!(result, vec![1, 2, 3, 0]);
242	}
243
244	#[test]
245	fn test_empty_values() {
246		let table = vec![1, 2, 3];
247		let values: Vec<i32> = vec![];
248		let result = count_multiplicities(&table, &values, true).unwrap();
249		assert_eq!(result, vec![0, 0, 0]);
250	}
251
252	#[test]
253	fn test_empty_table() {
254		let table: Vec<i32> = vec![];
255		let values = vec![1, 2, 3];
256		let result = count_multiplicities(&table, &values, false).unwrap();
257		assert_eq!(result, vec![]);
258	}
259
260	#[test]
261	fn test_value_not_in_table() {
262		let table = vec![1, 2, 3];
263		let values = vec![1, 4, 2];
264		let result = count_multiplicities(&table, &values, true);
265		assert!(result.is_err());
266		assert_eq!(result.unwrap_err().to_string(), "value 4 not in table");
267	}
268
269	#[test]
270	fn test_duplicates_in_table() {
271		let table = vec![1, 1, 2, 3];
272		let values = vec![1, 2, 2, 3, 3, 3];
273		let result = count_multiplicities(&table, &values, true).unwrap();
274		assert_eq!(result, vec![1, 1, 2, 3]);
275	}
276
277	#[test]
278	fn test_non_integer_values() {
279		let table = vec!["a", "b", "c"];
280		let values = vec!["a", "b", "b", "c", "c", "c"];
281		let result = count_multiplicities(&table, &values, true).unwrap();
282		assert_eq!(result, vec![1, 2, 3]);
283	}
284}
285
#[cfg(test)]
mod tests {
	use binius_core::{fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily};
	use binius_hal::make_portable_backend;
	use binius_hash::compress::Groestl256ByteCompression;
	use binius_math::DefaultEvaluationDomainFactory;
	use groestl_crypto::Groestl256;

	use super::test_plain_lookup;
	use crate::builder::ConstraintSystemBuilder;

	/// Number of committed bit columns used to encode each table row's multiplicity.
	const LOG_MAX_MULTIPLICITY: usize = 18;

	/// Registers the u8-multiplication lookup circuit on `builder`, panicking on failure.
	/// Prover and verifier must build the same circuit, so both sides go through here.
	fn build_u8_mul_lookup(builder: &mut ConstraintSystemBuilder, log_lookup_count: usize) {
		test_plain_lookup::test_u8_mul_lookup::<LOG_MAX_MULTIPLICITY>(builder, log_lookup_count)
			.unwrap();
	}

	#[test]
	fn test_plain_u8_mul_lookup() {
		let log_lookup_count = 19;
		let log_inv_rate = 1;
		let security_bits = 20;

		// Prover: build the circuit together with its witness and produce a proof.
		let proof = {
			let allocator = bumpalo::Bump::new();
			let mut prover_builder = ConstraintSystemBuilder::new_with_witness(&allocator);
			build_u8_mul_lookup(&mut prover_builder, log_lookup_count);

			let witness = prover_builder.take_witness().unwrap();
			let prover_cs = prover_builder.build().unwrap();
			// `validate_witness` is intentionally skipped: it is too slow for large
			// transparents such as the u8 multiplication `table`.

			let domain_factory = DefaultEvaluationDomainFactory::default();
			let backend = make_portable_backend();

			binius_core::constraint_system::prove::<
				crate::builder::types::U,
				CanonicalTowerFamily,
				_,
				Groestl256,
				Groestl256ByteCompression,
				HasherChallenger<Groestl256>,
				_,
			>(
				&prover_cs,
				log_inv_rate,
				security_bits,
				&[],
				witness,
				&domain_factory,
				&backend,
			)
			.unwrap()
		};

		// Verifier: rebuild the same circuit without a witness and check the proof.
		let mut verifier_builder = ConstraintSystemBuilder::new();
		build_u8_mul_lookup(&mut verifier_builder, log_lookup_count);
		let verifier_cs = verifier_builder.build().unwrap();

		binius_core::constraint_system::verify::<
			crate::builder::types::U,
			CanonicalTowerFamily,
			Groestl256,
			Groestl256ByteCompression,
			HasherChallenger<Groestl256>,
		>(&verifier_cs, log_inv_rate, security_bits, &[], proof)
		.unwrap();
	}
}