binius_circuits/lasso/
u8mul.rs1use anyhow::{ensure, Result};
4use binius_core::oracle::OracleId;
5use binius_field::{BinaryField16b, BinaryField32b, BinaryField8b, TowerField};
6use itertools::izip;
7
8use super::batch::LookupBatch;
9use crate::builder::{types::F, ConstraintSystemBuilder};
10
11type B8 = BinaryField8b;
12type B16 = BinaryField16b;
13type B32 = BinaryField32b;
14
15pub fn u8mul_bytesliced(
16 builder: &mut ConstraintSystemBuilder,
17 lookup_batch: &mut LookupBatch,
18 name: impl ToString + Clone,
19 mult_a: OracleId,
20 mult_b: OracleId,
21 n_multiplications: usize,
22) -> Result<[OracleId; 2], anyhow::Error> {
23 builder.push_namespace(name);
24 let log_rows = builder.log_rows([mult_a, mult_b])?;
25 let product = builder.add_committed_multiple("product", log_rows, B8::TOWER_LEVEL);
26
27 let lookup_u = builder.add_linear_combination(
28 "lookup_u",
29 log_rows,
30 [
31 (mult_a, <F as TowerField>::basis(3, 3)?),
32 (mult_b, <F as TowerField>::basis(3, 2)?),
33 (product[1], <F as TowerField>::basis(3, 1)?),
34 (product[0], <F as TowerField>::basis(3, 0)?),
35 ],
36 )?;
37
38 let mut u_to_t_mapping = Vec::new();
39
40 if let Some(witness) = builder.witness() {
41 let mut product_low_witness = witness.new_column::<B8>(product[0]);
42 let mut product_high_witness = witness.new_column::<B8>(product[1]);
43 let mut lookup_u_witness = witness.new_column::<B32>(lookup_u);
44 let mut u_to_t_mapping_witness = vec![0; 1 << log_rows];
45
46 let mult_a_ints = witness.get::<B8>(mult_a)?.as_slice::<u8>();
47 let mult_b_ints = witness.get::<B8>(mult_b)?.as_slice::<u8>();
48
49 let product_low_u8 = product_low_witness.as_mut_slice::<u8>();
50 let product_high_u8 = product_high_witness.as_mut_slice::<u8>();
51 let lookup_u_u32 = lookup_u_witness.as_mut_slice::<u32>();
52
53 for (a, b, lookup_u, product_low, product_high, u_to_t) in izip!(
54 mult_a_ints,
55 mult_b_ints,
56 lookup_u_u32.iter_mut(),
57 product_low_u8.iter_mut(),
58 product_high_u8.iter_mut(),
59 u_to_t_mapping_witness.iter_mut()
60 ) {
61 let a_int = *a as usize;
62 let b_int = *b as usize;
63 let ab_product = a_int * b_int;
64 let lookup_index = a_int << 8 | b_int;
65 *lookup_u = (lookup_index << 16 | ab_product) as u32;
66
67 *product_high = (ab_product >> 8) as u8;
68 *product_low = (ab_product & 0xff) as u8;
69
70 *u_to_t = lookup_index;
71 }
72
73 u_to_t_mapping = u_to_t_mapping_witness;
74 }
75
76 lookup_batch.add([lookup_u], u_to_t_mapping, n_multiplications);
77
78 builder.pop_namespace();
79 Ok(product)
80}
81
82pub fn u8mul(
83 builder: &mut ConstraintSystemBuilder,
84 lookup_batch: &mut LookupBatch,
85 name: impl ToString + Clone,
86 mult_a: OracleId,
87 mult_b: OracleId,
88 n_multiplications: usize,
89) -> Result<OracleId, anyhow::Error> {
90 builder.push_namespace(name.clone());
91
92 let product_bytesliced =
93 u8mul_bytesliced(builder, lookup_batch, name, mult_a, mult_b, n_multiplications)?;
94 let log_rows = builder.log_rows(product_bytesliced)?;
95 ensure!(n_multiplications <= 1 << log_rows);
96
97 let product = builder.add_linear_combination(
98 "bytes summed",
99 log_rows,
100 [
101 (product_bytesliced[0], <F as TowerField>::basis(3, 0)?),
102 (product_bytesliced[1], <F as TowerField>::basis(3, 1)?),
103 ],
104 )?;
105
106 if let Some(witness) = builder.witness() {
107 let product_low_witness = witness.get::<B8>(product_bytesliced[0])?;
108 let product_high_witness = witness.get::<B8>(product_bytesliced[1])?;
109
110 let mut product_witness = witness.new_column::<B16>(product);
111
112 let product_low_u8 = product_low_witness.as_slice::<u8>();
113 let product_high_u8 = product_high_witness.as_slice::<u8>();
114
115 let product_u16 = product_witness.as_mut_slice::<u16>();
116
117 for (row_idx, row_product) in product_u16.iter_mut().enumerate() {
118 *row_product = (product_high_u8[row_idx] as u16) << 8 | product_low_u8[row_idx] as u16;
119 }
120 }
121
122 builder.pop_namespace();
123 Ok(product)
124}