binius_circuits/lasso/big_integer_ops/
byte_sliced_test_utils.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{array, fmt::Debug};
4
5use alloy_primitives::U512;
6use binius_core::oracle::OracleId;
7use binius_field::{
8	tower_levels::TowerLevel, BinaryField1b, BinaryField32b, BinaryField8b, Field, TowerField,
9};
10use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng};
11
12use super::{
13	byte_sliced_add, byte_sliced_add_carryfree, byte_sliced_double_conditional_increment,
14	byte_sliced_modular_mul, byte_sliced_mul,
15};
16use crate::{
17	builder::test_utils::test_circuit,
18	lasso::{
19		batch::LookupBatch,
20		lookups::u8_arithmetic::{add_carryfree_lookup, add_lookup, dci_lookup, mul_lookup},
21	},
22	transparent,
23	unconstrained::unconstrained,
24};
25
26type B8 = BinaryField8b;
27type B32 = BinaryField32b;
28
29pub fn random_u512(rng: &mut impl Rng) -> U512 {
30	let limbs = array::from_fn(|_| rng.gen());
31	U512::from_limbs(limbs)
32}
33
34pub fn test_bytesliced_add<const WIDTH: usize, TL>()
35where
36	TL: TowerLevel,
37{
38	test_circuit(|builder| {
39		let log_size = 14;
40		let x_in = TL::from_fn(|_| unconstrained::<BinaryField8b>(builder, "x", log_size).unwrap());
41		let y_in = TL::from_fn(|_| unconstrained::<BinaryField8b>(builder, "y", log_size).unwrap());
42		let c_in = unconstrained::<BinaryField1b>(builder, "cin first", log_size)?;
43		let lookup_t_add = add_lookup(builder, "add table")?;
44		let mut lookup_batch_add = LookupBatch::new([lookup_t_add]);
45		let _sum_and_cout = byte_sliced_add::<TL>(
46			builder,
47			"lasso_bytesliced_add",
48			&x_in,
49			&y_in,
50			c_in,
51			log_size,
52			&mut lookup_batch_add,
53		)?;
54		lookup_batch_add.execute::<B32>(builder)?;
55		Ok(vec![])
56	})
57	.unwrap();
58}
59
60pub fn test_bytesliced_add_carryfree<const WIDTH: usize, TL>()
61where
62	TL: TowerLevel,
63{
64	test_circuit(|builder| {
65		let log_size = 14;
66		let x_in =
67			TL::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL));
68		let y_in =
69			TL::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL));
70		let c_in = builder.add_committed("c", log_size, BinaryField1b::TOWER_LEVEL);
71
72		if let Some(witness) = builder.witness() {
73			let mut x_in: [_; WIDTH] =
74				array::from_fn(|byte_idx| witness.new_column::<BinaryField8b>(x_in[byte_idx]));
75			let mut y_in: [_; WIDTH] =
76				array::from_fn(|byte_idx| witness.new_column::<BinaryField8b>(y_in[byte_idx]));
77			let mut c_in = witness.new_column::<BinaryField1b>(c_in);
78
79			let x_in_bytes_u8: [_; WIDTH] = x_in.each_mut().map(|col| col.as_mut_slice::<u8>());
80			let y_in_bytes_u8: [_; WIDTH] = y_in.each_mut().map(|col| col.as_mut_slice::<u8>());
81			let c_in_u8 = c_in.as_mut_slice::<u8>();
82
83			for row_idx in 0..1 << log_size {
84				let mut rng = thread_rng();
85				let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8);
86				let mut x = random_u512(&mut rng);
87				x &= input_bitmask;
88				let mut y = random_u512(&mut rng);
89				y &= input_bitmask;
90
91				let mut c: bool = rng.gen();
92
93				while (x + y + U512::from(c)) > input_bitmask {
94					x = random_u512(&mut rng);
95					x &= input_bitmask;
96					y = random_u512(&mut rng);
97					y &= input_bitmask;
98					c = rng.gen();
99				}
100
101				for byte_idx in 0..WIDTH {
102					x_in_bytes_u8[byte_idx][row_idx] = x.byte(byte_idx);
103
104					y_in_bytes_u8[byte_idx][row_idx] = y.byte(byte_idx);
105				}
106
107				c_in_u8[row_idx / 8] |= (c as u8) << (row_idx % 8);
108			}
109		}
110
111		let lookup_t_add = add_lookup(builder, "add table")?;
112		let lookup_t_add_carryfree = add_carryfree_lookup(builder, "add table")?;
113
114		let mut lookup_batch_add = LookupBatch::new([lookup_t_add]);
115		let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]);
116
117		let _sum_and_cout = byte_sliced_add_carryfree::<TL>(
118			builder,
119			"lasso_bytesliced_add_carryfree",
120			&x_in,
121			&y_in,
122			c_in,
123			log_size,
124			&mut lookup_batch_add,
125			&mut lookup_batch_add_carryfree,
126		)?;
127
128		lookup_batch_add.execute::<B32>(builder)?;
129		lookup_batch_add_carryfree.execute::<B32>(builder)?;
130		Ok(vec![])
131	})
132	.unwrap();
133}
134
135pub fn test_bytesliced_double_conditional_increment<const WIDTH: usize, TL>()
136where
137	TL: TowerLevel,
138{
139	test_circuit(|builder| {
140		let log_size = 14;
141		let x_in = TL::from_fn(|_| unconstrained::<BinaryField8b>(builder, "x", log_size).unwrap());
142		let first_c_in = unconstrained::<BinaryField1b>(builder, "cin first", log_size)?;
143		let second_c_in = unconstrained::<BinaryField1b>(builder, "cin second", log_size)?;
144		let zero_oracle_carry =
145			transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?;
146		let lookup_t_dci = dci_lookup(builder, "add table")?;
147		let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]);
148		let _sum_and_cout = byte_sliced_double_conditional_increment::<TL>(
149			builder,
150			"lasso_bytesliced_DCI",
151			&x_in,
152			first_c_in,
153			second_c_in,
154			log_size,
155			zero_oracle_carry,
156			&mut lookup_batch_dci,
157		)?;
158		lookup_batch_dci.execute::<B32>(builder)?;
159		Ok(vec![])
160	})
161	.unwrap();
162}
163
164pub fn test_bytesliced_mul<const WIDTH: usize, TL>()
165where
166	TL: TowerLevel,
167{
168	test_circuit(|builder| {
169		let log_size = 14;
170		let mult_a =
171			TL::Base::from_fn(|_| unconstrained::<BinaryField8b>(builder, "a", log_size).unwrap());
172		let mult_b =
173			TL::Base::from_fn(|_| unconstrained::<BinaryField8b>(builder, "b", log_size).unwrap());
174		let zero_oracle_carry =
175			transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?;
176		let lookup_t_mul = mul_lookup(builder, "mul lookup")?;
177		let lookup_t_add = add_lookup(builder, "add lookup")?;
178		let lookup_t_dci = dci_lookup(builder, "dci lookup")?;
179		let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]);
180		let mut lookup_batch_add = LookupBatch::new([lookup_t_add]);
181		let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]);
182		let _sum_and_cout = byte_sliced_mul::<TL::Base, TL>(
183			builder,
184			"lasso_bytesliced_mul",
185			&mult_a,
186			&mult_b,
187			log_size,
188			zero_oracle_carry,
189			&mut lookup_batch_mul,
190			&mut lookup_batch_add,
191			&mut lookup_batch_dci,
192		)?;
193		Ok(vec![])
194	})
195	.unwrap();
196}
197
198pub fn test_bytesliced_modular_mul<const WIDTH: usize, TL>()
199where
200	TL: TowerLevel<Data<usize>: Debug>,
201	TL::Base: TowerLevel<Data<usize> = [OracleId; WIDTH]>,
202{
203	test_circuit(|builder| {
204		let log_size = 12;
205		let mut rng = thread_rng();
206		let mult_a = builder.add_committed_multiple::<WIDTH>("a", log_size, B8::TOWER_LEVEL);
207		let mult_b = builder.add_committed_multiple::<WIDTH>("b", log_size, B8::TOWER_LEVEL);
208		let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8);
209		let modulus =
210			(random_u512(&mut StdRng::from_seed([42; 32])) % input_bitmask) + U512::from(1u8);
211
212		if let Some(witness) = builder.witness() {
213			let mut mult_a: [_; WIDTH] =
214				array::from_fn(|byte_idx| witness.new_column::<BinaryField8b>(mult_a[byte_idx]));
215
216			let mult_a_u8 = mult_a.each_mut().map(|col| col.as_mut_slice::<u8>());
217
218			let mut mult_b: [_; WIDTH] =
219				array::from_fn(|byte_idx| witness.new_column::<BinaryField8b>(mult_b[byte_idx]));
220
221			let mult_b_u8 = mult_b.each_mut().map(|col| col.as_mut_slice::<u8>());
222
223			for row_idx in 0..1 << log_size {
224				let mut a = random_u512(&mut rng);
225				let mut b = random_u512(&mut rng);
226
227				a %= modulus;
228				b %= modulus;
229
230				for byte_idx in 0..WIDTH {
231					mult_a_u8[byte_idx][row_idx] = a.byte(byte_idx);
232					mult_b_u8[byte_idx][row_idx] = b.byte(byte_idx);
233				}
234			}
235		}
236
237		let modulus_input: [_; WIDTH] = array::from_fn(|byte_idx| modulus.byte(byte_idx));
238		let zero_oracle_byte =
239			transparent::constant(builder, "zero carry", log_size, BinaryField8b::ZERO)?;
240		let zero_oracle_carry =
241			transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?;
242		let _modded_product = byte_sliced_modular_mul::<TL::Base, TL>(
243			builder,
244			"lasso_bytesliced_mul",
245			&mult_a,
246			&mult_b,
247			&modulus_input,
248			log_size,
249			zero_oracle_byte,
250			zero_oracle_carry,
251		)?;
252		Ok(vec![])
253	})
254	.unwrap();
255}