binius_circuits/
plain_lookup.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cmp::Reverse, fmt::Debug, hash::Hash};
4
5use anyhow::{ensure, Result};
6use binius_core::{
7	constraint_system::channel::{FlushDirection, OracleOrConst},
8	oracle::OracleId,
9};
10use binius_field::{
11	as_packed_field::PackScalar,
12	packed::{get_packed_slice, set_packed_slice},
13	BinaryField1b, ExtensionField, Field, PackedField, TowerField,
14};
15use bytemuck::Pod;
16use itertools::izip;
17
18use crate::builder::{
19	types::{F, U},
20	ConstraintSystemBuilder,
21};
22
23/// A gadget validating the lookup relation between:
24/// * `lookups_u` - the set of "looked up" tables
25/// * `lookup_t` - the lookup table
26///
27/// Both looked up and lookup tables are defined by tuples of column oracles, where each column has the same `FTable` type.
28/// Primary reason to support this behaviour is to be able to lookup into tables "wider" than the largest 128-bit field.
29///
30/// Looked up tables are assumed to have `n_lookups` values each, whereas `lookup_t` is considered to be always full.
31///
32/// The prover needs to provide multiplicities of the lookup table elements; the helper [`count_multiplicities`] method
33/// does that using a hash map of counters, but in some cases it should be possible to compute this table more efficiently
34/// or in an indirect way.
35///
36/// # How this Works
37/// We create two channel for this lookup - a multiplicity channel and a permutation channel.
38/// We let the prover push all values in `lookups_u`, that is all values to be looked up, into the multiplicity channel.
39/// We also must pull valid table values (i.e. values that appear in `lookup_t`) from this channel if the channel is to balance.
40/// By ensuring that only valid table values get pulled from the channel, and observing the channel to balance, we ensure
41/// that only valid table values get pushed (by the prover) into the channel. Therefore our construction is sound.
42///
43/// In order for the construction to be complete, allowing an honest prover to pass, we must pull each
44/// table value from the multiplicity channel with exactly the same multiplicity (duplicate count) that the prover pushed that table
45/// value into the channel. To do so, we allow the prover to commit information on the multiplicity of each table value.
46///
47/// The prover counts the multiplicity of each table value, and creates a bit column for each of the LOG_MAX_MULTIPLICITY bits in
48/// the bit-decomposition of the multiplicities.
49/// Then we flush the table values LOG_MAX_MULTIPLICITY times, each time using a different bit column as the 'selector' oracle to select
50/// which values in the table actually get pushed into the channel flushed. When flushing the table with the i'th bit column as the selector,
51///we flush with multiplicity 1 << i.
52///
53/// The reason for using _two_ channels is a prover-side optimization - instead of counting multiplicities on the original `lookup_t`, we
54/// commit a permuted version of that with non-decreasing multiplicities. This enforces nonzero scalars prefixes on the committed multiplicity
55/// bits columns, which can be used to optimize the flush and GKR reduction sumchecks. In order to constrain this committed lookup to be a
56/// permutation of lookup_t we do a push/pull on the permutation channel.
57pub fn plain_lookup<FTable, const LOG_MAX_MULTIPLICITY: usize>(
58	builder: &mut ConstraintSystemBuilder,
59	name: impl ToString,
60	n_lookups: &[usize],
61	lookups_u: &[impl AsRef<[OracleId]>],
62	lookup_t: impl AsRef<[OracleId]>,
63	multiplicities: Option<impl AsRef<[usize]>>,
64) -> Result<()>
65where
66	U: PackScalar<FTable> + Pod,
67	F: ExtensionField<FTable>,
68	FTable: TowerField,
69{
70	ensure!(n_lookups.len() == lookups_u.len(), "n_vars and lookups_u must be of the same length");
71	ensure!(
72		lookups_u
73			.iter()
74			.all(|oracles| oracles.as_ref().len() == lookup_t.as_ref().len()),
75		"looked up and lookup tables must have the same number of oracles"
76	);
77
78	let lookups_u_count_sum = n_lookups.iter().sum::<usize>();
79	ensure!(lookups_u_count_sum < 1 << LOG_MAX_MULTIPLICITY, "LOG_MAX_MULTIPLICITY too small");
80
81	builder.push_namespace(name);
82
83	let t_log_rows = builder.log_rows(lookup_t.as_ref().iter().copied())?;
84	let bits = builder.add_committed_multiple::<LOG_MAX_MULTIPLICITY>(
85		"multiplicity_bits",
86		t_log_rows,
87		BinaryField1b::TOWER_LEVEL,
88	);
89
90	let permuted_lookup_t = (0..lookup_t.as_ref().len())
91		.map(|i| {
92			builder.add_committed(format!("permuted_t_{}", i), t_log_rows, FTable::TOWER_LEVEL)
93		})
94		.collect::<Vec<_>>();
95
96	if let Some(witness) = builder.witness() {
97		let mut indexed_multiplicities = multiplicities
98			.expect("multiplicities should be supplied when proving")
99			.as_ref()
100			.iter()
101			.copied()
102			.enumerate()
103			.collect::<Vec<_>>();
104
105		let multiplicities_sum = indexed_multiplicities
106			.iter()
107			.map(|&(_, multiplicity)| multiplicity)
108			.sum::<usize>();
109		ensure!(multiplicities_sum == lookups_u_count_sum, "Multiplicities do not add up.");
110
111		indexed_multiplicities.sort_by_key(|&(_, multiplicity)| Reverse(multiplicity));
112
113		for (i, bit) in bits.into_iter().enumerate() {
114			let nonzero_scalars_prefix =
115				indexed_multiplicities.partition_point(|&(_, count)| count >= 1 << i);
116
117			let mut column = witness.new_column_with_nonzero_scalars_prefix::<BinaryField1b>(
118				bit,
119				nonzero_scalars_prefix,
120			);
121
122			let packed = column.packed();
123
124			for (j, &(_, multiplicity)) in indexed_multiplicities.iter().enumerate() {
125				if (1 << i) & multiplicity != 0 {
126					set_packed_slice(packed, j, BinaryField1b::ONE);
127				}
128			}
129		}
130
131		for (&permuted, &original) in izip!(&permuted_lookup_t, lookup_t.as_ref()) {
132			let original_slice = witness.get::<FTable>(original)?.packed();
133
134			let mut permuted_column = witness.new_column::<FTable>(permuted);
135			let permuted_slice = permuted_column.packed();
136
137			let mut iterator = indexed_multiplicities
138				.iter()
139				.map(|&(index, _)| get_packed_slice(original_slice, index));
140			for v in permuted_slice.iter_mut() {
141				*v = PackedField::from_scalars(&mut iterator);
142			}
143		}
144	}
145
146	let permutation_channel = builder.add_channel();
147	let multiplicity_channel = builder.add_channel();
148
149	builder.send(
150		permutation_channel,
151		1 << t_log_rows,
152		permuted_lookup_t.iter().copied().map(OracleOrConst::Oracle),
153	)?;
154	builder.receive(
155		permutation_channel,
156		1 << t_log_rows,
157		lookup_t.as_ref().iter().copied().map(OracleOrConst::Oracle),
158	)?;
159
160	for (lookup_u, &count) in izip!(lookups_u, n_lookups) {
161		builder.send(
162			multiplicity_channel,
163			count,
164			lookup_u.as_ref().iter().copied().map(OracleOrConst::Oracle),
165		)?;
166	}
167
168	for (i, bit) in bits.into_iter().enumerate() {
169		builder.flush_custom(
170			FlushDirection::Pull,
171			multiplicity_channel,
172			bit,
173			permuted_lookup_t.iter().copied().map(OracleOrConst::Oracle),
174			1 << i,
175		)?
176	}
177
178	builder.pop_namespace();
179
180	Ok(())
181}
182
#[cfg(test)]
pub mod test_plain_lookup {
	use binius_field::BinaryField32b;
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;
	use crate::transparent;

