binius_circuits/arithmetic/u32.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_core::{
4	oracle::{OracleId, ProjectionVariant, ShiftVariant},
5	transparent::MultilinearExtensionTransparent,
6};
7use binius_field::{
8	as_packed_field::PackedType, packed::set_packed_slice, underlier::WithUnderlier, BinaryField1b,
9	BinaryField32b, Field, PackedField, TowerField,
10};
11use binius_macros::arith_expr;
12use binius_maybe_rayon::prelude::*;
13use binius_utils::checked_arithmetics::checked_log_2;
14use bytemuck::{pod_collect_to_vec, Pod};
15
16use crate::{
17	builder::{
18		types::{F, U},
19		ConstraintSystemBuilder,
20	},
21	transparent,
22};
23
/// Single-bit binary field element. The gadgets in this module view `B1` columns as
/// sequences of `u32` words, i.e. 32 of these bits per word.
type B1 = BinaryField1b;
/// 32-bit binary tower field element, used for the packed view of one 32-bit word.
type B32 = BinaryField32b;
26
27pub fn packed(
28	builder: &mut ConstraintSystemBuilder,
29	name: impl ToString,
30	input: OracleId,
31) -> Result<OracleId, anyhow::Error> {
32	let packed = builder.add_packed(name, input, 5)?;
33	if let Some(witness) = builder.witness() {
34		witness.set(packed, witness.get::<B1>(input)?.repacked::<B32>())?;
35	}
36	Ok(packed)
37}
38
39pub fn mul_const(
40	builder: &mut ConstraintSystemBuilder,
41	name: impl ToString,
42	input: OracleId,
43	value: u32,
44	flags: super::Flags,
45) -> Result<OracleId, anyhow::Error> {
46	if value == 0 {
47		let log_rows = builder.log_rows([input])?;
48		return transparent::constant(builder, name, log_rows, B1::ZERO);
49	}
50
51	if value == 1 {
52		return Ok(input);
53	}
54
55	builder.push_namespace(name);
56	let mut tmp = value;
57	let mut offset = 0;
58	let mut result = input;
59	let mut first = true;
60	while tmp != 0 {
61		if tmp & 1 == 1 {
62			let shifted = shl(builder, format!("input_shl{offset}"), input, offset)?;
63			if first {
64				result = shifted;
65				first = false;
66			} else {
67				result = add(builder, format!("add_shl{offset}"), result, shifted, flags)?;
68			}
69		}
70		tmp >>= 1;
71		if tmp != 0 {
72			offset += 1;
73		}
74	}
75
76	if matches!(flags, super::Flags::Checked) {
77		// Shift overflow checking
78		for i in 32 - offset..32 {
79			let x = select_bit(builder, format!("bit{i}"), input, i)?;
80			builder.assert_zero("overflow", [x], arith_expr!([x] = x).convert_field());
81		}
82	}
83
84	builder.pop_namespace();
85	Ok(result)
86}
87
88pub fn add(
89	builder: &mut ConstraintSystemBuilder,
90	name: impl ToString,
91	xin: OracleId,
92	yin: OracleId,
93	flags: super::Flags,
94) -> Result<OracleId, anyhow::Error> {
95	builder.push_namespace(name);
96	let log_rows = builder.log_rows([xin, yin])?;
97	let cout = builder.add_committed("cout", log_rows, B1::TOWER_LEVEL);
98	let cin = builder.add_shifted("cin", cout, 1, 5, ShiftVariant::LogicalLeft)?;
99	let zout = builder.add_committed("zout", log_rows, B1::TOWER_LEVEL);
100
101	if let Some(witness) = builder.witness() {
102		(
103			witness.get::<B1>(xin)?.as_slice::<u32>(),
104			witness.get::<B1>(yin)?.as_slice::<u32>(),
105			witness.new_column::<B1>(zout).as_mut_slice::<u32>(),
106			witness.new_column::<B1>(cout).as_mut_slice::<u32>(),
107			witness.new_column::<B1>(cin).as_mut_slice::<u32>(),
108		)
109			.into_par_iter()
110			.for_each(|(xin, yin, zout, cout, cin)| {
111				let carry;
112				(*zout, carry) = (*xin).overflowing_add(*yin);
113				*cin = (*xin) ^ (*yin) ^ (*zout);
114				*cout = ((carry as u32) << 31) | (*cin >> 1);
115			});
116	}
117
118	builder.assert_zero(
119		"sum",
120		[xin, yin, cin, zout],
121		arith_expr!([xin, yin, cin, zout] = xin + yin + cin - zout).convert_field(),
122	);
123
124	builder.assert_zero(
125		"carry",
126		[xin, yin, cin, cout],
127		arith_expr!([xin, yin, cin, cout] = (xin + cin) * (yin + cin) + cin - cout).convert_field(),
128	);
129
130	// Overflow checking
131	if matches!(flags, super::Flags::Checked) {
132		let last_cout = select_bit(builder, "last_cout", cout, 31)?;
133		builder.assert_zero(
134			"overflow",
135			[last_cout],
136			arith_expr!([last_cout] = last_cout).convert_field(),
137		);
138	}
139
140	builder.pop_namespace();
141	Ok(zout)
142}
143
144pub fn sub(
145	builder: &mut ConstraintSystemBuilder,
146	name: impl ToString,
147	zin: OracleId,
148	yin: OracleId,
149	flags: super::Flags,
150) -> Result<OracleId, anyhow::Error> {
151	builder.push_namespace(name);
152	let log_rows = builder.log_rows([zin, yin])?;
153	let cout = builder.add_committed("cout", log_rows, B1::TOWER_LEVEL);
154	let cin = builder.add_shifted("cin", cout, 1, 5, ShiftVariant::LogicalLeft)?;
155	let xout = builder.add_committed("xin", log_rows, B1::TOWER_LEVEL);
156
157	if let Some(witness) = builder.witness() {
158		(
159			witness.get::<B1>(zin)?.as_slice::<u32>(),
160			witness.get::<B1>(yin)?.as_slice::<u32>(),
161			witness.new_column::<B1>(xout).as_mut_slice::<u32>(),
162			witness.new_column::<B1>(cout).as_mut_slice::<u32>(),
163			witness.new_column::<B1>(cin).as_mut_slice::<u32>(),
164		)
165			.into_par_iter()
166			.for_each(|(zout, yin, xin, cout, cin)| {
167				let carry;
168				(*xin, carry) = (*zout).overflowing_sub(*yin);
169				*cin = (*xin) ^ (*yin) ^ (*zout);
170				*cout = ((carry as u32) << 31) | (*cin >> 1);
171			});
172	}
173
174	builder.assert_zero(
175		"sum",
176		[xout, yin, cin, zin],
177		arith_expr!([xout, yin, cin, zin] = xout + yin + cin - zin).convert_field(),
178	);
179
180	builder.assert_zero(
181		"carry",
182		[xout, yin, cin, cout],
183		arith_expr!([xout, yin, cin, cout] = (xout + cin) * (yin + cin) + cin - cout)
184			.convert_field(),
185	);
186
187	// Underflow checking
188	if matches!(flags, super::Flags::Checked) {
189		let last_cout = select_bit(builder, "last_cout", cout, 31)?;
190		builder.assert_zero(
191			"underflow",
192			[last_cout],
193			arith_expr!([last_cout] = last_cout).convert_field(),
194		);
195	}
196
197	builder.pop_namespace();
198	Ok(xout)
199}
200
201pub fn half(
202	builder: &mut ConstraintSystemBuilder,
203	name: impl ToString,
204	input: OracleId,
205	flags: super::Flags,
206) -> Result<OracleId, anyhow::Error> {
207	if matches!(flags, super::Flags::Checked) {
208		// Assert that the number is even
209		let lsb = select_bit(builder, "lsb", input, 0)?;
210		builder.assert_zero("is_even", [lsb], arith_expr!([lsb] = lsb).convert_field());
211	}
212	shr(builder, name, input, 1)
213}
214
215pub fn shl(
216	builder: &mut ConstraintSystemBuilder,
217	name: impl ToString,
218	input: OracleId,
219	offset: usize,
220) -> Result<OracleId, anyhow::Error> {
221	if offset == 0 {
222		return Ok(input);
223	}
224
225	let shifted = builder.add_shifted(name, input, offset, 5, ShiftVariant::LogicalLeft)?;
226	if let Some(witness) = builder.witness() {
227		(
228			witness.new_column::<B1>(shifted).as_mut_slice::<u32>(),
229			witness.get::<B1>(input)?.as_slice::<u32>(),
230		)
231			.into_par_iter()
232			.for_each(|(shifted, input)| *shifted = *input << offset);
233	}
234
235	Ok(shifted)
236}
237
238pub fn shr(
239	builder: &mut ConstraintSystemBuilder,
240	name: impl ToString,
241	input: OracleId,
242	offset: usize,
243) -> Result<OracleId, anyhow::Error> {
244	if offset == 0 {
245		return Ok(input);
246	}
247
248	let shifted = builder.add_shifted(name, input, offset, 5, ShiftVariant::LogicalRight)?;
249	if let Some(witness) = builder.witness() {
250		(
251			witness.new_column::<B1>(shifted).as_mut_slice::<u32>(),
252			witness.get::<B1>(input)?.as_slice::<u32>(),
253		)
254			.into_par_iter()
255			.for_each(|(shifted, input)| *shifted = *input >> offset);
256	}
257
258	Ok(shifted)
259}
260
261pub fn select_bit(
262	builder: &mut ConstraintSystemBuilder,
263	name: impl ToString,
264	input: OracleId,
265	index: usize,
266) -> Result<OracleId, anyhow::Error> {
267	let log_rows = builder.log_rows([input])?;
268	anyhow::ensure!(log_rows >= 5, "Polynomial must have n_vars >= 5. Got {log_rows}");
269	anyhow::ensure!(index < 32, "Only index values between 0 and 32 are allowed. Got {index}");
270
271	let query = binius_core::polynomial::test_utils::decompose_index_to_hypercube_point(5, index);
272	let bits = builder.add_projected(name, input, query, ProjectionVariant::FirstVars)?;
273
274	if let Some(witness) = builder.witness() {
275		let mut bits = witness.new_column::<B1>(bits);
276		let bits = bits.packed();
277		let input = witness.get::<B1>(input)?.as_slice::<u32>();
278		input.iter().enumerate().for_each(|(i, &val)| {
279			let value = match (val >> index) & 1 {
280				0 => B1::ZERO,
281				_ => B1::ONE,
282			};
283			set_packed_slice(bits, i, value);
284		});
285	}
286
287	Ok(bits)
288}
289
290pub fn constant(
291	builder: &mut ConstraintSystemBuilder,
292	name: impl ToString,
293	log_count: usize,
294	value: u32,
295) -> Result<OracleId, anyhow::Error> {
296	builder.push_namespace(name);
297	// This would not need to be committed if we had `builder.add_unpacked(..)`
298	let output = builder.add_committed("output", log_count + 5, B1::TOWER_LEVEL);
299	if let Some(witness) = builder.witness() {
300		witness.new_column::<B1>(output).as_mut_slice().fill(value);
301	}
302
303	let output_packed = builder.add_packed("output_packed", output, 5)?;
304	let transparent = builder.add_transparent(
305		"transparent",
306		binius_core::transparent::constant::Constant::new(log_count, B32::new(value)),
307	)?;
308	if let Some(witness) = builder.witness() {
309		let packed = witness.get::<B1>(output)?.repacked::<B32>();
310		witness.set(output_packed, packed)?;
311		witness.set(transparent, packed)?;
312	}
313	builder.assert_zero(
314		"unpack",
315		[output_packed, transparent],
316		arith_expr!([x, y] = x - y).convert_field(),
317	);
318	builder.pop_namespace();
319	Ok(output)
320}
321
322pub const LOG_U32_BITS: usize = checked_log_2(32);
323
324#[inline]
325fn into_packed_vec<P>(src: &[impl Pod]) -> Vec<P>
326where
327	P: PackedField + WithUnderlier,
328	P::Underlier: Pod,
329{
330	pod_collect_to_vec::<_, P::Underlier>(src)
331		.into_iter()
332		.map(P::from_underlier)
333		.collect()
334}
335
336pub fn u32const_repeating(
337	log_size: usize,
338	builder: &mut ConstraintSystemBuilder,
339	x: u32,
340	name: &str,
341) -> Result<OracleId, anyhow::Error> {
342	let brodcasted = vec![x; 1 << (PackedType::<U, B1>::LOG_WIDTH.saturating_sub(LOG_U32_BITS))];
343
344	let transparent_id = builder.add_transparent(
345		format!("transparent {}", name),
346		MultilinearExtensionTransparent::<_, PackedType<U, F>, _>::from_values(into_packed_vec::<
347			PackedType<U, B1>,
348		>(&brodcasted))?,
349	)?;
350
351	let repeating_id = builder.add_repeating(
352		format!("repeating {}", name),
353		transparent_id,
354		log_size - PackedType::<U, B1>::LOG_WIDTH,
355	)?;
356
357	if let Some(witness) = builder.witness() {
358		let mut transparent_witness = witness.new_column::<B1>(transparent_id);
359		transparent_witness.as_mut_slice::<u32>().fill(x);
360
361		let mut repeating_witness = witness.new_column::<B1>(repeating_id);
362		repeating_witness.as_mut_slice::<u32>().fill(x);
363	}
364
365	Ok(repeating_id)
366}
367
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField1b, TowerField};

	use crate::{arithmetic, builder::test_utils::test_circuit, unconstrained::unconstrained};

	#[test]
	fn test_mul_const() {
		test_circuit(|builder| {
			let a = builder.add_committed("a", 5, BinaryField1b::TOWER_LEVEL);
			if let Some(witness) = builder.witness() {
				witness
					.new_column::<BinaryField1b>(a)
					.as_mut_slice::<u32>()
					.iter_mut()
					.for_each(|v| *v = 0b01000000_00000000_00000000_00000000u32);
			}
			// 0x4000_0000 * 3 = 0xC000_0000 still fits in 32 bits, so the checked
			// multiplication must be satisfiable.
			let _c = arithmetic::u32::mul_const(builder, "mul3", a, 3, arithmetic::Flags::Checked)?;
			Ok(vec![])
		})
		.unwrap();
	}

	#[test]
	fn test_add() {
		test_circuit(|builder| {
			let log_size = 14;
			let a = unconstrained::<BinaryField1b>(builder, "a", log_size)?;
			let b = unconstrained::<BinaryField1b>(builder, "b", log_size)?;
			let _c = arithmetic::u32::add(builder, "u32add", a, b, arithmetic::Flags::Unchecked)?;
			Ok(vec![])
		})
		.unwrap();
	}

	#[test]
	fn test_sub() {
		test_circuit(|builder| {
			let log_size = 7;
			let a = unconstrained::<BinaryField1b>(builder, "a", log_size)?;
			// Previously registered under the copy-pasted name "a".
			let b = unconstrained::<BinaryField1b>(builder, "b", log_size)?;
			let _c = arithmetic::u32::sub(builder, "c", a, b, arithmetic::Flags::Unchecked)?;
			Ok(vec![])
		})
		.unwrap();
	}
}