1use binius_core::oracle::ShiftVariant;
4use binius_field::{packed::set_packed_slice, Field, PackedExtension, PackedFieldIndexable};
5
6use crate::builder::{column::Col, types::B1, witness::TableWitnessSegment, TableBuilder, B128};
7
/// Gadget for wrapping addition of two 32-bit unsigned integers.
///
/// Every operand is a `Col<B1, 32>`: 32 bit-values per row, read as a `u32` during witness
/// population. The constraints are expressed bitwise over GF(2) with the help of a committed
/// column of per-bit carries, so that `zout = xin + yin` (plus an optional carry-in) mod 2^32.
#[derive(Debug)]
pub struct U32Add {
	/// First addend.
	pub xin: Col<B1, 32>,
	/// Second addend.
	pub yin: Col<B1, 32>,

	/// Carry into each bit position: `cout_shl`, plus `flags.carry_in_bit` when configured.
	cin: Col<B1, 32>,
	/// Committed carry out of each bit position; bit 31 is the carry out of the whole word.
	cout: Col<B1, 32>,
	/// `cout` shifted left by one bit, i.e. the carry arriving at each bit from the bit below.
	cout_shl: Col<B1, 32>,

	/// The 32-bit sum; committed or computed depending on `flags.commit_zout`.
	pub zout: Col<B1, 32>,
	/// Carry out of bit 31, present only when `flags.expose_final_carry` is set.
	pub final_carry: Option<Col<B1>>,
	/// Configuration this gadget was built with.
	pub flags: U32AddFlags,
}
32
/// Options controlling which columns and constraints [`U32Add::new`] creates.
#[derive(Debug, Default, Clone)]
pub struct U32AddFlags {
	/// Optional carry-in column. It is added (over GF(2)) to the shifted carry column to form
	/// `cin`, and added as an integer to `xin + yin` during witness population.
	/// NOTE(review): presumably only bit 0 may be set for the witness to satisfy the
	/// constraints — confirm with callers.
	pub carry_in_bit: Option<Col<B1, 32>>,
	/// If `true`, `zout` is a committed column constrained to equal the sum; otherwise it is
	/// a computed column.
	pub commit_zout: bool,
	/// If `true`, bit 31 of the carry column is exposed as `U32Add::final_carry`.
	pub expose_final_carry: bool,
}
42
43impl U32Add {
44 pub fn new(
45 table: &mut TableBuilder,
46 xin: Col<B1, 32>,
47 yin: Col<B1, 32>,
48 flags: U32AddFlags,
49 ) -> Self {
50 let cout = table.add_committed::<B1, 32>("cout");
51 let cout_shl = table.add_shifted("cout_shl", cout, 5, 1, ShiftVariant::LogicalLeft);
52
53 let cin = if let Some(carry_in_bit) = flags.carry_in_bit {
54 table.add_computed("cin", cout_shl + carry_in_bit)
55 } else {
56 cout_shl
57 };
58
59 let final_carry = flags
60 .expose_final_carry
61 .then(|| table.add_selected("final_carry", cout, 31));
62
63 table.assert_zero("carry_out", (xin + cin) * (yin + cin) + cin - cout);
64
65 let zout = if flags.commit_zout {
66 let zout = table.add_committed::<B1, 32>("zout");
67 table.assert_zero("zout", xin + yin + cin - zout);
68 zout
69 } else {
70 table.add_computed("zout", xin + yin + cin)
71 };
72
73 Self {
74 xin,
75 yin,
76 cin,
77 cout,
78 cout_shl,
79 final_carry,
80 zout,
81 flags,
82 }
83 }
84
85 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
86 where
87 P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1>,
88 {
89 let xin: std::cell::RefMut<'_, [u32]> = index.get_mut_as(self.xin)?;
90 let yin = index.get_mut_as(self.yin)?;
91 let mut cout = index.get_mut_as(self.cout)?;
92 let mut zout = index.get_mut_as(self.zout)?;
93 let mut final_carry = if let Some(final_carry) = self.final_carry {
94 let final_carry = index.get_mut(final_carry)?;
95 Some(final_carry)
96 } else {
97 None
98 };
99
100 if let Some(carry_in_bit_col) = self.flags.carry_in_bit {
101 let carry_in_bit = index.get_mut_as(carry_in_bit_col)?;
103
104 let mut cin = index.get_mut_as(self.cin)?;
105 let mut cout_shl = index.get_mut_as(self.cout_shl)?;
106 for i in 0..index.size() {
107 let (x_plus_y, carry0) = xin[i].overflowing_add(yin[i]);
108 let carry1;
109 (zout[i], carry1) = x_plus_y.overflowing_add(carry_in_bit[i]);
110 let carry = carry0 | carry1;
111
112 cin[i] = xin[i] ^ yin[i] ^ zout[i];
113 cout[i] = (carry as u32) << 31 | cin[i] >> 1;
114 cout_shl[i] = cout[i] << 1;
115
116 if let Some(ref mut final_carry) = final_carry {
117 set_packed_slice(&mut *final_carry, i, if carry { B1::ONE } else { B1::ZERO });
118 }
119 }
120 } else {
121 let mut cin = index.get_mut_as(self.cin)?;
123 for i in 0..index.size() {
124 let carry;
125 (zout[i], carry) = xin[i].overflowing_add(yin[i]);
126 cin[i] = xin[i] ^ yin[i] ^ zout[i];
127 cout[i] = (carry as u32) << 31 | cin[i] >> 1;
128 if let Some(ref mut final_carry) = final_carry {
129 set_packed_slice(&mut *final_carry, i, if carry { B1::ONE } else { B1::ZERO });
130 }
131 }
132 };
133 Ok(())
134 }
135}