	/// Packs a multiplication claim into one `u32`: `x` in bits 0..8, `y` in bits 8..16 and
	/// `z` in bits 16..32.
	const fn into_lookup_claim(x: u8, y: u8, z: u16) -> u32 {
		((z as u32) << 16) | ((y as u32) << 8) | (x as u32)
	}

	/// Builds the full table of packed `x * y` claims for all pairs of `u8` operands.
	fn generate_u8_mul_table() -> Vec<u32> {
		let mut result = Vec::with_capacity(1 << 16);
		for x in 0..=255u8 {
			for y in 0..=255u8 {
				let product = x as u16 * y as u16;
				result.push(into_lookup_claim(x, y, product));
			}
		}
		result
	}

	/// Overwrites `vals` with valid multiplication claims drawn from a fixed-seed RNG.
	fn generate_random_u8_mul_claims(vals: &mut [u32]) {
		use rand::Rng;
		let mut rng = StdRng::seed_from_u64(0);
		for val in vals {
			let x = rng.gen_range(0..=255u8);
			let y = rng.gen_range(0..=255u8);
			let product = x as u16 * y as u16;
			*val = into_lookup_claim(x, y, product);
		}
	}

	/// Adds a `u8` multiplication lookup with `1 << log_lookup_count` random claims to `builder`.
	///
	/// When the builder carries a witness, the claims column is populated and the table
	/// multiplicities are computed with [`count_multiplicities`].
	pub fn test_u8_mul_lookup<const LOG_MAX_MULTIPLICITY: usize>(
		builder: &mut ConstraintSystemBuilder,
		log_lookup_count: usize,
	) -> Result<(), anyhow::Error> {
		let table_values = generate_u8_mul_table();
		let table = transparent::make_transparent(
			builder,
			"u8_mul_table",
			bytemuck::cast_slice::<_, BinaryField32b>(&table_values),
		)?;

		let lookup_values =
			builder.add_committed("lookup_values", log_lookup_count, BinaryField32b::TOWER_LEVEL);

		let lookup_values_count = 1 << log_lookup_count;

		let multiplicities = if let Some(witness) = builder.witness() {
			let mut lookup_values_col = witness.new_column::<BinaryField32b>(lookup_values);
			// Only the first `lookup_values_count` entries are declared as lookups below, so
			// populate and count exactly that prefix; any further entries of the backing slice
			// must not contribute to the multiplicities.
			let claims = &mut lookup_values_col.as_mut_slice::<u32>()[..lookup_values_count];
			generate_random_u8_mul_claims(claims);
			Some(count_multiplicities(&table_values, claims, true)?)
		} else {
			None
		};

		plain_lookup::<BinaryField32b, LOG_MAX_MULTIPLICITY>(
			builder,
			"u8_mul_lookup",
			&[lookup_values_count],
			&[[lookup_values]],
			&[table],
			multiplicities,
		)?;

		Ok(())
	}
}
254
255pub fn count_multiplicities<T>(
256	table: &[T],
257	values: &[T],
258	check_inclusion: bool,
259) -> Result<Vec<usize>, anyhow::Error>
260where
261	T: Eq + Hash + Debug,
262{
263	use std::collections::{HashMap, HashSet};
264
265	if check_inclusion {
266		let table_set: HashSet<_> = table.iter().collect();
267		if let Some(invalid_value) = values.iter().find(|value| !table_set.contains(value)) {
268			return Err(anyhow::anyhow!("value {:?} not in table", invalid_value));
269		}
270	}
271
272	let counts: HashMap<_, usize> =
273		values
274			.iter()
275			.fold(HashMap::with_capacity(table.len()), |mut acc, value| {
276				*acc.entry(value).or_insert(0) += 1;
277				acc
278			});
279
280	let multiplicities = table
281		.iter()
282		.map(|item| counts.get(item).copied().unwrap_or(0))
283		.collect();
284
285	Ok(multiplicities)
286}
287
#[cfg(test)]
mod count_multiplicity_tests {
	use super::*;

