binius_circuits/arithmetic/
mul.rs

1// Copyright 2025 Irreducible Inc.
2
3//! Multiplication based on exponentiation.
4//!
5//! The core idea of this method is to verify the equality $a \cdot b = c$
6//! by checking if $(g^a)^b = g^{clow} \cdot (g^{2^{len(clow)}})^{chigh}$,
7//! where exponentiation proofs can be efficiently verified using the GKR exponentiation protocol.
8//!
9//! You can read more information in [Integer Multiplication in Binius](https://www.irreducible.com/posts/integer-multiplication-in-binius).
10
11use std::array;
12
13use anyhow::Error;
14use binius_core::oracle::OracleId;
15use binius_field::{
16	as_packed_field::PackedType,
17	packed::{get_packed_slice, set_packed_slice},
18	BinaryField, BinaryField16b, BinaryField1b, BinaryField64b, Field, TowerField,
19};
20use binius_macros::arith_expr;
21use binius_maybe_rayon::iter::{
22	IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
23};
24use binius_utils::bail;
25use itertools::izip;
26
27use super::static_exp::u16_static_exp_lookups;
28use crate::builder::{
29	types::{F, U},
30	ConstraintSystemBuilder,
31};
32
33pub fn mul<FExpBase>(
34	builder: &mut ConstraintSystemBuilder,
35	name: impl ToString,
36	xin_bits: Vec<OracleId>,
37	yin_bits: Vec<OracleId>,
38) -> Result<Vec<OracleId>, anyhow::Error>
39where
40	FExpBase: TowerField,
41	F: From<FExpBase>,
42{
43	let name = name.to_string();
44
45	let log_rows = builder.log_rows([xin_bits.clone(), yin_bits.clone()].into_iter().flatten())?;
46
47	// $g^x$
48	let xin_exp_result_id =
49		builder.add_committed(format!("{} xin_exp_result", name), log_rows, FExpBase::TOWER_LEVEL);
50
51	// $(g^x)^y$
52	let yin_exp_result_id =
53		builder.add_committed(format!("{} yin_exp_result", name), log_rows, FExpBase::TOWER_LEVEL);
54
55	// $g^{clow}$
56	let cout_low_exp_result_id = builder.add_committed(
57		format!("{} cout_low_exp_result", name),
58		log_rows,
59		FExpBase::TOWER_LEVEL,
60	);
61
62	// $(g^{2^{len(clow)}})^{chigh}$
63	let cout_high_exp_result_id = builder.add_committed(
64		format!("{} cout_high_exp_result", name),
65		log_rows,
66		FExpBase::TOWER_LEVEL,
67	);
68
69	let result_bits = xin_bits.len() + yin_bits.len();
70
71	if result_bits > FExpBase::N_BITS {
72		bail!(anyhow::anyhow!("FExpBase to small"));
73	}
74
75	let cout_bits = (0..result_bits)
76		.map(|i| {
77			builder.add_committed(
78				format!("{} bit of {}", i, name),
79				log_rows,
80				BinaryField1b::TOWER_LEVEL,
81			)
82		})
83		.collect::<Vec<_>>();
84
85	if let Some(witness) = builder.witness() {
86		let xin_columns = xin_bits
87			.iter()
88			.map(|&id| witness.get::<BinaryField1b>(id).map(|x| x.packed()))
89			.collect::<Result<Vec<_>, Error>>()?;
90
91		let yin_columns = yin_bits
92			.iter()
93			.map(|&id| witness.get::<BinaryField1b>(id).map(|x| x.packed()))
94			.collect::<Result<Vec<_>, Error>>()?;
95
96		let result = columns_to_numbers(&xin_columns)
97			.into_iter()
98			.zip(columns_to_numbers(&yin_columns))
99			.map(|(x, y)| x * y)
100			.collect::<Vec<_>>();
101
102		let mut cout_columns = cout_bits
103			.iter()
104			.map(|&id| witness.new_column::<BinaryField1b>(id))
105			.collect::<Vec<_>>();
106
107		let mut cout_columns_u8 = cout_columns
108			.iter_mut()
109			.map(|column| column.packed())
110			.collect::<Vec<_>>();
111
112		numbers_to_columns(&result, &mut cout_columns_u8);
113	}
114
115	// Handling special case when $x == 0$ $y == 0$ $c == 2^{2 \cdot n} -1$
116	builder.assert_zero(
117		name.clone(),
118		[xin_bits[0], yin_bits[0], cout_bits[0]],
119		arith_expr!([xin, yin, cout] = xin * yin - cout).convert_field(),
120	);
121
122	// $(g^x)^y = g^{clow} * (g^{2^{len(clow)}})^{chigh}$
123	builder.assert_zero(
124		name,
125		[
126			yin_exp_result_id,
127			cout_low_exp_result_id,
128			cout_high_exp_result_id,
129		],
130		arith_expr!([yin, low, high] = low * high - yin).convert_field(),
131	);
132
133	let (cout_low_bits, cout_high_bits) = cout_bits.split_at(cout_bits.len() / 2);
134
135	builder.add_static_exp(
136		xin_bits,
137		xin_exp_result_id,
138		FExpBase::MULTIPLICATIVE_GENERATOR.into(),
139		FExpBase::TOWER_LEVEL,
140	);
141	builder.add_dynamic_exp(yin_bits, yin_exp_result_id, xin_exp_result_id);
142	builder.add_static_exp(
143		cout_low_bits.to_vec(),
144		cout_low_exp_result_id,
145		FExpBase::MULTIPLICATIVE_GENERATOR.into(),
146		FExpBase::TOWER_LEVEL,
147	);
148	builder.add_static_exp(
149		cout_high_bits.to_vec(),
150		cout_high_exp_result_id,
151		exp_pow2(FExpBase::MULTIPLICATIVE_GENERATOR, cout_low_bits.len()).into(),
152		FExpBase::TOWER_LEVEL,
153	);
154
155	Ok(cout_bits)
156}
157
158/// u32 Multiplication based on plain lookups for static exponentiation
159/// and gkr_exp for dynamic exponentiation
160///
161/// The core idea of this method is to verify the equality $x \cdot y = c$
162/// by checking if
163///
164/// $(g^{xlow} \cdot (g^{2^{16}})^{xhigh})^y = \prod_{i=0}^{3} (g^{2^{(16 \cdot i)}})^{c_i} $,
165/// where $c_i$ is a $i$ 16-bit
166pub fn u32_mul<const LOG_MAX_MULTIPLICITY: usize>(
167	builder: &mut ConstraintSystemBuilder,
168	name: impl ToString,
169	xin_bits: [OracleId; 32],
170	yin_bits: [OracleId; 32],
171) -> Result<[OracleId; 64], anyhow::Error> {
172	let log_rows = builder.log_rows(xin_bits)?;
173
174	let name = name.to_string();
175
176	let [xin_low, xin_high] = array::from_fn(|i| {
177		let bits: [(usize, F); 16] =
178			array::from_fn(|j| (xin_bits[16 * i + j], <F as TowerField>::basis(0, j).unwrap()));
179
180		builder
181			.add_linear_combination("xin_low", log_rows, bits)
182			.unwrap()
183	});
184
185	if let Some(witness) = builder.witness() {
186		let xin_columns = xin_bits
187			.iter()
188			.map(|&id| witness.get::<BinaryField1b>(id).map(|x| x.packed()))
189			.collect::<Result<Vec<_>, Error>>()?;
190
191		let xin_numbers = columns_to_numbers(&xin_columns);
192
193		let mut xin_low = witness.new_column::<BinaryField16b>(xin_low);
194		let xin_low = xin_low.as_mut_slice::<u16>();
195
196		let mut xin_high = witness.new_column::<BinaryField16b>(xin_high);
197		let xin_high = xin_high.as_mut_slice::<u16>();
198
199		izip!(xin_numbers, xin_low, xin_high).for_each(|(xin, low, high)| {
200			*low = (xin & 0xFFFF) as u16;
201			*high = (xin >> 16) as u16;
202		});
203	}
204
205	//$g^{xlow}$
206	let (xin_low_exp_res_id, g) = u16_static_exp_lookups::<LOG_MAX_MULTIPLICITY>(
207		builder,
208		"xin_low_exp_res",
209		xin_low,
210		BinaryField64b::MULTIPLICATIVE_GENERATOR,
211		None,
212	)?;
213
214	//$(g^{2^{16}})^{xhigh}$
215	let (xin_high_exp_res_id, g_16) = u16_static_exp_lookups::<LOG_MAX_MULTIPLICITY>(
216		builder,
217		"xin_high_exp_res",
218		xin_high,
219		exp_pow2(BinaryField64b::MULTIPLICATIVE_GENERATOR, 16),
220		None,
221	)?;
222
223	//$g^{xin}$
224	let xin_exp_res_id =
225		builder.add_committed("xin_exp_result", log_rows, BinaryField64b::TOWER_LEVEL);
226
227	builder.assert_zero(
228		"xin_exp_res_id zerocheck",
229		[xin_low_exp_res_id, xin_high_exp_res_id, xin_exp_res_id],
230		arith_expr!(
231			[xin_low_exp_res, xin_high_exp_res, xin_exp_result_id] =
232				xin_low_exp_res * xin_high_exp_res - xin_exp_result_id
233		)
234		.convert_field(),
235	);
236
237	if let Some(witness) = builder.witness() {
238		let xin_low_exp_res = witness
239			.get::<BinaryField64b>(xin_low_exp_res_id)?
240			.as_slice::<BinaryField64b>();
241
242		let xin_high_exp_res = witness
243			.get::<BinaryField64b>(xin_high_exp_res_id)?
244			.as_slice::<BinaryField64b>();
245
246		let mut xin_exp_res = witness.new_column::<BinaryField64b>(xin_exp_res_id);
247		let xin_exp_res = xin_exp_res.as_mut_slice::<BinaryField64b>();
248		xin_exp_res
249			.par_iter_mut()
250			.enumerate()
251			.for_each(|(i, xin_exp_res)| {
252				*xin_exp_res = xin_low_exp_res[i] * xin_high_exp_res[i];
253			});
254	}
255
256	//$(g^{x})^{y}$
257	let yin_exp_result_id = builder.add_committed(
258		format!("{} yin_exp_result", name),
259		log_rows,
260		BinaryField64b::TOWER_LEVEL,
261	);
262
263	builder.add_dynamic_exp(yin_bits.to_vec(), yin_exp_result_id, xin_exp_res_id);
264
265	let cout_bits: [OracleId; 64] =
266		builder.add_committed_multiple("cout_bits", log_rows, BinaryField1b::TOWER_LEVEL);
267
268	let cout: [OracleId; 4] = array::from_fn(|i| {
269		let bits: [(usize, F); 16] =
270			array::from_fn(|j| (cout_bits[16 * i + j], <F as TowerField>::basis(0, j).unwrap()));
271
272		builder
273			.add_linear_combination("cout 16b", log_rows, bits)
274			.unwrap()
275	});
276
277	if let Some(witness) = builder.witness() {
278		let xin_columns = xin_bits
279			.iter()
280			.map(|&id| witness.get::<BinaryField1b>(id).map(|x| x.packed()))
281			.collect::<Result<Vec<_>, Error>>()?;
282
283		let yin_columns = yin_bits
284			.iter()
285			.map(|&id| witness.get::<BinaryField1b>(id).map(|x| x.packed()))
286			.collect::<Result<Vec<_>, Error>>()?;
287
288		let result = columns_to_numbers(&xin_columns)
289			.into_iter()
290			.zip(columns_to_numbers(&yin_columns))
291			.map(|(x, y)| x * y)
292			.collect::<Vec<_>>();
293
294		let mut cout_columns = cout_bits
295			.iter()
296			.map(|&id| witness.new_column::<BinaryField1b>(id))
297			.collect::<Vec<_>>();
298
299		let mut cout_columns = cout_columns
300			.iter_mut()
301			.map(|column| column.packed())
302			.collect::<Vec<_>>();
303
304		numbers_to_columns(&result, &mut cout_columns);
305
306		let mut cout = cout.map(|id| witness.new_column::<BinaryField16b>(id));
307
308		let mut cout = cout
309			.iter_mut()
310			.map(|cout| cout.as_mut_slice::<u16>())
311			.collect::<Vec<_>>();
312
313		cout.iter_mut().enumerate().for_each(|(j, cout)| {
314			cout.par_iter_mut().enumerate().for_each(|(i, cout)| {
315				let value = result[i];
316
317				*cout = ((value >> (j * 16)) & 0xFFFF) as u16;
318			});
319		});
320	}
321
322	//$(g^{2^{(16 \cdot i)}})^{c_i}$ where $c_i$ is a $i$ 16-bit
323	let cout_exp_res_id = (0..4)
324		.map(|i| {
325			let g_table = match i {
326				0 => Some(g),
327				1 => Some(g_16),
328				_ => None,
329			};
330
331			u16_static_exp_lookups::<LOG_MAX_MULTIPLICITY>(
332				builder,
333				format!("cout_exp_result_id {}", i),
334				cout[i],
335				exp_pow2(BinaryField64b::MULTIPLICATIVE_GENERATOR, 16 * i),
336				g_table,
337			)
338			.map(|res| res.0)
339		})
340		.collect::<Result<Vec<_>, anyhow::Error>>()?;
341
342	// Handling special case when $x == 0$ $y == 0$ $c == 2^{2 \cdot n} -1$
343	builder.assert_zero(
344		name.clone(),
345		[xin_bits[0], yin_bits[0], cout_bits[0]],
346		arith_expr!([x, y, c] = x * y - c).convert_field(),
347	);
348
349	// $(g^{xlow} \cdot (g^{2^{16}})^{xhigh})^y = \prod_{i=0}^{3} (g^{2^{(16 \cdot i)}})^{c_i} $
350	builder.assert_zero(
351		name,
352		[
353			yin_exp_result_id,
354			cout_exp_res_id[0],
355			cout_exp_res_id[1],
356			cout_exp_res_id[2],
357			cout_exp_res_id[3],
358		],
359		arith_expr!(
360			[yin, cout_0, cout_1, cout_2, cout_3] = cout_0 * cout_1 * cout_2 * cout_3 - yin
361		)
362		.convert_field(),
363	);
364
365	Ok(cout_bits)
366}
367
368fn exp_pow2<F: BinaryField>(mut g: F, log_exp: usize) -> F {
369	for _ in 0..log_exp {
370		g *= g
371	}
372	g
373}
374
375fn columns_to_numbers(columns: &[&[PackedType<U, BinaryField1b>]]) -> Vec<u128> {
376	let width = PackedType::<U, BinaryField1b>::WIDTH;
377	let mut numbers: Vec<u128> = vec![0; columns.first().map(|c| c.len() * width).unwrap_or(0)];
378
379	for (bit, column) in columns.iter().enumerate() {
380		numbers.par_iter_mut().enumerate().for_each(|(i, number)| {
381			if get_packed_slice(column, i) == BinaryField1b::ONE {
382				*number |= 1 << bit;
383			}
384		});
385	}
386	numbers
387}
388
389fn numbers_to_columns(numbers: &[u128], columns: &mut [&mut [PackedType<U, BinaryField1b>]]) {
390	columns
391		.par_iter_mut()
392		.enumerate()
393		.for_each(|(bit, column)| {
394			for (i, number) in numbers.iter().enumerate() {
395				if (number >> bit) & 1 == 1 {
396					set_packed_slice(column, i, BinaryField1b::ONE);
397				}
398			}
399		});
400}
401
#[cfg(test)]
mod tests {
	use binius_core::{
		constraint_system::{self},
		fiat_shamir::HasherChallenger,
		oracle::OracleId,
		tower::CanonicalTowerFamily,
	};
	use binius_field::{BinaryField1b, BinaryField8b};
	use binius_hal::make_portable_backend;
	use binius_hash::groestl::{Groestl256, Groestl256ByteCompression};

