1use std::{cmp::Reverse, fmt::Debug, hash::Hash};
4
5use anyhow::{ensure, Result};
6use binius_core::{
7 constraint_system::channel::{FlushDirection, OracleOrConst},
8 oracle::OracleId,
9};
10use binius_field::{
11 as_packed_field::PackScalar,
12 packed::{get_packed_slice, set_packed_slice},
13 BinaryField1b, ExtensionField, Field, PackedField, TowerField,
14};
15use bytemuck::Pod;
16use itertools::izip;
17
18use crate::builder::{
19 types::{F, U},
20 ConstraintSystemBuilder,
21};
22
23pub fn plain_lookup<FTable, const LOG_MAX_MULTIPLICITY: usize>(
66 builder: &mut ConstraintSystemBuilder,
67 name: impl ToString,
68 n_lookups: &[usize],
69 lookups_u: &[impl AsRef<[OracleId]>],
70 lookup_t: impl AsRef<[OracleId]>,
71 multiplicities: Option<impl AsRef<[usize]>>,
72) -> Result<()>
73where
74 U: PackScalar<FTable> + Pod,
75 F: ExtensionField<FTable>,
76 FTable: TowerField,
77{
78 ensure!(n_lookups.len() == lookups_u.len(), "n_vars and lookups_u must be of the same length");
79 ensure!(
80 lookups_u
81 .iter()
82 .all(|oracles| oracles.as_ref().len() == lookup_t.as_ref().len()),
83 "looked up and lookup tables must have the same number of oracles"
84 );
85
86 let lookups_u_count_sum = n_lookups.iter().sum::<usize>();
87 ensure!(lookups_u_count_sum < 1 << LOG_MAX_MULTIPLICITY, "LOG_MAX_MULTIPLICITY too small");
88
89 builder.push_namespace(name);
90
91 let t_log_rows = builder.log_rows(lookup_t.as_ref().iter().copied())?;
92 let bits = builder.add_committed_multiple::<LOG_MAX_MULTIPLICITY>(
93 "multiplicity_bits",
94 t_log_rows,
95 BinaryField1b::TOWER_LEVEL,
96 );
97
98 let permuted_lookup_t = (0..lookup_t.as_ref().len())
99 .map(|i| builder.add_committed(format!("permuted_t_{i}"), t_log_rows, FTable::TOWER_LEVEL))
100 .collect::<Vec<_>>();
101
102 if let Some(witness) = builder.witness() {
103 let mut indexed_multiplicities = multiplicities
104 .expect("multiplicities should be supplied when proving")
105 .as_ref()
106 .iter()
107 .copied()
108 .enumerate()
109 .collect::<Vec<_>>();
110
111 let multiplicities_sum = indexed_multiplicities
112 .iter()
113 .map(|&(_, multiplicity)| multiplicity)
114 .sum::<usize>();
115 ensure!(multiplicities_sum == lookups_u_count_sum, "Multiplicities do not add up.");
116
117 indexed_multiplicities.sort_by_key(|&(_, multiplicity)| Reverse(multiplicity));
118
119 for (i, bit) in bits.into_iter().enumerate() {
120 let nonzero_scalars_prefix =
121 indexed_multiplicities.partition_point(|&(_, count)| count >= 1 << i);
122
123 let mut column = witness.new_column_with_nonzero_scalars_prefix::<BinaryField1b>(
124 bit,
125 nonzero_scalars_prefix,
126 );
127
128 let packed = column.packed();
129
130 for (j, &(_, multiplicity)) in indexed_multiplicities.iter().enumerate() {
131 if (1 << i) & multiplicity != 0 {
132 set_packed_slice(packed, j, BinaryField1b::ONE);
133 }
134 }
135 }
136
137 for (&permuted, &original) in izip!(&permuted_lookup_t, lookup_t.as_ref()) {
138 let original_slice = witness.get::<FTable>(original)?.packed();
139
140 let mut permuted_column = witness.new_column::<FTable>(permuted);
141 let permuted_slice = permuted_column.packed();
142
143 let mut iterator = indexed_multiplicities
144 .iter()
145 .map(|&(index, _)| get_packed_slice(original_slice, index));
146 for v in permuted_slice.iter_mut() {
147 *v = PackedField::from_scalars(&mut iterator);
148 }
149 }
150 }
151
152 let permutation_channel = builder.add_channel();
153 let multiplicity_channel = builder.add_channel();
154
155 builder.send(
156 permutation_channel,
157 1 << t_log_rows,
158 permuted_lookup_t.iter().copied().map(OracleOrConst::Oracle),
159 )?;
160 builder.receive(
161 permutation_channel,
162 1 << t_log_rows,
163 lookup_t.as_ref().iter().copied().map(OracleOrConst::Oracle),
164 )?;
165
166 for (lookup_u, &count) in izip!(lookups_u, n_lookups) {
167 builder.send(
168 multiplicity_channel,
169 count,
170 lookup_u.as_ref().iter().copied().map(OracleOrConst::Oracle),
171 )?;
172 }
173
174 for (i, bit) in bits.into_iter().enumerate() {
175 builder.flush_custom(
176 FlushDirection::Pull,
177 multiplicity_channel,
178 vec![bit],
179 permuted_lookup_t.iter().copied().map(OracleOrConst::Oracle),
180 1 << i,
181 )?
182 }
183
184 builder.pop_namespace();
185
186 Ok(())
187}
188
#[cfg(test)]
pub mod test_plain_lookup {
    use binius_field::BinaryField32b;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::transparent;

