1use std::{cmp::Reverse, fmt::Debug, hash::Hash};
4
5use anyhow::{ensure, Result};
6use binius_core::{
7 constraint_system::channel::{FlushDirection, OracleOrConst},
8 oracle::OracleId,
9};
10use binius_field::{
11 as_packed_field::PackScalar,
12 packed::{get_packed_slice, set_packed_slice},
13 BinaryField1b, ExtensionField, Field, PackedField, TowerField,
14};
15use bytemuck::Pod;
16use itertools::izip;
17
18use crate::builder::{
19 types::{F, U},
20 ConstraintSystemBuilder,
21};
22
23pub fn plain_lookup<FTable, const LOG_MAX_MULTIPLICITY: usize>(
58 builder: &mut ConstraintSystemBuilder,
59 name: impl ToString,
60 n_lookups: &[usize],
61 lookups_u: &[impl AsRef<[OracleId]>],
62 lookup_t: impl AsRef<[OracleId]>,
63 multiplicities: Option<impl AsRef<[usize]>>,
64) -> Result<()>
65where
66 U: PackScalar<FTable> + Pod,
67 F: ExtensionField<FTable>,
68 FTable: TowerField,
69{
70 ensure!(n_lookups.len() == lookups_u.len(), "n_vars and lookups_u must be of the same length");
71 ensure!(
72 lookups_u
73 .iter()
74 .all(|oracles| oracles.as_ref().len() == lookup_t.as_ref().len()),
75 "looked up and lookup tables must have the same number of oracles"
76 );
77
78 let lookups_u_count_sum = n_lookups.iter().sum::<usize>();
79 ensure!(lookups_u_count_sum < 1 << LOG_MAX_MULTIPLICITY, "LOG_MAX_MULTIPLICITY too small");
80
81 builder.push_namespace(name);
82
83 let t_log_rows = builder.log_rows(lookup_t.as_ref().iter().copied())?;
84 let bits = builder.add_committed_multiple::<LOG_MAX_MULTIPLICITY>(
85 "multiplicity_bits",
86 t_log_rows,
87 BinaryField1b::TOWER_LEVEL,
88 );
89
90 let permuted_lookup_t = (0..lookup_t.as_ref().len())
91 .map(|i| {
92 builder.add_committed(format!("permuted_t_{}", i), t_log_rows, FTable::TOWER_LEVEL)
93 })
94 .collect::<Vec<_>>();
95
96 if let Some(witness) = builder.witness() {
97 let mut indexed_multiplicities = multiplicities
98 .expect("multiplicities should be supplied when proving")
99 .as_ref()
100 .iter()
101 .copied()
102 .enumerate()
103 .collect::<Vec<_>>();
104
105 let multiplicities_sum = indexed_multiplicities
106 .iter()
107 .map(|&(_, multiplicity)| multiplicity)
108 .sum::<usize>();
109 ensure!(multiplicities_sum == lookups_u_count_sum, "Multiplicities do not add up.");
110
111 indexed_multiplicities.sort_by_key(|&(_, multiplicity)| Reverse(multiplicity));
112
113 for (i, bit) in bits.into_iter().enumerate() {
114 let nonzero_scalars_prefix =
115 indexed_multiplicities.partition_point(|&(_, count)| count >= 1 << i);
116
117 let mut column = witness.new_column_with_nonzero_scalars_prefix::<BinaryField1b>(
118 bit,
119 nonzero_scalars_prefix,
120 );
121
122 let packed = column.packed();
123
124 for (j, &(_, multiplicity)) in indexed_multiplicities.iter().enumerate() {
125 if (1 << i) & multiplicity != 0 {
126 set_packed_slice(packed, j, BinaryField1b::ONE);
127 }
128 }
129 }
130
131 for (&permuted, &original) in izip!(&permuted_lookup_t, lookup_t.as_ref()) {
132 let original_slice = witness.get::<FTable>(original)?.packed();
133
134 let mut permuted_column = witness.new_column::<FTable>(permuted);
135 let permuted_slice = permuted_column.packed();
136
137 let mut iterator = indexed_multiplicities
138 .iter()
139 .map(|&(index, _)| get_packed_slice(original_slice, index));
140 for v in permuted_slice.iter_mut() {
141 *v = PackedField::from_scalars(&mut iterator);
142 }
143 }
144 }
145
146 let permutation_channel = builder.add_channel();
147 let multiplicity_channel = builder.add_channel();
148
149 builder.send(
150 permutation_channel,
151 1 << t_log_rows,
152 permuted_lookup_t.iter().copied().map(OracleOrConst::Oracle),
153 )?;
154 builder.receive(
155 permutation_channel,
156 1 << t_log_rows,
157 lookup_t.as_ref().iter().copied().map(OracleOrConst::Oracle),
158 )?;
159
160 for (lookup_u, &count) in izip!(lookups_u, n_lookups) {
161 builder.send(
162 multiplicity_channel,
163 count,
164 lookup_u.as_ref().iter().copied().map(OracleOrConst::Oracle),
165 )?;
166 }
167
168 for (i, bit) in bits.into_iter().enumerate() {
169 builder.flush_custom(
170 FlushDirection::Pull,
171 multiplicity_channel,
172 bit,
173 permuted_lookup_t.iter().copied().map(OracleOrConst::Oracle),
174 1 << i,
175 )?
176 }
177
178 builder.pop_namespace();
179
180 Ok(())
181}
182
#[cfg(test)]
pub mod test_plain_lookup {
    //! Shared fixture: a `u8 x u8 -> u16` multiplication lookup built on `plain_lookup`.
    use binius_field::BinaryField32b;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::transparent;

