binius_circuits/lasso/
u8mul.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use anyhow::{ensure, Result};
4use binius_core::oracle::OracleId;
5use binius_field::{BinaryField16b, BinaryField32b, BinaryField8b, TowerField};
6use itertools::izip;
7
8use super::batch::LookupBatch;
9use crate::builder::{types::F, ConstraintSystemBuilder};
10
11type B8 = BinaryField8b;
12type B16 = BinaryField16b;
13type B32 = BinaryField32b;
14
15pub fn u8mul_bytesliced(
16	builder: &mut ConstraintSystemBuilder,
17	lookup_batch: &mut LookupBatch,
18	name: impl ToString + Clone,
19	mult_a: OracleId,
20	mult_b: OracleId,
21	n_multiplications: usize,
22) -> Result<[OracleId; 2], anyhow::Error> {
23	builder.push_namespace(name);
24	let log_rows = builder.log_rows([mult_a, mult_b])?;
25	let product = builder.add_committed_multiple("product", log_rows, B8::TOWER_LEVEL);
26
27	let lookup_u = builder.add_linear_combination(
28		"lookup_u",
29		log_rows,
30		[
31			(mult_a, <F as TowerField>::basis(3, 3)?),
32			(mult_b, <F as TowerField>::basis(3, 2)?),
33			(product[1], <F as TowerField>::basis(3, 1)?),
34			(product[0], <F as TowerField>::basis(3, 0)?),
35		],
36	)?;
37
38	let mut u_to_t_mapping = Vec::new();
39
40	if let Some(witness) = builder.witness() {
41		let mut product_low_witness = witness.new_column::<B8>(product[0]);
42		let mut product_high_witness = witness.new_column::<B8>(product[1]);
43		let mut lookup_u_witness = witness.new_column::<B32>(lookup_u);
44		let mut u_to_t_mapping_witness = vec![0; 1 << log_rows];
45
46		let mult_a_ints = witness.get::<B8>(mult_a)?.as_slice::<u8>();
47		let mult_b_ints = witness.get::<B8>(mult_b)?.as_slice::<u8>();
48
49		let product_low_u8 = product_low_witness.as_mut_slice::<u8>();
50		let product_high_u8 = product_high_witness.as_mut_slice::<u8>();
51		let lookup_u_u32 = lookup_u_witness.as_mut_slice::<u32>();
52
53		for (a, b, lookup_u, product_low, product_high, u_to_t) in izip!(
54			mult_a_ints,
55			mult_b_ints,
56			lookup_u_u32.iter_mut(),
57			product_low_u8.iter_mut(),
58			product_high_u8.iter_mut(),
59			u_to_t_mapping_witness.iter_mut()
60		) {
61			let a_int = *a as usize;
62			let b_int = *b as usize;
63			let ab_product = a_int * b_int;
64			let lookup_index = a_int << 8 | b_int;
65			*lookup_u = (lookup_index << 16 | ab_product) as u32;
66
67			*product_high = (ab_product >> 8) as u8;
68			*product_low = (ab_product & 0xff) as u8;
69
70			*u_to_t = lookup_index;
71		}
72
73		u_to_t_mapping = u_to_t_mapping_witness;
74	}
75
76	lookup_batch.add([lookup_u], u_to_t_mapping, n_multiplications);
77
78	builder.pop_namespace();
79	Ok(product)
80}
81
82pub fn u8mul(
83	builder: &mut ConstraintSystemBuilder,
84	lookup_batch: &mut LookupBatch,
85	name: impl ToString + Clone,
86	mult_a: OracleId,
87	mult_b: OracleId,
88	n_multiplications: usize,
89) -> Result<OracleId, anyhow::Error> {
90	builder.push_namespace(name.clone());
91
92	let product_bytesliced =
93		u8mul_bytesliced(builder, lookup_batch, name, mult_a, mult_b, n_multiplications)?;
94	let log_rows = builder.log_rows(product_bytesliced)?;
95	ensure!(n_multiplications <= 1 << log_rows);
96
97	let product = builder.add_linear_combination(
98		"bytes summed",
99		log_rows,
100		[
101			(product_bytesliced[0], <F as TowerField>::basis(3, 0)?),
102			(product_bytesliced[1], <F as TowerField>::basis(3, 1)?),
103		],
104	)?;
105
106	if let Some(witness) = builder.witness() {
107		let product_low_witness = witness.get::<B8>(product_bytesliced[0])?;
108		let product_high_witness = witness.get::<B8>(product_bytesliced[1])?;
109
110		let mut product_witness = witness.new_column::<B16>(product);
111
112		let product_low_u8 = product_low_witness.as_slice::<u8>();
113		let product_high_u8 = product_high_witness.as_slice::<u8>();
114
115		let product_u16 = product_witness.as_mut_slice::<u16>();
116
117		for (row_idx, row_product) in product_u16.iter_mut().enumerate() {
118			*row_product = (product_high_u8[row_idx] as u16) << 8 | product_low_u8[row_idx] as u16;
119		}
120	}
121
122	builder.pop_namespace();
123	Ok(product)
124}