binius_circuits/
plain_lookup.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cmp::Reverse, fmt::Debug, hash::Hash};
4
5use anyhow::{ensure, Result};
6use binius_core::{
7	constraint_system::channel::{FlushDirection, OracleOrConst},
8	oracle::OracleId,
9};
10use binius_field::{
11	as_packed_field::PackScalar,
12	packed::{get_packed_slice, set_packed_slice},
13	BinaryField1b, ExtensionField, Field, PackedField, TowerField,
14};
15use bytemuck::Pod;
16use itertools::izip;
17
18use crate::builder::{
19	types::{F, U},
20	ConstraintSystemBuilder,
21};
22
23/// A gadget validating the lookup relation between:
24/// * `lookups_u` - the set of "looked up" tables
25/// * `lookup_t` - the lookup table
26///
27/// Both looked up and lookup tables are defined by tuples of column oracles, where each column has
28/// the same `FTable` type. Primary reason to support this behaviour is to be able to lookup into
29/// tables "wider" than the largest 128-bit field.
30///
31/// Looked up tables are assumed to have `n_lookups` values each, whereas `lookup_t` is considered
32/// to be always full.
33///
34/// The prover needs to provide multiplicities of the lookup table elements; the helper
35/// [`count_multiplicities`] method does that using a hash map of counters, but in some cases it
36/// should be possible to compute this table more efficiently or in an indirect way.
37///
38/// # How this Works
39/// We create two channel for this lookup - a multiplicity channel and a permutation channel.
40/// We let the prover push all values in `lookups_u`, that is all values to be looked up, into the
41/// multiplicity channel. We also must pull valid table values (i.e. values that appear in
42/// `lookup_t`) from this channel if the channel is to balance. By ensuring that only valid table
43/// values get pulled from the channel, and observing the channel to balance, we ensure
44/// that only valid table values get pushed (by the prover) into the channel. Therefore our
45/// construction is sound.
46///
47/// In order for the construction to be complete, allowing an honest prover to pass, we must pull
48/// each table value from the multiplicity channel with exactly the same multiplicity (duplicate
49/// count) that the prover pushed that table value into the channel. To do so, we allow the prover
50/// to commit information on the multiplicity of each table value.
51///
52/// The prover counts the multiplicity of each table value, and creates a bit column for each of the
53/// LOG_MAX_MULTIPLICITY bits in the bit-decomposition of the multiplicities.
54/// Then we flush the table values LOG_MAX_MULTIPLICITY times, each time using a different bit
55/// column as the 'selector' oracle to select which values in the table actually get pushed into the
56/// channel flushed. When flushing the table with the i'th bit column as the selector, we flush with
57/// multiplicity 1 << i.
58///
59/// The reason for using _two_ channels is a prover-side optimization - instead of counting
60/// multiplicities on the original `lookup_t`, we commit a permuted version of that with
61/// non-decreasing multiplicities. This enforces nonzero scalars prefixes on the committed
62/// multiplicity bits columns, which can be used to optimize the flush and GKR reduction sumchecks.
63/// In order to constrain this committed lookup to be a permutation of lookup_t we do a push/pull on
64/// the permutation channel.
65pub fn plain_lookup<FTable, const LOG_MAX_MULTIPLICITY: usize>(
66	builder: &mut ConstraintSystemBuilder,
67	name: impl ToString,
68	n_lookups: &[usize],
69	lookups_u: &[impl AsRef<[OracleId]>],
70	lookup_t: impl AsRef<[OracleId]>,
71	multiplicities: Option<impl AsRef<[usize]>>,
72) -> Result<()>
73where
74	U: PackScalar<FTable> + Pod,
75	F: ExtensionField<FTable>,
76	FTable: TowerField,
77{
78	ensure!(n_lookups.len() == lookups_u.len(), "n_vars and lookups_u must be of the same length");
79	ensure!(
80		lookups_u
81			.iter()
82			.all(|oracles| oracles.as_ref().len() == lookup_t.as_ref().len()),
83		"looked up and lookup tables must have the same number of oracles"
84	);
85
86	let lookups_u_count_sum = n_lookups.iter().sum::<usize>();
87	ensure!(lookups_u_count_sum < 1 << LOG_MAX_MULTIPLICITY, "LOG_MAX_MULTIPLICITY too small");
88
89	builder.push_namespace(name);
90
91	let t_log_rows = builder.log_rows(lookup_t.as_ref().iter().copied())?;
92	let bits = builder.add_committed_multiple::<LOG_MAX_MULTIPLICITY>(
93		"multiplicity_bits",
94		t_log_rows,
95		BinaryField1b::TOWER_LEVEL,
96	);
97
98	let permuted_lookup_t = (0..lookup_t.as_ref().len())
99		.map(|i| builder.add_committed(format!("permuted_t_{i}"), t_log_rows, FTable::TOWER_LEVEL))
100		.collect::<Vec<_>>();
101
102	if let Some(witness) = builder.witness() {
103		let mut indexed_multiplicities = multiplicities
104			.expect("multiplicities should be supplied when proving")
105			.as_ref()
106			.iter()
107			.copied()
108			.enumerate()
109			.collect::<Vec<_>>();
110
111		let multiplicities_sum = indexed_multiplicities
112			.iter()
113			.map(|&(_, multiplicity)| multiplicity)
114			.sum::<usize>();
115		ensure!(multiplicities_sum == lookups_u_count_sum, "Multiplicities do not add up.");
116
117		indexed_multiplicities.sort_by_key(|&(_, multiplicity)| Reverse(multiplicity));
118
119		for (i, bit) in bits.into_iter().enumerate() {
120			let nonzero_scalars_prefix =
121				indexed_multiplicities.partition_point(|&(_, count)| count >= 1 << i);
122
123			let mut column = witness.new_column_with_nonzero_scalars_prefix::<BinaryField1b>(
124				bit,
125				nonzero_scalars_prefix,
126			);
127
128			let packed = column.packed();
129
130			for (j, &(_, multiplicity)) in indexed_multiplicities.iter().enumerate() {
131				if (1 << i) & multiplicity != 0 {
132					set_packed_slice(packed, j, BinaryField1b::ONE);
133				}
134			}
135		}
136
137		for (&permuted, &original) in izip!(&permuted_lookup_t, lookup_t.as_ref()) {
138			let original_slice = witness.get::<FTable>(original)?.packed();
139
140			let mut permuted_column = witness.new_column::<FTable>(permuted);
141			let permuted_slice = permuted_column.packed();
142
143			let mut iterator = indexed_multiplicities
144				.iter()
145				.map(|&(index, _)| get_packed_slice(original_slice, index));
146			for v in permuted_slice.iter_mut() {
147				*v = PackedField::from_scalars(&mut iterator);
148			}
149		}
150	}
151
152	let permutation_channel = builder.add_channel();
153	let multiplicity_channel = builder.add_channel();
154
155	builder.send(
156		permutation_channel,
157		1 << t_log_rows,
158		permuted_lookup_t.iter().copied().map(OracleOrConst::Oracle),
159	)?;
160	builder.receive(
161		permutation_channel,
162		1 << t_log_rows,
163		lookup_t.as_ref().iter().copied().map(OracleOrConst::Oracle),
164	)?;
165
166	for (lookup_u, &count) in izip!(lookups_u, n_lookups) {
167		builder.send(
168			multiplicity_channel,
169			count,
170			lookup_u.as_ref().iter().copied().map(OracleOrConst::Oracle),
171		)?;
172	}
173
174	for (i, bit) in bits.into_iter().enumerate() {
175		builder.flush_custom(
176			FlushDirection::Pull,
177			multiplicity_channel,
178			vec![bit],
179			permuted_lookup_t.iter().copied().map(OracleOrConst::Oracle),
180			1 << i,
181		)?
182	}
183
184	builder.pop_namespace();
185
186	Ok(())
187}
188
#[cfg(test)]
pub mod test_plain_lookup {
	use binius_field::BinaryField32b;
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;
	use crate::transparent;

