1use std::{array, marker::PhantomData};
4
5use binius_core::oracle::ShiftVariant;
6use binius_field::{
7 Field, PackedExtension, PackedField, PackedFieldIndexable, packed::set_packed_slice,
8};
9use itertools::izip;
10
11use crate::{
12 builder::{B128, TableBuilder, column::Col, types::B1, witness::TableWitnessSegment},
13 gadgets::add::UnsignedAddPrimitives,
14};
15
/// 32-bit subtraction gadget: `zout = xin - yin - borrow_in` (wrapping), where `borrow_in` is the
/// optional [`U32SubFlags::borrow_in_bit`] column.
///
/// Every operand is a single column holding 32 bits per row.
#[derive(Debug)]
pub struct U32Sub {
    /// Minuend.
    pub xin: Col<B1, 32>,
    /// Subtrahend.
    pub yin: Col<B1, 32>,

    /// Committed borrow-out bits: bit `k` is the borrow out of bit position `k`.
    bout: Col<B1, 32>,
    /// `bout` shifted left by one bit inside each 32-bit word, i.e. the borrow each position
    /// receives from the position below it.
    bout_shl: Col<B1, 32>,
    /// Borrow into each bit position: `bout_shl` plus `borrow_in_bit` when one is configured;
    /// without a borrow-in this is the very same column as `bout_shl`.
    bin: Col<B1, 32>,

    /// Difference bits `xin + yin + bin` (XOR over GF(2)); committed or computed depending on
    /// [`U32SubFlags::commit_zout`].
    pub zout: Col<B1, 32>,
    /// Bit 31 of `bout`; present only when [`U32SubFlags::expose_final_borrow`] is set.
    pub final_borrow: Option<Col<B1>>,
    /// The flags the gadget was constructed with.
    pub flags: U32SubFlags,
}
40
/// Construction options shared by [`U32Sub`] and [`WideSub`].
#[derive(Debug, Default, Clone)]
pub struct U32SubFlags {
    /// Optional borrow-in column.
    ///
    /// [`U32Sub`] adds the whole column into its borrow-in bits, and its witness filler subtracts
    /// the row value as an integer, so each row is expected to hold 0 or 1. [`WideSub`] selects
    /// bit 0 of it as the borrow into its least significant bit.
    pub borrow_in_bit: Option<Col<B1, 32>>,
    /// When set, the borrow out of the most significant bit is exposed as `final_borrow`.
    pub expose_final_borrow: bool,
    /// When set, `zout` is a committed column constrained to equal `xin + yin + bin`; otherwise it
    /// is a computed column defined by that expression.
    pub commit_zout: bool,
}
50
51impl U32Sub {
52 pub fn new(
53 table: &mut TableBuilder,
54 xin: Col<B1, 32>,
55 yin: Col<B1, 32>,
56 flags: U32SubFlags,
57 ) -> Self {
58 let bout = table.add_committed("bout");
59 let bout_shl = table.add_shifted("bout_shl", bout, 5, 1, ShiftVariant::LogicalLeft);
60
61 let bin = if let Some(borrow_in_bit) = flags.borrow_in_bit {
62 table.add_computed("bin", bout_shl + borrow_in_bit)
63 } else {
64 bout_shl
65 };
66
67 let final_borrow = flags
68 .expose_final_borrow
69 .then(|| table.add_selected("final_borrow", bout, 31));
70
71 table.assert_zero("borrow_out", (bin + (xin - B1::ONE)) * (bin + yin) + bin - bout);
78
79 let zout = if flags.commit_zout {
80 let zout = table.add_committed("zout");
81 table.assert_zero("zout", xin + yin + bin - zout);
82 zout
83 } else {
84 table.add_computed("zout", xin + yin + bin)
85 };
86
87 U32Sub {
88 xin,
89 yin,
90 bout,
91 bout_shl,
92 bin,
93 zout,
94 final_borrow,
95 flags,
96 }
97 }
98}
99
100impl U32Sub {
101 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
102 where
103 P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1>,
104 {
105 let xin: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.xin)?;
106 let yin: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.yin)?;
107 let mut bout: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.bout)?;
108 let mut zout: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.zout)?;
109 let mut bin: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.bin)?;
110 let mut final_borrow = if let Some(final_borrow) = self.final_borrow {
111 let final_borrow = index.get_mut(final_borrow)?;
112 Some(final_borrow)
113 } else {
114 None
115 };
116
117 if let Some(borrow_in_bit) = self.flags.borrow_in_bit {
118 let borrow_in_bit = index.get_mut_as(borrow_in_bit)?;
120 let mut bout_shl = index.get_mut_as(self.bout_shl)?;
121
122 for i in 0..index.size() {
123 let (x_minus_y, borrow1) = xin[i].overflowing_sub(yin[i]);
124 let borrow2;
125 (zout[i], borrow2) = x_minus_y.overflowing_sub(borrow_in_bit[i]);
126 let borrow = borrow1 | borrow2;
127
128 bin[i] = xin[i] ^ yin[i] ^ zout[i];
129 bout[i] = (borrow as u32) << 31 | bin[i] >> 1;
130 bout_shl[i] = bout[i] << 1;
131
132 if let Some(ref mut final_borrow) = final_borrow {
133 set_packed_slice(
134 &mut *final_borrow,
135 i,
136 if borrow { B1::ONE } else { B1::ZERO },
137 );
138 }
139 }
140 } else {
141 for i in 0..index.size() {
143 let borrow;
144 (zout[i], borrow) = xin[i].overflowing_sub(yin[i]);
145 bin[i] = xin[i] ^ yin[i] ^ zout[i];
146 bout[i] = (borrow as u32) << 31 | bin[i] >> 1;
147
148 if let Some(ref mut final_borrow) = final_borrow {
149 set_packed_slice(
150 &mut *final_borrow,
151 i,
152 if borrow { B1::ONE } else { B1::ZERO },
153 );
154 }
155 }
156 }
157
158 Ok(())
159 }
160}
161
/// Subtraction gadget `xin - yin - borrow_in` over `BIT_LENGTH` separate single-bit columns.
///
/// The borrow ripples upward from bit 0, the least significant bit. `UX` is only a type-level
/// marker (e.g. `u32` for [`WideU32Sub`]); no value of it is stored.
#[derive(Debug)]
pub struct WideSub<UX: UnsignedAddPrimitives, const BIT_LENGTH: usize> {
    /// Minuend bits, least significant first.
    pub xin: [Col<B1>; BIT_LENGTH],
    /// Subtrahend bits, least significant first.
    pub yin: [Col<B1>; BIT_LENGTH],

