binius_m3/gadgets/
sub.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::{array, marker::PhantomData};
4
5use binius_core::oracle::ShiftVariant;
6use binius_field::{
7	Field, PackedExtension, PackedField, PackedFieldIndexable, packed::set_packed_slice,
8};
9use itertools::izip;
10
11use crate::{
12	builder::{B128, TableBuilder, column::Col, types::B1, witness::TableWitnessSegment},
13	gadgets::add::UnsignedAddPrimitives,
14};
15
/// A gadget for performing 32-bit integer subtraction on vertically-packed bit columns.
///
/// This gadget has input columns `xin` and `yin` for the two 32-bit integers to be subtracted, and
/// an output column `zout`, and it constrains that `xin - yin = zout` as integers modulo 2^32
/// (witness population uses wrapping subtraction). When `flags.borrow_in_bit` is set, that
/// borrow-in bit is subtracted as well.
///
/// Bitwise, the result is `zout = xin + yin + bin` over `B1`, where `bin` holds the borrow
/// entering each bit position.
#[derive(Debug)]
pub struct U32Sub {
	// Inputs
	/// The value being subtracted from (the minuend).
	pub xin: Col<B1, 32>,
	/// The value being subtracted (the subtrahend).
	pub yin: Col<B1, 32>,

	// Private
	/// Committed column holding the borrow leaving each bit position; bit 31 is the final borrow.
	bout: Col<B1, 32>,
	/// `bout` logically shifted left by one bit, so that the borrow leaving bit `i` lines up with
	/// bit `i + 1`.
	bout_shl: Col<B1, 32>,
	/// The borrow entering each bit position: `bout_shl` itself, or `bout_shl + borrow_in_bit`
	/// when `flags.borrow_in_bit` is set.
	bin: Col<B1, 32>,

	// Outputs
	/// The output column, either committed if `flags.commit_zout` is set, otherwise a linear
	/// combination derived column.
	pub zout: Col<B1, 32>,
	/// This is `Some` if `flags.expose_final_borrow` is set, otherwise it is `None`.
	///
	/// When present it is bit 31 of the borrow-out column.
	pub final_borrow: Option<Col<B1>>,
	/// Flags modifying the gadget's behavior.
	pub flags: U32SubFlags,
}
40
/// Flags modifying the behavior of the [`U32Sub`] gadget.
///
/// The default value has no borrow-in column, does not expose the final borrow and derives `zout`
/// as a computed column.
#[derive(Debug, Default, Clone)]
pub struct U32SubFlags {
	/// Optionally a column for a dynamic borrow in bit. This *must* be zero in all bits except the
	/// 0th.
	///
	/// When `None`, the borrow into bit 0 is fixed to zero.
	pub borrow_in_bit: Option<Col<B1, 32>>,
	/// Whether the gadget exposes the borrow out of the most significant bit as a separate
	/// `final_borrow` column.
	pub expose_final_borrow: bool,
	/// Whether `zout` is a committed column constrained to equal `xin + yin + bin`, instead of a
	/// computed column defined by that expression.
	pub commit_zout: bool,
}
50
51impl U32Sub {
52	pub fn new(
53		table: &mut TableBuilder,
54		xin: Col<B1, 32>,
55		yin: Col<B1, 32>,
56		flags: U32SubFlags,
57	) -> Self {
58		let bout = table.add_committed("bout");
59		let bout_shl = table.add_shifted("bout_shl", bout, 5, 1, ShiftVariant::LogicalLeft);
60
61		let bin = if let Some(borrow_in_bit) = flags.borrow_in_bit {
62			table.add_computed("bin", bout_shl + borrow_in_bit)
63		} else {
64			bout_shl
65		};
66
67		let final_borrow = flags
68			.expose_final_borrow
69			.then(|| table.add_selected("final_borrow", bout, 31));
70
71		// Check that the equation holds:
72		//
73		//     (bin + (1 - xin)) * (bin + yin) + bin = bout
74		//
75		// Note that we can't use the actual expression does `xin - B1::ONE` because of the expr
76		// builder, but in tower fields the order does not matter.
77		table.assert_zero("borrow_out", (bin + (xin - B1::ONE)) * (bin + yin) + bin - bout);
78
79		let zout = if flags.commit_zout {
80			let zout = table.add_committed("zout");
81			table.assert_zero("zout", xin + yin + bin - zout);
82			zout
83		} else {
84			table.add_computed("zout", xin + yin + bin)
85		};
86
87		U32Sub {
88			xin,
89			yin,
90			bout,
91			bout_shl,
92			bin,
93			zout,
94			final_borrow,
95			flags,
96		}
97	}
98}
99
100impl U32Sub {
101	pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
102	where
103		P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1>,
104	{
105		let xin: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.xin)?;
106		let yin: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.yin)?;
107		let mut bout: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.bout)?;
108		let mut zout: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.zout)?;
109		let mut bin: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.bin)?;
110		let mut final_borrow = if let Some(final_borrow) = self.final_borrow {
111			let final_borrow = index.get_mut(final_borrow)?;
112			Some(final_borrow)
113		} else {
114			None
115		};
116
117		if let Some(borrow_in_bit) = self.flags.borrow_in_bit {
118			// This is u32 assumed to be either 0 or 1.
119			let borrow_in_bit = index.get_mut_as(borrow_in_bit)?;
120			let mut bout_shl = index.get_mut_as(self.bout_shl)?;
121
122			for i in 0..index.size() {
123				let (x_minus_y, borrow1) = xin[i].overflowing_sub(yin[i]);
124				let borrow2;
125				(zout[i], borrow2) = x_minus_y.overflowing_sub(borrow_in_bit[i]);
126				let borrow = borrow1 | borrow2;
127
128				bin[i] = xin[i] ^ yin[i] ^ zout[i];
129				bout[i] = (borrow as u32) << 31 | bin[i] >> 1;
130				bout_shl[i] = bout[i] << 1;
131
132				if let Some(ref mut final_borrow) = final_borrow {
133					set_packed_slice(
134						&mut *final_borrow,
135						i,
136						if borrow { B1::ONE } else { B1::ZERO },
137					);
138				}
139			}
140		} else {
141			// When the borrow in bit is fixed to zero, we can simplify the logic.
142			for i in 0..index.size() {
143				let borrow;
144				(zout[i], borrow) = xin[i].overflowing_sub(yin[i]);
145				bin[i] = xin[i] ^ yin[i] ^ zout[i];
146				bout[i] = (borrow as u32) << 31 | bin[i] >> 1;
147
148				if let Some(ref mut final_borrow) = final_borrow {
149					set_packed_slice(
150						&mut *final_borrow,
151						i,
152						if borrow { B1::ONE } else { B1::ZERO },
153					);
154				}
155			}
156		}
157
158		Ok(())
159	}
160}
161
/// Gadget for unsigned subtraction using non-packed one-bit columns generic over `u32` and `u64`
///
/// Every bit of the operands lives in its own `Col<B1>`; index 0 is the bit that receives the
/// borrow-in. `UX` is only carried as a marker type in this struct.
#[derive(Debug)]
pub struct WideSub<UX: UnsignedAddPrimitives, const BIT_LENGTH: usize> {
	// Inputs
	/// Bit columns of the value being subtracted from.
	pub xin: [Col<B1>; BIT_LENGTH],
	/// Bit columns of the value being subtracted.
	pub yin: [Col<B1>; BIT_LENGTH],

