// binius_circuits/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs
use alloy_primitives::U512;
4use anyhow::Result;
5use binius_core::oracle::OracleId;
6use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b};
7
8use crate::{
9 builder::ConstraintSystemBuilder,
10 lasso::{batch::LookupBatch, u8_double_conditional_increment},
11};
12
/// Field type of the single-bit carry columns (`first_carry_in`, `second_carry_in`,
/// `carry_out`) when they are read back from the witness.
type B1 = BinaryField1b;
/// Field type of each byte column of the operand `x_in` and of the result `sum`
/// when they are read back from the witness.
type B8 = BinaryField8b;
15
16#[allow(clippy::too_many_arguments)]
17pub fn byte_sliced_double_conditional_increment<Level: TowerLevel<Data<OracleId>: Sized>>(
18 builder: &mut ConstraintSystemBuilder,
19 name: impl ToString,
20 x_in: &Level::Data<OracleId>,
21 first_carry_in: OracleId,
22 second_carry_in: OracleId,
23 log_size: usize,
24 zero_oracle_carry: usize,
25 lookup_batch_dci: &mut LookupBatch,
26) -> Result<(OracleId, Level::Data<OracleId>), anyhow::Error> {
27 if Level::WIDTH == 1 {
28 let (carry_out, sum) = u8_double_conditional_increment(
29 builder,
30 lookup_batch_dci,
31 "u8 DCI",
32 x_in[0],
33 first_carry_in,
34 second_carry_in,
35 log_size,
36 )?;
37 let mut sum_arr = Level::default();
38 sum_arr[0] = sum;
39 return Ok((carry_out, sum_arr));
40 }
41
42 builder.push_namespace(name);
43
44 let (lower_half_x, upper_half_x) = Level::split(x_in);
45
46 let (internal_carry, lower_sum) = byte_sliced_double_conditional_increment::<Level::Base>(
47 builder,
48 format!("lower sum {}b", Level::Base::WIDTH),
49 lower_half_x,
50 first_carry_in,
51 second_carry_in,
52 log_size,
53 zero_oracle_carry,
54 lookup_batch_dci,
55 )?;
56
57 let (carry_out, upper_sum) = byte_sliced_double_conditional_increment::<Level::Base>(
58 builder,
59 format!("upper sum {}b", Level::Base::WIDTH),
60 upper_half_x,
61 internal_carry,
62 zero_oracle_carry,
63 log_size,
64 zero_oracle_carry,
65 lookup_batch_dci,
66 )?;
67
68 let sum = Level::join(&lower_sum, &upper_sum);
69
70 if let Some(witness) = builder.witness() {
72 let x_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
73 let this_byte_oracle = x_in[this_byte_idx];
74 witness
75 .get::<B8>(this_byte_oracle)
76 .unwrap()
77 .as_slice::<u8>()
78 });
79
80 let sum_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
81 let this_byte_oracle = sum[this_byte_idx];
82 witness
83 .get::<B8>(this_byte_oracle)
84 .unwrap()
85 .as_slice::<u8>()
86 });
87
88 let first_cin_as_u8_packed = witness.get::<B1>(first_carry_in).unwrap().as_slice::<u8>();
89 let second_cin_as_u8_packed = witness.get::<B1>(second_carry_in).unwrap().as_slice::<u8>();
90
91 let cout_as_u8_packed = witness.get::<B1>(carry_out).unwrap().as_slice::<u8>();
92
93 for row_idx in 0..1 << log_size {
94 let mut x_u512 = U512::ZERO;
95 for (byte_idx, x_byte_column) in x_bytes_as_u8.clone().enumerate() {
96 x_u512 |= U512::from(x_byte_column[row_idx]) << (8 * byte_idx);
97 }
98
99 let mut sum_u512 = U512::ZERO;
100 for (byte_idx, sum_byte_column) in sum_bytes_as_u8.clone().enumerate() {
101 sum_u512 |= U512::from(sum_byte_column[row_idx]) << (8 * byte_idx);
102 }
103
104 let first_cin_u512 =
105 U512::from((first_cin_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
106
107 let second_cin_u512 =
108 U512::from((second_cin_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
109
110 let cout_u512 = U512::from((cout_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
111
112 let expected_sum_u128 = x_u512 + first_cin_u512 + second_cin_u512;
113
114 let sum_according_to_witness = sum_u512 | (cout_u512 << (Level::WIDTH * 8));
115
116 assert_eq!(expected_sum_u128, sum_according_to_witness);
117 }
118 }
119 builder.pop_namespace();
120
121 Ok((carry_out, sum))
122}