binius_m3/gadgets/u32/
add.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_core::oracle::ShiftVariant;
4use binius_field::{packed::set_packed_slice, Field, PackedExtension, PackedFieldIndexable};
5
6use crate::builder::{column::Col, types::B1, witness::TableWitnessSegment, TableBuilder, B128};
7
8/// A gadget for performing 32-bit integer addition on vertically-packed bit columns.
9///
10/// This gadget has input columns `xin` and `yin` for the two 32-bit integers to be added, and an
11/// output column `zout`, and it constrains that `xin + yin = zout` as integers.
12#[derive(Debug)]
13pub struct U32Add {
14	// Inputs
15	pub xin: Col<B1, 32>,
16	pub yin: Col<B1, 32>,
17
18	// Private
19	cin: Col<B1, 32>,
20	cout: Col<B1, 32>,
21	cout_shl: Col<B1, 32>,
22
23	// Outputs
24	/// The output column, either committed if `flags.commit_zout` is set, otherwise a linear
25	/// combination derived column.
26	pub zout: Col<B1, 32>,
27	/// This is `Some` if `flags.expose_final_carry` is set, otherwise it is `None`.
28	pub final_carry: Option<Col<B1>>,
29	/// Flags modifying the gadget's behavior.
30	pub flags: U32AddFlags,
31}
32
33/// Flags modifying the behavior of the [`U32Add`] gadget.
34#[derive(Debug, Default, Clone)]
35pub struct U32AddFlags {
36	// Optionally a column for a dynamic carry in bit. This *must* be zero in all bits except the
37	// 0th.
38	pub carry_in_bit: Option<Col<B1, 32>>,
39	pub commit_zout: bool,
40	pub expose_final_carry: bool,
41}
42
43impl U32Add {
44	pub fn new(
45		table: &mut TableBuilder,
46		xin: Col<B1, 32>,
47		yin: Col<B1, 32>,
48		flags: U32AddFlags,
49	) -> Self {
50		let cout = table.add_committed::<B1, 32>("cout");
51		let cout_shl = table.add_shifted("cout_shl", cout, 5, 1, ShiftVariant::LogicalLeft);
52
53		let cin = if let Some(carry_in_bit) = flags.carry_in_bit {
54			table.add_computed("cin", cout_shl + carry_in_bit)
55		} else {
56			cout_shl
57		};
58
59		let final_carry = flags
60			.expose_final_carry
61			.then(|| table.add_selected("final_carry", cout, 31));
62
63		table.assert_zero("carry_out", (xin + cin) * (yin + cin) + cin - cout);
64
65		let zout = if flags.commit_zout {
66			let zout = table.add_committed::<B1, 32>("zout");
67			table.assert_zero("zout", xin + yin + cin - zout);
68			zout
69		} else {
70			table.add_computed("zout", xin + yin + cin)
71		};
72
73		Self {
74			xin,
75			yin,
76			cin,
77			cout,
78			cout_shl,
79			final_carry,
80			zout,
81			flags,
82		}
83	}
84
85	pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
86	where
87		P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1>,
88	{
89		let xin: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.xin)?;
90		let yin = index.get_mut_as(self.yin)?;
91		let mut cout = index.get_mut_as(self.cout)?;
92		let mut zout = index.get_mut_as(self.zout)?;
93		let mut final_carry = if let Some(final_carry) = self.final_carry {
94			let final_carry = index.get_mut(final_carry)?;
95			Some(final_carry)
96		} else {
97			None
98		};
99
100		if let Some(carry_in_bit_col) = self.flags.carry_in_bit {
101			// This is u32 assumed to be either 0 or 1.
102			let carry_in_bit = index.get_mut_as(carry_in_bit_col)?;
103
104			let mut cin = index.get_mut_as(self.cin)?;
105			let mut cout_shl = index.get_mut_as(self.cout_shl)?;
106			for i in 0..index.size() {
107				let (x_plus_y, carry0) = xin[i].overflowing_add(yin[i]);
108				let carry1;
109				(zout[i], carry1) = x_plus_y.overflowing_add(carry_in_bit[i]);
110				let carry = carry0 | carry1;
111
112				cin[i] = xin[i] ^ yin[i] ^ zout[i];
113				cout[i] = (carry as u32) << 31 | cin[i] >> 1;
114				cout_shl[i] = cout[i] << 1;
115
116				if let Some(ref mut final_carry) = final_carry {
117					set_packed_slice(&mut *final_carry, i, if carry { B1::ONE } else { B1::ZERO });
118				}
119			}
120		} else {
121			// When the carry in bit is fixed to zero, we can simplify the logic.
122			let mut cin = index.get_mut_as(self.cin)?;
123			for i in 0..index.size() {
124				let carry;
125				(zout[i], carry) = xin[i].overflowing_add(yin[i]);
126				cin[i] = xin[i] ^ yin[i] ^ zout[i];
127				cout[i] = (carry as u32) << 31 | cin[i] >> 1;
128				if let Some(ref mut final_carry) = final_carry {
129					set_packed_slice(&mut *final_carry, i, if carry { B1::ONE } else { B1::ZERO });
130				}
131			}
132		};
133		Ok(())
134	}
135}
136
#[cfg(test)]
mod tests {
	use binius_field::{
		arch::OptimalUnderlier128b, as_packed_field::PackedType, packed::get_packed_slice,
	};
	use bumpalo::Bump;
	use rand::{prelude::StdRng, Rng as _, SeedableRng};

