binius_m3/gadgets/
u32.rs

1// Copyright 2025 Irreducible Inc.
2
3use binius_core::oracle::ShiftVariant;
4use binius_field::{packed::set_packed_slice, Field, PackedExtension, PackedFieldIndexable};
5
6use crate::builder::{column::Col, types::B1, witness::TableWitnessSegment, TableBuilder, B128};
7
/// A gadget for performing 32-bit integer addition on vertically-packed bit columns.
///
/// This gadget has input columns `xin` and `yin` for the two 32-bit integers to be added, and an
/// output column `zout`, and it constrains that `xin + yin = zout` as integers modulo 2^32
/// (wrapping addition). If `flags.carry_in_bit` is set, that carry-in bit is added as well. The
/// carry out of the most significant bit is discarded unless `flags.expose_final_carry` is set.
///
/// Each column holds one 32-bit integer per row, stored as 32 `B1` (single-bit) elements.
#[derive(Debug)]
pub struct U32Add {
	// Inputs
	/// First summand.
	pub xin: Col<B1, 32>,
	/// Second summand.
	pub yin: Col<B1, 32>,

	// Private
	/// Carry into each bit position: `cout_shl`, plus the carry-in bit when one is configured.
	cin: Col<B1, 32>,
	/// Committed carry out of each bit position.
	cout: Col<B1, 32>,
	/// `cout` shifted left by one bit within each row, so bit `i + 1` holds the carry out of
	/// bit `i` and bit 0 is zero.
	cout_shl: Col<B1, 32>,

	// Outputs
	/// The output column, either committed if `flags.commit_zout` is set, otherwise a linear
	/// combination derived column.
	pub zout: Col<B1, 32>,
	/// Bit 31 of `cout`, i.e. the carry out of the whole addition.
	/// This is `Some` if `flags.expose_final_carry` is set, otherwise it is `None`.
	pub final_carry: Option<Col<B1>>,
	/// Flags modifying the gadget's behavior.
	pub flags: U32AddFlags,
}
32
/// Flags modifying the behavior of the [`U32Add`] gadget.
#[derive(Debug, Default, Clone)]
pub struct U32AddFlags {
	/// Optionally a column for a dynamic carry in bit, added into the sum. This *must* be zero in
	/// all bits except the 0th; the gadget itself does not constrain this, so the caller is
	/// responsible for it.
	pub carry_in_bit: Option<Col<B1, 32>>,
	/// If set, `zout` is a committed column constrained to equal the sum. Otherwise it is a
	/// computed column derived from the inputs.
	pub commit_zout: bool,
	/// If set, [`U32Add::final_carry`] is `Some`, exposing the carry out of bit 31.
	pub expose_final_carry: bool,
}
42
43impl U32Add {
44	pub fn new(
45		table: &mut TableBuilder,
46		xin: Col<B1, 32>,
47		yin: Col<B1, 32>,
48		flags: U32AddFlags,
49	) -> Self {
50		let cout = table.add_committed::<B1, 32>("cout");
51		let cout_shl = table.add_shifted("cout_shl", cout, 5, 1, ShiftVariant::LogicalLeft);
52
53		let cin = if let Some(carry_in_bit) = flags.carry_in_bit {
54			table.add_computed("cin", cout_shl + carry_in_bit)
55		} else {
56			cout_shl
57		};
58
59		let final_carry = flags
60			.expose_final_carry
61			.then(|| table.add_selected("final_carry", cout, 31));
62
63		table.assert_zero("carry_out", (xin + cin) * (yin + cin) + cin - cout);
64
65		let zout = if flags.commit_zout {
66			let zout = table.add_committed::<B1, 32>("zout");
67			table.assert_zero("zout", xin + yin + cin - zout);
68			zout
69		} else {
70			table.add_computed("zout", xin + yin + cin)
71		};
72
73		Self {
74			xin,
75			yin,
76			cin,
77			cout,
78			cout_shl,
79			final_carry,
80			zout,
81			flags,
82		}
83	}
84
85	pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
86	where
87		P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1>,
88	{
89		let xin: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.xin)?;
90		let yin = index.get_mut_as(self.yin)?;
91		let mut cout = index.get_mut_as(self.cout)?;
92		let mut zout = index.get_mut_as(self.zout)?;
93		let mut final_carry = if let Some(final_carry) = self.final_carry {
94			let final_carry = index.get_mut(final_carry)?;
95			Some(final_carry)
96		} else {
97			None
98		};
99
100		if let Some(carry_in_bit_col) = self.flags.carry_in_bit {
101			// This is u32 assumed to be either 0 or 1.
102			let carry_in_bit = index.get_mut_as(carry_in_bit_col)?;
103
104			let mut cin = index.get_mut_as(self.cin)?;
105			let mut cout_shl = index.get_mut_as(self.cout_shl)?;
106			for i in 0..index.size() {
107				let (x_plus_y, carry0) = xin[i].overflowing_add(yin[i]);
108				let carry1;
109				(zout[i], carry1) = x_plus_y.overflowing_add(carry_in_bit[i]);
110				let carry = carry0 | carry1;
111
112				cin[i] = xin[i] ^ yin[i] ^ zout[i];
113				cout[i] = (carry as u32) << 31 | cin[i] >> 1;
114				cout_shl[i] = cout[i] << 1;
115
116				if let Some(ref mut final_carry) = final_carry {
117					set_packed_slice(&mut *final_carry, i, if carry { B1::ONE } else { B1::ZERO });
118				}
119			}
120		} else {
121			// When the carry in bit is fixed to zero, we can simplify the logic.
122			let mut cin = index.get_mut_as(self.cin)?;
123			for i in 0..index.size() {
124				let carry;
125				(zout[i], carry) = xin[i].overflowing_add(yin[i]);
126				cin[i] = xin[i] ^ yin[i] ^ zout[i];
127				cout[i] = (carry as u32) << 31 | cin[i] >> 1;
128				if let Some(ref mut final_carry) = final_carry {
129					set_packed_slice(&mut *final_carry, i, if carry { B1::ONE } else { B1::ZERO });
130				}
131			}
132		};
133		Ok(())
134	}
135}