binius_circuits/lasso/big_integer_ops/byte_sliced_modular_mul.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use alloy_primitives::U512;
4use anyhow::Result;
5use binius_core::{oracle::OracleId, transparent::constant::Constant};
6use binius_field::{
7	tower_levels::TowerLevel, underlier::WithUnderlier, BinaryField32b, BinaryField8b, TowerField,
8};
9use binius_macros::arith_expr;
10
11use super::{byte_sliced_add_carryfree, byte_sliced_mul};
12use crate::{
13	builder::{types::F, ConstraintSystemBuilder},
14	lasso::{
15		batch::LookupBatch,
16		lookups::u8_arithmetic::{add_carryfree_lookup, add_lookup, dci_lookup, mul_lookup},
17	},
18};
19
20type B8 = BinaryField8b;
21
22#[allow(clippy::too_many_arguments)]
23pub fn byte_sliced_modular_mul<LevelIn: TowerLevel, LevelOut: TowerLevel<Base = LevelIn>>(
24	builder: &mut ConstraintSystemBuilder,
25	name: impl ToString,
26	mult_a: &LevelIn::Data<OracleId>,
27	mult_b: &LevelIn::Data<OracleId>,
28	modulus_input: &[u8],
29	log_size: usize,
30	zero_byte_oracle: OracleId,
31	zero_carry_oracle: OracleId,
32) -> Result<LevelIn::Data<OracleId>, anyhow::Error> {
33	builder.push_namespace(name);
34
35	let lookup_t_mul = mul_lookup(builder, "mul table")?;
36	let lookup_t_add = add_lookup(builder, "add table")?;
37	let lookup_t_add_carryfree = add_carryfree_lookup(builder, "add cf table")?;
38
39	// The double conditional increment wont be used if we're at the base of the tower
40	let lookup_t_dci = if LevelIn::WIDTH == 1 {
41		usize::MAX
42	} else {
43		dci_lookup(builder, "dci table")?
44	};
45
46	let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]);
47	let mut lookup_batch_add = LookupBatch::new([lookup_t_add]);
48	let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]);
49
50	// This batch WILL NOT get executed if we are instantiating it for 8b mul
51	let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]);
52
53	let mut quotient = LevelIn::default();
54	let mut remainder = LevelIn::default();
55	let mut modulus = LevelIn::default();
56
57	for byte_idx in 0..LevelIn::WIDTH {
58		quotient[byte_idx] = builder.add_committed("quotient", log_size, B8::TOWER_LEVEL);
59		remainder[byte_idx] = builder.add_committed("remainder", log_size, B8::TOWER_LEVEL);
60		modulus[byte_idx] = builder.add_transparent(
61			"modulus",
62			Constant::new(
63				log_size,
64				<F as WithUnderlier>::from_underlier(<u8 as Into<
65					<F as WithUnderlier>::Underlier,
66				>>::into(modulus_input[byte_idx])),
67			),
68		)?;
69	}
70
71	let ab = byte_sliced_mul::<LevelIn, LevelOut>(
72		builder,
73		"ab",
74		mult_a,
75		mult_b,
76		log_size,
77		zero_carry_oracle,
78		&mut lookup_batch_mul,
79		&mut lookup_batch_add,
80		&mut lookup_batch_dci,
81	)?;
82
83	if let Some(witness) = builder.witness() {
84		let ab_bytes_as_u8: Vec<_> = (0..LevelOut::WIDTH)
85			.map(|this_byte_idx| {
86				let this_byte_oracle = ab[this_byte_idx];
87				witness
88					.get::<B8>(this_byte_oracle)
89					.unwrap()
90					.as_slice::<u8>()
91			})
92			.collect();
93
94		let mut quotient: Vec<_> = (0..LevelIn::WIDTH)
95			.map(|this_byte_idx| {
96				let this_byte_oracle = quotient[this_byte_idx];
97				witness.new_column::<B8>(this_byte_oracle)
98			})
99			.collect();
100
101		let mut remainder: Vec<_> = (0..LevelIn::WIDTH)
102			.map(|this_byte_idx| {
103				let this_byte_oracle: usize = remainder[this_byte_idx];
104				witness.new_column::<B8>(this_byte_oracle)
105			})
106			.collect();
107
108		let mut modulus: Vec<_> = (0..LevelIn::WIDTH)
109			.map(|this_byte_idx| {
110				let this_byte_oracle = modulus[this_byte_idx];
111				witness.new_column::<B8>(this_byte_oracle)
112			})
113			.collect();
114
115		let mut modulus_u512 = U512::ZERO;
116
117		for (byte_idx, modulus_byte_column) in modulus.iter_mut().enumerate() {
118			let modulus_byte_column_u8 = modulus_byte_column.as_mut_slice::<u8>();
119			modulus_u512 |= U512::from(modulus_input[byte_idx]) << (8 * byte_idx);
120			modulus_byte_column_u8.fill(modulus_input[byte_idx]);
121		}
122
123		for row_idx in 0..1 << log_size {
124			let mut ab_u512 = U512::ZERO;
125			for (byte_idx, ab_byte_column) in ab_bytes_as_u8.iter().enumerate() {
126				ab_u512 |= U512::from(ab_byte_column[row_idx]) << (8 * byte_idx);
127			}
128
129			let quotient_u512 = ab_u512 / modulus_u512;
130			let remainder_u512 = ab_u512 % modulus_u512;
131
132			for (byte_idx, quotient_byte_column) in quotient.iter_mut().enumerate() {
133				let quotient_byte_column_u8 = quotient_byte_column.as_mut_slice::<u8>();
134				quotient_byte_column_u8[row_idx] = quotient_u512.byte(byte_idx);
135			}
136
137			for (byte_idx, remainder_byte_column) in remainder.iter_mut().enumerate() {
138				let remainder_byte_column_u8 = remainder_byte_column.as_mut_slice::<u8>();
139				remainder_byte_column_u8[row_idx] = remainder_u512.byte(byte_idx);
140			}
141		}
142	}
143
144	let qm = byte_sliced_mul::<LevelIn, LevelOut>(
145		builder,
146		"qm",
147		&quotient,
148		&modulus,
149		log_size,
150		zero_carry_oracle,
151		&mut lookup_batch_mul,
152		&mut lookup_batch_add,
153		&mut lookup_batch_dci,
154	)?;
155
156	let mut repeating_zero = LevelIn::default();
157	for byte_idx in 0..LevelIn::WIDTH {
158		repeating_zero[byte_idx] = zero_byte_oracle;
159	}
160
161	let qm_plus_r = byte_sliced_add_carryfree::<LevelOut>(
162		builder,
163		"hi*lo",
164		&qm,
165		&LevelOut::join(&remainder, &repeating_zero),
166		zero_carry_oracle,
167		log_size,
168		&mut lookup_batch_add,
169		&mut lookup_batch_add_carryfree,
170	)?;
171
172	lookup_batch_mul.execute::<BinaryField32b>(builder)?;
173	lookup_batch_add.execute::<BinaryField32b>(builder)?;
174	lookup_batch_add_carryfree.execute::<BinaryField32b>(builder)?;
175
176	if LevelIn::WIDTH != 1 {
177		lookup_batch_dci.execute::<BinaryField32b>(builder)?;
178	}
179
180	let consistency = arith_expr!([x, y] = x - y);
181
182	for byte_idx in 0..LevelOut::WIDTH {
183		builder.assert_zero(
184			format!("byte_consistency_{byte_idx}"),
185			[ab[byte_idx], qm_plus_r[byte_idx]],
186			consistency.clone().convert_field(),
187		);
188	}
189
190	builder.pop_namespace();
191	Ok(remainder)
192}