binius_core/ring_switch/
eq_ind.rs1use std::{any::TypeId, iter, marker::PhantomData, sync::Arc};
4
5use binius_field::{
6 byte_iteration::{
7 can_iterate_bytes, create_partial_sums_lookup_tables, iterate_bytes, ByteIteratorCallback,
8 },
9 util::inner_product_unchecked,
10 BinaryField1b, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
11 TowerField,
12};
13use binius_math::{tensor_prod_eq_ind, MultilinearExtension};
14use binius_maybe_rayon::prelude::*;
15use binius_utils::bail;
16use bytemuck::zeroed_vec;
17
18use super::error::Error;
19use crate::{
20 polynomial::{Error as PolynomialError, MultivariatePoly},
21 tensor_algebra::TensorAlgebra,
22};
23
24#[derive(Debug)]
26pub struct RowBatchCoeffs<F> {
27 coeffs: Vec<F>,
28 partial_sums_lookup_table: Vec<F>,
31}
32
33impl<F: Field> RowBatchCoeffs<F> {
34 pub fn new(coeffs: Vec<F>) -> Self {
35 let partial_sums_lookup_table = if coeffs.len() >= 8 {
36 create_partial_sums_lookup_tables(coeffs.as_slice())
37 } else {
38 Vec::new()
39 };
40
41 Self {
42 coeffs,
43 partial_sums_lookup_table,
44 }
45 }
46
47 pub fn coeffs(&self) -> &[F] {
48 &self.coeffs
49 }
50}
51
52#[derive(Debug, Clone)]
59pub struct RingSwitchEqInd<FSub, F> {
60 z_vals: Arc<[F]>,
62 row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
63 mixing_coeff: F,
64 _marker: PhantomData<FSub>,
65}
66
67impl<FSub, F> RingSwitchEqInd<FSub, F>
68where
69 FSub: Field,
70 F: ExtensionField<FSub>,
71{
72 pub fn new(
73 z_vals: Arc<[F]>,
74 row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
75 mixing_coeff: F,
76 ) -> Result<Self, Error> {
77 if row_batch_coeffs.coeffs.len() < F::DEGREE {
78 bail!(Error::InvalidArgs(
79 "RingSwitchEqInd::new expects row_batch_coeffs length greater than or equal to \
80 the extension degree"
81 .into()
82 ));
83 }
84
85 Ok(Self {
86 z_vals,
87 row_batch_coeffs,
88 mixing_coeff,
89 _marker: PhantomData,
90 })
91 }
92
93 pub fn multilinear_extension<P: PackedFieldIndexable<Scalar = F>>(
94 &self,
95 ) -> Result<MultilinearExtension<P>, Error> {
96 let mut evals = zeroed_vec::<P>(1 << self.z_vals.len().saturating_sub(P::LOG_WIDTH));
97 evals[0].set(0, self.mixing_coeff);
98 tensor_prod_eq_ind(0, &mut evals, &self.z_vals)?;
99 P::unpack_scalars_mut(&mut evals)
100 .par_iter_mut()
101 .for_each(|val| {
102 *val = inner_product_subfield(*val, &self.row_batch_coeffs);
103 });
104 Ok(MultilinearExtension::from_values(evals)?)
105 }
106}
107
108#[inline(always)]
109fn inner_product_subfield<FSub, F>(value: F, row_batch_coeffs: &RowBatchCoeffs<F>) -> F
110where
111 FSub: Field,
112 F: ExtensionField<FSub>,
113{
114 if TypeId::of::<FSub>() == TypeId::of::<BinaryField1b>() && can_iterate_bytes::<F>() {
115 struct Callback<'a, F> {
119 partial_sums_lookup: &'a [F],
120 result: F,
121 }
122
123 impl<F: Field> ByteIteratorCallback for Callback<'_, F> {
124 #[inline(always)]
125 fn call(&mut self, iter: impl Iterator<Item = u8>) {
126 for (byte_index, byte) in iter.enumerate() {
127 self.result += self.partial_sums_lookup[(byte_index << 8) + byte as usize];
128 }
129 }
130 }
131
132 let mut callback = Callback {
133 partial_sums_lookup: &row_batch_coeffs.partial_sums_lookup_table,
134 result: F::ZERO,
135 };
136 iterate_bytes(std::slice::from_ref(&value), &mut callback);
137
138 callback.result
139 } else {
140 inner_product_unchecked(row_batch_coeffs.coeffs.iter().copied(), F::iter_bases(&value))
142 }
143}
144
145impl<FSub, F> MultivariatePoly<F> for RingSwitchEqInd<FSub, F>
146where
147 FSub: TowerField,
148 F: TowerField + PackedField<Scalar = F> + PackedExtension<FSub>,
149{
150 fn n_vars(&self) -> usize {
151 self.z_vals.len()
152 }
153
154 fn degree(&self) -> usize {
155 self.n_vars()
156 }
157
158 fn evaluate(&self, query: &[F]) -> Result<F, PolynomialError> {
159 if query.len() != self.n_vars() {
160 bail!(PolynomialError::IncorrectQuerySize {
161 expected: self.n_vars()
162 });
163 };
164
165 let tensor_eval = iter::zip(&*self.z_vals, query).fold(
166 <TensorAlgebra<FSub, F>>::from_vertical(self.mixing_coeff),
167 |eval, (&vert_i, &hztl_i)| {
168 let vert_scaled = eval.clone().scale_vertical(vert_i);
171 let hztl_scaled = eval.clone().scale_horizontal(hztl_i);
172 eval + &vert_scaled + &hztl_scaled
173 },
174 );
175
176 let folded_eval = tensor_eval.fold_vertical(&self.row_batch_coeffs.coeffs);
177 Ok(folded_eval)
178 }
179
180 fn binary_tower_level(&self) -> usize {
181 F::TOWER_LEVEL
182 }
183}
184
#[cfg(test)]
mod tests {
	use binius_field::{BinaryField128b, BinaryField8b};
	use binius_math::MultilinearQuery;
	use iter::repeat_with;
	use rand::{prelude::StdRng, SeedableRng};