    /// Borrow into bit 0: bit 0 of [`U32SubFlags::borrow_in_bit`] when configured, otherwise a
    /// constant zero column. Higher bits take their borrow from the previous `bout` entry.
    bin: Col<B1>,
    /// Committed borrow out of every bit position.
    bout: [Col<B1>; BIT_LENGTH],
    _marker: PhantomData<UX>,

    /// Difference bits `xin[i] + yin[i] + bin[i]` (XOR over GF(2)).
    pub zout: [Col<B1>; BIT_LENGTH],
    /// `bout[BIT_LENGTH - 1]`; present only when [`U32SubFlags::expose_final_borrow`] is set.
    pub final_borrow: Option<Col<B1>>,
    /// The flags the gadget was constructed with.
    pub flags: U32SubFlags,
}
178
179impl<UX: UnsignedAddPrimitives, const BIT_LENGTH: usize> WideSub<UX, BIT_LENGTH> {
180 pub fn new(
181 table: &mut TableBuilder,
182 xin: [Col<B1>; BIT_LENGTH],
183 yin: [Col<B1>; BIT_LENGTH],
184 flags: U32SubFlags,
185 ) -> Self {
186 let bout = table.add_committed_multiple("bout");
187
188 let bin: [_; BIT_LENGTH] = array::from_fn(|i| {
189 if i == 0 {
190 if let Some(borrow_in_bit) = flags.borrow_in_bit {
191 table.add_selected("bin[0]", borrow_in_bit, 0)
192 } else {
193 table.add_constant("bin[0]", [B1::ZERO])
194 }
195 } else {
196 bout[i - 1]
197 }
198 });
199
200 let final_borrow = flags.expose_final_borrow.then(|| bout[BIT_LENGTH - 1]);
201
202 let zout = if flags.commit_zout {
203 let zout = table.add_committed_multiple("zout");
204 for (bit, zout_bit) in zout.iter().enumerate() {
205 table
206 .assert_zero(format!("zout[{bit}]"), xin[bit] + yin[bit] + bin[bit] - *zout_bit)
207 }
208 zout
209 } else {
210 array::from_fn(|bit| {
211 table.add_computed(format!("zout[{bit}]"), xin[bit] + yin[bit] + bin[bit])
212 })
213 };
214
215 Self {
216 bin: bin[0],
217 xin,
218 yin,
219 bout,
220 zout,
221 final_borrow,
222 flags,
223 _marker: PhantomData,
224 }
225 }
226
227 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
228 where
229 P: PackedField<Scalar = B128> + PackedExtension<B1>,
230 {
231 let xin = array_util::try_map(self.xin, |bit_col| index.get(bit_col))?;
232 let yin = array_util::try_map(self.yin, |bit_col| index.get(bit_col))?;
233 let bout = array_util::try_map(self.bout, |bit_col| index.get_mut(bit_col))?;
234 let zout = array_util::try_map(self.zout, |bit_col| index.get_mut(bit_col))?;
235
236 type PB1<P> = <P as PackedExtension<B1>>::PackedSubfield;
237
238 let one = PB1::<P>::one();
239 let mut b_in = self
240 .flags
241 .borrow_in_bit
242 .map(|_| {
243 index
244 .get(self.bin)
245 .expect("witness index for borrow_in must be set before")
246 .to_vec()
247 })
248 .unwrap_or_else(|| {
249 let mut b_in = index
250 .get_mut(self.bin)
251 .expect("witness index for constant not set");
252 b_in.fill(PB1::<P>::zero());
253 vec![PB1::<P>::zero(); xin[0].len()]
254 });
255 for (x_bit, y_bit, mut b_out_bit, mut zout_bit) in
256 izip!(xin.into_iter(), yin.into_iter(), bout.into_iter(), zout.into_iter())
257 {
258 for (x, y, b_out, b_in, z_out) in izip!(
259 x_bit.iter().copied(),
260 y_bit.iter().copied(),
261 b_out_bit.iter_mut(),
262 b_in.iter_mut(),
263 zout_bit.iter_mut()
264 ) {
265 let x_bit_inv = one + x;
266 let new_borrow = y * (*b_in) + x_bit_inv * (*b_in + y);
267 *z_out = x + y + (*b_in);
268 *b_out = new_borrow;
269 *b_in = new_borrow;
270 }
271 }
272
273 Ok(())
274 }
275}
276
/// [`WideSub`] specialised to 32 single-bit operand columns, tagged with `u32`.
pub type WideU32Sub = WideSub<u32, 32>;
278
#[cfg(test)]
mod tests {
    use binius_compute::cpu::alloc::CpuComputeAllocator;
    use binius_field::{
        arch::OptimalUnderlier128b, as_packed_field::PackedType, packed::get_packed_slice,
    };
    use rand::{Rng as _, SeedableRng, prelude::StdRng};