    /// Packs the claim `x * y = z` into one word: `x` in bits 0..8, `y` in bits 8..16 and `z` in
    /// bits 16..32.
    const fn into_lookup_claim(x: u8, y: u8, z: u16) -> u32 {
        ((z as u32) << 16) | ((y as u32) << 8) | (x as u32)
    }

    /// Returns the packed claims for all 2^16 byte pairs, with `x` in the outer and `y` in the
    /// inner loop.
    fn generate_u8_mul_table() -> Vec<u32> {
        let mut result = Vec::with_capacity(1 << 16);
        for x in 0..=255u8 {
            for y in 0..=255u8 {
                let product = x as u16 * y as u16;
                result.push(into_lookup_claim(x, y, product));
            }
        }
        result
    }

    /// Overwrites every element of `vals` with a correct packed claim for a random byte pair.
    /// The pairs are drawn from an `StdRng` seeded with 0, so the output is deterministic.
    fn generate_random_u8_mul_claims(vals: &mut [u32]) {
        use rand::Rng;
        let mut rng = StdRng::seed_from_u64(0);
        for val in vals {
            let x = rng.gen_range(0..=255u8);
            let y = rng.gen_range(0..=255u8);
            let product = x as u16 * y as u16;
            *val = into_lookup_claim(x, y, product);
        }
    }

    /// Builds a circuit that looks up `2^log_lookup_count` committed u8 multiplication claims in
    /// a transparent table of all `2^16` claims, using [`plain_lookup`].
    ///
    /// When the builder carries a witness, the committed claims are filled with deterministic
    /// random values and their table multiplicities are computed for the prover.
    ///
    /// # Errors
    /// Propagates errors from creating the table, counting multiplicities, or `plain_lookup`.
    pub fn test_u8_mul_lookup<const LOG_MAX_MULTIPLICITY: usize>(
        builder: &mut ConstraintSystemBuilder,
        log_lookup_count: usize,
    ) -> Result<(), anyhow::Error> {
        let table_values = generate_u8_mul_table();
        let table = transparent::make_transparent(
            builder,
            "u8_mul_table",
            bytemuck::cast_slice::<_, BinaryField32b>(&table_values),
        )?;

        let lookup_values =
            builder.add_committed("lookup_values", log_lookup_count, BinaryField32b::TOWER_LEVEL);

        let lookup_values_count = 1 << log_lookup_count;

        let multiplicities = if let Some(witness) = builder.witness() {
            let mut lookup_values_col = witness.new_column::<BinaryField32b>(lookup_values);
            let mut_slice = lookup_values_col.as_mut_slice::<u32>();
            generate_random_u8_mul_claims(&mut mut_slice[0..lookup_values_count]);
            // Propagate rather than panic: this helper already returns a Result.
            Some(count_multiplicities(&table_values, mut_slice, true)?)
        } else {
            None
        };

        plain_lookup::<BinaryField32b, LOG_MAX_MULTIPLICITY>(
            builder,
            "u8_mul_lookup",
            &[lookup_values_count],
            &[[lookup_values]],
            &[table],
            multiplicities,
        )?;

        Ok(())
    }
}
260
261pub fn count_multiplicities<T>(
262 table: &[T],
263 values: &[T],
264 check_inclusion: bool,
265) -> Result<Vec<usize>, anyhow::Error>
266where
267 T: Eq + Hash + Debug,
268{
269 use std::collections::{HashMap, HashSet};
270
271 if check_inclusion {
272 let table_set: HashSet<_> = table.iter().collect();
273 if let Some(invalid_value) = values.iter().find(|value| !table_set.contains(value)) {
274 return Err(anyhow::anyhow!("value {:?} not in table", invalid_value));
275 }
276 }
277
278 let counts: HashMap<_, usize> =
279 values
280 .iter()
281 .fold(HashMap::with_capacity(table.len()), |mut acc, value| {
282 *acc.entry(value).or_insert(0) += 1;
283 acc
284 });
285
286 let multiplicities = table
287 .iter()
288 .map(|item| counts.get(item).copied().unwrap_or(0))
289 .collect();
290
291 Ok(multiplicities)
292}
293
#[cfg(test)]
mod count_multiplicity_tests {
    use super::*;

