1use binius_core::{constraint_system::channel::FlushDirection, oracle::OracleId};
4use binius_field::{
5 as_packed_field::PackScalar, packed::set_packed_slice, BinaryField1b, ExtensionField, Field,
6 TowerField,
7};
8use bytemuck::Pod;
9
10use crate::builder::{
11 types::{F, U},
12 ConstraintSystemBuilder,
13};
14
15pub fn plain_lookup<FS, const LOG_MAX_MULTIPLICITY: usize>(
45 builder: &mut ConstraintSystemBuilder,
46 table: OracleId,
47 lookup_values: OracleId,
48 lookup_values_count: usize,
49) -> Result<(), anyhow::Error>
50where
51 U: PackScalar<FS> + Pod,
52 F: ExtensionField<FS>,
53 FS: TowerField + Pod,
54{
55 let n_vars = builder.log_rows([table])?;
56
57 let channel = builder.add_channel();
58
59 builder.send(channel, lookup_values_count, [lookup_values])?;
60
61 let mut multiplicities = None;
62 if let Some(witness) = builder.witness() {
64 let table_slice = witness.get::<FS>(table)?.as_slice::<FS>();
65 let values_slice = witness.get::<FS>(lookup_values)?.as_slice::<FS>();
66
67 multiplicities = Some(count_multiplicities(
68 &table_slice[0..1 << n_vars],
69 &values_slice[0..lookup_values_count],
70 false,
71 )?);
72 }
73
74 let bits: [OracleId; LOG_MAX_MULTIPLICITY] = get_bits(builder, table, multiplicities)?;
75 bits.into_iter().enumerate().try_for_each(|(i, bit)| {
76 builder.flush_custom(FlushDirection::Pull, channel, bit, [table], 1 << i)
77 })?;
78
79 Ok(())
80}
81
82fn get_bits<FS, const LOG_MAX_MULTIPLICITY: usize>(
84 builder: &mut ConstraintSystemBuilder,
85 table: OracleId,
86 multiplicities: Option<Vec<usize>>,
87) -> Result<[OracleId; LOG_MAX_MULTIPLICITY], anyhow::Error>
88where
89 U: PackScalar<FS>,
90 F: ExtensionField<FS>,
91 FS: TowerField + Pod,
92{
93 let n_vars = builder.log_rows([table])?;
94
95 let bits: [OracleId; LOG_MAX_MULTIPLICITY] = builder
96 .add_committed_multiple::<LOG_MAX_MULTIPLICITY>("bits", n_vars, BinaryField1b::TOWER_LEVEL);
97
98 if let Some(witness) = builder.witness() {
99 let multiplicities =
100 multiplicities.ok_or_else(|| anyhow::anyhow!("multiplicities empty for prover"))?;
101 debug_assert_eq!(1 << n_vars, multiplicities.len());
102
103 if multiplicities
105 .iter()
106 .any(|&multiplicity| multiplicity >= 1 << LOG_MAX_MULTIPLICITY)
107 {
108 return Err(anyhow::anyhow!(
109 "one or more multiplicities exceed `1 << LOG_MAX_MULTIPLICITY`"
110 ));
111 }
112
113 let mut bit_cols = bits.map(|bit| witness.new_column::<BinaryField1b>(bit));
115 let mut packed_bit_cols = bit_cols.each_mut().map(|bit_col| bit_col.packed());
116
117 multiplicities
118 .iter()
119 .enumerate()
120 .for_each(|(i, multiplicity)| {
121 (0..LOG_MAX_MULTIPLICITY).for_each(|j| {
122 let bit_set = multiplicity & (1 << j) != 0;
123 set_packed_slice(
124 packed_bit_cols[j],
125 i,
126 match bit_set {
127 true => BinaryField1b::ONE,
128 false => BinaryField1b::ZERO,
129 },
130 );
131 })
132 });
133 }
134
135 Ok(bits)
136}
137
#[cfg(test)]
pub mod test_plain_lookup {
    use binius_field::BinaryField32b;
    use binius_maybe_rayon::prelude::*;

    use super::*;
    use crate::transparent;

