1use std::marker::PhantomData;
4
5use anyhow::Result;
6use binius_core::oracle::{OracleId, ShiftVariant};
7use binius_field::{
8 as_packed_field::{PackScalar, PackedType},
9 packed::set_packed_slice,
10 underlier::U1,
11 BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, PackedFieldIndexable, TowerField,
12};
13use itertools::izip;
14
15use super::lasso::lasso;
16use crate::{
17 builder::{
18 types::{F, U},
19 ConstraintSystemBuilder,
20 },
21 pack::pack,
22};
23
24const ADD_T_LOG_SIZE: usize = 17;
25
26type B1 = BinaryField1b;
27type B8 = BinaryField8b;
28type B32 = BinaryField32b;
29
30pub fn u32add<FInput, FOutput>(
31 builder: &mut ConstraintSystemBuilder,
32 name: impl ToString + Clone,
33 xin: OracleId,
34 yin: OracleId,
35) -> Result<OracleId, anyhow::Error>
36where
37 FInput: TowerField,
38 FOutput: TowerField,
39 U: PackScalar<FInput> + PackScalar<FOutput>,
40 B8: ExtensionField<FInput> + ExtensionField<FOutput>,
41 F: ExtensionField<FInput> + ExtensionField<FOutput>,
42{
43 let mut several = SeveralU32add::new(builder)?;
44 let sum = several.u32add::<FInput, FOutput>(builder, name.clone(), xin, yin)?;
45 several.finalize(builder, name)?;
46 Ok(sum)
47}
48
49pub struct SeveralU32add {
50 n_lookups: Vec<usize>,
51 lookup_t: OracleId,
52 lookups_u: Vec<[OracleId; 1]>,
53 u_to_t_mappings: Vec<Vec<usize>>,
54 finalized: bool,
55 _phantom: PhantomData<(U, F)>,
56}
57
58impl SeveralU32add {
59 pub fn new(builder: &mut ConstraintSystemBuilder) -> Result<Self> {
60 let lookup_t = builder.add_committed("lookup_t", ADD_T_LOG_SIZE, B32::TOWER_LEVEL);
61
62 if let Some(witness) = builder.witness() {
63 let mut lookup_t_witness = witness.new_column::<B32>(lookup_t);
64
65 let lookup_t_scalars =
66 PackedType::<U, B32>::unpack_scalars_mut(lookup_t_witness.packed());
67
68 for (i, lookup_t) in lookup_t_scalars.iter_mut().enumerate() {
69 let x = (i >> 9) & 0xff;
70 let y = (i >> 1) & 0xff;
71 let cin = i & 1;
72 let ab_sum = x + y + cin;
73 let cout = ab_sum >> 8;
74 let ab_sum = ab_sum & 0xff;
75
76 let lookup_t_u32 =
77 (((((((cin << 1 | cout) << 8) | x) << 8) | y) << 8) | ab_sum) as u32;
78
79 *lookup_t = BinaryField32b::new(lookup_t_u32);
80 }
81 }
82 Ok(Self {
83 n_lookups: Vec::new(),
84 lookup_t,
85 lookups_u: Vec::new(),
86 u_to_t_mappings: Vec::new(),
87 finalized: false,
88 _phantom: PhantomData,
89 })
90 }
91
92 pub fn u32add<FInput, FOutput>(
93 &mut self,
94 builder: &mut ConstraintSystemBuilder,
95 name: impl ToString,
96 xin: OracleId,
97 yin: OracleId,
98 ) -> Result<OracleId, anyhow::Error>
99 where
100 FInput: TowerField,
101 FOutput: TowerField,
102 U: PackScalar<FInput> + PackScalar<FOutput>,
103 F: ExtensionField<FInput> + ExtensionField<FOutput>,
104 B8: ExtensionField<FInput> + ExtensionField<FOutput>,
105 {
106 builder.push_namespace(name);
107
108 let input_log_size = builder.log_rows([xin, yin])?;
109
110 let b8_log_size = input_log_size - B8::TOWER_LEVEL + FInput::TOWER_LEVEL;
111
112 let output_log_size = input_log_size - FOutput::TOWER_LEVEL + FInput::TOWER_LEVEL;
113
114 let sum = builder.add_committed("sum", output_log_size, FOutput::TOWER_LEVEL);
115
116 let sum_packed = if FInput::TOWER_LEVEL == B8::TOWER_LEVEL {
117 sum
118 } else {
119 builder.add_packed("lasso sum packed", sum, B8::TOWER_LEVEL - FInput::TOWER_LEVEL)?
120 };
121
122 let cout = builder.add_committed("cout", b8_log_size, B1::TOWER_LEVEL);
123
124 let cin = builder.add_shifted("cin", cout, 1, 2, ShiftVariant::LogicalLeft)?;
125
126 let xin_u8 = pack::<FInput, B8>(xin, builder, "repacked xin")?;
127 let yin_u8 = pack::<FInput, B8>(yin, builder, "repacked yin")?;
128
129 let lookup_u = builder.add_linear_combination(
130 "lookup_u",
131 b8_log_size,
132 [
133 (cin, <F as TowerField>::basis(0, 25)?),
134 (cout, <F as TowerField>::basis(0, 24)?),
135 (xin_u8, <F as TowerField>::basis(3, 2)?),
136 (yin_u8, <F as TowerField>::basis(3, 1)?),
137 (sum_packed, <F as TowerField>::basis(3, 0)?),
138 ],
139 )?;
140
141 if let Some(witness) = builder.witness() {
142 let mut sum_witness = witness.new_column::<FOutput>(sum);
143 let mut cin_witness = witness.new_column::<B1>(cin);
144 let mut cout_witness = witness.new_column::<B1>(cout);
145 let mut lookup_u_witness = witness.new_column::<B32>(lookup_u);
146 let mut u_to_t_mapping_witness = vec![0; 1 << (b8_log_size)];
147
148 let x_ints = witness.get::<B8>(xin_u8)?.as_slice::<u8>();
149 let y_ints = witness.get::<B8>(yin_u8)?.as_slice::<u8>();
150
151 let sum_scalars = sum_witness.as_mut_slice::<u8>();
152 let packed_slice_cin = cin_witness.packed();
153 let packed_slice_cout = cout_witness.packed();
154 let lookup_u_scalars =
155 PackedType::<U, B32>::unpack_scalars_mut(lookup_u_witness.packed());
156
157 let mut temp_cout = 0;
158
159 for (i, (x, y, sum, lookup_u, u_to_t)) in izip!(
160 x_ints,
161 y_ints,
162 sum_scalars.iter_mut(),
163 lookup_u_scalars.iter_mut(),
164 u_to_t_mapping_witness.iter_mut()
165 )
166 .enumerate()
167 {
168 let x = *x as usize;
169 let y = *y as usize;
170
171 let cin = if i % 4 == 0 { 0 } else { temp_cout };
172
173 let xy_sum = x + y + cin;
174
175 temp_cout = xy_sum >> 8;
176
177 set_packed_slice(packed_slice_cin, i, BinaryField1b::new(U1::new(cin as u8)));
178 set_packed_slice(
179 packed_slice_cout,
180 i,
181 BinaryField1b::new(U1::new(temp_cout as u8)),
182 );
183
184 *u_to_t = (x << 8 | y) << 1 | cin;
185
186 let ab_sum = xy_sum & 0xff;
187
188 *sum = xy_sum as u8;
189
190 let lookup_u_u32 =
191 (((((((cin << 1 | temp_cout) << 8) | x) << 8) | y) << 8) | ab_sum) as u32;
192
193 *lookup_u = B32::new(lookup_u_u32);
194 }
195
196 std::mem::drop(sum_witness);
197
198 let sum_packed_witness = witness.get::<FOutput>(sum)?;
199
200 witness.set::<B8>(sum_packed, sum_packed_witness.repacked::<B8>())?;
201
202 self.u_to_t_mappings.push(u_to_t_mapping_witness)
203 }
204
205 self.lookups_u.push([lookup_u]);
206 self.n_lookups.push(1 << b8_log_size);
207
208 builder.pop_namespace();
209 Ok(sum)
210 }
211
212 pub fn finalize(
213 mut self,
214 builder: &mut ConstraintSystemBuilder,
215 name: impl ToString,
216 ) -> Result<()> {
217 let channel = builder.add_channel();
218 self.finalized = true;
219 lasso::<B32>(
220 builder,
221 name,
222 &self.n_lookups,
223 &self.u_to_t_mappings,
224 &self.lookups_u,
225 [self.lookup_t],
226 channel,
227 )
228 }
229}
230
231impl Drop for SeveralU32add {
232 fn drop(&mut self) {
233 assert!(self.finalized)
234 }
235}
236
#[cfg(test)]
mod tests {
    use binius_field::{BinaryField1b, BinaryField8b};

