binius_circuits/lasso/
u8mul.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// Copyright 2024-2025 Irreducible Inc.

use anyhow::{ensure, Result};
use binius_core::oracle::OracleId;
use binius_field::{
	as_packed_field::{PackScalar, PackedType},
	underlier::UnderlierType,
	BinaryField, BinaryField16b, BinaryField32b, BinaryField8b, ExtensionField,
	PackedFieldIndexable, TowerField,
};
use bytemuck::Pod;
use itertools::izip;

use super::batch::LookupBatch;
use crate::builder::ConstraintSystemBuilder;

// Short aliases for the binary tower fields used in this file: `B8` for the operand and
// product-byte columns, `B16` for the recombined 16-bit product column built by `u8mul`, and
// `B32` for the packed `lookup_u` column.
type B8 = BinaryField8b;
type B16 = BinaryField16b;
type B32 = BinaryField32b;

pub fn u8mul_bytesliced<U, F>(
	builder: &mut ConstraintSystemBuilder<U, F>,
	lookup_batch: &mut LookupBatch,
	name: impl ToString + Clone,
	mult_a: OracleId,
	mult_b: OracleId,
	n_multiplications: usize,
) -> Result<[OracleId; 2], anyhow::Error>
where
	U: Pod + UnderlierType + PackScalar<B8> + PackScalar<B16> + PackScalar<B32> + PackScalar<F>,
	PackedType<U, B8>: PackedFieldIndexable,
	PackedType<U, B16>: PackedFieldIndexable,
	PackedType<U, B32>: PackedFieldIndexable,
	F: TowerField + BinaryField + ExtensionField<B8> + ExtensionField<B16> + ExtensionField<B32>,
{
	builder.push_namespace(name.clone());
	let log_rows = builder.log_rows([mult_a, mult_b])?;
	let product = builder.add_committed_multiple("product", log_rows, B8::TOWER_LEVEL);

	let lookup_u = builder.add_linear_combination(
		"lookup_u",
		log_rows,
		[
			(mult_a, <F as TowerField>::basis(3, 3)?),
			(mult_b, <F as TowerField>::basis(3, 2)?),
			(product[1], <F as TowerField>::basis(3, 1)?),
			(product[0], <F as TowerField>::basis(3, 0)?),
		],
	)?;

	let mut u_to_t_mapping = Vec::new();

	if let Some(witness) = builder.witness() {
		let mut product_low_witness = witness.new_column::<B8>(product[0]);
		let mut product_high_witness = witness.new_column::<B8>(product[1]);
		let mut lookup_u_witness = witness.new_column::<B32>(lookup_u);
		let mut u_to_t_mapping_witness = vec![0; 1 << log_rows];

		let mult_a_ints = witness.get::<B8>(mult_a)?.as_slice::<u8>();
		let mult_b_ints = witness.get::<B8>(mult_b)?.as_slice::<u8>();

		let product_low_u8 = product_low_witness.as_mut_slice::<u8>();
		let product_high_u8 = product_high_witness.as_mut_slice::<u8>();
		let lookup_u_u32 = lookup_u_witness.as_mut_slice::<u32>();

		for (a, b, lookup_u, product_low, product_high, u_to_t) in izip!(
			mult_a_ints,
			mult_b_ints,
			lookup_u_u32.iter_mut(),
			product_low_u8.iter_mut(),
			product_high_u8.iter_mut(),
			u_to_t_mapping_witness.iter_mut()
		) {
			let a_int = *a as usize;
			let b_int = *b as usize;
			let ab_product = a_int * b_int;
			let lookup_index = a_int << 8 | b_int;
			*lookup_u = (lookup_index << 16 | ab_product) as u32;

			*product_high = (ab_product >> 8) as u8;
			*product_low = (ab_product & 0xff) as u8;

			*u_to_t = lookup_index;
		}

		u_to_t_mapping = u_to_t_mapping_witness;
	}

	lookup_batch.add(lookup_u, u_to_t_mapping, n_multiplications);

	builder.pop_namespace();
	Ok(product)
}

pub fn u8mul<U, F>(
	builder: &mut ConstraintSystemBuilder<U, F>,
	lookup_batch: &mut LookupBatch,
	name: impl ToString + Clone,
	mult_a: OracleId,
	mult_b: OracleId,
	n_multiplications: usize,
) -> Result<OracleId, anyhow::Error>
where
	U: Pod + UnderlierType + PackScalar<B8> + PackScalar<B16> + PackScalar<B32> + PackScalar<F>,
	PackedType<U, B8>: PackedFieldIndexable,
	PackedType<U, B16>: PackedFieldIndexable,
	PackedType<U, B32>: PackedFieldIndexable,
	F: TowerField + BinaryField + ExtensionField<B8> + ExtensionField<B16> + ExtensionField<B32>,
{
	builder.push_namespace(name.clone());

	let product_bytesliced =
		u8mul_bytesliced(builder, lookup_batch, name, mult_a, mult_b, n_multiplications)?;
	let log_rows = builder.log_rows(product_bytesliced)?;
	ensure!(n_multiplications <= 1 << log_rows);

	let product = builder.add_linear_combination(
		"bytes summed",
		log_rows,
		[
			(product_bytesliced[0], <F as TowerField>::basis(3, 0)?),
			(product_bytesliced[1], <F as TowerField>::basis(3, 1)?),
		],
	)?;

	if let Some(witness) = builder.witness() {
		let product_low_witness = witness.get::<B8>(product_bytesliced[0])?;
		let product_high_witness = witness.get::<B8>(product_bytesliced[1])?;

		let mut product_witness = witness.new_column::<B16>(product);

		let product_low_u8 = product_low_witness.as_slice::<u8>();
		let product_high_u8 = product_high_witness.as_slice::<u8>();

		let product_u16 = product_witness.as_mut_slice::<u16>();

		for (row_idx, row_product) in product_u16.iter_mut().enumerate() {
			*row_product = (product_high_u8[row_idx] as u16) << 8 | product_low_u8[row_idx] as u16;
		}
	}

	builder.pop_namespace();
	Ok(product)
}