    /// Packs a multiplication claim into one `u32`: `x` in bits 0..8, `y` in bits 8..16 and
    /// `z` in bits 16..32.
    const fn into_lookup_claim(x: u8, y: u8, z: u16) -> u32 {
        ((z as u32) << 16) | ((y as u32) << 8) | (x as u32)
    }

    /// Returns all 2^16 packed claims `(x, y, x * y)`, ordered by `x` and then `y`.
    fn generate_u8_mul_table() -> Vec<u32> {
        let mut result = Vec::with_capacity(1 << 16);
        for x in 0..=255u8 {
            for y in 0..=255u8 {
                let product = x as u16 * y as u16;
                result.push(into_lookup_claim(x, y, product));
            }
        }
        result
    }

    /// Overwrites every element of `vals` with a random valid claim. The RNG uses a fixed
    /// seed, so the output is deterministic.
    fn generate_random_u8_mul_claims(vals: &mut [u32]) {
        use rand::Rng;
        let mut rng = StdRng::seed_from_u64(0);
        for val in vals {
            let x = rng.gen_range(0..=255u8);
            let y = rng.gen_range(0..=255u8);
            let product = x as u16 * y as u16;
            *val = into_lookup_claim(x, y, product);
        }
    }

    /// Adds a u8 multiplication lookup of `2^log_lookup_count` claims to `builder`.
    ///
    /// The table is a transparent oracle holding every `(x, y, x * y)` claim. The looked-up
    /// values are a committed `BinaryField32b` column. When `builder` has a witness, that
    /// column is filled with random valid claims and the table multiplicities are computed
    /// from exactly those claims.
    ///
    /// # Errors
    /// Propagates errors from building the transparent table, from `count_multiplicities`
    /// and from `plain_lookup`.
    pub fn test_u8_mul_lookup<const LOG_MAX_MULTIPLICITY: usize>(
        builder: &mut ConstraintSystemBuilder,
        log_lookup_count: usize,
    ) -> Result<(), anyhow::Error> {
        let table_values = generate_u8_mul_table();
        let table = transparent::make_transparent(
            builder,
            "u8_mul_table",
            bytemuck::cast_slice::<_, BinaryField32b>(&table_values),
        )?;

        let lookup_values =
            builder.add_committed("lookup_values", log_lookup_count, BinaryField32b::TOWER_LEVEL);

        let lookup_values_count = 1 << log_lookup_count;

        let multiplicities = if let Some(witness) = builder.witness() {
            let mut lookup_values_col = witness.new_column::<BinaryField32b>(lookup_values);
            let mut_slice = lookup_values_col.as_mut_slice::<u32>();
            // Only the first `lookup_values_count` entries are generated and flushed. If the
            // column holds more scalars (NOTE(review): presumably packing padding for small
            // `log_lookup_count`), those extra entries must not be counted. Otherwise the
            // multiplicities would not sum to the number of lookups.
            let claims = &mut mut_slice[..lookup_values_count];
            generate_random_u8_mul_claims(claims);
            Some(count_multiplicities(&table_values, claims, true)?)
        } else {
            None
        };

        plain_lookup::<BinaryField32b, LOG_MAX_MULTIPLICITY>(
            builder,
            "u8_mul_lookup",
            &[1 << log_lookup_count],
            &[[lookup_values]],
            &[table],
            multiplicities,
        )?;

        Ok(())
    }
}
254
255pub fn count_multiplicities<T>(
256 table: &[T],
257 values: &[T],
258 check_inclusion: bool,
259) -> Result<Vec<usize>, anyhow::Error>
260where
261 T: Eq + Hash + Debug,
262{
263 use std::collections::{HashMap, HashSet};
264
265 if check_inclusion {
266 let table_set: HashSet<_> = table.iter().collect();
267 if let Some(invalid_value) = values.iter().find(|value| !table_set.contains(value)) {
268 return Err(anyhow::anyhow!("value {:?} not in table", invalid_value));
269 }
270 }
271
272 let counts: HashMap<_, usize> =
273 values
274 .iter()
275 .fold(HashMap::with_capacity(table.len()), |mut acc, value| {
276 *acc.entry(value).or_insert(0) += 1;
277 acc
278 });
279
280 let multiplicities = table
281 .iter()
282 .map(|item| counts.get(item).copied().unwrap_or(0))
283 .collect();
284
285 Ok(multiplicities)
286}
287
#[cfg(test)]
mod count_multiplicity_tests {
    //! Unit tests for `count_multiplicities`.
    use super::*;