	use super::*;
	use crate::builder::{ConstraintSystem, Statement, WitnessIndex};

	/// Default flags: random operands, wrapping sum, constraint validation.
	#[test]
	fn test_basic() {
		const TABLE_SZ: usize = 1 << 14;

		let mut cs = ConstraintSystem::new();
		let mut table = cs.add_table("u32_add test");

		let xin = table.add_committed::<B1, 32>("xin");
		let yin = table.add_committed::<B1, 32>("yin");

		let adder = U32Add::new(&mut table, xin, yin, U32AddFlags::default());
		let table_id = table.id();
		let statement = Statement {
			boundaries: vec![],
			table_sizes: vec![TABLE_SZ],
		};
		let allocator = Bump::new();
		let mut witness =
			WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

		let table_witness = witness.init_table(table_id, TABLE_SZ).unwrap();
		let mut segment = table_witness.full_segment();

		let mut rng = StdRng::seed_from_u64(0);

		// Generate random u32 operands and expected results.
		//
		// (x, y, z)
		let test_vector: Vec<(u32, u32, u32)> = (0..segment.size())
			.map(|_| {
				let x = rng.gen::<u32>();
				let y = rng.gen::<u32>();
				let z = x.wrapping_add(y);
				(x, y, z)
			})
			.collect();

		{
			let mut xin_bits = segment.get_mut_as::<u32, _, 32>(adder.xin).unwrap();
			let mut yin_bits = segment.get_mut_as::<u32, _, 32>(adder.yin).unwrap();
			for (i, (x, y, _)) in test_vector.iter().enumerate() {
				xin_bits[i] = *x;
				yin_bits[i] = *y;
			}
		}

		// Populate the gadget
		adder.populate(&mut segment).unwrap();

		{
			// Verify results
			let zout_bits = segment.get_as::<u32, _, 32>(adder.zout).unwrap();
			for (i, (_, _, z)) in test_vector.iter().enumerate() {
				assert_eq!(zout_bits[i], *z);
			}
		}

		// Validate constraint system
		let ccs = cs.compile(&statement).unwrap();
		let witness = witness.into_multilinear_extension_index();

		binius_core::constraint_system::validate::validate_witness(&ccs, &[], &witness).unwrap();
	}