	use super::*;

	type F = BinaryField128b;

	/// Checks that `MultivariatePoly::evaluate` agrees with evaluating the materialised
	/// multilinear extension at a random point, for the subfield `FS` of `F`.
	fn check_evaluation_consistency<FS>()
	where
		FS: TowerField,
		F: ExtensionField<FS> + PackedExtension<FS>,
	{
		// kappa = log2 of the extension degree; there is one row-batching coefficient per
		// `FS`-basis coordinate of `F`.
		let degree = <F as ExtensionField<FS>>::DEGREE;
		assert!(degree.is_power_of_two());
		let kappa = degree.ilog2() as usize;
		let ell = 10;
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = ell - kappa;
		let z_vals = repeat_with(|| <F as Field>::random(&mut rng))
			.take(n_vars)
			.collect::<Arc<[_]>>();

		let row_batch_coeffs = repeat_with(|| <F as Field>::random(&mut rng))
			.take(degree)
			.collect::<Vec<_>>();
		let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(row_batch_coeffs));

		let eval_point = repeat_with(|| <F as Field>::random(&mut rng))
			.take(n_vars)
			.collect::<Vec<_>>();
		let eval_query = MultilinearQuery::<F>::expand(&eval_point);

		let mixing_coeff = <F as Field>::random(&mut rng);

		let rs_eq = RingSwitchEqInd::<FS, _>::new(z_vals, row_batch_coeffs, mixing_coeff).unwrap();
		let mle = rs_eq.multilinear_extension::<F>().unwrap();

		let val1 = rs_eq.evaluate(&eval_point).unwrap();
		let val2 = mle.evaluate(&eval_query).unwrap();
		assert_eq!(val1, val2);
	}

	#[test]
	fn test_evaluation_consistency() {
		check_evaluation_consistency::<BinaryField8b>();
	}

	/// With a 1-bit subfield, `multilinear_extension` goes through the byte-lookup fast path of
	/// `inner_product_subfield`, which the 8-bit case never reaches.
	#[test]
	fn test_evaluation_consistency_1b() {
		check_evaluation_consistency::<BinaryField1b>();
	}

	#[test]
	fn test_new_rejects_too_few_row_batch_coeffs() {
		// BinaryField128b has degree 16 over BinaryField8b, so 4 coefficients are too few.
		let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(vec![<F as Field>::ZERO; 4]));
		let result = RingSwitchEqInd::<BinaryField8b, F>::new(
			Arc::from(Vec::<F>::new()),
			row_batch_coeffs,
			<F as Field>::ONE,
		);
		assert!(matches!(result, Err(Error::InvalidArgs(_))));
	}
}