    #[test]
    fn test_basic_functionality() {
        let counts = count_multiplicities(&[1, 2, 3, 4], &[1, 2, 2, 3, 3, 3], true).unwrap();
        assert_eq!(counts, [1, 2, 3, 0]);
    }

    #[test]
    fn test_empty_values() {
        let no_values: [i32; 0] = [];
        let counts = count_multiplicities(&[1, 2, 3], &no_values, true).unwrap();
        assert_eq!(counts, [0, 0, 0]);
    }

    #[test]
    fn test_empty_table() {
        let no_table: [i32; 0] = [];
        let counts = count_multiplicities(&no_table, &[1, 2, 3], false).unwrap();
        assert!(counts.is_empty());
    }

    #[test]
    fn test_value_not_in_table() {
        let err = count_multiplicities(&[1, 2, 3], &[1, 4, 2], true)
            .expect_err("4 is not an entry of the table");
        assert_eq!(err.to_string(), "value 4 not in table");
    }

    #[test]
    fn test_duplicates_in_table() {
        // Each copy of a duplicated table entry receives the full count.
        let counts = count_multiplicities(&[1, 1, 2, 3], &[1, 2, 2, 3, 3, 3], true).unwrap();
        assert_eq!(counts, [1, 1, 2, 3]);
    }

    #[test]
    fn test_non_integer_values() {
        let counts =
            count_multiplicities(&["a", "b", "c"], &["a", "b", "b", "c", "c", "c"], true).unwrap();
        assert_eq!(counts, [1, 2, 3]);
    }
}
341
#[cfg(test)]
mod tests {
    //! End-to-end check: prove, then verify, the u8 multiplication lookup circuit.
    use binius_core::{
        constraint_system::{prove, verify},
        fiat_shamir::HasherChallenger,
        tower::CanonicalTowerFamily,
    };
    use binius_hal::make_portable_backend;
    use binius_hash::groestl::{Groestl256, Groestl256ByteCompression};

    use super::test_plain_lookup::test_u8_mul_lookup;
    use crate::builder::{types::U, ConstraintSystemBuilder};

    #[test]
    fn test_plain_u8_mul_lookup() {
        const MAX_LOG_MULTIPLICITY: usize = 20;
        const LOG_LOOKUP_COUNT: usize = 19;

        let (log_inv_rate, security_bits) = (1, 20);

        // Prover: build the circuit together with its witness and produce a proof.
        let proof = {
            let arena = bumpalo::Bump::new();
            let mut prover_builder = ConstraintSystemBuilder::new_with_witness(&arena);
            test_u8_mul_lookup::<MAX_LOG_MULTIPLICITY>(&mut prover_builder, LOG_LOOKUP_COUNT)
                .unwrap();

            let witness = prover_builder.take_witness().unwrap();
            let prover_cs = prover_builder.build().unwrap();
            let backend = make_portable_backend();

            prove::<
                U,
                CanonicalTowerFamily,
                Groestl256,
                Groestl256ByteCompression,
                HasherChallenger<Groestl256>,
                _,
            >(&prover_cs, log_inv_rate, security_bits, &[], witness, &backend)
            .unwrap()
        };

        // Verifier: rebuild the identical circuit without a witness and check the proof.
        let mut verifier_builder = ConstraintSystemBuilder::new();
        test_u8_mul_lookup::<MAX_LOG_MULTIPLICITY>(&mut verifier_builder, LOG_LOOKUP_COUNT)
            .unwrap();
        let verifier_cs = verifier_builder.build().unwrap();

        verify::<
            U,
            CanonicalTowerFamily,
            Groestl256,
            Groestl256ByteCompression,
            HasherChallenger<Groestl256>,
        >(&verifier_cs, log_inv_rate, security_bits, &[], proof)
        .unwrap();
    }
}