binius_circuits/lasso/big_integer_ops/
byte_sliced_add_carryfree.rs1use alloy_primitives::U512;
4use anyhow::Result;
5use binius_core::oracle::OracleId;
6use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b};
7
8use super::byte_sliced_add;
9use crate::{
10 builder::ConstraintSystemBuilder,
11 lasso::{batch::LookupBatch, u8add_carryfree},
12};
13
14type B1 = BinaryField1b;
15type B8 = BinaryField8b;
16
17#[allow(clippy::too_many_arguments)]
18pub fn byte_sliced_add_carryfree<Level: TowerLevel<Data<OracleId>: Sized>>(
19 builder: &mut ConstraintSystemBuilder,
20 name: impl ToString,
21 x_in: &Level::Data<OracleId>,
22 y_in: &Level::Data<OracleId>,
23 carry_in: OracleId,
24 log_size: usize,
25 lookup_batch_add: &mut LookupBatch,
26 lookup_batch_add_carryfree: &mut LookupBatch,
27) -> Result<Level::Data<OracleId>, anyhow::Error> {
28 if Level::WIDTH == 1 {
29 let sum = u8add_carryfree(
30 builder,
31 lookup_batch_add_carryfree,
32 "u8 carryfree add",
33 x_in[0],
34 y_in[0],
35 carry_in,
36 log_size,
37 )?;
38 let mut sum_arr = Level::default();
39 sum_arr[0] = sum;
40 return Ok(sum_arr);
41 }
42
43 builder.push_namespace(name);
44
45 let (lower_half_x, upper_half_x) = Level::split(x_in);
46 let (lower_half_y, upper_half_y) = Level::split(y_in);
47
48 let (internal_carry, lower_sum) = byte_sliced_add::<Level::Base>(
49 builder,
50 format!("lower sum {}b", Level::Base::WIDTH),
51 lower_half_x,
52 lower_half_y,
53 carry_in,
54 log_size,
55 lookup_batch_add,
56 )?;
57
58 let upper_sum = byte_sliced_add_carryfree::<Level::Base>(
59 builder,
60 format!("upper sum {}b", Level::Base::WIDTH),
61 upper_half_x,
62 upper_half_y,
63 internal_carry,
64 log_size,
65 lookup_batch_add,
66 lookup_batch_add_carryfree,
67 )?;
68
69 let sum = Level::join(&lower_sum, &upper_sum);
70
71 if let Some(witness) = builder.witness() {
73 let x_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
74 let this_byte_oracle = x_in[this_byte_idx];
75 witness
76 .get::<B8>(this_byte_oracle)
77 .unwrap()
78 .as_slice::<u8>()
79 });
80
81 let y_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
82 let this_byte_oracle = y_in[this_byte_idx];
83 witness
84 .get::<B8>(this_byte_oracle)
85 .unwrap()
86 .as_slice::<u8>()
87 });
88
89 let sum_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
90 let this_byte_oracle = sum[this_byte_idx];
91 witness
92 .get::<B8>(this_byte_oracle)
93 .unwrap()
94 .as_slice::<u8>()
95 });
96
97 let cin_as_u8_packed = witness.get::<B1>(carry_in).unwrap().as_slice::<u8>();
98
99 for row_idx in 0..1 << log_size {
100 let mut x_u512 = U512::ZERO;
101 for (byte_idx, x_byte_column) in x_bytes_as_u8.clone().enumerate() {
102 x_u512 |= U512::from(x_byte_column[row_idx]) << (8 * byte_idx);
103 }
104
105 let mut y_u512 = U512::ZERO;
106 for (byte_idx, y_byte_column) in y_bytes_as_u8.clone().enumerate() {
107 y_u512 |= U512::from(y_byte_column[row_idx]) << (8 * byte_idx);
108 }
109
110 let mut sum_u512 = U512::ZERO;
111 for (byte_idx, sum_byte_column) in sum_bytes_as_u8.clone().enumerate() {
112 sum_u512 |= U512::from(sum_byte_column[row_idx]) << (8 * byte_idx);
113 }
114
115 let cin_u512 = U512::from((cin_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
116
117 let expected_sum_u512 = x_u512 + y_u512 + cin_u512;
118
119 assert_eq!(expected_sum_u512, sum_u512);
120 }
121 }
122 builder.pop_namespace();
123
124 Ok(sum)
125}