    /// Each table entry gets the number of times it occurs among the values.
    #[test]
    fn test_basic_functionality() {
        let counts = count_multiplicities(&[1, 2, 3, 4], &[1, 2, 2, 3, 3, 3], true).unwrap();
        assert_eq!(counts, vec![1, 2, 3, 0]);
    }

    /// No values means every table entry has multiplicity zero.
    #[test]
    fn test_empty_values() {
        let counts = count_multiplicities::<i32>(&[1, 2, 3], &[], true).unwrap();
        assert_eq!(counts, vec![0, 0, 0]);
    }

    /// An empty table yields an empty result when inclusion is not checked.
    #[test]
    fn test_empty_table() {
        let counts = count_multiplicities::<i32>(&[], &[1, 2, 3], false).unwrap();
        assert!(counts.is_empty());
    }

    /// With inclusion checking, a value missing from the table is reported by name.
    #[test]
    fn test_value_not_in_table() {
        let error = count_multiplicities(&[1, 2, 3], &[1, 4, 2], true).unwrap_err();
        assert_eq!(error.to_string(), "value 4 not in table");
    }

    /// Duplicate table entries each receive the full count of their value.
    #[test]
    fn test_duplicates_in_table() {
        let counts = count_multiplicities(&[1, 1, 2, 3], &[1, 2, 2, 3, 3, 3], true).unwrap();
        assert_eq!(counts, vec![1, 1, 2, 3]);
    }

    /// Works for any hashable element type, not just integers.
    #[test]
    fn test_non_integer_values() {
        let counts =
            count_multiplicities(&["a", "b", "c"], &["a", "b", "b", "c", "c", "c"], true).unwrap();
        assert_eq!(counts, vec![1, 2, 3]);
    }
}
347
#[cfg(test)]
mod tests {
    use binius_core::fiat_shamir::HasherChallenger;
    use binius_field::tower::CanonicalTowerFamily;
    use binius_hal::make_portable_backend;
    use binius_hash::groestl::{Groestl256, Groestl256ByteCompression};

    use super::test_plain_lookup;
    use crate::builder::ConstraintSystemBuilder;

    /// End-to-end check of the plain lookup gadget.
    ///
    /// First, prove a lookup of `2^19` random u8 multiplication claims into the full u8
    /// multiplication table. Then verify that proof against a constraint system rebuilt without
    /// a witness.
    #[test]
    fn test_plain_u8_mul_lookup() {
        const MAX_LOG_MULTIPLICITY: usize = 20;
        let log_lookup_count = 19;

        let log_inv_rate = 1;
        let security_bits = 20;

        // Prover side: build the circuit together with its witness and produce a proof.
        let proof = {
            let arena = bumpalo::Bump::new();
            let mut prover_builder = ConstraintSystemBuilder::new_with_witness(&arena);

            test_plain_lookup::test_u8_mul_lookup::<MAX_LOG_MULTIPLICITY>(
                &mut prover_builder,
                log_lookup_count,
            )
            .unwrap();

            // Extract the witness first, then finalize the constraint system.
            let prover_witness = prover_builder.take_witness().unwrap();
            let prover_cs = prover_builder.build().unwrap();
            let portable_backend = make_portable_backend();

            binius_core::constraint_system::prove::<
                crate::builder::types::U,
                CanonicalTowerFamily,
                Groestl256,
                Groestl256ByteCompression,
                HasherChallenger<Groestl256>,
                _,
            >(&prover_cs, log_inv_rate, security_bits, &[], prover_witness, &portable_backend)
            .unwrap()
        };

        // Verifier side: rebuild the same circuit without a witness and check the proof.
        {
            let mut verifier_builder = ConstraintSystemBuilder::new();

            test_plain_lookup::test_u8_mul_lookup::<MAX_LOG_MULTIPLICITY>(
                &mut verifier_builder,
                log_lookup_count,
            )
            .unwrap();

            let verifier_cs = verifier_builder.build().unwrap();

            binius_core::constraint_system::verify::<
                crate::builder::types::U,
                CanonicalTowerFamily,
                Groestl256,
                Groestl256ByteCompression,
                HasherChallenger<Groestl256>,
            >(&verifier_cs, log_inv_rate, security_bits, &[], proof)
            .unwrap();
        }
    }
}