binius_circuits/lasso/big_integer_ops/byte_sliced_add.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use alloy_primitives::U512;
4use anyhow::Result;
5use binius_core::oracle::OracleId;
6use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b};
7
8use crate::{
9	builder::ConstraintSystemBuilder,
10	lasso::{batch::LookupBatch, u8add},
11};
12
13type B1 = BinaryField1b;
14type B8 = BinaryField8b;
15
16pub fn byte_sliced_add<Level: TowerLevel<Data<OracleId>: Sized>>(
17	builder: &mut ConstraintSystemBuilder,
18	name: impl ToString + Clone,
19	x_in: &Level::Data<OracleId>,
20	y_in: &Level::Data<OracleId>,
21	carry_in: OracleId,
22	log_size: usize,
23	lookup_batch_add: &mut LookupBatch,
24) -> Result<(OracleId, Level::Data<OracleId>), anyhow::Error> {
25	if Level::WIDTH == 1 {
26		let (carry_out, sum) =
27			u8add(builder, lookup_batch_add, name, x_in[0], y_in[0], carry_in, log_size)?;
28		let mut sum_arr = Level::default();
29		sum_arr[0] = sum;
30		return Ok((carry_out, sum_arr));
31	}
32
33	builder.push_namespace(name);
34
35	let (lower_half_x, upper_half_x) = Level::split(x_in);
36	let (lower_half_y, upper_half_y) = Level::split(y_in);
37
38	let (internal_carry, lower_sum) = byte_sliced_add::<Level::Base>(
39		builder,
40		format!("lower sum {}b", Level::Base::WIDTH),
41		lower_half_x,
42		lower_half_y,
43		carry_in,
44		log_size,
45		lookup_batch_add,
46	)?;
47
48	let (carry_out, upper_sum) = byte_sliced_add::<Level::Base>(
49		builder,
50		format!("upper sum {}b", Level::Base::WIDTH),
51		upper_half_x,
52		upper_half_y,
53		internal_carry,
54		log_size,
55		lookup_batch_add,
56	)?;
57
58	let sum = Level::join(&lower_sum, &upper_sum);
59
60	// Everything below is for test assertions
61	if let Some(witness) = builder.witness() {
62		let x_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
63			let this_byte_oracle = x_in[this_byte_idx];
64			witness
65				.get::<B8>(this_byte_oracle)
66				.unwrap()
67				.as_slice::<u8>()
68		});
69
70		let y_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
71			let this_byte_oracle = y_in[this_byte_idx];
72			witness
73				.get::<B8>(this_byte_oracle)
74				.unwrap()
75				.as_slice::<u8>()
76		});
77
78		let sum_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
79			let this_byte_oracle = sum[this_byte_idx];
80			witness
81				.get::<B8>(this_byte_oracle)
82				.unwrap()
83				.as_slice::<u8>()
84		});
85
86		let cin_as_u8_packed = witness.get::<B1>(carry_in).unwrap().as_slice::<u8>();
87
88		let cout_as_u8_packed = witness.get::<B1>(carry_out).unwrap().as_slice::<u8>();
89
90		for row_idx in 0..1 << log_size {
91			let mut x_u512 = U512::ZERO;
92			for (byte_idx, x_byte_column) in x_bytes_as_u8.clone().enumerate() {
93				x_u512 |= U512::from(x_byte_column[row_idx]) << (8 * byte_idx);
94			}
95
96			let mut y_u512 = U512::ZERO;
97			for (byte_idx, y_byte_column) in y_bytes_as_u8.clone().enumerate() {
98				y_u512 |= U512::from(y_byte_column[row_idx]) << (8 * byte_idx);
99			}
100
101			let mut sum_u512 = U512::ZERO;
102			for (byte_idx, sum_byte_column) in sum_bytes_as_u8.clone().enumerate() {
103				sum_u512 |= U512::from(sum_byte_column[row_idx]) << (8 * byte_idx);
104			}
105
106			let cin_u512 = U512::from((cin_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
107
108			let cout_u512 = U512::from((cout_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
109
110			let expected_sum_u128 = x_u512 + y_u512 + cin_u512;
111
112			let sum_according_to_witness = if cout_u512 == U512::ZERO {
113				sum_u512
114			} else {
115				sum_u512 | (cout_u512 << (Level::WIDTH * 8))
116			};
117
118			assert_eq!(expected_sum_u128, sum_according_to_witness);
119		}
120	}
121	builder.pop_namespace();
122
123	Ok((carry_out, sum))
124}