binius_circuits/lasso/
u32add.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::marker::PhantomData;
4
5use anyhow::Result;
6use binius_core::oracle::{OracleId, ShiftVariant};
7use binius_field::{
8	as_packed_field::{PackScalar, PackedType},
9	packed::set_packed_slice,
10	underlier::U1,
11	BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, PackedFieldIndexable, TowerField,
12};
13use itertools::izip;
14
15use super::lasso::lasso;
16use crate::{
17	builder::{
18		types::{F, U},
19		ConstraintSystemBuilder,
20	},
21	pack::pack,
22};
23
/// Log2 of the number of rows in the addition lookup table: each row is indexed by an 8-bit
/// `x`, an 8-bit `y` and a 1-bit carry-in.
const ADD_T_LOG_SIZE: usize = 8 + 8 + 1;
25
26type B1 = BinaryField1b;
27type B8 = BinaryField8b;
28type B32 = BinaryField32b;
29
30pub fn u32add<FInput, FOutput>(
31	builder: &mut ConstraintSystemBuilder,
32	name: impl ToString + Clone,
33	xin: OracleId,
34	yin: OracleId,
35) -> Result<OracleId, anyhow::Error>
36where
37	FInput: TowerField,
38	FOutput: TowerField,
39	U: PackScalar<FInput> + PackScalar<FOutput>,
40	B8: ExtensionField<FInput> + ExtensionField<FOutput>,
41	F: ExtensionField<FInput> + ExtensionField<FOutput>,
42{
43	let mut several = SeveralU32add::new(builder)?;
44	let sum = several.u32add::<FInput, FOutput>(builder, name.clone(), xin, yin)?;
45	several.finalize(builder, name)?;
46	Ok(sum)
47}
48
49pub struct SeveralU32add {
50	n_lookups: Vec<usize>,
51	lookup_t: OracleId,
52	lookups_u: Vec<[OracleId; 1]>,
53	u_to_t_mappings: Vec<Vec<usize>>,
54	finalized: bool,
55	_phantom: PhantomData<(U, F)>,
56}
57
58impl SeveralU32add {
59	pub fn new(builder: &mut ConstraintSystemBuilder) -> Result<Self> {
60		let lookup_t = builder.add_committed("lookup_t", ADD_T_LOG_SIZE, B32::TOWER_LEVEL);
61
62		if let Some(witness) = builder.witness() {
63			let mut lookup_t_witness = witness.new_column::<B32>(lookup_t);
64
65			let lookup_t_scalars =
66				PackedType::<U, B32>::unpack_scalars_mut(lookup_t_witness.packed());
67
68			for (i, lookup_t) in lookup_t_scalars.iter_mut().enumerate() {
69				let x = (i >> 9) & 0xff;
70				let y = (i >> 1) & 0xff;
71				let cin = i & 1;
72				let ab_sum = x + y + cin;
73				let cout = ab_sum >> 8;
74				let ab_sum = ab_sum & 0xff;
75
76				let lookup_t_u32 =
77					(((((((cin << 1 | cout) << 8) | x) << 8) | y) << 8) | ab_sum) as u32;
78
79				*lookup_t = BinaryField32b::new(lookup_t_u32);
80			}
81		}
82		Ok(Self {
83			n_lookups: Vec::new(),
84			lookup_t,
85			lookups_u: Vec::new(),
86			u_to_t_mappings: Vec::new(),
87			finalized: false,
88			_phantom: PhantomData,
89		})
90	}
91
92	pub fn u32add<FInput, FOutput>(
93		&mut self,
94		builder: &mut ConstraintSystemBuilder,
95		name: impl ToString,
96		xin: OracleId,
97		yin: OracleId,
98	) -> Result<OracleId, anyhow::Error>
99	where
100		FInput: TowerField,
101		FOutput: TowerField,
102		U: PackScalar<FInput> + PackScalar<FOutput>,
103		F: ExtensionField<FInput> + ExtensionField<FOutput>,
104		B8: ExtensionField<FInput> + ExtensionField<FOutput>,
105	{
106		builder.push_namespace(name);
107
108		let input_log_size = builder.log_rows([xin, yin])?;
109
110		let b8_log_size = input_log_size - B8::TOWER_LEVEL + FInput::TOWER_LEVEL;
111
112		let output_log_size = input_log_size - FOutput::TOWER_LEVEL + FInput::TOWER_LEVEL;
113
114		let sum = builder.add_committed("sum", output_log_size, FOutput::TOWER_LEVEL);
115
116		let sum_packed = if FInput::TOWER_LEVEL == B8::TOWER_LEVEL {
117			sum
118		} else {
119			builder.add_packed("lasso sum packed", sum, B8::TOWER_LEVEL - FInput::TOWER_LEVEL)?
120		};
121
122		let cout = builder.add_committed("cout", b8_log_size, B1::TOWER_LEVEL);
123
124		let cin = builder.add_shifted("cin", cout, 1, 2, ShiftVariant::LogicalLeft)?;
125
126		let xin_u8 = pack::<FInput, B8>(xin, builder, "repacked xin")?;
127		let yin_u8 = pack::<FInput, B8>(yin, builder, "repacked yin")?;
128
129		let lookup_u = builder.add_linear_combination(
130			"lookup_u",
131			b8_log_size,
132			[
133				(cin, <F as TowerField>::basis(0, 25)?),
134				(cout, <F as TowerField>::basis(0, 24)?),
135				(xin_u8, <F as TowerField>::basis(3, 2)?),
136				(yin_u8, <F as TowerField>::basis(3, 1)?),
137				(sum_packed, <F as TowerField>::basis(3, 0)?),
138			],
139		)?;
140
141		if let Some(witness) = builder.witness() {
142			let mut sum_witness = witness.new_column::<FOutput>(sum);
143			let mut cin_witness = witness.new_column::<B1>(cin);
144			let mut cout_witness = witness.new_column::<B1>(cout);
145			let mut lookup_u_witness = witness.new_column::<B32>(lookup_u);
146			let mut u_to_t_mapping_witness = vec![0; 1 << (b8_log_size)];
147
148			let x_ints = witness.get::<B8>(xin_u8)?.as_slice::<u8>();
149			let y_ints = witness.get::<B8>(yin_u8)?.as_slice::<u8>();
150
151			let sum_scalars = sum_witness.as_mut_slice::<u8>();
152			let packed_slice_cin = cin_witness.packed();
153			let packed_slice_cout = cout_witness.packed();
154			let lookup_u_scalars =
155				PackedType::<U, B32>::unpack_scalars_mut(lookup_u_witness.packed());
156
157			let mut temp_cout = 0;
158
159			for (i, (x, y, sum, lookup_u, u_to_t)) in izip!(
160				x_ints,
161				y_ints,
162				sum_scalars.iter_mut(),
163				lookup_u_scalars.iter_mut(),
164				u_to_t_mapping_witness.iter_mut()
165			)
166			.enumerate()
167			{
168				let x = *x as usize;
169				let y = *y as usize;
170
171				let cin = if i % 4 == 0 { 0 } else { temp_cout };
172
173				let xy_sum = x + y + cin;
174
175				temp_cout = xy_sum >> 8;
176
177				set_packed_slice(packed_slice_cin, i, BinaryField1b::new(U1::new(cin as u8)));
178				set_packed_slice(
179					packed_slice_cout,
180					i,
181					BinaryField1b::new(U1::new(temp_cout as u8)),
182				);
183
184				*u_to_t = (x << 8 | y) << 1 | cin;
185
186				let ab_sum = xy_sum & 0xff;
187
188				*sum = xy_sum as u8;
189
190				let lookup_u_u32 =
191					(((((((cin << 1 | temp_cout) << 8) | x) << 8) | y) << 8) | ab_sum) as u32;
192
193				*lookup_u = B32::new(lookup_u_u32);
194			}
195
196			std::mem::drop(sum_witness);
197
198			let sum_packed_witness = witness.get::<FOutput>(sum)?;
199
200			witness.set::<B8>(sum_packed, sum_packed_witness.repacked::<B8>())?;
201
202			self.u_to_t_mappings.push(u_to_t_mapping_witness)
203		}
204
205		self.lookups_u.push([lookup_u]);
206		self.n_lookups.push(1 << b8_log_size);
207
208		builder.pop_namespace();
209		Ok(sum)
210	}
211
212	pub fn finalize(
213		mut self,
214		builder: &mut ConstraintSystemBuilder,
215		name: impl ToString,
216	) -> Result<()> {
217		let channel = builder.add_channel();
218		self.finalized = true;
219		lasso::<B32>(
220			builder,
221			name,
222			&self.n_lookups,
223			&self.u_to_t_mappings,
224			&self.lookups_u,
225			[self.lookup_t],
226			channel,
227		)
228	}
229}
230
231impl Drop for SeveralU32add {
232	fn drop(&mut self) {
233		assert!(self.finalized)
234	}
235}
236
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField1b, BinaryField8b};

	use super::{u32add, SeveralU32add};
	use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained};

	/// Several additions of different sizes share one lookup table and one lasso argument.
	#[test]
	fn test_several_lasso_u32add() {
		test_circuit(|builder| {
			let mut adder = SeveralU32add::new(builder).unwrap();
			for log_size in 11..=13 {
				// BinaryField8b is used here because we utilize an 8x8x1→8 table
				let lhs = unconstrained::<BinaryField8b>(builder, "add_a", log_size).unwrap();
				let rhs = unconstrained::<BinaryField8b>(builder, "add_b", log_size).unwrap();
				adder
					.u32add::<BinaryField8b, BinaryField8b>(builder, "lasso_u32add", lhs, rhs)
					.unwrap();
			}
			adder.finalize(builder, "lasso_u32add").unwrap();
			Ok(vec![])
		})
		.unwrap();
	}

	/// A single addition over bit columns through the convenience wrapper.
	#[test]
	fn test_lasso_u32add() {
		test_circuit(|builder| {
			let log_size = 14;
			let lhs = unconstrained::<BinaryField1b>(builder, "add_a", log_size)?;
			let rhs = unconstrained::<BinaryField1b>(builder, "add_b", log_size)?;
			u32add::<BinaryField1b, BinaryField1b>(builder, "lasso_u32add", lhs, rhs)?;
			Ok(vec![])
		})
		.unwrap();
	}
}