binius_circuits/lasso/u32add.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::marker::PhantomData;
4
5use anyhow::Result;
6use binius_core::oracle::{OracleId, ShiftVariant};
7use binius_field::{
8	as_packed_field::PackScalar,
9	packed::{packed_from_fn_with_offset, set_packed_slice},
10	underlier::U1,
11	BinaryField1b, BinaryField32b, BinaryField8b, ExtensionField, TowerField,
12};
13use itertools::izip;
14
15use super::lasso::lasso;
16use crate::{
17	builder::{
18		types::{F, U},
19		ConstraintSystemBuilder,
20	},
21	pack::pack,
22};
23
24const ADD_T_LOG_SIZE: usize = 17;
25
26type B1 = BinaryField1b;
27type B8 = BinaryField8b;
28type B32 = BinaryField32b;
29
30pub fn u32add<FInput, FOutput>(
31	builder: &mut ConstraintSystemBuilder,
32	name: impl ToString + Clone,
33	xin: OracleId,
34	yin: OracleId,
35) -> Result<OracleId, anyhow::Error>
36where
37	FInput: TowerField,
38	FOutput: TowerField,
39	U: PackScalar<FInput> + PackScalar<FOutput>,
40	B8: ExtensionField<FInput> + ExtensionField<FOutput>,
41	F: ExtensionField<FInput> + ExtensionField<FOutput>,
42{
43	let mut several = SeveralU32add::new(builder)?;
44	let sum = several.u32add::<FInput, FOutput>(builder, name.clone(), xin, yin)?;
45	several.finalize(builder, name)?;
46	Ok(sum)
47}
48
49pub struct SeveralU32add {
50	n_lookups: Vec<usize>,
51	lookup_t: OracleId,
52	lookups_u: Vec<[OracleId; 1]>,
53	u_to_t_mappings: Vec<Vec<usize>>,
54	finalized: bool,
55	_phantom: PhantomData<(U, F)>,
56}
57
58impl SeveralU32add {
59	pub fn new(builder: &mut ConstraintSystemBuilder) -> Result<Self> {
60		let lookup_t = builder.add_committed("lookup_t", ADD_T_LOG_SIZE, B32::TOWER_LEVEL);
61
62		if let Some(witness) = builder.witness() {
63			let mut lookup_t_witness = witness.new_column::<B32>(lookup_t);
64
65			let lookup_t = lookup_t_witness.packed();
66			for (i, lookup_t) in lookup_t.iter_mut().enumerate() {
67				*lookup_t = packed_from_fn_with_offset(i, |i| {
68					let x = (i >> 9) & 0xff;
69					let y = (i >> 1) & 0xff;
70					let cin = i & 1;
71					let ab_sum = x + y + cin;
72					let cout = ab_sum >> 8;
73					let ab_sum = ab_sum & 0xff;
74
75					let lookup_t_u32 =
76						(((((((cin << 1 | cout) << 8) | x) << 8) | y) << 8) | ab_sum) as u32;
77
78					BinaryField32b::new(lookup_t_u32)
79				});
80			}
81		}
82		Ok(Self {
83			n_lookups: Vec::new(),
84			lookup_t,
85			lookups_u: Vec::new(),
86			u_to_t_mappings: Vec::new(),
87			finalized: false,
88			_phantom: PhantomData,
89		})
90	}
91
92	pub fn u32add<FInput, FOutput>(
93		&mut self,
94		builder: &mut ConstraintSystemBuilder,
95		name: impl ToString,
96		xin: OracleId,
97		yin: OracleId,
98	) -> Result<OracleId, anyhow::Error>
99	where
100		FInput: TowerField,
101		FOutput: TowerField,
102		U: PackScalar<FInput> + PackScalar<FOutput>,
103		F: ExtensionField<FInput> + ExtensionField<FOutput>,
104		B8: ExtensionField<FInput> + ExtensionField<FOutput>,
105	{
106		builder.push_namespace(name);
107
108		let input_log_size = builder.log_rows([xin, yin])?;
109
110		let b8_log_size = input_log_size - B8::TOWER_LEVEL + FInput::TOWER_LEVEL;
111
112		let output_log_size = input_log_size - FOutput::TOWER_LEVEL + FInput::TOWER_LEVEL;
113
114		let sum = builder.add_committed("sum", output_log_size, FOutput::TOWER_LEVEL);
115
116		let sum_packed = if FInput::TOWER_LEVEL == B8::TOWER_LEVEL {
117			sum
118		} else {
119			builder.add_packed("lasso sum packed", sum, B8::TOWER_LEVEL - FInput::TOWER_LEVEL)?
120		};
121
122		let cout = builder.add_committed("cout", b8_log_size, B1::TOWER_LEVEL);
123
124		let cin = builder.add_shifted("cin", cout, 1, 2, ShiftVariant::LogicalLeft)?;
125
126		let xin_u8 = pack::<FInput, B8>(xin, builder, "repacked xin")?;
127		let yin_u8 = pack::<FInput, B8>(yin, builder, "repacked yin")?;
128
129		let lookup_u = builder.add_linear_combination(
130			"lookup_u",
131			b8_log_size,
132			[
133				(cin, <F as TowerField>::basis(0, 25)?),
134				(cout, <F as TowerField>::basis(0, 24)?),
135				(xin_u8, <F as TowerField>::basis(3, 2)?),
136				(yin_u8, <F as TowerField>::basis(3, 1)?),
137				(sum_packed, <F as TowerField>::basis(3, 0)?),
138			],
139		)?;
140
141		if let Some(witness) = builder.witness() {
142			let mut sum_witness = witness.new_column::<FOutput>(sum);
143			let mut cin_witness = witness.new_column::<B1>(cin);
144			let mut cout_witness = witness.new_column::<B1>(cout);
145			let mut lookup_u_witness = witness.new_column::<B32>(lookup_u);
146			let mut u_to_t_mapping_witness = vec![0; 1 << (b8_log_size)];
147
148			let x_ints = witness.get::<B8>(xin_u8)?.as_slice::<u8>();
149			let y_ints = witness.get::<B8>(yin_u8)?.as_slice::<u8>();
150
151			let sum_scalars = sum_witness.as_mut_slice::<u8>();
152			let packed_slice_cin = cin_witness.packed();
153			let packed_slice_cout = cout_witness.packed();
154			let lookup_u = lookup_u_witness.packed();
155
156			let mut temp_cout = 0;
157
158			for (i, (x, y, sum, u_to_t)) in
159				izip!(x_ints, y_ints, sum_scalars.iter_mut(), u_to_t_mapping_witness.iter_mut())
160					.enumerate()
161			{
162				let x = *x as usize;
163				let y = *y as usize;
164
165				let cin = if i % 4 == 0 { 0 } else { temp_cout };
166
167				let xy_sum = x + y + cin;
168
169				temp_cout = xy_sum >> 8;
170
171				set_packed_slice(packed_slice_cin, i, BinaryField1b::new(U1::new(cin as u8)));
172				set_packed_slice(
173					packed_slice_cout,
174					i,
175					BinaryField1b::new(U1::new(temp_cout as u8)),
176				);
177
178				*u_to_t = (x << 8 | y) << 1 | cin;
179
180				let ab_sum = xy_sum & 0xff;
181
182				*sum = xy_sum as u8;
183
184				let lookup_u_u32 =
185					(((((((cin << 1 | temp_cout) << 8) | x) << 8) | y) << 8) | ab_sum) as u32;
186
187				set_packed_slice(lookup_u, i, B32::new(lookup_u_u32));
188			}
189
190			std::mem::drop(sum_witness);
191
192			let sum_packed_witness = witness.get::<FOutput>(sum)?;
193
194			witness.set::<B8>(sum_packed, sum_packed_witness.repacked::<B8>())?;
195
196			self.u_to_t_mappings.push(u_to_t_mapping_witness)
197		}
198
199		self.lookups_u.push([lookup_u]);
200		self.n_lookups.push(1 << b8_log_size);
201
202		builder.pop_namespace();
203		Ok(sum)
204	}
205
206	pub fn finalize(
207		mut self,
208		builder: &mut ConstraintSystemBuilder,
209		name: impl ToString,
210	) -> Result<()> {
211		let channel = builder.add_channel();
212		self.finalized = true;
213		lasso::<B32>(
214			builder,
215			name,
216			&self.n_lookups,
217			&self.u_to_t_mappings,
218			&self.lookups_u,
219			[self.lookup_t],
220			channel,
221		)
222	}
223}
224
225impl Drop for SeveralU32add {
226	fn drop(&mut self) {
227		assert!(self.finalized)
228	}
229}
230
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField1b, BinaryField8b};

	use super::{u32add, SeveralU32add};
	use crate::{builder::test_utils::test_circuit, unconstrained::unconstrained};

	/// Several additions of different sizes share one table and are finalized together.
	#[test]
	fn test_several_lasso_u32add() {
		test_circuit(|builder| {
			let mut adder = SeveralU32add::new(builder).unwrap();
			for log_size in 11..=13 {
				// BinaryField8b is used here because we utilize an 8x8x1→8 table
				let lhs = unconstrained::<BinaryField8b>(builder, "add_a", log_size).unwrap();
				let rhs = unconstrained::<BinaryField8b>(builder, "add_b", log_size).unwrap();
				adder
					.u32add::<BinaryField8b, BinaryField8b>(builder, "lasso_u32add", lhs, rhs)
					.unwrap();
			}
			adder.finalize(builder, "lasso_u32add").unwrap();
			Ok(vec![])
		})
		.unwrap();
	}

	/// A single addition through the convenience wrapper, with bit-packed inputs.
	#[test]
	fn test_lasso_u32add() {
		test_circuit(|builder| {
			let log_size = 14;
			let lhs = unconstrained::<BinaryField1b>(builder, "add_a", log_size)?;
			let rhs = unconstrained::<BinaryField1b>(builder, "add_b", log_size)?;
			u32add::<BinaryField1b, BinaryField1b>(builder, "lasso_u32add", lhs, rhs)?;
			Ok(vec![])
		})
		.unwrap();
	}
}