	/// Every table entry is reported with its number of occurrences, zero when unused.
	#[test]
	fn test_basic_functionality() {
		let counts = count_multiplicities(&[1, 2, 3, 4], &[1, 2, 2, 3, 3, 3], true).unwrap();
		assert_eq!(counts, [1, 2, 3, 0]);
	}

	/// Without any looked up values all multiplicities are zero.
	#[test]
	fn test_empty_values() {
		let no_values: [i32; 0] = [];
		let counts = count_multiplicities(&[1, 2, 3], &no_values, true).unwrap();
		assert_eq!(counts, [0, 0, 0]);
	}

	/// An empty table yields an empty result when the inclusion check is disabled.
	#[test]
	fn test_empty_table() {
		let empty_table: [i32; 0] = [];
		let counts = count_multiplicities(&empty_table, &[1, 2, 3], false).unwrap();
		assert_eq!(counts, Vec::<usize>::new());
	}

	/// With the inclusion check enabled, a value missing from the table is reported as an error.
	#[test]
	fn test_value_not_in_table() {
		let outcome = count_multiplicities(&[1, 2, 3], &[1, 4, 2], true);
		let error = outcome.unwrap_err();
		assert_eq!(error.to_string(), "value 4 not in table");
	}

	/// A duplicated table entry receives the full count at each of its positions.
	#[test]
	fn test_duplicates_in_table() {
		let counts = count_multiplicities(&[1, 1, 2, 3], &[1, 2, 2, 3, 3, 3], true).unwrap();
		assert_eq!(counts, [1, 1, 2, 3]);
	}

