binius_circuits/lasso/
u8add_carryfree.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use anyhow::Result;
4use binius_core::oracle::OracleId;
5use binius_field::{BinaryField1b, BinaryField32b, BinaryField8b, TowerField};
6
7use super::batch::LookupBatch;
8use crate::builder::{types::F, ConstraintSystemBuilder};
9
10type B1 = BinaryField1b;
11type B8 = BinaryField8b;
12type B32 = BinaryField32b;
13
14pub fn u8add_carryfree(
15	builder: &mut ConstraintSystemBuilder,
16	lookup_batch: &mut LookupBatch,
17	name: impl ToString + Clone,
18	x_in: OracleId,
19	y_in: OracleId,
20	carry_in: OracleId,
21	log_size: usize,
22) -> Result<OracleId, anyhow::Error> {
23	builder.push_namespace(name);
24
25	let sum = builder.add_committed("sum", log_size, B8::TOWER_LEVEL);
26
27	let lookup_u = builder.add_linear_combination(
28		"lookup_u",
29		log_size,
30		[
31			(carry_in, <F as TowerField>::basis(3, 3)?),
32			(x_in, <F as TowerField>::basis(3, 2)?),
33			(y_in, <F as TowerField>::basis(3, 1)?),
34			(sum, <F as TowerField>::basis(3, 0)?),
35		],
36	)?;
37
38	let mut u_to_t_mapping = vec![];
39
40	if let Some(witness) = builder.witness() {
41		let mut sum_witness = witness.new_column::<B8>(sum);
42		let mut lookup_u_witness = witness.new_column::<B32>(lookup_u);
43		let mut u_to_t_mapping_witness = vec![0; 1 << log_size];
44
45		let x_in_u8 = witness.get::<B8>(x_in)?.as_slice::<u8>();
46		let y_in_u8 = witness.get::<B8>(y_in)?.as_slice::<u8>();
47		let carry_in_u8_packed = witness.get::<B1>(carry_in)?.as_slice::<u8>();
48
49		let sum_u8 = sum_witness.as_mut_slice::<u8>();
50		let lookup_u_u32 = lookup_u_witness.as_mut_slice::<u32>();
51
52		for row_idx in 0..1 << log_size {
53			let carry_in_usize = ((carry_in_u8_packed[row_idx / 8] >> (row_idx % 8)) & 1) as usize;
54
55			let x_in_usize = x_in_u8[row_idx] as usize;
56			let y_in_usize = y_in_u8[row_idx] as usize;
57			let xy_sum_usize = x_in_usize + y_in_usize + carry_in_usize;
58			let lookup_index = (carry_in_usize << 16) | (x_in_usize << 8) | y_in_usize;
59			let lookup_value = if xy_sum_usize <= 0xff {
60				(carry_in_usize << 24) | (x_in_usize << 16) | (y_in_usize << 8) | xy_sum_usize
61			} else {
62				0
63			};
64
65			lookup_u_u32[row_idx] = lookup_value as u32;
66
67			sum_u8[row_idx] = xy_sum_usize as u8;
68
69			u_to_t_mapping_witness[row_idx] = lookup_index;
70		}
71
72		u_to_t_mapping = u_to_t_mapping_witness;
73	}
74
75	lookup_batch.add([lookup_u], u_to_t_mapping, 1 << log_size);
76
77	builder.pop_namespace();
78	Ok(sum)
79}