binius_m3/gadgets/
mul.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::{array, marker::PhantomData};
4
5use anyhow::Result;
6use binius_field::{
7	BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedSubfield, TowerField,
8	packed::set_packed_slice,
9};
10use itertools::izip;
11
12use crate::{
13	builder::{B1, B32, B64, B128, Col, Expr, TableBuilder, TableWitnessSegment},
14	gadgets::{
15		add::{Incr, UnsignedAddPrimitives},
16		sub::{U32SubFlags, WideU32Sub},
17		util::pack_fp,
18	},
19};
20
21/// Helper trait to create Multiplication gadgets for unsigned integers of different bit lengths.
22pub trait UnsignedMulPrimitives {
23	type FP: TowerField;
24	type FExpBase: TowerField + ExtensionField<Self::FP>;
25
26	const BIT_LENGTH: usize;
27
28	/// Computes the unsigned primitive mul of `x*y = z` and returns the tuple (z_high, z_low)
29	/// representing the high and low bits respectively.
30	fn mul(x: Self::FP, y: Self::FP) -> (Self::FP, Self::FP);
31
32	fn is_bit_set_at(a: Self::FP, index: usize) -> bool {
33		<Self::FP as ExtensionField<B1>>::get_base(&a, index) == B1::ONE
34	}
35
36	fn generator() -> Self::FExpBase;
37
38	/// Returns the generator shifted by the bit length of `Self::FP`.
39	fn shifted_generator() -> Self::FExpBase;
40}
41
42impl UnsignedMulPrimitives for u32 {
43	type FP = B32;
44	type FExpBase = B64;
45
46	const BIT_LENGTH: usize = 32;
47
48	fn mul(x: B32, y: B32) -> (B32, B32) {
49		let res = x.val() as u64 * y.val() as u64;
50		let low = B32::new(res as u32);
51		let high = B32::new((res >> 32) as u32);
52		(high, low)
53	}
54
55	fn generator() -> B64 {
56		B64::MULTIPLICATIVE_GENERATOR
57	}
58
59	fn shifted_generator() -> B64 {
60		let mut g = B64::MULTIPLICATIVE_GENERATOR;
61		for _ in 0..32 {
62			g = g.square();
63		}
64		g
65	}
66}
67
68impl UnsignedMulPrimitives for u64 {
69	type FP = B64;
70	type FExpBase = B128;
71
72	const BIT_LENGTH: usize = 64;
73
74	fn mul(x: B64, y: B64) -> (B64, B64) {
75		let res = x.val() as u128 * y.val() as u128;
76		let low = B64::new(res as u64);
77		let high = B64::new((res >> 64) as u64);
78		(high, low)
79	}
80
81	fn generator() -> Self::FExpBase {
82		B128::MULTIPLICATIVE_GENERATOR
83	}
84
85	fn shifted_generator() -> B128 {
86		let mut g = Self::generator();
87		for _ in 0..64 {
88			g = g.square();
89		}
90		g
91	}
92}
93
94// Internally used to have deduplicated implementations.
95#[derive(Debug)]
96struct Mul<UX: UnsignedMulPrimitives, const BIT_LENGTH: usize> {
97	x_in_bits: [Col<B1>; BIT_LENGTH],
98	y_in_bits: [Col<B1>; BIT_LENGTH],
99	out_high_bits: [Col<B1>; BIT_LENGTH],
100	out_low_bits: [Col<B1>; BIT_LENGTH],
101
102	pub xin: Col<UX::FP>,
103	pub yin: Col<UX::FP>,
104	pub out_high: Col<UX::FP>,
105	pub out_low: Col<UX::FP>,
106
107	_marker: PhantomData<UX>,
108}
109
110impl<
111	FExpBase: TowerField,
112	FP: TowerField,
113	UX: UnsignedMulPrimitives<FP = FP, FExpBase = FExpBase>,
114	const BIT_LENGTH: usize,
115> Mul<UX, BIT_LENGTH>
116where
117	FExpBase: ExtensionField<FP> + ExtensionField<B1>,
118	B128: ExtensionField<FExpBase> + ExtensionField<FP> + ExtensionField<B1>,
119{
120	pub fn new(table: &mut TableBuilder) -> Self {
121		let x_in_bits = table.add_committed_multiple("x_in_bits");
122		let y_in_bits = table.add_committed_multiple("y_in_bits");
123
124		Self::with_inputs(table, x_in_bits, y_in_bits)
125	}
126
127	pub fn with_inputs(
128		table: &mut TableBuilder,
129		xin_bits: [Col<B1>; BIT_LENGTH],
130		yin_bits: [Col<B1>; BIT_LENGTH],
131	) -> Self {
132		assert_eq!(FExpBase::TOWER_LEVEL, FP::TOWER_LEVEL + 1);
133		assert_eq!(BIT_LENGTH, 1 << FP::TOWER_LEVEL);
134		assert_eq!(BIT_LENGTH, UX::BIT_LENGTH);
135		// These are currently the only bit lengths I've tested
136		assert!(BIT_LENGTH == 32 || BIT_LENGTH == 64);
137
138		let x_in = table.add_computed("x_in", pack_fp(xin_bits));
139		let y_in = table.add_computed("y_in", pack_fp(yin_bits));
140
141		let generator = UX::generator();
142		let generator_pow_bit_len = UX::shifted_generator();
143
144		let g_pow_x = table.add_static_exp::<FExpBase>("g^x", &xin_bits, generator);
145		let g_pow_xy = table.add_dynamic_exp::<FExpBase>("(g^x)^y", &yin_bits, g_pow_x);
146
147		let out_high_bits = table.add_committed_multiple("out_high");
148		let out_low_bits = table.add_committed_multiple("out_low");
149
150		let out_high = table.add_computed("out_high", pack_fp(out_high_bits));
151		let out_low = table.add_computed("out_low", pack_fp(out_low_bits));
152
153		let g_pow_out_low: Col<FExpBase> =
154			table.add_static_exp("g^(out_low)", &out_low_bits, generator);
155		let g_pow_out_high: Col<FExpBase> = table.add_static_exp(
156			"(g^(2^BIT_LENGTH))^(out_high)",
157			&out_high_bits,
158			generator_pow_bit_len,
159		);
160
161		table.assert_zero("order_non_wrapping", xin_bits[0] * yin_bits[0] - out_low_bits[0]);
162		table.assert_zero("exponentiation_equality", g_pow_xy - g_pow_out_low * g_pow_out_high);
163
164		Self {
165			x_in_bits: xin_bits,
166			y_in_bits: yin_bits,
167			out_high_bits,
168			out_low_bits,
169			xin: x_in,
170			yin: y_in,
171			out_high,
172			out_low,
173			_marker: PhantomData,
174		}
175	}
176
177	#[inline]
178	fn populate_internal<P>(
179		&self,
180		index: &mut TableWitnessSegment<P>,
181		x_vals: impl IntoIterator<Item = FP>,
182		y_vals: impl IntoIterator<Item = FP>,
183		fill_input_bits: bool,
184	) -> Result<()>
185	where
186		P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<FP>,
187	{
188		let mut x_in_bits = array_util::try_map(self.x_in_bits, |bit| index.get_mut(bit))?;
189		let mut y_in_bits = array_util::try_map(self.y_in_bits, |bit| index.get_mut(bit))?;
190		let mut out_low_bits = array_util::try_map(self.out_low_bits, |bit| index.get_mut(bit))?;
191		let mut out_high_bits = array_util::try_map(self.out_high_bits, |bit| index.get_mut(bit))?;
192
193		let mut x_in = index.get_mut(self.xin)?;
194		let mut y_in = index.get_mut(self.yin)?;
195		let mut out_low = index.get_mut(self.out_low)?;
196		let mut out_high = index.get_mut(self.out_high)?;
197
198		for (i, (x, y)) in x_vals.into_iter().zip(y_vals.into_iter()).enumerate() {
199			let (res_high, res_low) = UX::mul(x, y);
200			set_packed_slice(&mut x_in, i, x);
201			set_packed_slice(&mut y_in, i, y);
202			set_packed_slice(&mut out_low, i, res_low);
203			set_packed_slice(&mut out_high, i, res_high);
204
205			for bit_idx in 0..BIT_LENGTH {
206				if fill_input_bits {
207					set_packed_slice(
208						&mut x_in_bits[bit_idx],
209						i,
210						B1::from(UX::is_bit_set_at(x, bit_idx)),
211					);
212					set_packed_slice(
213						&mut y_in_bits[bit_idx],
214						i,
215						B1::from(UX::is_bit_set_at(y, bit_idx)),
216					);
217				}
218				set_packed_slice(
219					&mut out_low_bits[bit_idx],
220					i,
221					B1::from(UX::is_bit_set_at(res_low, bit_idx)),
222				);
223				set_packed_slice(
224					&mut out_high_bits[bit_idx],
225					i,
226					B1::from(UX::is_bit_set_at(res_high, bit_idx)),
227				);
228			}
229		}
230
231		// NB: Exponentiation result columns are filled by the core constraint system prover.
232
233		Ok(())
234	}
235
236	pub fn populate_with_inputs<P>(
237		&self,
238		index: &mut TableWitnessSegment<P>,
239		x_vals: impl IntoIterator<Item = FP>,
240		y_vals: impl IntoIterator<Item = FP>,
241	) -> Result<()>
242	where
243		P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<FP>,
244	{
245		self.populate_internal(index, x_vals, y_vals, true)
246	}
247
248	pub fn populate<P>(
249		&self,
250		index: &mut TableWitnessSegment<P>,
251		x_vals: impl IntoIterator<Item = FP>,
252		y_vals: impl IntoIterator<Item = FP>,
253	) -> Result<()>
254	where
255		P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<FP>,
256	{
257		self.populate_internal(index, x_vals, y_vals, false)
258	}
259}
260
261#[derive(Debug)]
262pub struct MulUU32 {
263	inner: Mul<u32, 32>,
264
265	pub xin: Col<B32>,
266	pub yin: Col<B32>,
267	pub out_high: Col<B32>,
268	pub out_low: Col<B32>,
269	pub out_high_bits: [Col<B1>; 32],
270	pub out_low_bits: [Col<B1>; 32],
271}
272
273impl MulUU32 {
274	/// Constructor for `u32` multiplication gadget that creates the columns for inputs.
275	/// You must call `MulUU32::populate` to fill the witness data.
276	pub fn new(table: &mut TableBuilder) -> Self {
277		let inner = Mul::new(table);
278
279		Self {
280			xin: inner.xin,
281			yin: inner.yin,
282			out_high: inner.out_high,
283			out_low: inner.out_low,
284			out_high_bits: inner.out_high_bits,
285			out_low_bits: inner.out_low_bits,
286			inner,
287		}
288	}
289
290	/// Constructor for `u32` multiplication gadget that uses the provided columns for inputs.
291	/// You must call `MulUU32::populate_with_inputs` to fill the witness data.
292	pub fn with_inputs(
293		table: &mut TableBuilder,
294		xin_bits: [Col<B1>; 32],
295		yin_bits: [Col<B1>; 32],
296	) -> Self {
297		let inner = Mul::with_inputs(table, xin_bits, yin_bits);
298
299		Self {
300			xin: inner.xin,
301			yin: inner.yin,
302			out_high: inner.out_high,
303			out_low: inner.out_low,
304			out_high_bits: inner.out_high_bits,
305			out_low_bits: inner.out_low_bits,
306			inner,
307		}
308	}
309
310	pub fn populate_with_inputs<P>(
311		&self,
312		index: &mut TableWitnessSegment<P>,
313		x_vals: impl IntoIterator<Item = B32>,
314		y_vals: impl IntoIterator<Item = B32>,
315	) -> Result<()>
316	where
317		P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<B32>,
318	{
319		self.inner.populate_with_inputs(index, x_vals, y_vals)
320	}
321
322	pub fn populate<P>(
323		&self,
324		index: &mut TableWitnessSegment<P>,
325		x_vals: impl IntoIterator<Item = B32>,
326		y_vals: impl IntoIterator<Item = B32>,
327	) -> Result<()>
328	where
329		P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<B32>,
330	{
331		self.inner.populate(index, x_vals, y_vals)
332	}
333}
334
335#[derive(Debug)]
336pub struct MulUU64 {
337	inner: Mul<u64, 64>,
338
339	pub xin: Col<B64>,
340	pub yin: Col<B64>,
341	pub out_high: Col<B64>,
342	pub out_low: Col<B64>,
343	pub out_high_bits: [Col<B1>; 64],
344	pub out_low_bits: [Col<B1>; 64],
345}
346
347impl MulUU64 {
348	/// Constructor for `u64` multiplication gadget that creates the columns for inputs.
349	/// You must call `MulUU64::populate` to fill the witness data.
350	pub fn new(table: &mut TableBuilder) -> Self {
351		let inner = Mul::new(table);
352
353		Self {
354			xin: inner.xin,
355			yin: inner.yin,
356			out_high: inner.out_high,
357			out_low: inner.out_low,
358			out_high_bits: inner.out_high_bits,
359			out_low_bits: inner.out_low_bits,
360			inner,
361		}
362	}
363
364	/// Constructor for `u64` multiplication gadget that uses the provided columns for inputs.
365	/// You must call `MulUU64::populate_with_inputs` to fill the witness data.
366	pub fn with_inputs(
367		table: &mut TableBuilder,
368		xin_bits: [Col<B1>; 64],
369		yin_bits: [Col<B1>; 64],
370	) -> Self {
371		let inner = Mul::with_inputs(table, xin_bits, yin_bits);
372
373		Self {
374			xin: inner.xin,
375			yin: inner.yin,
376			out_high: inner.out_high,
377			out_low: inner.out_low,
378			out_high_bits: inner.out_high_bits,
379			out_low_bits: inner.out_low_bits,
380			inner,
381		}
382	}
383
384	pub fn populate_with_inputs<P>(
385		&self,
386		index: &mut TableWitnessSegment<P>,
387		x_vals: impl IntoIterator<Item = B64>,
388		y_vals: impl IntoIterator<Item = B64>,
389	) -> Result<()>
390	where
391		P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<B64>,
392	{
393		self.inner.populate_with_inputs(index, x_vals, y_vals)
394	}
395
396	pub fn populate<P>(
397		&self,
398		index: &mut TableWitnessSegment<P>,
399		x_vals: impl IntoIterator<Item = B64>,
400		y_vals: impl IntoIterator<Item = B64>,
401	) -> Result<()>
402	where
403		P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<B64>,
404	{
405		self.inner.populate(index, x_vals, y_vals)
406	}
407}
408
409#[derive(Debug)]
410pub struct MulSS32 {
411	mul_inner: MulUU32,
412	x_in_bits: [Col<B1>; 32],
413	y_in_bits: [Col<B1>; 32],
414	y_sub: WideU32Sub,
415	new_prod_high_bits: [Col<B1>; 32],
416	x_sub: WideU32Sub,
417
418	// Outputs
419	pub out_bits: [Col<B1>; 64],
420	pub xin: Col<B32>,
421	pub yin: Col<B32>,
422	pub out_high: Col<B32>,
423	pub out_low: Col<B32>,
424}
425
426impl MulSS32 {
427	/// Create the gadget by automatically committing the required input columns.
428	pub fn new(table: &mut TableBuilder) -> Self {
429		let x_in_bits = table.add_committed_multiple("x_in_bits");
430		let y_in_bits = table.add_committed_multiple("y_in_bits");
431
432		Self::with_input(table, x_in_bits, y_in_bits)
433	}
434
435	/// Create the gadget with the supplied `xin_bits` and `yin_bits` columns.
436	pub fn with_input(
437		table: &mut TableBuilder,
438		xin_bits: [Col<B1>; 32],
439		yin_bits: [Col<B1>; 32],
440	) -> Self {
441		let xin = table.add_computed("x_in", pack_fp(xin_bits));
442		let yin = table.add_computed("y_in", pack_fp(yin_bits));
443
444		let x_is_negative = xin_bits[31]; // Will be 1 if negative
445		let y_is_negative = yin_bits[31]; // Will be 1 if negative
446
447		let mut inner_mul_table = table.with_namespace("MulUU32");
448		let mul_inner = MulUU32::with_inputs(&mut inner_mul_table, xin_bits, yin_bits);
449
450		let out_low_bits: [_; 32] = array::from_fn(|i| mul_inner.out_low_bits[i]);
451
452		let prod_high_bits: [_; 32] = array::from_fn(|i| mul_inner.out_high_bits[i]);
453
454		let mut inner_y_sub_table = table.with_namespace("X_less_than_zero");
455		let y_sub = WideU32Sub::new(
456			&mut inner_y_sub_table,
457			prod_high_bits,
458			yin_bits,
459			U32SubFlags {
460				commit_zout: true,
461				..Default::default()
462			},
463		);
464
465		let new_prod_high_bits = array::from_fn(|bit| {
466			table.add_computed(
467				format!("new_prod_high[{bit}]"),
468				prod_high_bits[bit] + x_is_negative * (prod_high_bits[bit] + y_sub.zout[bit]),
469			)
470		});
471
472		let mut inner_x_sub_table = table.with_namespace("Y_less_than_zero");
473		let x_sub = WideU32Sub::new(
474			&mut inner_x_sub_table,
475			new_prod_high_bits,
476			xin_bits,
477			U32SubFlags {
478				commit_zout: true,
479				..Default::default()
480			},
481		);
482
483		let out_high_bits: [_; 32] = array::from_fn(|bit| {
484			table.add_computed(
485				format!("out_high[{bit}]"),
486				new_prod_high_bits[bit]
487					+ y_is_negative * (new_prod_high_bits[bit] + x_sub.zout[bit]),
488			)
489		});
490
491		let out_high = table.add_computed("out_high", pack_fp(out_high_bits));
492		let out_low = table.add_computed("out_low", pack_fp(out_low_bits));
493		let out_bits: [_; 64] = array::from_fn(|i| {
494			if i < 32 {
495				out_low_bits[i]
496			} else {
497				out_high_bits[i - 32]
498			}
499		});
500
501		Self {
502			y_sub,
503			new_prod_high_bits,
504			x_sub,
505			x_in_bits: xin_bits,
506			y_in_bits: yin_bits,
507			out_bits,
508			mul_inner,
509			xin,
510			yin,
511			out_low,
512			out_high,
513		}
514	}
515
516	pub fn populate_with_inputs<P>(
517		&self,
518		index: &mut TableWitnessSegment<P>,
519		x_vals: impl IntoIterator<Item = B32> + Clone,
520		y_vals: impl IntoIterator<Item = B32> + Clone,
521	) -> Result<()>
522	where
523		P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<B32>,
524	{
525		// For interior mutability we need scoped refs.
526		{
527			let mut x_in_bits = array_util::try_map(self.x_in_bits, |bit| index.get_mut(bit))?;
528			let mut y_in_bits = array_util::try_map(self.y_in_bits, |bit| index.get_mut(bit))?;
529			let mut out_bits = array_util::try_map(self.out_bits, |bit| index.get_mut(bit))?;
530			let mut new_prod_high_bits =
531				array_util::try_map(self.new_prod_high_bits, |bit| index.get_mut(bit))?;
532
533			let mut x_in = index.get_mut(self.xin)?;
534			let mut y_in = index.get_mut(self.yin)?;
535			let mut out_low = index.get_mut(self.out_low)?;
536			let mut out_high = index.get_mut(self.out_high)?;
537
538			for (i, (x, y)) in x_vals
539				.clone()
540				.into_iter()
541				.zip(y_vals.clone().into_iter())
542				.enumerate()
543			{
544				let res = x.val() as u64 * y.val() as u64;
545				let prod_hi = B32::new((res >> 32) as u32);
546				let prod_lo = B32::new(res as u32);
547				set_packed_slice(&mut x_in, i, x);
548				set_packed_slice(&mut y_in, i, y);
549				set_packed_slice(&mut out_low, i, prod_lo);
550				let new_prod_hi = if (x.val() as i32) < 0 {
551					B32::new(prod_hi.val().wrapping_sub(y.val()))
552				} else {
553					prod_hi
554				};
555				let out_hi = if (y.val() as i32) < 0 {
556					B32::new(new_prod_hi.val().wrapping_sub(x.val()))
557				} else {
558					new_prod_hi
559				};
560				set_packed_slice(&mut out_high, i, out_hi);
561
562				for bit_idx in 0..32 {
563					set_packed_slice(
564						&mut x_in_bits[bit_idx],
565						i,
566						B1::from(u32::is_bit_set_at(x, bit_idx)),
567					);
568					set_packed_slice(
569						&mut y_in_bits[bit_idx],
570						i,
571						B1::from(u32::is_bit_set_at(y, bit_idx)),
572					);
573					set_packed_slice(
574						&mut out_bits[bit_idx + 32],
575						i,
576						B1::from(u32::is_bit_set_at(out_hi, bit_idx)),
577					);
578					set_packed_slice(
579						&mut new_prod_high_bits[bit_idx],
580						i,
581						B1::from(u32::is_bit_set_at(new_prod_hi, bit_idx)),
582					);
583				}
584			}
585		}
586
587		self.mul_inner.populate(index, x_vals, y_vals)?;
588		self.y_sub.populate(index)?;
589		self.x_sub.populate(index)?;
590
591		Ok(())
592	}
593}
594
595/// A gadget that computes Signed x Unsigned multiplication with the full 64-bit signed result
596#[derive(Debug)]
597pub struct MulSU32 {
598	mul_inner: MulUU32,
599	x_in_bits: [Col<B1>; 32],
600	y_in_bits: [Col<B1>; 32],
601	out_high_bits: [Col<B1>; 32],
602	y_sub: WideU32Sub,
603
604	/// Output columns
605	pub xin: Col<B32>,
606	pub yin: Col<B32>,
607	pub out_high: Col<B32>,
608	pub out_low: Col<B32>,
609}
610
611impl MulSU32 {
612	pub fn new(table: &mut TableBuilder) -> Self {
613		let x_in_bits = table.add_committed_multiple("x_in_bits");
614		let y_in_bits = table.add_committed_multiple("y_in_bits");
615
616		let x_in = table.add_computed("x_in", pack_fp(x_in_bits));
617		let y_in = table.add_computed("y_in", pack_fp(y_in_bits));
618
619		let x_is_negative = x_in_bits[31];
620		let mul_inner = MulUU32::with_inputs(table, x_in_bits, y_in_bits);
621		let prod_high_bits: [_; 32] = array::from_fn(|i| mul_inner.out_high_bits[i]);
622
623		let mut inner_y_sub_table = table.with_namespace("X_less_than_zero");
624		let y_sub = WideU32Sub::new(
625			&mut inner_y_sub_table,
626			prod_high_bits,
627			y_in_bits,
628			U32SubFlags {
629				commit_zout: true,
630				..Default::default()
631			},
632		);
633		let out_high_bits = array::from_fn(|bit| {
634			table.add_computed(
635				format!("out_high[{bit}]"),
636				prod_high_bits[bit] + x_is_negative * (prod_high_bits[bit] + y_sub.zout[bit]),
637			)
638		});
639
640		let out_low_bits: [_; 32] = array::from_fn(|i| mul_inner.out_low_bits[i]);
641
642		let out_high = table.add_computed("out_high", pack_fp(out_high_bits));
643		let out_low = table.add_computed("out_low", pack_fp(out_low_bits));
644
645		Self {
646			mul_inner,
647			x_in_bits,
648			y_in_bits,
649			y_sub,
650			out_high_bits,
651			xin: x_in,
652			yin: y_in,
653			out_low,
654			out_high,
655		}
656	}
657
658	pub fn populate_with_inputs<P>(
659		&self,
660		index: &mut TableWitnessSegment<P>,
661		x_vals: impl IntoIterator<Item = B32> + Clone,
662		y_vals: impl IntoIterator<Item = B32> + Clone,
663	) -> Result<()>
664	where
665		P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<B32>,
666	{
667		{
668			let mut x_in_bits = array_util::try_map(self.x_in_bits, |bit| index.get_mut(bit))?;
669			let mut y_in_bits = array_util::try_map(self.y_in_bits, |bit| index.get_mut(bit))?;
670			let mut out_high_bits =
671				array_util::try_map(self.out_high_bits, |bit| index.get_mut(bit))?;
672
673			let mut x_in = index.get_mut(self.xin)?;
674			let mut y_in = index.get_mut(self.yin)?;
675			let mut out_low = index.get_mut(self.out_low)?;
676			let mut out_high = index.get_mut(self.out_high)?;
677
678			for (i, (x, y)) in x_vals
679				.clone()
680				.into_iter()
681				.zip(y_vals.clone().into_iter())
682				.enumerate()
683			{
684				let res = x.val() as u64 * y.val() as u64;
685				let prod_hi = B32::new((res >> 32) as u32);
686				let prod_lo = B32::new(res as u32);
687				set_packed_slice(&mut x_in, i, x);
688				set_packed_slice(&mut y_in, i, y);
689				set_packed_slice(&mut out_low, i, prod_lo);
690				let out_hi = if (x.val() as i32) < 0 {
691					B32::new(prod_hi.val().wrapping_sub(y.val()))
692				} else {
693					prod_hi
694				};
695				set_packed_slice(&mut out_high, i, out_hi);
696
697				for bit_idx in 0..32 {
698					set_packed_slice(
699						&mut x_in_bits[bit_idx],
700						i,
701						B1::from(u32::is_bit_set_at(x, bit_idx)),
702					);
703					set_packed_slice(
704						&mut y_in_bits[bit_idx],
705						i,
706						B1::from(u32::is_bit_set_at(y, bit_idx)),
707					);
708					set_packed_slice(
709						&mut out_high_bits[bit_idx],
710						i,
711						B1::from(u32::is_bit_set_at(out_hi, bit_idx)),
712					);
713				}
714			}
715		}
716
717		self.mul_inner.populate(index, x_vals, y_vals)?;
718		self.y_sub.populate(index)?;
719
720		Ok(())
721	}
722}
723
724/// Simple struct to convert to and from Two's complement representation based on bits. See
725/// [`SignConverter::new`]
726///
727/// NOTE: *We do not handle witness generation for the `converted_bits` and should be handled by
728/// caller*
729#[derive(Debug)]
730pub struct SignConverter<UPrimitive: UnsignedAddPrimitives, const BIT_LENGTH: usize> {
731	twos_complement: TwosComplement<UPrimitive, BIT_LENGTH>,
732
733	// Output columns
734	pub converted_bits: [Col<B1>; BIT_LENGTH],
735}
736
737impl<UPrimitive: UnsignedAddPrimitives, const BIT_LENGTH: usize>
738	SignConverter<UPrimitive, BIT_LENGTH>
739{
740	/// Used to conditionally select bit representation based on the MSB (sign bit)
741	///
742	/// ## Parameters
743	/// * `in_bits`: The input bits from MSB to LSB
744	/// * `conditional`: The conditional bit to choose input bits, or it's two's complement
745	///
746	/// ## Example
747	/// - If the conditional is zero, the output will be the input bits.
748	/// - If the conditional is one, the output will be the two's complement of input bits.
749	pub fn new(
750		table: &mut TableBuilder,
751		xin: [Col<B1>; BIT_LENGTH],
752		conditional: Expr<B1, 1>,
753	) -> Self {
754		let twos_complement = TwosComplement::new(table, xin);
755		let converted_bits = array::from_fn(|bit| {
756			table.add_computed(
757				format!("converted_bits[{bit}]"),
758				twos_complement.result_bits[bit] * conditional.clone()
759					+ (conditional.clone() + B1::ONE) * xin[bit],
760			)
761		});
762		Self {
763			twos_complement,
764			converted_bits,
765		}
766	}
767
768	pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<()>
769	where
770		P: PackedField<Scalar = B128> + PackedExtension<B1>,
771	{
772		self.twos_complement.populate(index)
773	}
774}
775
776/// Simple gadget that's used to convert to and from two's complement binary representations
777#[derive(Debug)]
778pub struct TwosComplement<UPrimitive: UnsignedAddPrimitives, const BIT_LENGTH: usize> {
779	inverted: [Col<B1>; BIT_LENGTH],
780	inner_incr: Incr<UPrimitive, BIT_LENGTH>,
781
782	// Input columns
783	pub xin: [Col<B1>; BIT_LENGTH],
784	// Output columns
785	pub result_bits: [Col<B1>; BIT_LENGTH],
786}
787
788impl<UPrimitive: UnsignedAddPrimitives, const BIT_LENGTH: usize>
789	TwosComplement<UPrimitive, BIT_LENGTH>
790{
791	pub fn new(table: &mut TableBuilder, xin: [Col<B1>; BIT_LENGTH]) -> Self {
792		let inverted =
793			array::from_fn(|i| table.add_computed(format!("inverted[{i}]"), xin[i] + B1::ONE));
794		let mut inner_table = table.with_namespace("Increment");
795		let inner_incr = Incr::new(&mut inner_table, inverted);
796
797		Self {
798			inverted,
799			result_bits: inner_incr.zout,
800			inner_incr,
801			xin,
802		}
803	}
804
805	pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
806	where
807		P: PackedField<Scalar = B128> + PackedExtension<B1>,
808	{
809		let one = PackedSubfield::<P, B1>::broadcast(B1::ONE);
810		for (inverted, xin) in izip!(self.inverted.iter(), self.xin.iter()) {
811			let inp = index.get(*xin)?;
812			let mut inverted = index.get_mut(*inverted)?;
813			for (inp, value) in izip!(inp.iter(), inverted.iter_mut()) {
814				*value = *inp + one;
815			}
816		}
817
818		self.inner_incr.populate(index)?;
819
820		Ok(())
821	}
822}