binius_m3/gadgets/
div.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::array;
4
5use binius_field::{Field, PackedExtension, PackedField, packed::set_packed_slice};
6use itertools::izip;
7
8use crate::{
9	builder::{B1, B32, B64, B128, Col, TableBuilder, TableWitnessSegment},
10	gadgets::{
11		add::{U32AddFlags, WideAdd},
12		mul::{MulSS32, MulUU32, SignConverter, UnsignedMulPrimitives},
13		sub::{U32SubFlags, WideSub},
14		util::pack_fp,
15	},
16};
17
18/// Gadget for unsigned division of two u32s.
19///
20/// `p = q*a + r`
21#[derive(Debug)]
22pub struct DivUU32 {
23	mul_inner: MulUU32,
24	sum: WideAdd<u64, 64>,
25	sub: WideSub<u64, 64>,
26
27	p_in_bits: [Col<B1>; 32],
28	q_in_bits: [Col<B1>; 32],
29	out_div_bits: [Col<B1>; 32],
30	out_rem_bits: [Col<B1>; 32],
31
32	pub p_in: Col<B32>,
33	pub q_in: Col<B32>,
34	pub out_div: Col<B32>,
35	pub out_rem: Col<B32>,
36}
37
38impl DivUU32 {
39	pub fn new(table: &mut TableBuilder) -> Self {
40		let zero = table.add_constant("zero", [B1::ZERO]);
41		let p_in_bits = table.add_committed_multiple("p_in_bits");
42		let q_in_bits = table.add_committed_multiple("q_in_bits");
43
44		let p_in = table.add_computed("p_in", pack_fp(p_in_bits));
45		let q_in = table.add_computed("q_in", pack_fp(q_in_bits));
46
47		let zero_extend_p: [_; 64] = array::from_fn(|i| if i < 32 { p_in_bits[i] } else { zero });
48
49		let out_div_bits = table.add_committed_multiple("out_div_bits");
50		let out_rem_bits = table.add_committed_multiple("out_rem_bits");
51
52		let zero_extend_q: [_; 64] = array::from_fn(|i| if i < 32 { q_in_bits[i] } else { zero });
53		let zero_extend_rem = array::from_fn(|i| if i < 32 { out_rem_bits[i] } else { zero });
54
55		let out_div = table.add_computed("out_div", pack_fp(out_div_bits));
56		let out_rem = table.add_computed("out_rem", pack_fp(out_rem_bits));
57
58		let mul_inner = MulUU32::with_inputs(table, q_in_bits, out_div_bits);
59
60		// Check q is non-zero
61		table.assert_nonzero(q_in);
62
63		let product_cols = array::from_fn(|i| {
64			if i < 32 {
65				mul_inner.out_low_bits[i]
66			} else {
67				mul_inner.out_high_bits[i - 32]
68			}
69		});
70
71		// Check p = q * a + r in 64 bits
72		let sum = WideAdd::<u64, 64>::new(
73			table,
74			product_cols,
75			zero_extend_rem,
76			U32AddFlags {
77				commit_zout: true,
78				expose_final_carry: true,
79				..Default::default()
80			},
81		);
82
83		#[allow(clippy::needless_range_loop)]
84		for bit in 0..64 {
85			table.assert_zero(
86				format!("division_satisfied[{bit}]"),
87				zero_extend_p[bit] - sum.z_out[bit],
88			);
89		}
90
91		// Add constraint to make sure that r < q by computing s = r - q in a larger bit length.
92		// There maybe a better way to do it with channels and simpler comparator logic.
93		let mut inner_comparator = table.with_namespace("sign_comparator");
94		let sub = WideSub::<u64, 64>::new(
95			&mut inner_comparator,
96			zero_extend_rem,
97			zero_extend_q,
98			U32SubFlags {
99				commit_zout: true,
100				..Default::default()
101			},
102		);
103		// Check that the sign bit is set
104		table.assert_zero("less_than", sub.zout[63] + B1::ONE);
105
106		Self {
107			mul_inner,
108			sum,
109			sub,
110
111			p_in_bits,
112			q_in_bits,
113			out_div_bits,
114			out_rem_bits,
115
116			p_in,
117			q_in,
118			out_div,
119			out_rem,
120		}
121	}
122
123	pub fn populate_with_inputs<P>(
124		&self,
125		index: &mut TableWitnessSegment<P>,
126		p_vals: impl IntoIterator<Item = B32>,
127		q_vals: impl IntoIterator<Item = B32> + Clone,
128	) -> anyhow::Result<()>
129	where
130		P: PackedField<Scalar = B128>
131			+ PackedExtension<B1>
132			+ PackedExtension<B32>
133			+ PackedExtension<B64>,
134	{
135		let mut inner_div = Vec::new();
136		{
137			let mut p_in_bits = array_util::try_map(self.p_in_bits, |bit| index.get_mut(bit))?;
138			let mut q_in_bits = array_util::try_map(self.q_in_bits, |bit| index.get_mut(bit))?;
139			let mut out_div_bits =
140				array_util::try_map(self.out_div_bits, |bit| index.get_mut(bit))?;
141			let mut out_rem_bits =
142				array_util::try_map(self.out_rem_bits, |bit| index.get_mut(bit))?;
143
144			let mut p_in = index.get_mut(self.p_in)?;
145			let mut q_in = index.get_mut(self.q_in)?;
146			let mut out_div = index.get_mut(self.out_div)?;
147			let mut out_rem = index.get_mut(self.out_rem)?;
148
149			for (i, (p, q)) in izip!(p_vals, q_vals.clone()).enumerate() {
150				let div = p.val() / q.val();
151				let rem = p.val() % q.val();
152				set_packed_slice(&mut p_in, i, p);
153				set_packed_slice(&mut q_in, i, q);
154				let div = B32::new(div);
155				let rem = B32::new(rem);
156				set_packed_slice(&mut out_div, i, div);
157				set_packed_slice(&mut out_rem, i, rem);
158
159				inner_div.push(div);
160
161				for bit_idx in 0..32 {
162					set_packed_slice(
163						&mut p_in_bits[bit_idx],
164						i,
165						B1::from(u32::is_bit_set_at(p, bit_idx)),
166					);
167					set_packed_slice(
168						&mut q_in_bits[bit_idx],
169						i,
170						B1::from(u32::is_bit_set_at(q, bit_idx)),
171					);
172					set_packed_slice(
173						&mut out_div_bits[bit_idx],
174						i,
175						B1::from(u32::is_bit_set_at(div, bit_idx)),
176					);
177					set_packed_slice(
178						&mut out_rem_bits[bit_idx],
179						i,
180						B1::from(u32::is_bit_set_at(rem, bit_idx)),
181					);
182				}
183			}
184		}
185
186		self.mul_inner.populate(index, q_vals, inner_div)?;
187		self.sum.populate(index)?;
188		self.sub.populate(index)?;
189
190		Ok(())
191	}
192}
193
194/// Gadget for signed division of two i32s.
195///
196/// `p = q*a + r` where p and r have the same sign bit.
197#[derive(Debug)]
198pub struct DivSS32 {
199	mul_inner: MulSS32,
200	sum: WideAdd<u64, 64>,
201	sub: WideAdd<u64, 64>,
202	abs_r_value: SignConverter<u64, 64>,
203	neg_abs_q_value: SignConverter<u64, 64>,
204	abs_r_bits: [Col<B1>; 64],
205	neg_abs_q_bits: [Col<B1>; 64],
206
207	pub p_in_bits: [Col<B1>; 32],
208	pub q_in_bits: [Col<B1>; 32],
209	pub out_div_bits: [Col<B1>; 32],
210	pub out_rem_bits: [Col<B1>; 32],
211
212	pub p_in: Col<B32>,
213	pub q_in: Col<B32>,
214	pub out_div: Col<B32>,
215	pub out_rem: Col<B32>,
216}
217
218impl DivSS32 {
219	pub fn new(table: &mut TableBuilder) -> Self {
220		let p_in_bits = table.add_committed_multiple("p_in_bits");
221		let q_in_bits = table.add_committed_multiple("q_in_bits");
222
223		let p_in = table.add_computed("p_in", pack_fp(p_in_bits));
224		let q_in = table.add_computed("q_in", pack_fp(q_in_bits));
225
226		let sign_extend_p: [_; 64] =
227			array::from_fn(|i| if i < 32 { p_in_bits[i] } else { p_in_bits[31] });
228
229		let out_div_bits = table.add_committed_multiple("out_div_bits");
230		let out_rem_bits = table.add_committed_multiple("out_rem_bits");
231
232		// Check sign(p) == sign(r)
233		table.assert_zero("sign_dividend_eq_sign_rem", p_in_bits[31] - out_rem_bits[31]);
234
235		let sign_extend_q: [_; 64] =
236			array::from_fn(|i| if i < 32 { q_in_bits[i] } else { q_in_bits[31] });
237		let sign_extend_rem = array::from_fn(|i| {
238			if i < 32 {
239				out_rem_bits[i]
240			} else {
241				out_rem_bits[31]
242			}
243		});
244
245		let out_div = table.add_computed("out_div", pack_fp(out_div_bits));
246		let out_rem = table.add_computed("out_rem", pack_fp(out_rem_bits));
247
248		let mul_inner = MulSS32::with_input(table, q_in_bits, out_div_bits);
249
250		// Check q is non-zero
251		table.assert_nonzero(q_in);
252
253		// Check p = q * a + r in 64 bits
254		let sum = WideAdd::<u64, 64>::new(
255			table,
256			mul_inner.out_bits,
257			sign_extend_rem,
258			U32AddFlags {
259				commit_zout: true,
260				expose_final_carry: true,
261				..Default::default()
262			},
263		);
264
265		#[allow(clippy::needless_range_loop)]
266		for bit in 0..64 {
267			table.assert_zero(
268				format!("division_satisfied_{bit}"),
269				sign_extend_p[bit] - sum.z_out[bit],
270			);
271		}
272
273		// Add constraint to make sure that |r| < |q| by computing s = |r| - |q| in a larger bit
274		// length. There maybe a better way to do it with channels and simpler comparator logic.
275		let r_is_negative = out_rem_bits[31];
276		let mut inner_abs_rem_table = table.with_namespace("rem_abs_value");
277		let abs_r_value =
278			SignConverter::new(&mut inner_abs_rem_table, sign_extend_rem, r_is_negative.into());
279		let abs_r_bits = abs_r_value.converted_bits;
280
281		let q_is_positive = q_in_bits[31] + B1::ONE;
282		let mut inner_neg_abs_q_table = table.with_namespace("neg_abs_q");
283		let neg_abs_q_value =
284			SignConverter::new(&mut inner_neg_abs_q_table, sign_extend_q, q_is_positive);
285		let neg_abs_q_bits = neg_abs_q_value.converted_bits;
286
287		let sub = WideAdd::<u64, 64>::new(
288			table,
289			abs_r_bits,
290			neg_abs_q_bits,
291			U32AddFlags {
292				commit_zout: true,
293				..Default::default()
294			},
295		);
296		// Check that the sign bit is set
297		table.assert_zero("less_than", sub.z_out[63] + B1::ONE);
298
299		Self {
300			mul_inner,
301			sum,
302			sub,
303			abs_r_value,
304			neg_abs_q_value,
305
306			abs_r_bits,
307			neg_abs_q_bits,
308			p_in_bits,
309			q_in_bits,
310			out_div_bits,
311			out_rem_bits,
312
313			p_in,
314			q_in,
315			out_div,
316			out_rem,
317		}
318	}
319
320	pub fn populate_with_inputs<P>(
321		&self,
322		index: &mut TableWitnessSegment<P>,
323		p_vals: impl IntoIterator<Item = B32>,
324		q_vals: impl IntoIterator<Item = B32> + Clone,
325	) -> anyhow::Result<()>
326	where
327		P: PackedField<Scalar = B128> + PackedExtension<B1> + PackedExtension<B32>,
328	{
329		// This vector holds the witness data for the inner multiplication gadget
330		let mut inner_div = Vec::new();
331
332		{
333			let mut p_in_bits = array_util::try_map(self.p_in_bits, |bit| index.get_mut(bit))?;
334			let mut q_in_bits = array_util::try_map(self.q_in_bits, |bit| index.get_mut(bit))?;
335			let mut out_div_bits =
336				array_util::try_map(self.out_div_bits, |bit| index.get_mut(bit))?;
337			let mut out_rem_bits =
338				array_util::try_map(self.out_rem_bits, |bit| index.get_mut(bit))?;
339			let mut abs_r_bits = array_util::try_map(self.abs_r_bits, |bit| index.get_mut(bit))?;
340			let mut neg_abs_q_bits =
341				array_util::try_map(self.neg_abs_q_bits, |bit| index.get_mut(bit))?;
342
343			let mut p_in = index.get_mut(self.p_in)?;
344			let mut q_in = index.get_mut(self.q_in)?;
345			let mut out_div = index.get_mut(self.out_div)?;
346			let mut out_rem = index.get_mut(self.out_rem)?;
347
348			for (i, (p, q)) in izip!(p_vals, q_vals.clone()).enumerate() {
349				let p_i32 = p.val() as i32;
350				let q_i32 = q.val() as i32;
351				let div = p_i32 / q_i32;
352				let rem = p_i32 % q_i32;
353				let abs_rem_b64 = if rem < 0 {
354					B64::new((-rem) as u64)
355				} else {
356					B64::new(rem as u64)
357				};
358				let neg_abs_q_b64 = if q_i32 < 0 {
359					B64::new(q_i32 as i64 as u64)
360				} else {
361					B64::new((-q_i32) as i64 as u64)
362				};
363				set_packed_slice(&mut p_in, i, p);
364				set_packed_slice(&mut q_in, i, q);
365				let div = B32::new(div as u32);
366				let rem = B32::new(rem as u32);
367				set_packed_slice(&mut out_div, i, div);
368				set_packed_slice(&mut out_rem, i, rem);
369				inner_div.push(div);
370
371				for bit_idx in 0..32 {
372					set_packed_slice(
373						&mut p_in_bits[bit_idx],
374						i,
375						B1::from(u32::is_bit_set_at(p, bit_idx)),
376					);
377					set_packed_slice(
378						&mut q_in_bits[bit_idx],
379						i,
380						B1::from(u32::is_bit_set_at(q, bit_idx)),
381					);
382					set_packed_slice(
383						&mut out_div_bits[bit_idx],
384						i,
385						B1::from(u32::is_bit_set_at(div, bit_idx)),
386					);
387					set_packed_slice(
388						&mut out_rem_bits[bit_idx],
389						i,
390						B1::from(u32::is_bit_set_at(rem, bit_idx)),
391					);
392				}
393				for bit_idx in 0..64 {
394					set_packed_slice(
395						&mut abs_r_bits[bit_idx],
396						i,
397						B1::from(u64::is_bit_set_at(abs_rem_b64, bit_idx)),
398					);
399					set_packed_slice(
400						&mut neg_abs_q_bits[bit_idx],
401						i,
402						B1::from(u64::is_bit_set_at(neg_abs_q_b64, bit_idx)),
403					);
404				}
405			}
406		}
407
408		self.mul_inner
409			.populate_with_inputs(index, q_vals, inner_div)?;
410		self.sum.populate(index)?;
411		self.abs_r_value.populate(index)?;
412		self.neg_abs_q_value.populate(index)?;
413		self.sub.populate(index)?;
414
415		Ok(())
416	}
417}