binius_circuits/lasso/big_integer_ops/byte_sliced_add_carryfree.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use alloy_primitives::U512;
4use anyhow::Result;
5use binius_core::oracle::OracleId;
6use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b};
7
8use super::byte_sliced_add;
9use crate::{
10	builder::ConstraintSystemBuilder,
11	lasso::{batch::LookupBatch, u8add_carryfree},
12};
13
14type B1 = BinaryField1b;
15type B8 = BinaryField8b;
16
17#[allow(clippy::too_many_arguments)]
18pub fn byte_sliced_add_carryfree<Level: TowerLevel<Data<OracleId>: Sized>>(
19	builder: &mut ConstraintSystemBuilder,
20	name: impl ToString,
21	x_in: &Level::Data<OracleId>,
22	y_in: &Level::Data<OracleId>,
23	carry_in: OracleId,
24	log_size: usize,
25	lookup_batch_add: &mut LookupBatch,
26	lookup_batch_add_carryfree: &mut LookupBatch,
27) -> Result<Level::Data<OracleId>, anyhow::Error> {
28	if Level::WIDTH == 1 {
29		let sum = u8add_carryfree(
30			builder,
31			lookup_batch_add_carryfree,
32			"u8 carryfree add",
33			x_in[0],
34			y_in[0],
35			carry_in,
36			log_size,
37		)?;
38		let mut sum_arr = Level::default();
39		sum_arr[0] = sum;
40		return Ok(sum_arr);
41	}
42
43	builder.push_namespace(name);
44
45	let (lower_half_x, upper_half_x) = Level::split(x_in);
46	let (lower_half_y, upper_half_y) = Level::split(y_in);
47
48	let (internal_carry, lower_sum) = byte_sliced_add::<Level::Base>(
49		builder,
50		format!("lower sum {}b", Level::Base::WIDTH),
51		lower_half_x,
52		lower_half_y,
53		carry_in,
54		log_size,
55		lookup_batch_add,
56	)?;
57
58	let upper_sum = byte_sliced_add_carryfree::<Level::Base>(
59		builder,
60		format!("upper sum {}b", Level::Base::WIDTH),
61		upper_half_x,
62		upper_half_y,
63		internal_carry,
64		log_size,
65		lookup_batch_add,
66		lookup_batch_add_carryfree,
67	)?;
68
69	let sum = Level::join(&lower_sum, &upper_sum);
70
71	// Everything below is for test assertions
72	if let Some(witness) = builder.witness() {
73		let x_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
74			let this_byte_oracle = x_in[this_byte_idx];
75			witness
76				.get::<B8>(this_byte_oracle)
77				.unwrap()
78				.as_slice::<u8>()
79		});
80
81		let y_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
82			let this_byte_oracle = y_in[this_byte_idx];
83			witness
84				.get::<B8>(this_byte_oracle)
85				.unwrap()
86				.as_slice::<u8>()
87		});
88
89		let sum_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
90			let this_byte_oracle = sum[this_byte_idx];
91			witness
92				.get::<B8>(this_byte_oracle)
93				.unwrap()
94				.as_slice::<u8>()
95		});
96
97		let cin_as_u8_packed = witness.get::<B1>(carry_in).unwrap().as_slice::<u8>();
98
99		for row_idx in 0..1 << log_size {
100			let mut x_u512 = U512::ZERO;
101			for (byte_idx, x_byte_column) in x_bytes_as_u8.clone().enumerate() {
102				x_u512 |= U512::from(x_byte_column[row_idx]) << (8 * byte_idx);
103			}
104
105			let mut y_u512 = U512::ZERO;
106			for (byte_idx, y_byte_column) in y_bytes_as_u8.clone().enumerate() {
107				y_u512 |= U512::from(y_byte_column[row_idx]) << (8 * byte_idx);
108			}
109
110			let mut sum_u512 = U512::ZERO;
111			for (byte_idx, sum_byte_column) in sum_bytes_as_u8.clone().enumerate() {
112				sum_u512 |= U512::from(sum_byte_column[row_idx]) << (8 * byte_idx);
113			}
114
115			let cin_u512 = U512::from((cin_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
116
117			let expected_sum_u512 = x_u512 + y_u512 + cin_u512;
118
119			assert_eq!(expected_sum_u512, sum_u512);
120		}
121	}
122	builder.pop_namespace();
123
124	Ok(sum)
125}