binius_circuits/lasso/big_integer_ops/
byte_sliced_test_utils.rs1use std::{array, fmt::Debug};
4
5use alloy_primitives::U512;
6use binius_core::oracle::OracleId;
7use binius_field::{
8 tower_levels::TowerLevel, BinaryField1b, BinaryField32b, BinaryField8b, Field, TowerField,
9};
10use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng};
11
12use super::{
13 byte_sliced_add, byte_sliced_add_carryfree, byte_sliced_double_conditional_increment,
14 byte_sliced_modular_mul, byte_sliced_mul,
15};
16use crate::{
17 builder::test_utils::test_circuit,
18 lasso::{
19 batch::LookupBatch,
20 lookups::u8_arithmetic::{add_carryfree_lookup, add_lookup, dci_lookup, mul_lookup},
21 },
22 transparent,
23 unconstrained::unconstrained,
24};
25
26type B8 = BinaryField8b;
27type B32 = BinaryField32b;
28
29pub fn random_u512(rng: &mut impl Rng) -> U512 {
30 let limbs = array::from_fn(|_| rng.gen());
31 U512::from_limbs(limbs)
32}
33
34pub fn test_bytesliced_add<const WIDTH: usize, TL>()
35where
36 TL: TowerLevel,
37{
38 test_circuit(|builder| {
39 let log_size = 14;
40 let x_in = TL::from_fn(|_| unconstrained::<BinaryField8b>(builder, "x", log_size).unwrap());
41 let y_in = TL::from_fn(|_| unconstrained::<BinaryField8b>(builder, "y", log_size).unwrap());
42 let c_in = unconstrained::<BinaryField1b>(builder, "cin first", log_size)?;
43 let lookup_t_add = add_lookup(builder, "add table")?;
44 let mut lookup_batch_add = LookupBatch::new([lookup_t_add]);
45 let _sum_and_cout = byte_sliced_add::<TL>(
46 builder,
47 "lasso_bytesliced_add",
48 &x_in,
49 &y_in,
50 c_in,
51 log_size,
52 &mut lookup_batch_add,
53 )?;
54 lookup_batch_add.execute::<B32>(builder)?;
55 Ok(vec![])
56 })
57 .unwrap();
58}
59
60pub fn test_bytesliced_add_carryfree<const WIDTH: usize, TL>()
61where
62 TL: TowerLevel,
63{
64 test_circuit(|builder| {
65 let log_size = 14;
66 let x_in =
67 TL::from_fn(|_| builder.add_committed("x", log_size, BinaryField8b::TOWER_LEVEL));
68 let y_in =
69 TL::from_fn(|_| builder.add_committed("y", log_size, BinaryField8b::TOWER_LEVEL));
70 let c_in = builder.add_committed("c", log_size, BinaryField1b::TOWER_LEVEL);
71
72 if let Some(witness) = builder.witness() {
73 let mut x_in: [_; WIDTH] =
74 array::from_fn(|byte_idx| witness.new_column::<BinaryField8b>(x_in[byte_idx]));
75 let mut y_in: [_; WIDTH] =
76 array::from_fn(|byte_idx| witness.new_column::<BinaryField8b>(y_in[byte_idx]));
77 let mut c_in = witness.new_column::<BinaryField1b>(c_in);
78
79 let x_in_bytes_u8: [_; WIDTH] = x_in.each_mut().map(|col| col.as_mut_slice::<u8>());
80 let y_in_bytes_u8: [_; WIDTH] = y_in.each_mut().map(|col| col.as_mut_slice::<u8>());
81 let c_in_u8 = c_in.as_mut_slice::<u8>();
82
83 for row_idx in 0..1 << log_size {
84 let mut rng = thread_rng();
85 let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8);
86 let mut x = random_u512(&mut rng);
87 x &= input_bitmask;
88 let mut y = random_u512(&mut rng);
89 y &= input_bitmask;
90
91 let mut c: bool = rng.gen();
92
93 while (x + y + U512::from(c)) > input_bitmask {
94 x = random_u512(&mut rng);
95 x &= input_bitmask;
96 y = random_u512(&mut rng);
97 y &= input_bitmask;
98 c = rng.gen();
99 }
100
101 for byte_idx in 0..WIDTH {
102 x_in_bytes_u8[byte_idx][row_idx] = x.byte(byte_idx);
103
104 y_in_bytes_u8[byte_idx][row_idx] = y.byte(byte_idx);
105 }
106
107 c_in_u8[row_idx / 8] |= (c as u8) << (row_idx % 8);
108 }
109 }
110
111 let lookup_t_add = add_lookup(builder, "add table")?;
112 let lookup_t_add_carryfree = add_carryfree_lookup(builder, "add table")?;
113
114 let mut lookup_batch_add = LookupBatch::new([lookup_t_add]);
115 let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]);
116
117 let _sum_and_cout = byte_sliced_add_carryfree::<TL>(
118 builder,
119 "lasso_bytesliced_add_carryfree",
120 &x_in,
121 &y_in,
122 c_in,
123 log_size,
124 &mut lookup_batch_add,
125 &mut lookup_batch_add_carryfree,
126 )?;
127
128 lookup_batch_add.execute::<B32>(builder)?;
129 lookup_batch_add_carryfree.execute::<B32>(builder)?;
130 Ok(vec![])
131 })
132 .unwrap();
133}
134
135pub fn test_bytesliced_double_conditional_increment<const WIDTH: usize, TL>()
136where
137 TL: TowerLevel,
138{
139 test_circuit(|builder| {
140 let log_size = 14;
141 let x_in = TL::from_fn(|_| unconstrained::<BinaryField8b>(builder, "x", log_size).unwrap());
142 let first_c_in = unconstrained::<BinaryField1b>(builder, "cin first", log_size)?;
143 let second_c_in = unconstrained::<BinaryField1b>(builder, "cin second", log_size)?;
144 let zero_oracle_carry =
145 transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?;
146 let lookup_t_dci = dci_lookup(builder, "add table")?;
147 let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]);
148 let _sum_and_cout = byte_sliced_double_conditional_increment::<TL>(
149 builder,
150 "lasso_bytesliced_DCI",
151 &x_in,
152 first_c_in,
153 second_c_in,
154 log_size,
155 zero_oracle_carry,
156 &mut lookup_batch_dci,
157 )?;
158 lookup_batch_dci.execute::<B32>(builder)?;
159 Ok(vec![])
160 })
161 .unwrap();
162}
163
164pub fn test_bytesliced_mul<const WIDTH: usize, TL>()
165where
166 TL: TowerLevel,
167{
168 test_circuit(|builder| {
169 let log_size = 14;
170 let mult_a =
171 TL::Base::from_fn(|_| unconstrained::<BinaryField8b>(builder, "a", log_size).unwrap());
172 let mult_b =
173 TL::Base::from_fn(|_| unconstrained::<BinaryField8b>(builder, "b", log_size).unwrap());
174 let zero_oracle_carry =
175 transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?;
176 let lookup_t_mul = mul_lookup(builder, "mul lookup")?;
177 let lookup_t_add = add_lookup(builder, "add lookup")?;
178 let lookup_t_dci = dci_lookup(builder, "dci lookup")?;
179 let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]);
180 let mut lookup_batch_add = LookupBatch::new([lookup_t_add]);
181 let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]);
182 let _sum_and_cout = byte_sliced_mul::<TL::Base, TL>(
183 builder,
184 "lasso_bytesliced_mul",
185 &mult_a,
186 &mult_b,
187 log_size,
188 zero_oracle_carry,
189 &mut lookup_batch_mul,
190 &mut lookup_batch_add,
191 &mut lookup_batch_dci,
192 )?;
193 Ok(vec![])
194 })
195 .unwrap();
196}
197
198pub fn test_bytesliced_modular_mul<const WIDTH: usize, TL>()
199where
200 TL: TowerLevel<Data<usize>: Debug>,
201 TL::Base: TowerLevel<Data<usize> = [OracleId; WIDTH]>,
202{
203 test_circuit(|builder| {
204 let log_size = 12;
205 let mut rng = thread_rng();
206 let mult_a = builder.add_committed_multiple::<WIDTH>("a", log_size, B8::TOWER_LEVEL);
207 let mult_b = builder.add_committed_multiple::<WIDTH>("b", log_size, B8::TOWER_LEVEL);
208 let input_bitmask = (U512::from(1u8) << (8 * WIDTH)) - U512::from(1u8);
209 let modulus =
210 (random_u512(&mut StdRng::from_seed([42; 32])) % input_bitmask) + U512::from(1u8);
211
212 if let Some(witness) = builder.witness() {
213 let mut mult_a: [_; WIDTH] =
214 array::from_fn(|byte_idx| witness.new_column::<BinaryField8b>(mult_a[byte_idx]));
215
216 let mult_a_u8 = mult_a.each_mut().map(|col| col.as_mut_slice::<u8>());
217
218 let mut mult_b: [_; WIDTH] =
219 array::from_fn(|byte_idx| witness.new_column::<BinaryField8b>(mult_b[byte_idx]));
220
221 let mult_b_u8 = mult_b.each_mut().map(|col| col.as_mut_slice::<u8>());
222
223 for row_idx in 0..1 << log_size {
224 let mut a = random_u512(&mut rng);
225 let mut b = random_u512(&mut rng);
226
227 a %= modulus;
228 b %= modulus;
229
230 for byte_idx in 0..WIDTH {
231 mult_a_u8[byte_idx][row_idx] = a.byte(byte_idx);
232 mult_b_u8[byte_idx][row_idx] = b.byte(byte_idx);
233 }
234 }
235 }
236
237 let modulus_input: [_; WIDTH] = array::from_fn(|byte_idx| modulus.byte(byte_idx));
238 let zero_oracle_byte =
239 transparent::constant(builder, "zero carry", log_size, BinaryField8b::ZERO)?;
240 let zero_oracle_carry =
241 transparent::constant(builder, "zero carry", log_size, BinaryField1b::ZERO)?;
242 let _modded_product = byte_sliced_modular_mul::<TL::Base, TL>(
243 builder,
244 "lasso_bytesliced_mul",
245 &mult_a,
246 &mult_b,
247 &modulus_input,
248 log_size,
249 zero_oracle_byte,
250 zero_oracle_carry,
251 )?;
252 Ok(vec![])
253 })
254 .unwrap();
255}