1use std::array;
12
13use anyhow::Error;
14use binius_core::oracle::OracleId;
15use binius_field::{
16 as_packed_field::PackedType,
17 packed::{get_packed_slice, set_packed_slice},
18 BinaryField, BinaryField16b, BinaryField1b, BinaryField64b, Field, TowerField,
19};
20use binius_macros::arith_expr;
21use binius_maybe_rayon::iter::{
22 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
23};
24use binius_utils::bail;
25use itertools::izip;
26
27use super::static_exp::u16_static_exp_lookups;
28use crate::builder::{
29 types::{F, U},
30 ConstraintSystemBuilder,
31};
32
33pub fn mul<FExpBase>(
34 builder: &mut ConstraintSystemBuilder,
35 name: impl ToString,
36 xin_bits: Vec<OracleId>,
37 yin_bits: Vec<OracleId>,
38) -> Result<Vec<OracleId>, anyhow::Error>
39where
40 FExpBase: TowerField,
41 F: From<FExpBase>,
42{
43 let name = name.to_string();
44
45 let log_rows = builder.log_rows([xin_bits.clone(), yin_bits.clone()].into_iter().flatten())?;
46
47 let xin_exp_result_id =
49 builder.add_committed(format!("{} xin_exp_result", name), log_rows, FExpBase::TOWER_LEVEL);
50
51 let yin_exp_result_id =
53 builder.add_committed(format!("{} yin_exp_result", name), log_rows, FExpBase::TOWER_LEVEL);
54
55 let cout_low_exp_result_id = builder.add_committed(
57 format!("{} cout_low_exp_result", name),
58 log_rows,
59 FExpBase::TOWER_LEVEL,
60 );
61
62 let cout_high_exp_result_id = builder.add_committed(
64 format!("{} cout_high_exp_result", name),
65 log_rows,
66 FExpBase::TOWER_LEVEL,
67 );
68
69 let result_bits = xin_bits.len() + yin_bits.len();
70
71 if result_bits > FExpBase::N_BITS {
72 bail!(anyhow::anyhow!("FExpBase to small"));
73 }
74
75 let cout_bits = (0..result_bits)
76 .map(|i| {
77 builder.add_committed(
78 format!("{} bit of {}", i, name),
79 log_rows,
80 BinaryField1b::TOWER_LEVEL,
81 )
82 })
83 .collect::<Vec<_>>();
84
85 if let Some(witness) = builder.witness() {
86 let xin_columns = xin_bits
87 .iter()
88 .map(|&id| witness.get::<BinaryField1b>(id).map(|x| x.packed()))
89 .collect::<Result<Vec<_>, Error>>()?;
90
91 let yin_columns = yin_bits
92 .iter()
93 .map(|&id| witness.get::<BinaryField1b>(id).map(|x| x.packed()))
94 .collect::<Result<Vec<_>, Error>>()?;
95
96 let result = columns_to_numbers(&xin_columns)
97 .into_iter()
98 .zip(columns_to_numbers(&yin_columns))
99 .map(|(x, y)| x * y)
100 .collect::<Vec<_>>();
101
102 let mut cout_columns = cout_bits
103 .iter()
104 .map(|&id| witness.new_column::<BinaryField1b>(id))
105 .collect::<Vec<_>>();
106
107 let mut cout_columns_u8 = cout_columns
108 .iter_mut()
109 .map(|column| column.packed())
110 .collect::<Vec<_>>();
111
112 numbers_to_columns(&result, &mut cout_columns_u8);
113 }
114
115 builder.assert_zero(
117 name.clone(),
118 [xin_bits[0], yin_bits[0], cout_bits[0]],
119 arith_expr!([xin, yin, cout] = xin * yin - cout).convert_field(),
120 );
121
122 builder.assert_zero(
124 name,
125 [
126 yin_exp_result_id,
127 cout_low_exp_result_id,
128 cout_high_exp_result_id,
129 ],
130 arith_expr!([yin, low, high] = low * high - yin).convert_field(),
131 );
132
133 let (cout_low_bits, cout_high_bits) = cout_bits.split_at(cout_bits.len() / 2);
134
135 builder.add_static_exp(
136 xin_bits,
137 xin_exp_result_id,
138 FExpBase::MULTIPLICATIVE_GENERATOR.into(),
139 FExpBase::TOWER_LEVEL,
140 );
141 builder.add_dynamic_exp(yin_bits, yin_exp_result_id, xin_exp_result_id);
142 builder.add_static_exp(
143 cout_low_bits.to_vec(),
144 cout_low_exp_result_id,
145 FExpBase::MULTIPLICATIVE_GENERATOR.into(),
146 FExpBase::TOWER_LEVEL,
147 );
148 builder.add_static_exp(
149 cout_high_bits.to_vec(),
150 cout_high_exp_result_id,
151 exp_pow2(FExpBase::MULTIPLICATIVE_GENERATOR, cout_low_bits.len()).into(),
152 FExpBase::TOWER_LEVEL,
153 );
154
155 Ok(cout_bits)
156}
157
158pub fn u32_mul<const LOG_MAX_MULTIPLICITY: usize>(
167 builder: &mut ConstraintSystemBuilder,
168 name: impl ToString,
169 xin_bits: [OracleId; 32],
170 yin_bits: [OracleId; 32],
171) -> Result<[OracleId; 64], anyhow::Error> {
172 let log_rows = builder.log_rows(xin_bits)?;
173
174 let name = name.to_string();
175
176 let [xin_low, xin_high] = array::from_fn(|i| {
177 let bits: [(usize, F); 16] =
178 array::from_fn(|j| (xin_bits[16 * i + j], <F as TowerField>::basis(0, j).unwrap()));
179
180 builder
181 .add_linear_combination("xin_low", log_rows, bits)
182 .unwrap()
183 });
184
185 if let Some(witness) = builder.witness() {
186 let xin_columns = xin_bits
187 .iter()
188 .map(|&id| witness.get::<BinaryField1b>(id).map(|x| x.packed()))
189 .collect::<Result<Vec<_>, Error>>()?;
190
191 let xin_numbers = columns_to_numbers(&xin_columns);
192
193 let mut xin_low = witness.new_column::<BinaryField16b>(xin_low);
194 let xin_low = xin_low.as_mut_slice::<u16>();
195
196 let mut xin_high = witness.new_column::<BinaryField16b>(xin_high);
197 let xin_high = xin_high.as_mut_slice::<u16>();
198
199 izip!(xin_numbers, xin_low, xin_high).for_each(|(xin, low, high)| {
200 *low = (xin & 0xFFFF) as u16;
201 *high = (xin >> 16) as u16;
202 });
203 }
204
205 let (xin_low_exp_res_id, g) = u16_static_exp_lookups::<LOG_MAX_MULTIPLICITY>(
207 builder,
208 "xin_low_exp_res",
209 xin_low,
210 BinaryField64b::MULTIPLICATIVE_GENERATOR,
211 None,
212 )?;
213
214 let (xin_high_exp_res_id, g_16) = u16_static_exp_lookups::<LOG_MAX_MULTIPLICITY>(
216 builder,
217 "xin_high_exp_res",
218 xin_high,
219 exp_pow2(BinaryField64b::MULTIPLICATIVE_GENERATOR, 16),
220 None,
221 )?;
222
223 let xin_exp_res_id =
225 builder.add_committed("xin_exp_result", log_rows, BinaryField64b::TOWER_LEVEL);
226
227 builder.assert_zero(
228 "xin_exp_res_id zerocheck",
229 [xin_low_exp_res_id, xin_high_exp_res_id, xin_exp_res_id],
230 arith_expr!(
231 [xin_low_exp_res, xin_high_exp_res, xin_exp_result_id] =
232 xin_low_exp_res * xin_high_exp_res - xin_exp_result_id
233 )
234 .convert_field(),
235 );
236
237 if let Some(witness) = builder.witness() {
238 let xin_low_exp_res = witness
239 .get::<BinaryField64b>(xin_low_exp_res_id)?
240 .as_slice::<BinaryField64b>();
241
242 let xin_high_exp_res = witness
243 .get::<BinaryField64b>(xin_high_exp_res_id)?
244 .as_slice::<BinaryField64b>();
245
246 let mut xin_exp_res = witness.new_column::<BinaryField64b>(xin_exp_res_id);
247 let xin_exp_res = xin_exp_res.as_mut_slice::<BinaryField64b>();
248 xin_exp_res
249 .par_iter_mut()
250 .enumerate()
251 .for_each(|(i, xin_exp_res)| {
252 *xin_exp_res = xin_low_exp_res[i] * xin_high_exp_res[i];
253 });
254 }
255
256 let yin_exp_result_id = builder.add_committed(
258 format!("{} yin_exp_result", name),
259 log_rows,
260 BinaryField64b::TOWER_LEVEL,
261 );
262
263 builder.add_dynamic_exp(yin_bits.to_vec(), yin_exp_result_id, xin_exp_res_id);
264
265 let cout_bits: [OracleId; 64] =
266 builder.add_committed_multiple("cout_bits", log_rows, BinaryField1b::TOWER_LEVEL);
267
268 let cout: [OracleId; 4] = array::from_fn(|i| {
269 let bits: [(usize, F); 16] =
270 array::from_fn(|j| (cout_bits[16 * i + j], <F as TowerField>::basis(0, j).unwrap()));
271
272 builder
273 .add_linear_combination("cout 16b", log_rows, bits)
274 .unwrap()
275 });
276
277 if let Some(witness) = builder.witness() {
278 let xin_columns = xin_bits
279 .iter()
280 .map(|&id| witness.get::<BinaryField1b>(id).map(|x| x.packed()))
281 .collect::<Result<Vec<_>, Error>>()?;
282
283 let yin_columns = yin_bits
284 .iter()
285 .map(|&id| witness.get::<BinaryField1b>(id).map(|x| x.packed()))
286 .collect::<Result<Vec<_>, Error>>()?;
287
288 let result = columns_to_numbers(&xin_columns)
289 .into_iter()
290 .zip(columns_to_numbers(&yin_columns))
291 .map(|(x, y)| x * y)
292 .collect::<Vec<_>>();
293
294 let mut cout_columns = cout_bits
295 .iter()
296 .map(|&id| witness.new_column::<BinaryField1b>(id))
297 .collect::<Vec<_>>();
298
299 let mut cout_columns = cout_columns
300 .iter_mut()
301 .map(|column| column.packed())
302 .collect::<Vec<_>>();
303
304 numbers_to_columns(&result, &mut cout_columns);
305
306 let mut cout = cout.map(|id| witness.new_column::<BinaryField16b>(id));
307
308 let mut cout = cout
309 .iter_mut()
310 .map(|cout| cout.as_mut_slice::<u16>())
311 .collect::<Vec<_>>();
312
313 cout.iter_mut().enumerate().for_each(|(j, cout)| {
314 cout.par_iter_mut().enumerate().for_each(|(i, cout)| {
315 let value = result[i];
316
317 *cout = ((value >> (j * 16)) & 0xFFFF) as u16;
318 });
319 });
320 }
321
322 let cout_exp_res_id = (0..4)
324 .map(|i| {
325 let g_table = match i {
326 0 => Some(g),
327 1 => Some(g_16),
328 _ => None,
329 };
330
331 u16_static_exp_lookups::<LOG_MAX_MULTIPLICITY>(
332 builder,
333 format!("cout_exp_result_id {}", i),
334 cout[i],
335 exp_pow2(BinaryField64b::MULTIPLICATIVE_GENERATOR, 16 * i),
336 g_table,
337 )
338 .map(|res| res.0)
339 })
340 .collect::<Result<Vec<_>, anyhow::Error>>()?;
341
342 builder.assert_zero(
344 name.clone(),
345 [xin_bits[0], yin_bits[0], cout_bits[0]],
346 arith_expr!([x, y, c] = x * y - c).convert_field(),
347 );
348
349 builder.assert_zero(
351 name,
352 [
353 yin_exp_result_id,
354 cout_exp_res_id[0],
355 cout_exp_res_id[1],
356 cout_exp_res_id[2],
357 cout_exp_res_id[3],
358 ],
359 arith_expr!(
360 [yin, cout_0, cout_1, cout_2, cout_3] = cout_0 * cout_1 * cout_2 * cout_3 - yin
361 )
362 .convert_field(),
363 );
364
365 Ok(cout_bits)
366}
367
368fn exp_pow2<F: BinaryField>(mut g: F, log_exp: usize) -> F {
369 for _ in 0..log_exp {
370 g *= g
371 }
372 g
373}
374
375fn columns_to_numbers(columns: &[&[PackedType<U, BinaryField1b>]]) -> Vec<u128> {
376 let width = PackedType::<U, BinaryField1b>::WIDTH;
377 let mut numbers: Vec<u128> = vec![0; columns.first().map(|c| c.len() * width).unwrap_or(0)];
378
379 for (bit, column) in columns.iter().enumerate() {
380 numbers.par_iter_mut().enumerate().for_each(|(i, number)| {
381 if get_packed_slice(column, i) == BinaryField1b::ONE {
382 *number |= 1 << bit;
383 }
384 });
385 }
386 numbers
387}
388
389fn numbers_to_columns(numbers: &[u128], columns: &mut [&mut [PackedType<U, BinaryField1b>]]) {
390 columns
391 .par_iter_mut()
392 .enumerate()
393 .for_each(|(bit, column)| {
394 for (i, number) in numbers.iter().enumerate() {
395 if (number >> bit) & 1 == 1 {
396 set_packed_slice(column, i, BinaryField1b::ONE);
397 }
398 }
399 });
400}
401
#[cfg(test)]
mod tests {
    use binius_core::{
        constraint_system::{self},
        fiat_shamir::HasherChallenger,
        tower::CanonicalTowerFamily,
    };
    use binius_field::{BinaryField1b, BinaryField8b};
    use binius_hal::make_portable_backend;
    use binius_hash::groestl::{Groestl256, Groestl256ByteCompression};