	/// The borrow entering bit 0: bit 0 of `flags.borrow_in_bit` if set, otherwise a constant
	/// zero column.
	bin: Col<B1>,
	/// Committed columns holding the borrow leaving each bit position.
	bout: [Col<B1>; BIT_LENGTH],
	_marker: PhantomData<UX>,

	// Outputs
	/// Bit columns of the difference, committed if `flags.commit_zout` is set, otherwise computed.
	pub zout: [Col<B1>; BIT_LENGTH],
	/// The borrow out of the last bit; `Some` if `flags.expose_final_borrow` is set.
	pub final_borrow: Option<Col<B1>>,
	/// Flags modifying the gadget's behavior.
	pub flags: U32SubFlags,
}
178
179impl<UX: UnsignedAddPrimitives, const BIT_LENGTH: usize> WideSub<UX, BIT_LENGTH> {
180	pub fn new(
181		table: &mut TableBuilder,
182		xin: [Col<B1>; BIT_LENGTH],
183		yin: [Col<B1>; BIT_LENGTH],
184		flags: U32SubFlags,
185	) -> Self {
186		let bout = table.add_committed_multiple("bout");
187
188		let bin: [_; BIT_LENGTH] = array::from_fn(|i| {
189			if i == 0 {
190				if let Some(borrow_in_bit) = flags.borrow_in_bit {
191					table.add_selected("bin[0]", borrow_in_bit, 0)
192				} else {
193					table.add_constant("bin[0]", [B1::ZERO])
194				}
195			} else {
196				bout[i - 1]
197			}
198		});
199
200		let final_borrow = flags.expose_final_borrow.then(|| bout[BIT_LENGTH - 1]);
201
202		let zout = if flags.commit_zout {
203			let zout = table.add_committed_multiple("zout");
204			for (bit, zout_bit) in zout.iter().enumerate() {
205				table
206					.assert_zero(format!("zout[{bit}]"), xin[bit] + yin[bit] + bin[bit] - *zout_bit)
207			}
208			zout
209		} else {
210			array::from_fn(|bit| {
211				table.add_computed(format!("zout[{bit}]"), xin[bit] + yin[bit] + bin[bit])
212			})
213		};
214
215		Self {
216			bin: bin[0],
217			xin,
218			yin,
219			bout,
220			zout,
221			final_borrow,
222			flags,
223			_marker: PhantomData,
224		}
225	}
226
227	pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
228	where
229		P: PackedField<Scalar = B128> + PackedExtension<B1>,
230	{
231		let xin = array_util::try_map(self.xin, |bit_col| index.get(bit_col))?;
232		let yin = array_util::try_map(self.yin, |bit_col| index.get(bit_col))?;
233		let bout = array_util::try_map(self.bout, |bit_col| index.get_mut(bit_col))?;
234		let zout = array_util::try_map(self.zout, |bit_col| index.get_mut(bit_col))?;
235
236		type PB1<P> = <P as PackedExtension<B1>>::PackedSubfield;
237
238		let one = PB1::<P>::one();
239		let mut b_in = self
240			.flags
241			.borrow_in_bit
242			.map(|_| {
243				index
244					.get(self.bin)
245					.expect("witness index for borrow_in must be set before")
246					.to_vec()
247			})
248			.unwrap_or_else(|| {
249				let mut b_in = index
250					.get_mut(self.bin)
251					.expect("witness index for constant not set");
252				b_in.fill(PB1::<P>::zero());
253				vec![PB1::<P>::zero(); xin[0].len()]
254			});
255		for (x_bit, y_bit, mut b_out_bit, mut zout_bit) in
256			izip!(xin.into_iter(), yin.into_iter(), bout.into_iter(), zout.into_iter())
257		{
258			for (x, y, b_out, b_in, z_out) in izip!(
259				x_bit.iter().copied(),
260				y_bit.iter().copied(),
261				b_out_bit.iter_mut(),
262				b_in.iter_mut(),
263				zout_bit.iter_mut()
264			) {
265				let x_bit_inv = one + x;
266				let new_borrow = y * (*b_in) + x_bit_inv * (*b_in + y);
267				*z_out = x + y + (*b_in);
268				*b_out = new_borrow;
269				*b_in = new_borrow;
270			}
271		}
272
273		Ok(())
274	}
275}
276
/// [`WideSub`] specialized to `u32`, with one `B1` column for each of the 32 bits.
pub type WideU32Sub = WideSub<u32, 32>;
278
279#[cfg(test)]
280mod tests {
281	use binius_compute::cpu::alloc::CpuComputeAllocator;
282	use binius_field::{
283		arch::OptimalUnderlier128b, as_packed_field::PackedType, packed::get_packed_slice,
284	};
285	use rand::{Rng as _, SeedableRng, prelude::StdRng};
286
287	use super::*;
288	use crate::builder::{ConstraintSystem, WitnessIndex};
289
290	#[test]
291	fn prop_test_no_borrow() {
292		const N_ITER: usize = 1 << 14;
293
294		let mut rng = StdRng::seed_from_u64(0);
295		let test_vector: Vec<(u32, u32, u32, u32, bool)> = (0..N_ITER)
296			.map(|_| {
297				let x: u32 = rng.random();
298				let y: u32 = rng.random();
299				let z: u32 = x.wrapping_sub(y);
300				// (x, y, borrow_in, zout, final_borrow)
301				(x, y, 0x00000000, z, false)
302			})
303			.collect();
304
305		TestPlan {
306			dyn_borrow_in: false,
307			expose_final_borrow: false,
308			commit_zout: true,
309			test_vector,
310		}
311		.execute();
312	}
313
314	#[test]
315	fn prop_test_with_borrow() {
316		const N_ITER: usize = 1 << 14;
317
318		let mut rng = StdRng::seed_from_u64(0);
319		let test_vector: Vec<(u32, u32, u32, u32, bool)> = (0..N_ITER)
320			.map(|_| {
321				let x: u32 = rng.random();
322				let y: u32 = rng.random();
323				let borrow_in = rng.random::<bool>() as u32;
324				let (x_minus_y, borrow1) = x.overflowing_sub(y);
325				let (z, borrow2) = x_minus_y.overflowing_sub(borrow_in);
326				let final_borrow = borrow1 | borrow2;
327				(x, y, borrow_in, z, final_borrow)
328			})
329			.collect();
330
331		TestPlan {
332			dyn_borrow_in: true,
333			expose_final_borrow: true,
334			commit_zout: true,
335			test_vector,
336		}
337		.execute();
338	}
339
340	#[test]
341	fn test_borrow() {
342		// (x, y, borrow_in, zout, final_borrow)
343		let test_vector = [
344			(0x00000000, 0x00000001, 0x00000000, 0xFFFFFFFF, true), // 0 - 1 = max_u32 (underflow)
345			(0xFFFFFFFF, 0x00000001, 0x00000000, 0xFFFFFFFE, false), // max - 1 = max - 1
346			(0x80000000, 0x00000001, 0x00000000, 0x7FFFFFFF, false), // Sign bit transition
347			(0x00000005, 0x00000005, 0x00000001, 0xFFFFFFFF, true), /* 5 - 5 - 1 = -1 (borrow_in
348			                                                         * causes underflow) */
349		];
350		TestPlan {
351			dyn_borrow_in: true,
352			expose_final_borrow: true,
353			commit_zout: true,
354			test_vector: test_vector.to_vec(),
355		}
356		.execute();
357	}
358
	/// Describes one run of the [`U32Sub`] gadget: the flags to build it with and the rows to
	/// check.
	struct TestPlan {
		/// Whether to add a committed `borrow_in` column and pass it as `borrow_in_bit`.
		dyn_borrow_in: bool,
		/// Forwarded to `U32SubFlags::expose_final_borrow`.
		expose_final_borrow: bool,
		/// Forwarded to `U32SubFlags::commit_zout`.
		commit_zout: bool,
		/// (x, y, borrow_in, zout, final_borrow)
		test_vector: Vec<(u32, u32, u32, u32, bool)>,
	}
366
367	impl TestPlan {
368		fn execute(self) {
369			let mut cs = ConstraintSystem::new();
370			let mut table = cs.add_table("u32_sub");
371
372			let xin = table.add_committed::<B1, 32>("xin");
373			let yin = table.add_committed::<B1, 32>("yin");
374
375			let borrow_in = self
376				.dyn_borrow_in
377				.then_some(table.add_committed::<B1, 32>("borrow_in"));
378
379			let flags = U32SubFlags {
380				borrow_in_bit: borrow_in,
381				expose_final_borrow: self.expose_final_borrow,
382				commit_zout: self.commit_zout,
383			};
384			let subber = U32Sub::new(&mut table, xin, yin, flags);
385			assert!(subber.final_borrow.is_some() == self.expose_final_borrow);
386
387			let table_id = table.id();
388			let mut allocator = CpuComputeAllocator::new(1 << 16);
389			let allocator = allocator.into_bump_allocator();
390			let mut witness =
391				WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);
392
393			let table_witness = witness
394				.init_table(table_id, self.test_vector.len())
395				.unwrap();
396			let mut segment = table_witness.full_segment();
397
398			{
399				let mut xin_bits = segment.get_mut_as::<u32, _, 32>(subber.xin).unwrap();
400				let mut yin_bits = segment.get_mut_as::<u32, _, 32>(subber.yin).unwrap();
401				let mut borrow_in_bits =
402					borrow_in.map(|borrow_in| segment.get_mut_as::<u32, _, 32>(borrow_in).unwrap());
403				for (i, (x, y, borrow_in, _, _)) in self.test_vector.iter().enumerate() {
404					xin_bits[i] = *x;
405					yin_bits[i] = *y;
406					if let Some(ref mut borrow_in_bits) = borrow_in_bits {
407						borrow_in_bits[i] = *borrow_in;
408					}
409				}
410			}
411
412			// Populate the gadget
413			subber.populate(&mut segment).unwrap();
414
415			{
416				// Verify results
417				let zout_bits = segment.get_as::<u32, _, 32>(subber.zout).unwrap();
418				let final_borrow = subber
419					.final_borrow
420					.map(|final_borrow| segment.get(final_borrow).unwrap());
421				for (i, (_, _, _, zout, expected_borrow)) in self.test_vector.iter().enumerate() {
422					assert_eq!(zout_bits[i], *zout);
423
424					if let Some(ref final_borrow) = final_borrow {
425						assert_eq!(get_packed_slice(final_borrow, i), B1::from(*expected_borrow));
426					}
427				}
428			}
429
430			// Validate constraint system
431			let ccs = cs.compile().unwrap();
432			let table_sizes = witness.table_sizes();
433			let witness = witness.into_multilinear_extension_index();
434
435			binius_core::constraint_system::validate::validate_witness(
436				&ccs,
437				&[],
438				&table_sizes,
439				&witness,
440			)
441			.unwrap();
442		}
443	}
444}