1use binius_core::oracle::ShiftVariant;
4use binius_field::{packed::set_packed_slice, Field, PackedExtension, PackedFieldIndexable};
5
6use crate::builder::{column::Col, types::B1, witness::TableWitnessSegment, TableBuilder, B128};
7
8#[derive(Debug)]
13pub struct U32Add {
14 pub xin: Col<B1, 32>,
16 pub yin: Col<B1, 32>,
17
18 cin: Col<B1, 32>,
20 cout: Col<B1, 32>,
21 cout_shl: Col<B1, 32>,
22
23 pub zout: Col<B1, 32>,
27 pub final_carry: Option<Col<B1>>,
29 pub flags: U32AddFlags,
31}
32
33#[derive(Debug, Default, Clone)]
35pub struct U32AddFlags {
36 pub carry_in_bit: Option<Col<B1, 32>>,
39 pub commit_zout: bool,
40 pub expose_final_carry: bool,
41}
42
43impl U32Add {
44 pub fn new(
45 table: &mut TableBuilder,
46 xin: Col<B1, 32>,
47 yin: Col<B1, 32>,
48 flags: U32AddFlags,
49 ) -> Self {
50 let cout = table.add_committed::<B1, 32>("cout");
51 let cout_shl = table.add_shifted("cout_shl", cout, 5, 1, ShiftVariant::LogicalLeft);
52
53 let cin = if let Some(carry_in_bit) = flags.carry_in_bit {
54 table.add_computed("cin", cout_shl + carry_in_bit)
55 } else {
56 cout_shl
57 };
58
59 let final_carry = flags
60 .expose_final_carry
61 .then(|| table.add_selected("final_carry", cout, 31));
62
63 table.assert_zero("carry_out", (xin + cin) * (yin + cin) + cin - cout);
64
65 let zout = if flags.commit_zout {
66 let zout = table.add_committed::<B1, 32>("zout");
67 table.assert_zero("zout", xin + yin + cin - zout);
68 zout
69 } else {
70 table.add_computed("zout", xin + yin + cin)
71 };
72
73 Self {
74 xin,
75 yin,
76 cin,
77 cout,
78 cout_shl,
79 final_carry,
80 zout,
81 flags,
82 }
83 }
84
85 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
86 where
87 P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1>,
88 {
89 let xin: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.xin)?;
90 let yin = index.get_mut_as(self.yin)?;
91 let mut cout = index.get_mut_as(self.cout)?;
92 let mut zout = index.get_mut_as(self.zout)?;
93 let mut final_carry = if let Some(final_carry) = self.final_carry {
94 let final_carry = index.get_mut(final_carry)?;
95 Some(final_carry)
96 } else {
97 None
98 };
99
100 if let Some(carry_in_bit_col) = self.flags.carry_in_bit {
101 let carry_in_bit = index.get_mut_as(carry_in_bit_col)?;
103
104 let mut cin = index.get_mut_as(self.cin)?;
105 let mut cout_shl = index.get_mut_as(self.cout_shl)?;
106 for i in 0..index.size() {
107 let (x_plus_y, carry0) = xin[i].overflowing_add(yin[i]);
108 let carry1;
109 (zout[i], carry1) = x_plus_y.overflowing_add(carry_in_bit[i]);
110 let carry = carry0 | carry1;
111
112 cin[i] = xin[i] ^ yin[i] ^ zout[i];
113 cout[i] = (carry as u32) << 31 | cin[i] >> 1;
114 cout_shl[i] = cout[i] << 1;
115
116 if let Some(ref mut final_carry) = final_carry {
117 set_packed_slice(&mut *final_carry, i, if carry { B1::ONE } else { B1::ZERO });
118 }
119 }
120 } else {
121 let mut cin = index.get_mut_as(self.cin)?;
123 for i in 0..index.size() {
124 let carry;
125 (zout[i], carry) = xin[i].overflowing_add(yin[i]);
126 cin[i] = xin[i] ^ yin[i] ^ zout[i];
127 cout[i] = (carry as u32) << 31 | cin[i] >> 1;
128 if let Some(ref mut final_carry) = final_carry {
129 set_packed_slice(&mut *final_carry, i, if carry { B1::ONE } else { B1::ZERO });
130 }
131 }
132 };
133 Ok(())
134 }
135}
136
#[cfg(test)]
mod tests {
    use binius_field::{
        arch::OptimalUnderlier128b, as_packed_field::PackedType, packed::get_packed_slice,
    };
    use bumpalo::Bump;
    use rand::{prelude::StdRng, Rng as _, SeedableRng};

    use super::*;
    use crate::builder::{ConstraintSystem, Statement, WitnessIndex};