	use super::mul;
	use crate::{
		builder::{types::U, ConstraintSystemBuilder},
		unconstrained::unconstrained,
	};

	/// Allocates `n_bits` unconstrained `BinaryField1b` columns named `{prefix}_{i}`, each with
	/// log size `log_size`.
	fn unconstrained_bits(
		builder: &mut ConstraintSystemBuilder,
		prefix: &str,
		n_bits: usize,
		log_size: usize,
	) -> Vec<OracleId> {
		(0..n_bits)
			.map(|i| {
				unconstrained::<BinaryField1b>(builder, format!("{}_{}", prefix, i), log_size)
					.unwrap()
			})
			.collect()
	}

	#[test]
	fn test_mul() {
		let allocator = bumpalo::Bump::new();
		let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator);

		let log_n_muls = 9;

		let in_a = unconstrained_bits(&mut builder, "in_a", 2, log_n_muls);
		let in_b = unconstrained_bits(&mut builder, "in_b", 2, log_n_muls);

		mul::<BinaryField8b>(&mut builder, "test", in_a, in_b).unwrap();

		let witness = builder
			.take_witness()
			.expect("builder created with witness");

		let constraint_system = builder.build().unwrap();

		let backend = make_portable_backend();

		let proof = constraint_system::prove::<
			U,
			CanonicalTowerFamily,
			Groestl256,
			Groestl256ByteCompression,
			HasherChallenger<Groestl256>,
			_,
		>(&constraint_system, 1, 10, &[], witness, &backend)
		.unwrap();

		constraint_system::verify::<
			U,
			CanonicalTowerFamily,
			Groestl256,
			Groestl256ByteCompression,
			HasherChallenger<Groestl256>,
		>(&constraint_system, 1, 10, &[], proof)
		.unwrap();
	}

	#[test]
	fn test_mul_rejects_product_wider_than_exp_base() {
		let allocator = bumpalo::Bump::new();
		let mut builder = ConstraintSystemBuilder::new_with_witness(&allocator);

		let log_n_muls = 9;

		// 5 + 4 = 9 product bits do not fit into the 8-bit exponentiation base field.
		let in_a = unconstrained_bits(&mut builder, "in_a", 5, log_n_muls);
		let in_b = unconstrained_bits(&mut builder, "in_b", 4, log_n_muls);

		assert!(mul::<BinaryField8b>(&mut builder, "test", in_a, in_b).is_err());
	}
}