// binius_circuits/lasso/big_integer_ops/byte_sliced_mul.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use alloy_primitives::U512;
4use anyhow::Result;
5use binius_core::oracle::OracleId;
6use binius_field::{tower_levels::TowerLevel, BinaryField8b};
7
8use super::{byte_sliced_add, byte_sliced_double_conditional_increment};
9use crate::{
10	builder::ConstraintSystemBuilder,
11	lasso::{batch::LookupBatch, u8mul::u8mul_bytesliced},
12};
13
14type B8 = BinaryField8b;
15
16#[allow(clippy::too_many_arguments)]
17pub fn byte_sliced_mul<LevelIn: TowerLevel, LevelOut: TowerLevel<Base = LevelIn>>(
18	builder: &mut ConstraintSystemBuilder,
19	name: impl ToString,
20	mult_a: &LevelIn::Data<OracleId>,
21	mult_b: &LevelIn::Data<OracleId>,
22	log_size: usize,
23	zero_carry_oracle: OracleId,
24	lookup_batch_mul: &mut LookupBatch,
25	lookup_batch_add: &mut LookupBatch,
26	lookup_batch_dci: &mut LookupBatch,
27) -> Result<LevelOut::Data<OracleId>, anyhow::Error> {
28	if LevelIn::WIDTH == 1 {
29		let result_of_u8mul = u8mul_bytesliced(
30			builder,
31			lookup_batch_mul,
32			"u8 mul",
33			mult_a[0],
34			mult_b[0],
35			1 << log_size,
36		)?;
37		let mut lower_result_of_u8mul = LevelIn::default();
38		lower_result_of_u8mul[0] = result_of_u8mul[0];
39		let mut upper_result_of_u8mul = LevelIn::default();
40		upper_result_of_u8mul[0] = result_of_u8mul[1];
41
42		let result_typed_arr = LevelOut::join(&lower_result_of_u8mul, &upper_result_of_u8mul);
43
44		return Ok(result_typed_arr);
45	}
46
47	builder.push_namespace(name);
48
49	let (mult_a_low, mult_a_high) = LevelIn::split(mult_a);
50	let (mult_b_low, mult_b_high) = LevelIn::split(mult_b);
51
52	let a_lo_b_lo = byte_sliced_mul::<LevelIn::Base, LevelOut::Base>(
53		builder,
54		format!("lo*lo {}b", LevelIn::Base::WIDTH),
55		mult_a_low,
56		mult_b_low,
57		log_size,
58		zero_carry_oracle,
59		lookup_batch_mul,
60		lookup_batch_add,
61		lookup_batch_dci,
62	)?;
63	let a_lo_b_hi = byte_sliced_mul::<LevelIn::Base, LevelOut::Base>(
64		builder,
65		format!("lo*hi {}b", LevelIn::Base::WIDTH),
66		mult_a_low,
67		mult_b_high,
68		log_size,
69		zero_carry_oracle,
70		lookup_batch_mul,
71		lookup_batch_add,
72		lookup_batch_dci,
73	)?;
74	let a_hi_b_lo = byte_sliced_mul::<LevelIn::Base, LevelOut::Base>(
75		builder,
76		format!("hi*lo {}b", LevelIn::Base::WIDTH),
77		mult_a_high,
78		mult_b_low,
79		log_size,
80		zero_carry_oracle,
81		lookup_batch_mul,
82		lookup_batch_add,
83		lookup_batch_dci,
84	)?;
85	let a_hi_b_hi = byte_sliced_mul::<LevelIn::Base, LevelOut::Base>(
86		builder,
87		format!("hi*hi {}b", LevelIn::Base::WIDTH),
88		mult_a_high,
89		mult_b_high,
90		log_size,
91		zero_carry_oracle,
92		lookup_batch_mul,
93		lookup_batch_add,
94		lookup_batch_dci,
95	)?;
96
97	let (karatsuba_carry_for_high_chunk, karatsuba_term) = byte_sliced_add::<LevelIn>(
98		builder,
99		format!("karastsuba addition {}b", LevelIn::WIDTH),
100		&a_lo_b_hi,
101		&a_hi_b_lo,
102		zero_carry_oracle,
103		log_size,
104		lookup_batch_add,
105	)?;
106
107	let (a_lo_b_lo_lower_half, a_lo_b_lo_upper_half) = LevelIn::split(&a_lo_b_lo);
108	let (a_hi_b_hi_lower_half, a_hi_b_hi_upper_half) = LevelIn::split(&a_hi_b_hi);
109
110	let (additional_carry_for_high_chunk, final_middle_chunk) = byte_sliced_add::<LevelIn>(
111		builder,
112		format!("post kartsuba middle term addition {}b", LevelIn::WIDTH),
113		&karatsuba_term,
114		&LevelIn::join(a_lo_b_lo_upper_half, a_hi_b_hi_lower_half),
115		zero_carry_oracle,
116		log_size,
117		lookup_batch_add,
118	)?;
119
120	let (_, final_high_chunk) = byte_sliced_double_conditional_increment::<LevelIn::Base>(
121		builder,
122		format!("high chunk DCI {}b", LevelIn::Base::WIDTH),
123		a_hi_b_hi_upper_half,
124		karatsuba_carry_for_high_chunk,
125		additional_carry_for_high_chunk,
126		log_size,
127		zero_carry_oracle,
128		lookup_batch_dci,
129	)?;
130
131	let (final_middle_chunk_lower_half, final_middle_chunk_upper_half) =
132		LevelIn::split(&final_middle_chunk);
133
134	let final_lower_half = LevelIn::join(a_lo_b_lo_lower_half, final_middle_chunk_lower_half);
135
136	let final_upper_half = LevelIn::join(final_middle_chunk_upper_half, &final_high_chunk);
137
138	let product = LevelOut::join(&final_lower_half, &final_upper_half);
139
140	// All of the code below is for test assertions
141	if let Some(witness) = builder.witness() {
142		let a_bytes_as_u8 = (0..LevelIn::WIDTH).map(|this_byte_idx| {
143			let this_byte_oracle = mult_a[this_byte_idx];
144			witness
145				.get::<B8>(this_byte_oracle)
146				.unwrap()
147				.as_slice::<u8>()
148		});
149
150		let b_bytes_as_u8 = (0..LevelIn::WIDTH).map(|this_byte_idx| {
151			let this_byte_oracle = mult_b[this_byte_idx];
152			witness
153				.get::<B8>(this_byte_oracle)
154				.unwrap()
155				.as_slice::<u8>()
156		});
157
158		let product_bytes_as_u8 = (0..LevelOut::WIDTH).map(|this_byte_idx| {
159			let this_byte_oracle = product[this_byte_idx];
160			witness
161				.get::<B8>(this_byte_oracle)
162				.unwrap()
163				.as_slice::<u8>()
164		});
165
166		for row_idx in 0..1 << log_size {
167			let mut a_u512 = U512::ZERO;
168			for (byte_idx, a_byte_column) in a_bytes_as_u8.clone().enumerate() {
169				a_u512 |= U512::from(a_byte_column[row_idx]) << (8 * byte_idx);
170			}
171
172			let mut b_u512 = U512::ZERO;
173			for (byte_idx, b_byte_column) in b_bytes_as_u8.clone().enumerate() {
174				b_u512 |= U512::from(b_byte_column[row_idx]) << (8 * byte_idx);
175			}
176
177			let mut product_u512 = U512::ZERO;
178			for (byte_idx, product_byte_column) in product_bytes_as_u8.clone().enumerate() {
179				product_u512 |= U512::from(product_byte_column[row_idx]) << (8 * byte_idx);
180			}
181
182			assert_eq!(a_u512 * b_u512, product_u512);
183		}
184	}
185
186	builder.pop_namespace();
187	Ok(product)
188}