    use super::mul;
    use crate::{
        builder::{types::U, ConstraintSystemBuilder},
        unconstrained::unconstrained,
    };

    /// Builds a 2-bit by 2-bit multiplication over `BinaryField8b` on unconstrained inputs,
    /// then checks that the resulting constraint system can be proven and verified.
    #[test]
    fn test_mul() {
        let allocator = bumpalo::Bump::new();
        let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator);

        let log_n_muls = 9;

        // Two 1-bit columns per operand; all of `in_a` is created before `in_b`.
        let mut in_a = Vec::with_capacity(2);
        for i in 0..2 {
            let column =
                unconstrained::<BinaryField1b>(&mut builder, format!("in_a_{}", i), log_n_muls)
                    .unwrap();
            in_a.push(column);
        }

        let mut in_b = Vec::with_capacity(2);
        for i in 0..2 {
            let column =
                unconstrained::<BinaryField1b>(&mut builder, format!("in_b_{}", i), log_n_muls)
                    .unwrap();
            in_b.push(column);
        }

        mul::<BinaryField8b>(&mut builder, "test", in_a, in_b).unwrap();

        let witness = builder
            .take_witness()
            .expect("builder created with witness");
        let constraint_system = builder.build().unwrap();
        let backend = make_portable_backend();

        let proof = constraint_system::prove::<
            U,
            CanonicalTowerFamily,
            Groestl256,
            Groestl256ByteCompression,
            HasherChallenger<Groestl256>,
            _,
        >(&constraint_system, 1, 10, &[], witness, &backend)
        .unwrap();

        constraint_system::verify::<
            U,
            CanonicalTowerFamily,
            Groestl256,
            Groestl256ByteCompression,
            HasherChallenger<Groestl256>,
        >(&constraint_system, 1, 10, &[], proof)
        .unwrap();
    }
}