    use super::*;
    use crate::builder::{ConstraintSystem, WitnessIndex};

    /// Number of random rows used by each property test.
    const N_ITER: usize = 1 << 14;

    /// Random rows `(x, y, 0, x - y, borrow)` for configurations without a borrow-in column.
    fn random_rows_without_borrow_in() -> Vec<(u32, u32, u32, u32, bool)> {
        let mut rng = StdRng::seed_from_u64(0);
        (0..N_ITER)
            .map(|_| {
                let x: u32 = rng.random();
                let y: u32 = rng.random();
                let (z, borrow) = x.overflowing_sub(y);
                (x, y, 0, z, borrow)
            })
            .collect()
    }

    #[test]
    fn prop_test_no_borrow() {
        TestPlan {
            dyn_borrow_in: false,
            expose_final_borrow: false,
            commit_zout: true,
            test_vector: random_rows_without_borrow_in(),
        }
        .execute();
    }

    #[test]
    fn prop_test_no_borrow_computed_zout() {
        TestPlan {
            dyn_borrow_in: false,
            expose_final_borrow: true,
            commit_zout: false,
            test_vector: random_rows_without_borrow_in(),
        }
        .execute();
    }

    #[test]
    fn prop_test_with_borrow() {
        let mut rng = StdRng::seed_from_u64(0);
        let test_vector: Vec<(u32, u32, u32, u32, bool)> = (0..N_ITER)
            .map(|_| {
                let x: u32 = rng.random();
                let y: u32 = rng.random();
                let borrow_in = rng.random::<bool>() as u32;
                let (x_minus_y, borrow1) = x.overflowing_sub(y);
                let (z, borrow2) = x_minus_y.overflowing_sub(borrow_in);
                let final_borrow = borrow1 | borrow2;
                (x, y, borrow_in, z, final_borrow)
            })
            .collect();

        TestPlan {
            dyn_borrow_in: true,
            expose_final_borrow: true,
            commit_zout: true,
            test_vector,
        }
        .execute();
    }

