binius_circuits/lasso/
u8_double_conditional_increment.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use anyhow::Result;
4use binius_core::oracle::OracleId;
5use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField};
6
7use super::batch::LookupBatch;
8use crate::builder::{types::F, ConstraintSystemBuilder};
9
10type B1 = BinaryField1b;
11type B8 = BinaryField8b;
12type B32 = BinaryField32b;
13
14pub fn u8_double_conditional_increment(
15	builder: &mut ConstraintSystemBuilder,
16	lookup_batch: &mut LookupBatch,
17	name: impl ToString + Clone,
18	x_in: OracleId,
19	first_carry_in: OracleId,
20	second_carry_in: OracleId,
21	log_size: usize,
22) -> Result<(OracleId, OracleId), anyhow::Error> {
23	builder.push_namespace(name);
24
25	let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL);
26
27	let carry_out = builder.add_committed("cout", log_size, B1::TOWER_LEVEL);
28
29	let lookup_u = builder.add_linear_combination(
30		"lookup_u",
31		log_size,
32		[
33			(first_carry_in, <F as TowerField>::basis(0, 18)?),
34			(second_carry_in, <F as TowerField>::basis(0, 17)?),
35			(carry_out, <F as TowerField>::basis(3, 2)?),
36			(x_in, <F as TowerField>::basis(3, 1)?),
37			(sum, <F as TowerField>::basis(3, 0)?),
38		],
39	)?;
40
41	let mut u_to_t_mapping = vec![];
42
43	if let Some(witness) = builder.witness() {
44		let mut sum_witness = witness.new_column::<B8>(sum);
45		let mut carry_out_witness = witness.new_column::<B1>(carry_out);
46		let mut lookup_u_witness = witness.new_column::<B32>(lookup_u);
47		let mut u_to_t_mapping_witness = vec![0; 1 << log_size];
48
49		let x_in_u8 = witness.get::<B8>(x_in)?.as_slice::<u8>();
50		let first_carry_in_u8_packed = witness.get::<B1>(first_carry_in)?.as_slice::<u8>();
51		let second_carry_in_u8_packed = witness.get::<B1>(second_carry_in)?.as_slice::<u8>();
52
53		let sum_u8 = sum_witness.as_mut_slice::<u8>();
54		let carry_out_u8_packed = carry_out_witness.as_mut_slice::<u8>();
55		let lookup_u_u32 = lookup_u_witness.as_mut_slice::<u32>();
56
57		for row_idx in 0..1 << log_size {
58			let first_carry_in_usize =
59				((first_carry_in_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1) as usize;
60			let second_carry_in_usize =
61				((second_carry_in_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1) as usize;
62
63			let x_in_usize = x_in_u8[row_idx] as usize;
64			let sum_with_carry_out = x_in_usize + first_carry_in_usize + second_carry_in_usize;
65			let sum_usize = sum_with_carry_out & 0xff;
66			let carry_out_usize = sum_with_carry_out >> 8;
67			let lookup_index =
68				(first_carry_in_usize << 9) | (second_carry_in_usize << 8) | x_in_usize;
69			let lookup_value = (first_carry_in_usize << 18)
70				| (second_carry_in_usize << 17)
71				| (carry_out_usize << 16)
72				| (x_in_usize << 8)
73				| sum_usize;
74
75			lookup_u_u32[row_idx] = lookup_value as u32;
76
77			sum_u8[row_idx] = sum_usize as u8;
78
79			// Write our value to the bit
80			carry_out_u8_packed[row_idx / 8] |= (carry_out_usize << (row_idx % 8)) as u8;
81
82			u_to_t_mapping_witness[row_idx] = lookup_index;
83		}
84
85		u_to_t_mapping = u_to_t_mapping_witness;
86	}
87
88	lookup_batch.add([lookup_u], u_to_t_mapping, 1 << log_size);
89
90	builder.pop_namespace();
91	Ok((carry_out, sum))
92}