	/// Packs a multiplication claim `x * y = z` into one `u32`: `x` in bits 0..8, `y` in bits
	/// 8..16 and `z` in bits 16..32.
	const fn into_lookup_claim(x: u8, y: u8, z: u16) -> u32 {
		((z as u32) << 16) | ((y as u32) << 8) | (x as u32)
	}

	/// Builds the full `u8 * u8` multiplication table (`1 << 16` packed claims).
	fn generate_u8_mul_table() -> Vec<u32> {
		let mut result = Vec::with_capacity(1 << 16);
		for x in 0..=255u8 {
			for y in 0..=255u8 {
				let product = x as u16 * y as u16;
				result.push(into_lookup_claim(x, y, product));
			}
		}
		result
	}

	/// Overwrites every element of `vals` with a valid multiplication claim drawn from a
	/// fixed-seed RNG, so the generated claims are reproducible across calls.
	fn generate_random_u8_mul_claims(vals: &mut [u32]) {
		use rand::Rng;
		let mut rng = StdRng::seed_from_u64(0);
		for val in vals {
			let x = rng.gen_range(0..=255u8);
			let y = rng.gen_range(0..=255u8);
			let product = x as u16 * y as u16;
			*val = into_lookup_claim(x, y, product);
		}
	}

	/// Adds a `plain_lookup` of `1 << log_lookup_count` random `u8` multiplication claims into
	/// the full multiplication table to `builder`.
	///
	/// When the builder has a witness, the claims column is populated and the table
	/// multiplicities are computed with [`count_multiplicities`].
	pub fn test_u8_mul_lookup<const LOG_MAX_MULTIPLICITY: usize>(
		builder: &mut ConstraintSystemBuilder,
		log_lookup_count: usize,
	) -> Result<(), anyhow::Error> {
		let table_values = generate_u8_mul_table();
		let table = transparent::make_transparent(
			builder,
			"u8_mul_table",
			bytemuck::cast_slice::<_, BinaryField32b>(&table_values),
		)?;

		let lookup_values =
			builder.add_committed("lookup_values", log_lookup_count, BinaryField32b::TOWER_LEVEL);

		let lookup_values_count = 1 << log_lookup_count;

		let multiplicities = if let Some(witness) = builder.witness() {
			let mut lookup_values_col = witness.new_column::<BinaryField32b>(lookup_values);
			let mut_slice = lookup_values_col.as_mut_slice::<u32>();
			// Generate and count over the same prefix: `plain_lookup` requires the
			// multiplicities to sum to exactly `lookup_values_count`, so any elements of the
			// backing slice beyond that prefix must not be counted.
			let claims = &mut mut_slice[..lookup_values_count];
			generate_random_u8_mul_claims(claims);
			Some(count_multiplicities(&table_values, &*claims, true)?)
		} else {
			None
		};

		plain_lookup::<BinaryField32b, LOG_MAX_MULTIPLICITY>(
			builder,
			"u8_mul_lookup",
			&[1 << log_lookup_count],
			&[[lookup_values]],
			&[table],
			multiplicities,
		)?;

		Ok(())
	}
}
260
261pub fn count_multiplicities<T>(
262	table: &[T],
263	values: &[T],
264	check_inclusion: bool,
265) -> Result<Vec<usize>, anyhow::Error>
266where
267	T: Eq + Hash + Debug,
268{
269	use std::collections::{HashMap, HashSet};
270
271	if check_inclusion {
272		let table_set: HashSet<_> = table.iter().collect();
273		if let Some(invalid_value) = values.iter().find(|value| !table_set.contains(value)) {
274			return Err(anyhow::anyhow!("value {:?} not in table", invalid_value));
275		}
276	}
277
278	let counts: HashMap<_, usize> =
279		values
280			.iter()
281			.fold(HashMap::with_capacity(table.len()), |mut acc, value| {
282				*acc.entry(value).or_insert(0) += 1;
283				acc
284			});
285
286	let multiplicities = table
287		.iter()
288		.map(|item| counts.get(item).copied().unwrap_or(0))
289		.collect();
290
291	Ok(multiplicities)
292}
293
#[cfg(test)]
mod count_multiplicity_tests {
	use super::*;