    /// Packs the claim `x * y = z` into one word: `x` in bits 0..8, `y` in bits 8..16 and
    /// `z` in bits 16..32.
    const fn into_lookup_claim(x: u8, y: u8, z: u16) -> u32 {
        let low = x as u32;
        let mid = (y as u32) << 8;
        let high = (z as u32) << 16;
        high | mid | low
    }

    /// Builds every 8-bit multiplication claim (256 * 256 = 2^16 words), with `x` varying
    /// slowest and `y` fastest.
    fn generate_u8_mul_table() -> Vec<u32> {
        let mut table = Vec::with_capacity(1 << 16);
        table.extend((0..=u8::MAX).flat_map(|x| {
            (0..=u8::MAX).map(move |y| into_lookup_claim(x, y, u16::from(x) * u16::from(y)))
        }));
        table
    }

    /// Overwrites every slot of `vals` with an independently sampled, correct
    /// multiplication claim (uniform `x` and `y` in `0..=255`).
    fn generate_random_u8_mul_claims(vals: &mut [u32]) {
        use rand::Rng;

        vals.par_iter_mut().for_each(|slot| {
            let mut rng = rand::thread_rng();
            let x: u8 = rng.gen_range(0..=255u8);
            let y: u8 = rng.gen_range(0..=255u8);
            *slot = into_lookup_claim(x, y, u16::from(x) * u16::from(y));
        });
    }

    /// Wires up a u8-multiplication lookup on `builder`.
    ///
    /// The full multiplication table becomes a transparent oracle, and `2^log_lookup_count`
    /// lookup values are committed. When a witness is present, the lookup values are filled
    /// with random valid claims. They are then constrained with `plain_lookup`.
    pub fn test_u8_mul_lookup<const LOG_MAX_MULTIPLICITY: usize>(
        builder: &mut ConstraintSystemBuilder,
        log_lookup_count: usize,
    ) -> Result<(), anyhow::Error> {
        let table_words = generate_u8_mul_table();
        let table = transparent::make_transparent(
            builder,
            "u8_mul_table",
            bytemuck::cast_slice::<_, BinaryField32b>(&table_words),
        )?;

        let lookup_values =
            builder.add_committed("lookup_values", log_lookup_count, BinaryField32b::TOWER_LEVEL);
        let lookup_values_count = 1 << log_lookup_count;

        // Prover only: populate the committed column with random claims.
        if let Some(witness) = builder.witness() {
            let mut column = witness.new_column::<BinaryField32b>(lookup_values);
            let words = column.as_mut_slice::<u32>();
            generate_random_u8_mul_claims(&mut words[..lookup_values_count]);
        }

        plain_lookup::<BinaryField32b, LOG_MAX_MULTIPLICITY>(
            builder,
            table,
            lookup_values,
            lookup_values_count,
        )
    }
}
204
205fn count_multiplicities<T: Eq + std::hash::Hash + Clone + std::fmt::Debug>(
206 table: &[T],
207 values: &[T],
208 check_inclusion: bool,
209) -> Result<Vec<usize>, anyhow::Error> {
210 use std::collections::{HashMap, HashSet};
211
212 if check_inclusion {
213 let table_set: HashSet<_> = table.iter().cloned().collect();
214 if let Some(invalid_value) = values.iter().find(|value| !table_set.contains(value)) {
215 return Err(anyhow::anyhow!("value {:?} not in table", invalid_value));
216 }
217 }
218
219 let counts: HashMap<_, usize> = values.iter().fold(HashMap::new(), |mut acc, value| {
220 *acc.entry(value).or_insert(0) += 1;
221 acc
222 });
223
224 let multiplicities = table
225 .iter()
226 .map(|item| counts.get(item).copied().unwrap_or(0))
227 .collect();
228
229 Ok(multiplicities)
230}
231
#[cfg(test)]
mod count_multiplicity_tests {
    use super::*;

    /// Runs `count_multiplicities` and unwraps a successful result, failing the test
    /// otherwise.
    fn counted<T: Eq + std::hash::Hash + Clone + std::fmt::Debug>(
        table: &[T],
        values: &[T],
        check_inclusion: bool,
    ) -> Vec<usize> {
        count_multiplicities(table, values, check_inclusion).expect("counting should succeed")
    }

