binius_core/ring_switch/
eq_ind.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{any::TypeId, iter, marker::PhantomData, sync::Arc};
4
5use binius_field::{
6	byte_iteration::{
7		can_iterate_bytes, create_partial_sums_lookup_tables, iterate_bytes, ByteIteratorCallback,
8	},
9	util::inner_product_unchecked,
10	BinaryField1b, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
11	TowerField,
12};
13use binius_math::{tensor_prod_eq_ind, MultilinearExtension};
14use binius_maybe_rayon::prelude::*;
15use binius_utils::bail;
16use bytemuck::zeroed_vec;
17
18use super::error::Error;
19use crate::{
20	polynomial::{Error as PolynomialError, MultivariatePoly},
21	tensor_algebra::TensorAlgebra,
22};
23
/// Row-batching coefficients, bundled with data precomputed from them for fast folding.
///
/// Alongside the raw coefficients this keeps a partial-sums lookup table, which lets a value
/// whose subfield coordinates are single bits be folded against the coefficients one byte at a
/// time instead of one bit at a time.
#[derive(Debug)]
pub struct RowBatchCoeffs<F> {
	/// The row-batching coefficients, in the order they were supplied.
	coeffs: Vec<F>,
	/// Partial sums of `coeffs`, laid out as consecutive 256-entry chunks and indexed as
	/// `chunk * 256 + byte` (one chunk per byte of a folded value). Used only by the 1-bit
	/// folding fast path; empty when it was not built.
	partial_sums_lookup_table: Vec<F>,
}
32
33impl<F: Field> RowBatchCoeffs<F> {
34	pub fn new(coeffs: Vec<F>) -> Self {
35		let partial_sums_lookup_table = if coeffs.len() >= 8 {
36			create_partial_sums_lookup_tables(coeffs.as_slice())
37		} else {
38			Vec::new()
39		};
40
41		Self {
42			coeffs,
43			partial_sums_lookup_table,
44		}
45	}
46
47	pub fn coeffs(&self) -> &[F] {
48		&self.coeffs
49	}
50}
51
52/// The multilinear function $A$ from [DP24] Section 5.
53///
54/// The function $A$ is $\ell':= \ell - \kappa$-variate and depends on the last $\ell'$ coordinates
55/// of the evaluation point as well as the $\kappa$ mixing challenges.
56///
57/// [DP24]: <https://eprint.iacr.org/2024/504>
58#[derive(Debug, Clone)]
59pub struct RingSwitchEqInd<FSub, F> {
60	/// $z_{\kappa}, \ldots, z_{\ell-1}$
61	z_vals: Arc<[F]>,
62	row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
63	mixing_coeff: F,
64	_marker: PhantomData<FSub>,
65}
66
67impl<FSub, F> RingSwitchEqInd<FSub, F>
68where
69	FSub: Field,
70	F: ExtensionField<FSub>,
71{
72	pub fn new(
73		z_vals: Arc<[F]>,
74		row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
75		mixing_coeff: F,
76	) -> Result<Self, Error> {
77		if row_batch_coeffs.coeffs.len() < F::DEGREE {
78			bail!(Error::InvalidArgs(
79				"RingSwitchEqInd::new expects row_batch_coeffs length greater than or equal to \
80				the extension degree"
81					.into()
82			));
83		}
84
85		Ok(Self {
86			z_vals,
87			row_batch_coeffs,
88			mixing_coeff,
89			_marker: PhantomData,
90		})
91	}
92
93	pub fn multilinear_extension<P: PackedFieldIndexable<Scalar = F>>(
94		&self,
95	) -> Result<MultilinearExtension<P>, Error> {
96		let mut evals = zeroed_vec::<P>(1 << self.z_vals.len().saturating_sub(P::LOG_WIDTH));
97		evals[0].set(0, self.mixing_coeff);
98		tensor_prod_eq_ind(0, &mut evals, &self.z_vals)?;
99		P::unpack_scalars_mut(&mut evals)
100			.par_iter_mut()
101			.for_each(|val| {
102				*val = inner_product_subfield(*val, &self.row_batch_coeffs);
103			});
104		Ok(MultilinearExtension::from_values(evals)?)
105	}
106}
107
108#[inline(always)]
109fn inner_product_subfield<FSub, F>(value: F, row_batch_coeffs: &RowBatchCoeffs<F>) -> F
110where
111	FSub: Field,
112	F: ExtensionField<FSub>,
113{
114	if TypeId::of::<FSub>() == TypeId::of::<BinaryField1b>() && can_iterate_bytes::<F>() {
115		// Special case when we are folding with 1-bit coefficients.
116		// Use partial sums lookup table to speed up the computation.
117
118		struct Callback<'a, F> {
119			partial_sums_lookup: &'a [F],
120			result: F,
121		}
122
123		impl<F: Field> ByteIteratorCallback for Callback<'_, F> {
124			#[inline(always)]
125			fn call(&mut self, iter: impl Iterator<Item = u8>) {
126				for (byte_index, byte) in iter.enumerate() {
127					self.result += self.partial_sums_lookup[(byte_index << 8) + byte as usize];
128				}
129			}
130		}
131
132		let mut callback = Callback {
133			partial_sums_lookup: &row_batch_coeffs.partial_sums_lookup_table,
134			result: F::ZERO,
135		};
136		iterate_bytes(std::slice::from_ref(&value), &mut callback);
137
138		callback.result
139	} else {
140		// fall back to the general case
141		inner_product_unchecked(row_batch_coeffs.coeffs.iter().copied(), F::iter_bases(&value))
142	}
143}
144
145impl<FSub, F> MultivariatePoly<F> for RingSwitchEqInd<FSub, F>
146where
147	FSub: TowerField,
148	F: TowerField + PackedField<Scalar = F> + PackedExtension<FSub>,
149{
150	fn n_vars(&self) -> usize {
151		self.z_vals.len()
152	}
153
154	fn degree(&self) -> usize {
155		self.n_vars()
156	}
157
158	fn evaluate(&self, query: &[F]) -> Result<F, PolynomialError> {
159		if query.len() != self.n_vars() {
160			bail!(PolynomialError::IncorrectQuerySize {
161				expected: self.n_vars()
162			});
163		};
164
165		let tensor_eval = iter::zip(&*self.z_vals, query).fold(
166			<TensorAlgebra<FSub, F>>::from_vertical(self.mixing_coeff),
167			|eval, (&vert_i, &hztl_i)| {
168				// This formula is specific to characteristic 2 fields
169				// Here we know that $h v + (1 - h) (1 - v) = 1 + h + v$.
170				let vert_scaled = eval.clone().scale_vertical(vert_i);
171				let hztl_scaled = eval.clone().scale_horizontal(hztl_i);
172				eval + &vert_scaled + &hztl_scaled
173			},
174		);
175
176		let folded_eval = tensor_eval.fold_vertical(&self.row_batch_coeffs.coeffs);
177		Ok(folded_eval)
178	}
179
180	fn binary_tower_level(&self) -> usize {
181		F::TOWER_LEVEL
182	}
183}
184
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128b, BinaryField8b};
	use binius_math::MultilinearQuery;
	use iter::repeat_with;
	use rand::{prelude::StdRng, SeedableRng};

	use super::*;

	type F = BinaryField128b;

	/// Checks that `RingSwitchEqInd::evaluate` agrees with evaluating the materialized
	/// multilinear extension at the same random point, when row-batching over subfield `FS`.
	fn check_evaluation_consistency<FS>()
	where
		FS: TowerField,
		F: ExtensionField<FS> + PackedExtension<FS>,
	{
		let kappa = <TensorAlgebra<FS, F>>::kappa();
		let ell = 10;
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = ell - kappa;
		let z_vals = repeat_with(|| <F as Field>::random(&mut rng))
			.take(n_vars)
			.collect::<Arc<[_]>>();

		let row_batch_coeffs = repeat_with(|| <F as Field>::random(&mut rng))
			.take(1 << kappa)
			.collect::<Vec<_>>();
		let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(row_batch_coeffs));

		let eval_point = repeat_with(|| <F as Field>::random(&mut rng))
			.take(n_vars)
			.collect::<Vec<_>>();
		let eval_query = MultilinearQuery::<F>::expand(&eval_point);

		let mixing_coeff = <F as Field>::random(&mut rng);

		let rs_eq = RingSwitchEqInd::<FS, _>::new(z_vals, row_batch_coeffs, mixing_coeff).unwrap();
		let mle = rs_eq.multilinear_extension::<F>().unwrap();

		let val1 = rs_eq.evaluate(&eval_point).unwrap();
		let val2 = mle.evaluate(&eval_query).unwrap();
		assert_eq!(val1, val2);
	}

	#[test]
	fn test_evaluation_consistency() {
		check_evaluation_consistency::<BinaryField8b>();
	}

	/// With a 1-bit subfield, `inner_product_subfield` takes the partial-sums lookup-table
	/// path rather than the generic inner product; make sure that path agrees as well.
	#[test]
	fn test_evaluation_consistency_1b() {
		check_evaluation_consistency::<BinaryField1b>();
	}
}