	/// Counts are reported per table entry, with zero for entries never looked up.
	#[test]
	fn test_basic_functionality() {
		let counts = count_multiplicities(&[1, 2, 3, 4], &[1, 2, 2, 3, 3, 3], true).unwrap();
		assert_eq!(counts, [1, 2, 3, 0]);
	}

	/// With nothing looked up, every table entry has multiplicity zero.
	#[test]
	fn test_empty_values() {
		let no_values: [i32; 0] = [];
		let counts = count_multiplicities(&[1, 2, 3], &no_values, true).unwrap();
		assert_eq!(counts, [0, 0, 0]);
	}

	/// An empty table yields an empty result when the inclusion check is disabled.
	#[test]
	fn test_empty_table() {
		let empty_table: [i32; 0] = [];
		let counts = count_multiplicities(&empty_table, &[1, 2, 3], false).unwrap();
		assert!(counts.is_empty());
	}

	/// The inclusion check reports the offending value.
	#[test]
	fn test_value_not_in_table() {
		let outcome = count_multiplicities(&[1, 2, 3], &[1, 4, 2], true);
		let error = outcome.unwrap_err();
		assert_eq!(error.to_string(), "value 4 not in table");
	}

	/// Every occurrence of a duplicated table entry receives the full count.
	#[test]
	fn test_duplicates_in_table() {
		let counts = count_multiplicities(&[1, 1, 2, 3], &[1, 2, 2, 3, 3, 3], true).unwrap();
		assert_eq!(counts, [1, 1, 2, 3]);
	}