    use super::SeveralU32add;
    use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained};

    /// Three additions of different sizes share a single lookup table.
    #[test]
    fn test_several_lasso_u32add() {
        test_circuit(|builder| {
            let mut adder = SeveralU32add::new(builder).unwrap();
            for log_size in [11, 12, 13] {
                let lhs = unconstrained::<BinaryField8b>(builder, "add_a", log_size).unwrap();
                let rhs = unconstrained::<BinaryField8b>(builder, "add_b", log_size).unwrap();
                adder
                    .u32add::<BinaryField8b, BinaryField8b>(builder, "lasso_u32add", lhs, rhs)
                    .unwrap();
            }
            adder.finalize(builder, "lasso_u32add").unwrap();
            Ok(vec![])
        })
        .unwrap();
    }

    /// A single addition through the convenience wrapper, with bit-level operands.
    #[test]
    fn test_lasso_u32add() {
        test_circuit(|builder| {
            let log_size = 14;
            let lhs = unconstrained::<BinaryField1b>(builder, "add_a", log_size)?;
            let rhs = unconstrained::<BinaryField1b>(builder, "add_b", log_size)?;
            super::u32add::<BinaryField1b, BinaryField1b>(builder, "lasso_u32add", lhs, rhs)?;
            Ok(vec![])
        })
        .unwrap();
    }
}
283}