binius_circuits/lasso/big_integer_ops/byte_sliced_double_conditional_increment.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use alloy_primitives::U512;
4use anyhow::Result;
5use binius_core::oracle::OracleId;
6use binius_field::{tower_levels::TowerLevel, BinaryField1b, BinaryField8b};
7
8use crate::{
9	builder::ConstraintSystemBuilder,
10	lasso::{batch::LookupBatch, u8_double_conditional_increment},
11};
12
13type B1 = BinaryField1b;
14type B8 = BinaryField8b;
15
16#[allow(clippy::too_many_arguments)]
17pub fn byte_sliced_double_conditional_increment<Level: TowerLevel<Data<OracleId>: Sized>>(
18	builder: &mut ConstraintSystemBuilder,
19	name: impl ToString,
20	x_in: &Level::Data<OracleId>,
21	first_carry_in: OracleId,
22	second_carry_in: OracleId,
23	log_size: usize,
24	zero_oracle_carry: usize,
25	lookup_batch_dci: &mut LookupBatch,
26) -> Result<(OracleId, Level::Data<OracleId>), anyhow::Error> {
27	if Level::WIDTH == 1 {
28		let (carry_out, sum) = u8_double_conditional_increment(
29			builder,
30			lookup_batch_dci,
31			"u8 DCI",
32			x_in[0],
33			first_carry_in,
34			second_carry_in,
35			log_size,
36		)?;
37		let mut sum_arr = Level::default();
38		sum_arr[0] = sum;
39		return Ok((carry_out, sum_arr));
40	}
41
42	builder.push_namespace(name);
43
44	let (lower_half_x, upper_half_x) = Level::split(x_in);
45
46	let (internal_carry, lower_sum) = byte_sliced_double_conditional_increment::<Level::Base>(
47		builder,
48		format!("lower sum {}b", Level::Base::WIDTH),
49		lower_half_x,
50		first_carry_in,
51		second_carry_in,
52		log_size,
53		zero_oracle_carry,
54		lookup_batch_dci,
55	)?;
56
57	let (carry_out, upper_sum) = byte_sliced_double_conditional_increment::<Level::Base>(
58		builder,
59		format!("upper sum {}b", Level::Base::WIDTH),
60		upper_half_x,
61		internal_carry,
62		zero_oracle_carry,
63		log_size,
64		zero_oracle_carry,
65		lookup_batch_dci,
66	)?;
67
68	let sum = Level::join(&lower_sum, &upper_sum);
69
70	// Everything below is for test assertions
71	if let Some(witness) = builder.witness() {
72		let x_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
73			let this_byte_oracle = x_in[this_byte_idx];
74			witness
75				.get::<B8>(this_byte_oracle)
76				.unwrap()
77				.as_slice::<u8>()
78		});
79
80		let sum_bytes_as_u8 = (0..Level::WIDTH).map(|this_byte_idx| {
81			let this_byte_oracle = sum[this_byte_idx];
82			witness
83				.get::<B8>(this_byte_oracle)
84				.unwrap()
85				.as_slice::<u8>()
86		});
87
88		let first_cin_as_u8_packed = witness.get::<B1>(first_carry_in).unwrap().as_slice::<u8>();
89		let second_cin_as_u8_packed = witness.get::<B1>(second_carry_in).unwrap().as_slice::<u8>();
90
91		let cout_as_u8_packed = witness.get::<B1>(carry_out).unwrap().as_slice::<u8>();
92
93		for row_idx in 0..1 << log_size {
94			let mut x_u512 = U512::ZERO;
95			for (byte_idx, x_byte_column) in x_bytes_as_u8.clone().enumerate() {
96				x_u512 |= U512::from(x_byte_column[row_idx]) << (8 * byte_idx);
97			}
98
99			let mut sum_u512 = U512::ZERO;
100			for (byte_idx, sum_byte_column) in sum_bytes_as_u8.clone().enumerate() {
101				sum_u512 |= U512::from(sum_byte_column[row_idx]) << (8 * byte_idx);
102			}
103
104			let first_cin_u512 =
105				U512::from((first_cin_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
106
107			let second_cin_u512 =
108				U512::from((second_cin_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
109
110			let cout_u512 = U512::from((cout_as_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1);
111
112			let expected_sum_u128 = x_u512 + first_cin_u512 + second_cin_u512;
113
114			let sum_according_to_witness = sum_u512 | (cout_u512 << (Level::WIDTH * 8));
115
116			assert_eq!(expected_sum_u128, sum_according_to_witness);
117		}
118	}
119	builder.pop_namespace();
120
121	Ok((carry_out, sum))
122}