	/// Dynamic carry-in column with the final carry exposed, on hand-picked edge cases.
	#[test]
	fn test_add_with_carry() {
		const TABLE_SZ: usize = 1 << 2;

		let mut cs = ConstraintSystem::new();
		let mut table = cs.add_table("u32_add_with_carry test");

		let xin = table.add_committed::<B1, 32>("xin");
		let yin = table.add_committed::<B1, 32>("yin");
		let carry_in = table.add_committed::<B1, 32>("carry_in");

		let flags = U32AddFlags {
			carry_in_bit: Some(carry_in),
			expose_final_carry: true,
			commit_zout: false,
		};
		let adder = U32Add::new(&mut table, xin, yin, flags);
		let table_id = table.id();
		let statement = Statement {
			boundaries: vec![],
			table_sizes: vec![TABLE_SZ],
		};
		let allocator = Bump::new();
		let mut witness =
			WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

		let table_witness = witness.init_table(table_id, TABLE_SZ).unwrap();
		let mut segment = table_witness.full_segment();

		// The test vector that contains interesting cases.
		//
		// (x, y, carry_in, zout, final_carry)
		let test_vector = [
			(0xFFFFFFFF, 0x00000001, 0x00000000, 0x00000000, true),
			(0xFFFFFFFF, 0x00000000, 0x00000000, 0xFFFFFFFF, false),
			(0x7FFFFFFF, 0x00000001, 0x00000000, 0x80000000, false),
			(0xFFFF0000, 0x0000FFFF, 0x00000001, 0x00000000, true),
		];
		assert_eq!(test_vector.len(), segment.size());

		{
			// Populate the columns with the inputs from the test vector.
			let mut xin_bits = segment.get_mut_as::<u32, _, 32>(adder.xin).unwrap();
			let mut yin_bits = segment.get_mut_as::<u32, _, 32>(adder.yin).unwrap();
			let mut carry_in_bits = segment.get_mut_as::<u32, _, 32>(carry_in).unwrap();
			for (i, (x, y, carry, _, _)) in test_vector.iter().enumerate() {
				xin_bits[i] = *x;
				yin_bits[i] = *y;
				carry_in_bits[i] = *carry;
			}
		}

		// Populate the gadget
		adder.populate(&mut segment).unwrap();

		{
			// Verify results
			let zout_bits = segment.get_as::<u32, _, 32>(adder.zout).unwrap();
			let final_carry = segment.get(adder.final_carry.unwrap()).unwrap();

			for (i, (_, _, _, zout, expected_carry)) in test_vector.iter().enumerate() {
				assert_eq!(zout_bits[i], *zout);

				// Check final carry bit
				assert_eq!(get_packed_slice(&final_carry, i), B1::from(*expected_carry));
			}
		}

		// Validate constraint system
		let ccs = cs.compile(&statement).unwrap();
		let witness = witness.into_multilinear_extension_index();

		binius_core::constraint_system::validate::validate_witness(&ccs, &[], &witness).unwrap();
	}

	/// Committed `zout` with the final carry exposed and no dynamic carry in, covering operands
	/// whose sum overflows 32 bits.
	#[test]
	fn test_commit_zout_with_final_carry() {
		const TABLE_SZ: usize = 1 << 8;

		let mut cs = ConstraintSystem::new();
		let mut table = cs.add_table("u32_add_commit_zout test");

		let xin = table.add_committed::<B1, 32>("xin");
		let yin = table.add_committed::<B1, 32>("yin");

		let flags = U32AddFlags {
			carry_in_bit: None,
			commit_zout: true,
			expose_final_carry: true,
		};
		let adder = U32Add::new(&mut table, xin, yin, flags);
		let table_id = table.id();
		let statement = Statement {
			boundaries: vec![],
			table_sizes: vec![TABLE_SZ],
		};
		let allocator = Bump::new();
		let mut witness =
			WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

		let table_witness = witness.init_table(table_id, TABLE_SZ).unwrap();
		let mut segment = table_witness.full_segment();

		let mut rng = StdRng::seed_from_u64(0);

		// Fixed overflow / no-overflow edge cases, followed by random operands.
		//
		// (x, y, z, final_carry)
		let edge_cases: [(u32, u32); 4] =
			[(u32::MAX, 1), (u32::MAX, 0), (u32::MAX, u32::MAX), (0, 0)];
		let test_vector: Vec<(u32, u32, u32, bool)> = edge_cases
			.iter()
			.copied()
			.chain(std::iter::repeat_with(|| (rng.gen::<u32>(), rng.gen::<u32>())))
			.take(segment.size())
			.map(|(x, y)| {
				let (z, carry) = x.overflowing_add(y);
				(x, y, z, carry)
			})
			.collect();
		assert_eq!(test_vector.len(), segment.size());

		{
			let mut xin_bits = segment.get_mut_as::<u32, _, 32>(adder.xin).unwrap();
			let mut yin_bits = segment.get_mut_as::<u32, _, 32>(adder.yin).unwrap();
			for (i, (x, y, _, _)) in test_vector.iter().enumerate() {
				xin_bits[i] = *x;
				yin_bits[i] = *y;
			}
		}

		// Populate the gadget
		adder.populate(&mut segment).unwrap();

		{
			// Verify results
			let zout_bits = segment.get_as::<u32, _, 32>(adder.zout).unwrap();
			let final_carry = segment.get(adder.final_carry.unwrap()).unwrap();
			for (i, (_, _, z, carry)) in test_vector.iter().enumerate() {
				assert_eq!(zout_bits[i], *z);
				assert_eq!(get_packed_slice(&final_carry, i), B1::from(*carry));
			}
		}

		// Validate constraint system
		let ccs = cs.compile(&statement).unwrap();
		let witness = witness.into_multilinear_extension_index();

		binius_core::constraint_system::validate::validate_witness(&ccs, &[], &witness).unwrap();
	}
}