	/// Counting works for any hashable element type, not just integers.
	#[test]
	fn test_non_integer_values() {
		let counts =
			count_multiplicities(&["a", "b", "c"], &["a", "b", "b", "c", "c", "c"], true).unwrap();
		assert_eq!(counts, [1, 2, 3]);
	}
}
341
#[cfg(test)]
mod tests {
	use binius_core::{fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily};
	use binius_hal::make_portable_backend;
	use binius_hash::groestl::{Groestl256, Groestl256ByteCompression};

	use super::test_plain_lookup;
	use crate::builder::{types::U, ConstraintSystemBuilder};

	/// End-to-end check of the lookup gadget: proves the u8 multiplication lookup using a
	/// builder that carries a witness, then rebuilds the constraint system without a witness
	/// and verifies the resulting proof.
	#[test]
	fn test_plain_u8_mul_lookup() {
		const MAX_LOG_MULTIPLICITY: usize = 20;
		let log_lookup_count = 19;

		let log_inv_rate = 1;
		let security_bits = 20;

		// Prover side. The bump allocator backing the witness is scoped to this block.
		let proof = {
			let allocator = bumpalo::Bump::new();
			let mut prover_builder = ConstraintSystemBuilder::new_with_witness(&allocator);

			test_plain_lookup::test_u8_mul_lookup::<MAX_LOG_MULTIPLICITY>(
				&mut prover_builder,
				log_lookup_count,
			)
			.unwrap();

			let witness = prover_builder.take_witness().unwrap();
			let prover_system = prover_builder.build().unwrap();
			// validating witness with `validate_witness` is too slow for large transparents like the `table`

			let backend = make_portable_backend();

			binius_core::constraint_system::prove::<
				U,
				CanonicalTowerFamily,
				Groestl256,
				Groestl256ByteCompression,
				HasherChallenger<Groestl256>,
				_,
			>(&prover_system, log_inv_rate, security_bits, &[], witness, &backend)
			.unwrap()
		};

		// Verifier side: same circuit, built without a witness.
		let mut verifier_builder = ConstraintSystemBuilder::new();

		test_plain_lookup::test_u8_mul_lookup::<MAX_LOG_MULTIPLICITY>(
			&mut verifier_builder,
			log_lookup_count,
		)
		.unwrap();

		let verifier_system = verifier_builder.build().unwrap();

		binius_core::constraint_system::verify::<
			U,
			CanonicalTowerFamily,
			Groestl256,
			Groestl256ByteCompression,
			HasherChallenger<Groestl256>,
		>(&verifier_system, log_inv_rate, security_bits, &[], proof)
		.unwrap();
	}
}