binius_circuits/lasso/big_integer_ops/
byte_sliced_add.rs1use alloy_primitives::U512;
4use anyhow::Result;
5use binius_core::oracle::OracleId;
6use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b};
7
8use crate::{
9 builder::ConstraintSystemBuilder,
10 lasso::{batch::LookupBatch, u8add},
11};
12
13type B1 = BinaryField1b;
14type B8 = BinaryField8b;
15
16pub fn byte_sliced_add<Level: TowerLevel<Data<OracleId>: Sized>>(
17 builder: &mut ConstraintSystemBuilder,
18 name: impl ToString + Clone,
19 x_in: &Level::Data<OracleId>,
20 y_in: &Level::Data<OracleId>,
21 carry_in: OracleId,
22 log_size: usize,
23 lookup_batch_add: &mut LookupBatch,
24) -> Result<(OracleId, Level::Data<OracleId>), anyhow::Error> {
25 if Level::WIDTH == 1 {
26 let (carry_out, sum) =
27 u8add(builder, lookup_batch_add, name, x_in[0], y_in[0], carry_in, log_size)?;
28 let mut sum_arr = Level::default();
29 sum_arr[0] = sum;
30 return Ok((carry_out, sum_arr));
31 }
32
33 builder.push_namespace(name);
34
35 let (lower_half_x, upper_half_x) = Level::split(x_in);
36 let (lower_half_y, upper_half_y) = Level::split(y_in);
37
38 let (internal_carry, lower_sum) = byte_sliced_add::<Level::Base>(
39 builder,
40 format!("lower sum {}b", Level::Base::WIDTH),
41 lower_half_x,
42 lower_half_y,
43 carry_in,
44 log_size,
45 lookup_batch_add,
46 )?;
47
48 let (carry_out, upper_sum) = byte_sliced_add::<Level::Base>(
49 builder,
50 format!("upper sum {}b", Level::Base::WIDTH),
51 upper_half_x,
52 upper_half_y,
53 internal_carry,
54 log_size,
55 lookup_batch_add,
56 )?;
57
58 let sum = Level::join(&lower_sum, &upper_sum);
59
60 if let Some(witness) = builder.witness() {
62 let x_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
63 let this_byte_oracle = x_in[this_byte_idx];
64 witness
65 .get::<B8>(this_byte_oracle)
66 .unwrap()
67 .as_slice::<u8>()
68 });
69
70 let y_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
71 let this_byte_oracle = y_in[this_byte_idx];
72 witness
73 .get::<B8>(this_byte_oracle)
74 .unwrap()
75 .as_slice::<u8>()
76 });
77
78 let sum_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
79 let this_byte_oracle = sum[this_byte_idx];
80 witness
81 .get::<B8>(this_byte_oracle)
82 .unwrap()
83 .as_slice::<u8>()
84 });
85
86 let cin_as_u8_packed = witness.get::<B1>(carry_in).unwrap().as_slice::<u8>();
87
88 let cout_as_u8_packed = witness.get::<B1>(carry_out).unwrap().as_slice::<u8>();
89
90 for row_idx in 0..1 << log_size {
91 let mut x_u512 = U512::ZERO;
92 for (byte_idx, x_byte_column) in x_bytes_as_u8.clone().enumerate() {
93 x_u512 |= U512::from(x_byte_column[row_idx]) << (8 * byte_idx);
94 }
95
96 let mut y_u512 = U512::ZERO;
97 for (byte_idx, y_byte_column) in y_bytes_as_u8.clone().enumerate() {
98 y_u512 |= U512::from(y_byte_column[row_idx]) << (8 * byte_idx);
99 }
100
101 let mut sum_u512 = U512::ZERO;
102 for (byte_idx, sum_byte_column) in sum_bytes_as_u8.clone().enumerate() {
103 sum_u512 |= U512::from(sum_byte_column[row_idx]) << (8 * byte_idx);
104 }
105
106 let cin_u512 = U512::from((cin_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
107
108 let cout_u512 = U512::from((cout_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
109
110 let expected_sum_u128 = x_u512 + y_u512 + cin_u512;
111
112 let sum_according_to_witness = if cout_u512 == U512::ZERO {
113 sum_u512
114 } else {
115 sum_u512 | (cout_u512 << (Level::WIDTH * 8))
116 };
117
118 assert_eq!(expected_sum_u128, sum_according_to_witness);
119 }
120 }
121 builder.pop_namespace();
122
123 Ok((carry_out, sum))
124}