binius_circuits/lasso/big_integer_ops/
byte_sliced_modular_mul.rs1use alloy_primitives::U512;
4use anyhow::Result;
5use binius_core::{oracle::OracleId, transparent::constant::Constant};
6use binius_field::{
7 tower_levels::TowerLevel, underlier::WithUnderlier, BinaryField32b, BinaryField8b, TowerField,
8};
9use binius_macros::arith_expr;
10
11use super::{byte_sliced_add_carryfree, byte_sliced_mul};
12use crate::{
13 builder::{types::F, ConstraintSystemBuilder},
14 lasso::{
15 batch::LookupBatch,
16 lookups::u8_arithmetic::{add_carryfree_lookup, add_lookup, dci_lookup, mul_lookup},
17 },
18};
19
/// Shorthand for the 8-bit binary tower field: every byte limb committed or filled in
/// this module (quotient, remainder, modulus, product bytes) lives at this level.
type B8 = BinaryField8b;
21
22#[allow(clippy::too_many_arguments)]
23pub fn byte_sliced_modular_mul<LevelIn: TowerLevel, LevelOut: TowerLevel<Base = LevelIn>>(
24 builder: &mut ConstraintSystemBuilder,
25 name: impl ToString,
26 mult_a: &LevelIn::Data<OracleId>,
27 mult_b: &LevelIn::Data<OracleId>,
28 modulus_input: &[u8],
29 log_size: usize,
30 zero_byte_oracle: OracleId,
31 zero_carry_oracle: OracleId,
32) -> Result<LevelIn::Data<OracleId>, anyhow::Error> {
33 builder.push_namespace(name);
34
35 let lookup_t_mul = mul_lookup(builder, "mul table")?;
36 let lookup_t_add = add_lookup(builder, "add table")?;
37 let lookup_t_add_carryfree = add_carryfree_lookup(builder, "add cf table")?;
38
39 let lookup_t_dci = if LevelIn::WIDTH == 1 {
41 usize::MAX
42 } else {
43 dci_lookup(builder, "dci table")?
44 };
45
46 let mut lookup_batch_mul = LookupBatch::new([lookup_t_mul]);
47 let mut lookup_batch_add = LookupBatch::new([lookup_t_add]);
48 let mut lookup_batch_add_carryfree = LookupBatch::new([lookup_t_add_carryfree]);
49
50 let mut lookup_batch_dci = LookupBatch::new([lookup_t_dci]);
52
53 let mut quotient = LevelIn::default();
54 let mut remainder = LevelIn::default();
55 let mut modulus = LevelIn::default();
56
57 for byte_idx in 0..LevelIn::WIDTH {
58 quotient[byte_idx] = builder.add_committed("quotient", log_size, B8::TOWER_LEVEL);
59 remainder[byte_idx] = builder.add_committed("remainder", log_size, B8::TOWER_LEVEL);
60 modulus[byte_idx] = builder.add_transparent(
61 "modulus",
62 Constant::new(
63 log_size,
64 <F as WithUnderlier>::from_underlier(<u8 as Into<
65 <F as WithUnderlier>::Underlier,
66 >>::into(modulus_input[byte_idx])),
67 ),
68 )?;
69 }
70
71 let ab = byte_sliced_mul::<LevelIn, LevelOut>(
72 builder,
73 "ab",
74 mult_a,
75 mult_b,
76 log_size,
77 zero_carry_oracle,
78 &mut lookup_batch_mul,
79 &mut lookup_batch_add,
80 &mut lookup_batch_dci,
81 )?;
82
83 if let Some(witness) = builder.witness() {
84 let ab_bytes_as_u8: Vec<_> = (0..LevelOut::WIDTH)
85 .map(|this_byte_idx| {
86 let this_byte_oracle = ab[this_byte_idx];
87 witness
88 .get::<B8>(this_byte_oracle)
89 .unwrap()
90 .as_slice::<u8>()
91 })
92 .collect();
93
94 let mut quotient: Vec<_> = (0..LevelIn::WIDTH)
95 .map(|this_byte_idx| {
96 let this_byte_oracle = quotient[this_byte_idx];
97 witness.new_column::<B8>(this_byte_oracle)
98 })
99 .collect();
100
101 let mut remainder: Vec<_> = (0..LevelIn::WIDTH)
102 .map(|this_byte_idx| {
103 let this_byte_oracle: usize = remainder[this_byte_idx];
104 witness.new_column::<B8>(this_byte_oracle)
105 })
106 .collect();
107
108 let mut modulus: Vec<_> = (0..LevelIn::WIDTH)
109 .map(|this_byte_idx| {
110 let this_byte_oracle = modulus[this_byte_idx];
111 witness.new_column::<B8>(this_byte_oracle)
112 })
113 .collect();
114
115 let mut modulus_u512 = U512::ZERO;
116
117 for (byte_idx, modulus_byte_column) in modulus.iter_mut().enumerate() {
118 let modulus_byte_column_u8 = modulus_byte_column.as_mut_slice::<u8>();
119 modulus_u512 |= U512::from(modulus_input[byte_idx]) << (8 * byte_idx);
120 modulus_byte_column_u8.fill(modulus_input[byte_idx]);
121 }
122
123 for row_idx in 0..1 << log_size {
124 let mut ab_u512 = U512::ZERO;
125 for (byte_idx, ab_byte_column) in ab_bytes_as_u8.iter().enumerate() {
126 ab_u512 |= U512::from(ab_byte_column[row_idx]) << (8 * byte_idx);
127 }
128
129 let quotient_u512 = ab_u512 / modulus_u512;
130 let remainder_u512 = ab_u512 % modulus_u512;
131
132 for (byte_idx, quotient_byte_column) in quotient.iter_mut().enumerate() {
133 let quotient_byte_column_u8 = quotient_byte_column.as_mut_slice::<u8>();
134 quotient_byte_column_u8[row_idx] = quotient_u512.byte(byte_idx);
135 }
136
137 for (byte_idx, remainder_byte_column) in remainder.iter_mut().enumerate() {
138 let remainder_byte_column_u8 = remainder_byte_column.as_mut_slice::<u8>();
139 remainder_byte_column_u8[row_idx] = remainder_u512.byte(byte_idx);
140 }
141 }
142 }
143
144 let qm = byte_sliced_mul::<LevelIn, LevelOut>(
145 builder,
146 "qm",
147 "ient,
148 &modulus,
149 log_size,
150 zero_carry_oracle,
151 &mut lookup_batch_mul,
152 &mut lookup_batch_add,
153 &mut lookup_batch_dci,
154 )?;
155
156 let mut repeating_zero = LevelIn::default();
157 for byte_idx in 0..LevelIn::WIDTH {
158 repeating_zero[byte_idx] = zero_byte_oracle;
159 }
160
161 let qm_plus_r = byte_sliced_add_carryfree::<LevelOut>(
162 builder,
163 "hi*lo",
164 &qm,
165 &LevelOut::join(&remainder, &repeating_zero),
166 zero_carry_oracle,
167 log_size,
168 &mut lookup_batch_add,
169 &mut lookup_batch_add_carryfree,
170 )?;
171
172 lookup_batch_mul.execute::<BinaryField32b>(builder)?;
173 lookup_batch_add.execute::<BinaryField32b>(builder)?;
174 lookup_batch_add_carryfree.execute::<BinaryField32b>(builder)?;
175
176 if LevelIn::WIDTH != 1 {
177 lookup_batch_dci.execute::<BinaryField32b>(builder)?;
178 }
179
180 let consistency = arith_expr!([x, y] = x - y);
181
182 for byte_idx in 0..LevelOut::WIDTH {
183 builder.assert_zero(
184 format!("byte_consistency_{byte_idx}"),
185 [ab[byte_idx], qm_plus_r[byte_idx]],
186 consistency.clone().convert_field(),
187 );
188 }
189
190 builder.pop_namespace();
191 Ok(remainder)
192}