	/// Works for any `Eq + Hash + Debug` element type, e.g. string slices.
	#[test]
	fn test_non_integer_values() {
		let counts =
			count_multiplicities(&["a", "b", "c"], &["a", "b", "b", "c", "c", "c"], true).unwrap();
		assert_eq!(counts, [1, 2, 3]);
	}
}
347
#[cfg(test)]
mod tests {
	use binius_core::fiat_shamir::HasherChallenger;
	use binius_field::tower::CanonicalTowerFamily;
	use binius_hal::make_portable_backend;
	use binius_hash::groestl::{Groestl256, Groestl256ByteCompression};

	use super::test_plain_lookup;
	use crate::builder::{types::U, ConstraintSystemBuilder};

	/// Proves and then verifies a lookup of `1 << 19` random `u8` multiplication claims.
	#[test]
	fn test_plain_u8_mul_lookup() {
		const LOG_MAX_MULTIPLICITY: usize = 20;
		let log_lookups = 19;

		let log_inv_rate = 1;
		let security_bits = 20;

		// Prover side: build the constraint system together with its witness. The scope
		// releases the witness allocator as soon as the proof exists.
		let proof = {
			let arena = bumpalo::Bump::new();
			let mut prover_builder = ConstraintSystemBuilder::new_with_witness(&arena);

			test_plain_lookup::test_u8_mul_lookup::<LOG_MAX_MULTIPLICITY>(
				&mut prover_builder,
				log_lookups,
			)
			.unwrap();

			let witness = prover_builder.take_witness().unwrap();
			let prover_cs = prover_builder.build().unwrap();
			// `validate_witness` is deliberately skipped: it is too slow for large
			// transparents like the multiplication table.

			let backend = make_portable_backend();

			binius_core::constraint_system::prove::<
				U,
				CanonicalTowerFamily,
				Groestl256,
				Groestl256ByteCompression,
				HasherChallenger<Groestl256>,
				_,
			>(&prover_cs, log_inv_rate, security_bits, &[], witness, &backend)
			.unwrap()
		};

		// Verifier side: rebuild the same constraint system without a witness and check the
		// proof against it.
		let mut verifier_builder = ConstraintSystemBuilder::new();

		test_plain_lookup::test_u8_mul_lookup::<LOG_MAX_MULTIPLICITY>(
			&mut verifier_builder,
			log_lookups,
		)
		.unwrap();

		let verifier_cs = verifier_builder.build().unwrap();

		binius_core::constraint_system::verify::<
			U,
			CanonicalTowerFamily,
			Groestl256,
			Groestl256ByteCompression,
			HasherChallenger<Groestl256>,
		>(&verifier_cs, log_inv_rate, security_bits, &[], proof)
		.unwrap();
	}
}