    #[test]
    fn test_borrow() {
        let test_vector = [
            // 0 - 1 wraps around and borrows.
            (0x00000000, 0x00000001, 0x00000000, 0xFFFFFFFF, true),
            // Largest minuend: no borrow.
            (0xFFFFFFFF, 0x00000001, 0x00000000, 0xFFFFFFFE, false),
            // The borrow ripples through bits 0..=30 and is absorbed by bit 31.
            (0x80000000, 0x00000001, 0x00000000, 0x7FFFFFFF, false),
            // Equal operands: only the borrow-in makes the result wrap.
            (0x00000005, 0x00000005, 0x00000001, 0xFFFFFFFF, true),
        ];

        TestPlan {
            dyn_borrow_in: true,
            expose_final_borrow: true,
            commit_zout: true,
            test_vector: test_vector.to_vec(),
        }
        .execute();
    }

    /// One end-to-end run of [`U32Sub`]: build the table, fill the witness, compare the outputs
    /// against the expected rows and validate the witness against the compiled constraints.
    struct TestPlan {
        dyn_borrow_in: bool,
        expose_final_borrow: bool,
        commit_zout: bool,
        /// Rows of `(x, y, borrow_in, expected_zout, expected_final_borrow)`.
        test_vector: Vec<(u32, u32, u32, u32, bool)>,
    }

    impl TestPlan {
        fn execute(self) {
            let mut cs = ConstraintSystem::new();
            let mut table = cs.add_table("u32_sub");

            let xin = table.add_committed::<B1, 32>("xin");
            let yin = table.add_committed::<B1, 32>("yin");

            // `then` (not `then_some`): the argument of `then_some` is evaluated eagerly and
            // would add an unused "borrow_in" column even when no borrow-in is requested.
            let borrow_in = self
                .dyn_borrow_in
                .then(|| table.add_committed::<B1, 32>("borrow_in"));

            let flags = U32SubFlags {
                borrow_in_bit: borrow_in,
                expose_final_borrow: self.expose_final_borrow,
                commit_zout: self.commit_zout,
            };
            let subber = U32Sub::new(&mut table, xin, yin, flags);
            assert_eq!(subber.final_borrow.is_some(), self.expose_final_borrow);

            let table_id = table.id();
            let mut allocator = CpuComputeAllocator::new(1 << 16);
            let allocator = allocator.into_bump_allocator();
            let mut witness =
                WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

            let table_witness = witness
                .init_table(table_id, self.test_vector.len())
                .unwrap();
            let mut segment = table_witness.full_segment();

            {
                let mut xin_bits = segment.get_mut_as::<u32, _, 32>(subber.xin).unwrap();
                let mut yin_bits = segment.get_mut_as::<u32, _, 32>(subber.yin).unwrap();
                let mut borrow_in_bits =
                    borrow_in.map(|borrow_in| segment.get_mut_as::<u32, _, 32>(borrow_in).unwrap());
                for (i, (x, y, borrow_in, _, _)) in self.test_vector.iter().enumerate() {
                    xin_bits[i] = *x;
                    yin_bits[i] = *y;
                    if let Some(ref mut borrow_in_bits) = borrow_in_bits {
                        borrow_in_bits[i] = *borrow_in;
                    }
                }
            }

            subber.populate(&mut segment).unwrap();

            {
                let zout_bits = segment.get_as::<u32, _, 32>(subber.zout).unwrap();
                let final_borrow = subber
                    .final_borrow
                    .map(|final_borrow| segment.get(final_borrow).unwrap());
                for (i, (_, _, _, zout, expected_borrow)) in self.test_vector.iter().enumerate() {
                    assert_eq!(zout_bits[i], *zout);

                    if let Some(ref final_borrow) = final_borrow {
                        assert_eq!(get_packed_slice(final_borrow, i), B1::from(*expected_borrow));
                    }
                }
            }

            let ccs = cs.compile().unwrap();
            let table_sizes = witness.table_sizes();
            let witness = witness.into_multilinear_extension_index();

            binius_core::constraint_system::validate::validate_witness(
                &ccs,
                &[],
                &table_sizes,
                &witness,
            )
            .unwrap();
        }
    }
}