1use std::marker::PhantomData;
4
5use anyhow::Result;
6use binius_core::oracle::{OracleId, ShiftVariant};
7use binius_field::{
8 as_packed_field::PackScalar,
9 packed::{packed_from_fn_with_offset, set_packed_slice},
10 underlier::U1,
11 BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, TowerField,
12};
13use itertools::izip;
14
15use super::lasso::lasso;
16use crate::{
17 builder::{
18 types::{F, U},
19 ConstraintSystemBuilder,
20 },
21 pack::pack,
22};
23
24const ADD_T_LOG_SIZE: usize = 17;
25
26type B1 = BinaryField1b;
27type B8 = BinaryField8b;
28type B32 = BinaryField32b;
29
30pub fn u32add<FInput, FOutput>(
31 builder: &mut ConstraintSystemBuilder,
32 name: impl ToString + Clone,
33 xin: OracleId,
34 yin: OracleId,
35) -> Result<OracleId, anyhow::Error>
36where
37 FInput: TowerField,
38 FOutput: TowerField,
39 U: PackScalar<FInput> + PackScalar<FOutput>,
40 B8: ExtensionField<FInput> + ExtensionField<FOutput>,
41 F: ExtensionField<FInput> + ExtensionField<FOutput>,
42{
43 let mut several = SeveralU32add::new(builder)?;
44 let sum = several.u32add::<FInput, FOutput>(builder, name.clone(), xin, yin)?;
45 several.finalize(builder, name)?;
46 Ok(sum)
47}
48
49pub struct SeveralU32add {
50 n_lookups: Vec<usize>,
51 lookup_t: OracleId,
52 lookups_u: Vec<[OracleId; 1]>,
53 u_to_t_mappings: Vec<Vec<usize>>,
54 finalized: bool,
55 _phantom: PhantomData<(U, F)>,
56}
57
58impl SeveralU32add {
59 pub fn new(builder: &mut ConstraintSystemBuilder) -> Result<Self> {
60 let lookup_t = builder.add_committed("lookup_t", ADD_T_LOG_SIZE, B32::TOWER_LEVEL);
61
62 if let Some(witness) = builder.witness() {
63 let mut lookup_t_witness = witness.new_column::<B32>(lookup_t);
64
65 let lookup_t = lookup_t_witness.packed();
66 for (i, lookup_t) in lookup_t.iter_mut().enumerate() {
67 *lookup_t = packed_from_fn_with_offset(i, |i| {
68 let x = (i >> 9) & 0xff;
69 let y = (i >> 1) & 0xff;
70 let cin = i & 1;
71 let ab_sum = x + y + cin;
72 let cout = ab_sum >> 8;
73 let ab_sum = ab_sum & 0xff;
74
75 let lookup_t_u32 =
76 (((((((cin << 1 | cout) << 8) | x) << 8) | y) << 8) | ab_sum) as u32;
77
78 BinaryField32b::new(lookup_t_u32)
79 });
80 }
81 }
82 Ok(Self {
83 n_lookups: Vec::new(),
84 lookup_t,
85 lookups_u: Vec::new(),
86 u_to_t_mappings: Vec::new(),
87 finalized: false,
88 _phantom: PhantomData,
89 })
90 }
91
92 pub fn u32add<FInput, FOutput>(
93 &mut self,
94 builder: &mut ConstraintSystemBuilder,
95 name: impl ToString,
96 xin: OracleId,
97 yin: OracleId,
98 ) -> Result<OracleId, anyhow::Error>
99 where
100 FInput: TowerField,
101 FOutput: TowerField,
102 U: PackScalar<FInput> + PackScalar<FOutput>,
103 F: ExtensionField<FInput> + ExtensionField<FOutput>,
104 B8: ExtensionField<FInput> + ExtensionField<FOutput>,
105 {
106 builder.push_namespace(name);
107
108 let input_log_size = builder.log_rows([xin, yin])?;
109
110 let b8_log_size = input_log_size - B8::TOWER_LEVEL + FInput::TOWER_LEVEL;
111
112 let output_log_size = input_log_size - FOutput::TOWER_LEVEL + FInput::TOWER_LEVEL;
113
114 let sum = builder.add_committed("sum", output_log_size, FOutput::TOWER_LEVEL);
115
116 let sum_packed = if FInput::TOWER_LEVEL == B8::TOWER_LEVEL {
117 sum
118 } else {
119 builder.add_packed("lasso sum packed", sum, B8::TOWER_LEVEL - FInput::TOWER_LEVEL)?
120 };
121
122 let cout = builder.add_committed("cout", b8_log_size, B1::TOWER_LEVEL);
123
124 let cin = builder.add_shifted("cin", cout, 1, 2, ShiftVariant::LogicalLeft)?;
125
126 let xin_u8 = pack::<FInput, B8>(xin, builder, "repacked xin")?;
127 let yin_u8 = pack::<FInput, B8>(yin, builder, "repacked yin")?;
128
129 let lookup_u = builder.add_linear_combination(
130 "lookup_u",
131 b8_log_size,
132 [
133 (cin, <F as TowerField>::basis(0, 25)?),
134 (cout, <F as TowerField>::basis(0, 24)?),
135 (xin_u8, <F as TowerField>::basis(3, 2)?),
136 (yin_u8, <F as TowerField>::basis(3, 1)?),
137 (sum_packed, <F as TowerField>::basis(3, 0)?),
138 ],
139 )?;
140
141 if let Some(witness) = builder.witness() {
142 let mut sum_witness = witness.new_column::<FOutput>(sum);
143 let mut cin_witness = witness.new_column::<B1>(cin);
144 let mut cout_witness = witness.new_column::<B1>(cout);
145 let mut lookup_u_witness = witness.new_column::<B32>(lookup_u);
146 let mut u_to_t_mapping_witness = vec![0; 1 << (b8_log_size)];
147
148 let x_ints = witness.get::<B8>(xin_u8)?.as_slice::<u8>();
149 let y_ints = witness.get::<B8>(yin_u8)?.as_slice::<u8>();
150
151 let sum_scalars = sum_witness.as_mut_slice::<u8>();
152 let packed_slice_cin = cin_witness.packed();
153 let packed_slice_cout = cout_witness.packed();
154 let lookup_u = lookup_u_witness.packed();
155
156 let mut temp_cout = 0;
157
158 for (i, (x, y, sum, u_to_t)) in
159 izip!(x_ints, y_ints, sum_scalars.iter_mut(), u_to_t_mapping_witness.iter_mut())
160 .enumerate()
161 {
162 let x = *x as usize;
163 let y = *y as usize;
164
165 let cin = if i % 4 == 0 { 0 } else { temp_cout };
166
167 let xy_sum = x + y + cin;
168
169 temp_cout = xy_sum >> 8;
170
171 set_packed_slice(packed_slice_cin, i, BinaryField1b::new(U1::new(cin as u8)));
172 set_packed_slice(
173 packed_slice_cout,
174 i,
175 BinaryField1b::new(U1::new(temp_cout as u8)),
176 );
177
178 *u_to_t = (x << 8 | y) << 1 | cin;
179
180 let ab_sum = xy_sum & 0xff;
181
182 *sum = xy_sum as u8;
183
184 let lookup_u_u32 =
185 (((((((cin << 1 | temp_cout) << 8) | x) << 8) | y) << 8) | ab_sum) as u32;
186
187 set_packed_slice(lookup_u, i, B32::new(lookup_u_u32));
188 }
189
190 std::mem::drop(sum_witness);
191
192 let sum_packed_witness = witness.get::<FOutput>(sum)?;
193
194 witness.set::<B8>(sum_packed, sum_packed_witness.repacked::<B8>())?;
195
196 self.u_to_t_mappings.push(u_to_t_mapping_witness)
197 }
198
199 self.lookups_u.push([lookup_u]);
200 self.n_lookups.push(1 << b8_log_size);
201
202 builder.pop_namespace();
203 Ok(sum)
204 }
205
206 pub fn finalize(
207 mut self,
208 builder: &mut ConstraintSystemBuilder,
209 name: impl ToString,
210 ) -> Result<()> {
211 let channel = builder.add_channel();
212 self.finalized = true;
213 lasso::<B32>(
214 builder,
215 name,
216 &self.n_lookups,
217 &self.u_to_t_mappings,
218 &self.lookups_u,
219 [self.lookup_t],
220 channel,
221 )
222 }
223}
224
225impl Drop for SeveralU32add {
226 fn drop(&mut self) {
227 assert!(self.finalized)
228 }
229}
230
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField1b, BinaryField8b};

    use super::SeveralU32add;
    use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained};

    /// Registers three byte-column additions of different sizes on one shared batch and
    /// finalizes them together.
    #[test]
    fn test_several_lasso_u32add() {
        let outcome = test_circuit(|builder| {
            let mut batch = SeveralU32add::new(builder).unwrap();
            for log_rows in 11..=13 {
                let lhs = unconstrained::<BinaryField8b>(builder, "add_a", log_rows).unwrap();
                let rhs = unconstrained::<BinaryField8b>(builder, "add_b", log_rows).unwrap();
                batch
                    .u32add::<BinaryField8b, BinaryField8b>(builder, "lasso_u32add", lhs, rhs)
                    .unwrap();
            }
            batch.finalize(builder, "lasso_u32add").unwrap();
            Ok(vec![])
        });
        outcome.unwrap();
    }

    /// Runs the single-addition wrapper on bit-level input and output columns.
    #[test]
    fn test_lasso_u32add() {
        let outcome = test_circuit(|builder| {
            let log_rows = 14;
            let lhs = unconstrained::<BinaryField1b>(builder, "add_a", log_rows)?;
            let rhs = unconstrained::<BinaryField1b>(builder, "add_b", log_rows)?;
            super::u32add::<BinaryField1b, BinaryField1b>(builder, "lasso_u32add", lhs, rhs)?;
            Ok(vec![])
        });
        outcome.unwrap();
    }
}