    #[test]
    fn test_basic_functionality() {
        assert_eq!(counted(&[1, 2, 3, 4], &[1, 2, 2, 3, 3, 3], true), [1, 2, 3, 0]);
    }

    #[test]
    fn test_empty_values() {
        assert_eq!(counted::<i32>(&[1, 2, 3], &[], true), [0, 0, 0]);
    }

    #[test]
    fn test_empty_table() {
        assert!(counted::<i32>(&[], &[1, 2, 3], false).is_empty());
    }

    #[test]
    fn test_value_not_in_table() {
        let outcome = count_multiplicities(&[1, 2, 3], &[1, 4, 2], true);
        let err = outcome.expect_err("4 is missing from the table");
        assert_eq!(err.to_string(), "value 4 not in table");
    }

    #[test]
    fn test_duplicates_in_table() {
        // Each copy of a duplicated table entry receives the full count.
        assert_eq!(counted(&[1, 1, 2, 3], &[1, 2, 2, 3, 3, 3], true), [1, 1, 2, 3]);
    }

    #[test]
    fn test_non_integer_values() {
        assert_eq!(
            counted(&["a", "b", "c"], &["a", "b", "b", "c", "c", "c"], true),
            [1, 2, 3]
        );
    }
}
285
#[cfg(test)]
mod tests {
    use binius_core::{fiat_shamir::HasherChallenger, tower::CanonicalTowerFamily};
    use binius_hal::make_portable_backend;
    use binius_hash::compress::Groestl256ByteCompression;
    use binius_math::DefaultEvaluationDomainFactory;
    use groestl_crypto::Groestl256;

    use super::test_plain_lookup;
    use crate::builder::ConstraintSystemBuilder;

    /// End-to-end round trip for the plain lookup gadget.
    ///
    /// The u8-multiplication lookup circuit is built with a witness and proved. The same
    /// circuit is then rebuilt without a witness, and the proof is verified against it.
    #[test]
    fn test_plain_u8_mul_lookup() {
        // Passed as `LOG_MAX_MULTIPLICITY`: 18 bit columns, so a table row may be looked up
        // at most 2^18 - 1 times before `get_bits` rejects the witness.
        const MAX_LOG_MULTIPLICITY: usize = 18;
        // 2^19 random lookups into the 2^16-row multiplication table.
        let log_lookup_count = 19;

        let log_inv_rate = 1;
        // NOTE(review): 20 security bits looks chosen to keep the test fast rather than for
        // production use — confirm.
        let security_bits = 20;

        // Prover side: build the circuit with a witness and produce a proof.
        let proof = {
            let allocator = bumpalo::Bump::new();
            let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator);

            test_plain_lookup::test_u8_mul_lookup::<MAX_LOG_MULTIPLICITY>(
                &mut builder,
                log_lookup_count,
            )
            .unwrap();

            // The witness is taken out of the builder before `build()` is called on it.
            let witness = builder.take_witness().unwrap();
            let constraint_system = builder.build().unwrap();
            let domain_factory = DefaultEvaluationDomainFactory::default();
            let backend = make_portable_backend();

            binius_core::constraint_system::prove::<
                crate::builder::types::U,
                CanonicalTowerFamily,
                _,
                Groestl256,
                Groestl256ByteCompression,
                HasherChallenger<Groestl256>,
                _,
            >(
                &constraint_system,
                log_inv_rate,
                security_bits,
                &[],
                witness,
                &domain_factory,
                &backend,
            )
            .unwrap()
        };

        // Verifier side: rebuild the same circuit without a witness and check the proof.
        {
            let mut builder = ConstraintSystemBuilder::new();

            test_plain_lookup::test_u8_mul_lookup::<MAX_LOG_MULTIPLICITY>(
                &mut builder,
                log_lookup_count,
            )
            .unwrap();

            let constraint_system = builder.build().unwrap();

            binius_core::constraint_system::verify::<
                crate::builder::types::U,
                CanonicalTowerFamily,
                Groestl256,
                Groestl256ByteCompression,
                HasherChallenger<Groestl256>,
            >(&constraint_system, log_inv_rate, security_bits, &[], proof)
            .unwrap();
        }
    }
}