binius_circuits/lasso/big_integer_ops/
byte_sliced_mul.rs1use alloy_primitives::U512;
4use anyhow::Result;
5use binius_core::oracle::OracleId;
6use binius_field::{tower_levels::TowerLevel, BinaryField8b};
7
8use super::{byte_sliced_add, byte_sliced_double_conditional_increment};
9use crate::{
10 builder::ConstraintSystemBuilder,
11 lasso::{batch::LookupBatch, u8mul::u8mul_bytesliced},
12};
13
14type B8 = BinaryField8b;
15
16#[allow(clippy::too_many_arguments)]
17pub fn byte_sliced_mul<LevelIn: TowerLevel, LevelOut: TowerLevel<Base = LevelIn>>(
18 builder: &mut ConstraintSystemBuilder,
19 name: impl ToString,
20 mult_a: &LevelIn::Data<OracleId>,
21 mult_b: &LevelIn::Data<OracleId>,
22 log_size: usize,
23 zero_carry_oracle: OracleId,
24 lookup_batch_mul: &mut LookupBatch,
25 lookup_batch_add: &mut LookupBatch,
26 lookup_batch_dci: &mut LookupBatch,
27) -> Result<LevelOut::Data<OracleId>, anyhow::Error> {
28 if LevelIn::WIDTH == 1 {
29 let result_of_u8mul = u8mul_bytesliced(
30 builder,
31 lookup_batch_mul,
32 "u8 mul",
33 mult_a[0],
34 mult_b[0],
35 1 << log_size,
36 )?;
37 let mut lower_result_of_u8mul = LevelIn::default();
38 lower_result_of_u8mul[0] = result_of_u8mul[0];
39 let mut upper_result_of_u8mul = LevelIn::default();
40 upper_result_of_u8mul[0] = result_of_u8mul[1];
41
42 let result_typed_arr = LevelOut::join(&lower_result_of_u8mul, &upper_result_of_u8mul);
43
44 return Ok(result_typed_arr);
45 }
46
47 builder.push_namespace(name);
48
49 let (mult_a_low, mult_a_high) = LevelIn::split(mult_a);
50 let (mult_b_low, mult_b_high) = LevelIn::split(mult_b);
51
52 let a_lo_b_lo = byte_sliced_mul::<LevelIn::Base, LevelOut::Base>(
53 builder,
54 format!("lo*lo {}b", LevelIn::Base::WIDTH),
55 mult_a_low,
56 mult_b_low,
57 log_size,
58 zero_carry_oracle,
59 lookup_batch_mul,
60 lookup_batch_add,
61 lookup_batch_dci,
62 )?;
63 let a_lo_b_hi = byte_sliced_mul::<LevelIn::Base, LevelOut::Base>(
64 builder,
65 format!("lo*hi {}b", LevelIn::Base::WIDTH),
66 mult_a_low,
67 mult_b_high,
68 log_size,
69 zero_carry_oracle,
70 lookup_batch_mul,
71 lookup_batch_add,
72 lookup_batch_dci,
73 )?;
74 let a_hi_b_lo = byte_sliced_mul::<LevelIn::Base, LevelOut::Base>(
75 builder,
76 format!("hi*lo {}b", LevelIn::Base::WIDTH),
77 mult_a_high,
78 mult_b_low,
79 log_size,
80 zero_carry_oracle,
81 lookup_batch_mul,
82 lookup_batch_add,
83 lookup_batch_dci,
84 )?;
85 let a_hi_b_hi = byte_sliced_mul::<LevelIn::Base, LevelOut::Base>(
86 builder,
87 format!("hi*hi {}b", LevelIn::Base::WIDTH),
88 mult_a_high,
89 mult_b_high,
90 log_size,
91 zero_carry_oracle,
92 lookup_batch_mul,
93 lookup_batch_add,
94 lookup_batch_dci,
95 )?;
96
97 let (karatsuba_carry_for_high_chunk, karatsuba_term) = byte_sliced_add::<LevelIn>(
98 builder,
99 format!("karastsuba addition {}b", LevelIn::WIDTH),
100 &a_lo_b_hi,
101 &a_hi_b_lo,
102 zero_carry_oracle,
103 log_size,
104 lookup_batch_add,
105 )?;
106
107 let (a_lo_b_lo_lower_half, a_lo_b_lo_upper_half) = LevelIn::split(&a_lo_b_lo);
108 let (a_hi_b_hi_lower_half, a_hi_b_hi_upper_half) = LevelIn::split(&a_hi_b_hi);
109
110 let (additional_carry_for_high_chunk, final_middle_chunk) = byte_sliced_add::<LevelIn>(
111 builder,
112 format!("post kartsuba middle term addition {}b", LevelIn::WIDTH),
113 &karatsuba_term,
114 &LevelIn::join(a_lo_b_lo_upper_half, a_hi_b_hi_lower_half),
115 zero_carry_oracle,
116 log_size,
117 lookup_batch_add,
118 )?;
119
120 let (_, final_high_chunk) = byte_sliced_double_conditional_increment::<LevelIn::Base>(
121 builder,
122 format!("high chunk DCI {}b", LevelIn::Base::WIDTH),
123 a_hi_b_hi_upper_half,
124 karatsuba_carry_for_high_chunk,
125 additional_carry_for_high_chunk,
126 log_size,
127 zero_carry_oracle,
128 lookup_batch_dci,
129 )?;
130
131 let (final_middle_chunk_lower_half, final_middle_chunk_upper_half) =
132 LevelIn::split(&final_middle_chunk);
133
134 let final_lower_half = LevelIn::join(a_lo_b_lo_lower_half, final_middle_chunk_lower_half);
135
136 let final_upper_half = LevelIn::join(final_middle_chunk_upper_half, &final_high_chunk);
137
138 let product = LevelOut::join(&final_lower_half, &final_upper_half);
139
140 if let Some(witness) = builder.witness() {
142 let a_bytes_as_u8 = (0..LevelIn::WIDTH).map(|this_byte_idx| {
143 let this_byte_oracle = mult_a[this_byte_idx];
144 witness
145 .get::<B8>(this_byte_oracle)
146 .unwrap()
147 .as_slice::<u8>()
148 });
149
150 let b_bytes_as_u8 = (0..LevelIn::WIDTH).map(|this_byte_idx| {
151 let this_byte_oracle = mult_b[this_byte_idx];
152 witness
153 .get::<B8>(this_byte_oracle)
154 .unwrap()
155 .as_slice::<u8>()
156 });
157
158 let product_bytes_as_u8 = (0..LevelOut::WIDTH).map(|this_byte_idx| {
159 let this_byte_oracle = product[this_byte_idx];
160 witness
161 .get::<B8>(this_byte_oracle)
162 .unwrap()
163 .as_slice::<u8>()
164 });
165
166 for row_idx in 0..1 << log_size {
167 let mut a_u512 = U512::ZERO;
168 for (byte_idx, a_byte_column) in a_bytes_as_u8.clone().enumerate() {
169 a_u512 |= U512::from(a_byte_column[row_idx]) << (8 * byte_idx);
170 }
171
172 let mut b_u512 = U512::ZERO;
173 for (byte_idx, b_byte_column) in b_bytes_as_u8.clone().enumerate() {
174 b_u512 |= U512::from(b_byte_column[row_idx]) << (8 * byte_idx);
175 }
176
177 let mut product_u512 = U512::ZERO;
178 for (byte_idx, product_byte_column) in product_bytes_as_u8.clone().enumerate() {
179 product_u512 |= U512::from(product_byte_column[row_idx]) << (8 * byte_idx);
180 }
181
182 assert_eq!(a_u512 * b_u512, product_u512);
183 }
184 }
185
186 builder.pop_namespace();
187 Ok(product)
188}