    /// Adds random pairs with default flags, checks the sums against
    /// `wrapping_add`, then validates the witness against the compiled
    /// constraint system.
    #[test]
    fn test_basic() {
        const TABLE_SZ: usize = 1 << 14;

        let mut cs = ConstraintSystem::new();
        let mut table = cs.add_table("u32_add test");

        let xin = table.add_committed::<B1, 32>("xin");
        let yin = table.add_committed::<B1, 32>("yin");

        let adder = U32Add::new(&mut table, xin, yin, U32AddFlags::default());
        let table_id = table.id();
        let statement = Statement {
            boundaries: vec![],
            table_sizes: vec![TABLE_SZ],
        };
        let allocator = Bump::new();
        let mut witness =
            WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

        let table_witness = witness.init_table(table_id, TABLE_SZ).unwrap();
        let mut segment = table_witness.full_segment();

        let mut rng = StdRng::seed_from_u64(0);

        // One (x, y, expected sum) triple per row; x is drawn before y.
        let row_count = segment.size();
        let mut cases: Vec<(u32, u32, u32)> = Vec::with_capacity(row_count);
        for _ in 0..row_count {
            let x = rng.gen::<u32>();
            let y = rng.gen::<u32>();
            cases.push((x, y, x.wrapping_add(y)));
        }

        // Write the inputs; the borrows end with this scope.
        {
            let mut xs = segment.get_mut_as::<u32, _, 32>(adder.xin).unwrap();
            let mut ys = segment.get_mut_as::<u32, _, 32>(adder.yin).unwrap();
            for (row, &(x, y, _)) in cases.iter().enumerate() {
                xs[row] = x;
                ys[row] = y;
            }
        }

        adder.populate(&mut segment).unwrap();

        // Every populated sum must match the reference value.
        {
            let sums = segment.get_as::<u32, _, 32>(adder.zout).unwrap();
            for (row, &(_, _, expected)) in cases.iter().enumerate() {
                assert_eq!(sums[row], expected);
            }
        }

        let ccs = cs.compile(&statement).unwrap();
        let witness = witness.into_multilinear_extension_index();

        binius_core::constraint_system::validate::validate_witness(&ccs, &[], &witness).unwrap();
    }

    /// Exercises the carry-in column and the exposed final carry on four
    /// hand-picked rows, then validates the witness.
    #[test]
    fn test_add_with_carry() {
        const TABLE_SZ: usize = 1 << 2;

        let mut cs = ConstraintSystem::new();
        let mut table = cs.add_table("u32_add_with_carry test");

        let xin = table.add_committed::<B1, 32>("xin");
        let yin = table.add_committed::<B1, 32>("yin");
        let carry_in = table.add_committed::<B1, 32>("carry_in");

        let flags = U32AddFlags {
            carry_in_bit: Some(carry_in),
            expose_final_carry: true,
            commit_zout: false,
        };
        let adder = U32Add::new(&mut table, xin, yin, flags);
        let table_id = table.id();
        let statement = Statement {
            boundaries: vec![],
            table_sizes: vec![TABLE_SZ],
        };
        let allocator = Bump::new();
        let mut witness =
            WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

        let table_witness = witness.init_table(table_id, TABLE_SZ).unwrap();
        let mut segment = table_witness.full_segment();

        // Columns: (x, y, carry in, expected sum, expected final carry).
        let cases = [
            (0xFFFFFFFF, 0x00000001, 0x00000000, 0x00000000, true),
            (0xFFFFFFFF, 0x00000000, 0x00000000, 0xFFFFFFFF, false),
            (0x7FFFFFFF, 0x00000001, 0x00000000, 0x80000000, false),
            (0xFFFF0000, 0x0000FFFF, 0x00000001, 0x00000000, true),
        ];
        assert_eq!(cases.len(), segment.size());

        // Write the inputs; the borrows end with this scope.
        {
            let mut xs = segment.get_mut_as::<u32, _, 32>(adder.xin).unwrap();
            let mut ys = segment.get_mut_as::<u32, _, 32>(adder.yin).unwrap();
            let mut carries = segment.get_mut_as::<u32, _, 32>(carry_in).unwrap();
            for (row, &(x, y, carry_value, _, _)) in cases.iter().enumerate() {
                xs[row] = x;
                ys[row] = y;
                carries[row] = carry_value;
            }
        }

        adder.populate(&mut segment).unwrap();

        // Check both the sum and the exposed carry for every row.
        {
            let sums = segment.get_as::<u32, _, 32>(adder.zout).unwrap();
            let final_carry = segment.get(adder.final_carry.unwrap()).unwrap();

            for (row, &(_, _, _, expected_sum, expected_carry)) in cases.iter().enumerate() {
                assert_eq!(sums[row], expected_sum);

                assert_eq!(get_packed_slice(&final_carry, row), B1::from(expected_carry));
            }
        }

        let ccs = cs.compile(&statement).unwrap();
        let witness = witness.into_multilinear_extension_index();

        binius_core::constraint_system::validate::validate_witness(&ccs, &[], &witness).unwrap();
    }
}