1use std::{iter, marker::PhantomData, sync::Arc};
4
5use binius_compute::{
6 ComputeLayer, ComputeLayerExecutor, ComputeMemory, SizedSlice, SubfieldSlice,
7 alloc::ComputeAllocator, cpu::CpuMemory,
8};
9use binius_field::{ExtensionField, Field, PackedExtension, PackedField, TowerField};
10use binius_utils::bail;
11
12use super::error::Error;
13use crate::{
14 polynomial::{Error as PolynomialError, MultivariatePoly},
15 tensor_algebra::TensorAlgebra,
16};
17
/// Coefficients used to batch the rows of a tensor-algebra element into a
/// single field element.
///
/// In practice these are the multilinear-query expansion of verifier-sampled
/// row-batching challenges (see how the test module constructs them with
/// `MultilinearQuery::expand`).
#[derive(Debug)]
pub struct RowBatchCoeffs<F> {
    // One coefficient per row; `RingSwitchEqInd::new` requires at least
    // `F::DEGREE` entries.
    coeffs: Vec<F>,
}
23
24impl<F: Field> RowBatchCoeffs<F> {
25 pub fn new(coeffs: Vec<F>) -> Self {
26 Self { coeffs }
27 }
28
29 pub fn coeffs(&self) -> &[F] {
30 &self.coeffs
31 }
32}
33
/// The ring-switching equality indicator polynomial.
///
/// Bundles the data needed to evaluate the indicator: the evaluation point
/// `z_vals`, the row-batching coefficients, and a mixing coefficient that
/// seeds the evaluation (see `evaluate`, which starts the accumulation from
/// `TensorAlgebra::from_vertical(mixing_coeff)`). `FSub` is the subfield the
/// indicator relates to `F`; it is tracked only at the type level.
#[derive(Debug, Clone)]
pub struct RingSwitchEqInd<FSub, F> {
    // Evaluation point; its length is the number of variables (`n_vars`).
    z_vals: Arc<[F]>,
    // Coefficients used to fold the tensor-algebra evaluation down to `F`.
    row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
    // Scalar that seeds the tensor expansion / evaluation accumulator.
    mixing_coeff: F,
    // `FSub` appears only in trait bounds, never in stored data.
    _marker: PhantomData<FSub>,
}
48
/// Device-memory buffers prepared ahead of time for
/// [`RingSwitchEqInd::multilinear_extension`].
pub struct RingSwitchEqIndPrecompute<'a, F: Field, Mem: ComputeMemory<F>> {
    // Scratch buffer of size 2^n_vars for the tensor expansion of `z_vals`;
    // its first element is seeded with the mixing coefficient.
    evals: Mem::FSliceMut<'a>,
    // Device copy of the first `2^kappa` row-batching coefficients.
    row_batching_query_expansion: Mem::FSlice<'a>,
    // Output buffer that receives the folded multilinear extension.
    mle: Mem::FSliceMut<'a>,
}
54
55impl<FSub, F> RingSwitchEqInd<FSub, F>
56where
57 FSub: Field,
58 F: ExtensionField<FSub>,
59{
60 pub fn new(
61 z_vals: Arc<[F]>,
62 row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
63 mixing_coeff: F,
64 ) -> Result<Self, Error> {
65 if row_batch_coeffs.coeffs.len() < F::DEGREE {
66 bail!(Error::InvalidArgs(
67 "RingSwitchEqInd::new expects row_batch_coeffs length greater than or equal to \
68 the extension degree"
69 .into()
70 ));
71 }
72
73 Ok(Self {
74 z_vals,
75 row_batch_coeffs,
76 mixing_coeff,
77 _marker: PhantomData,
78 })
79 }
80
81 pub fn precompute_values<'a, Hal: ComputeLayer<F>, HostAllocatorType, DeviceAllocatorType>(
82 z_vals: Arc<[F]>,
83 row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
84 mixing_coeff: F,
85 kappa: usize,
86 hal: &Hal,
87 dev_alloc: &'a DeviceAllocatorType,
88 host_alloc: &HostAllocatorType,
89 ) -> Result<RingSwitchEqIndPrecompute<'a, F, Hal::DevMem>, Error>
90 where
91 HostAllocatorType: ComputeAllocator<F, CpuMemory>,
92 DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
93 {
94 let extension_degree = 1 << (kappa);
95
96 let mut row_batching_query_expansion = dev_alloc.alloc(extension_degree)?;
97
98 hal.copy_h2d(
99 &row_batch_coeffs.coeffs()[0..extension_degree],
100 &mut row_batching_query_expansion,
101 )?;
102
103 let row_batching_query_expansion = Hal::DevMem::to_const(row_batching_query_expansion);
104
105 let n_vars = z_vals.len();
106 let mut evals = dev_alloc.alloc(1 << n_vars)?;
107
108 {
109 let host_val = host_alloc.alloc(1)?;
110 host_val[0] = mixing_coeff;
111 let mut dev_val = Hal::DevMem::slice_power_of_two_mut(&mut evals, 1);
112 hal.copy_h2d(host_val, &mut dev_val)?;
113 }
114
115 let mle = dev_alloc.alloc(evals.len())?;
116
117 Ok(RingSwitchEqIndPrecompute {
118 evals,
119 row_batching_query_expansion,
120 mle,
121 })
122 }
123
124 pub fn multilinear_extension<
125 'a,
126 Mem: ComputeMemory<F>,
127 Exec: ComputeLayerExecutor<F, DevMem = Mem>,
128 >(
129 &self,
130 precompute: RingSwitchEqIndPrecompute<'a, F, Mem>,
131 exec: &mut Exec,
132 tower_level: usize,
133 ) -> Result<Mem::FSlice<'a>, Error> {
134 let RingSwitchEqIndPrecompute {
135 mut evals,
136 row_batching_query_expansion,
137 mut mle,
138 } = precompute;
139
140 exec.tensor_expand(0, &self.z_vals, &mut evals)?;
141
142 let subfield_vector = SubfieldSlice::new(Mem::as_const(&evals), tower_level);
143
144 exec.fold_right(subfield_vector, row_batching_query_expansion, &mut mle)?;
145
146 Ok(Mem::to_const(mle))
147 }
148}
149
150impl<FSub, F> MultivariatePoly<F> for RingSwitchEqInd<FSub, F>
151where
152 FSub: TowerField,
153 F: TowerField + PackedField<Scalar = F> + PackedExtension<FSub>,
154{
155 fn n_vars(&self) -> usize {
156 self.z_vals.len()
157 }
158
159 fn degree(&self) -> usize {
160 self.n_vars()
161 }
162
163 fn evaluate(&self, query: &[F]) -> Result<F, PolynomialError> {
164 if query.len() != self.n_vars() {
165 bail!(PolynomialError::IncorrectQuerySize {
166 expected: self.n_vars(),
167 actual: query.len(),
168 });
169 };
170
171 let tensor_eval = iter::zip(&*self.z_vals, query).fold(
172 <TensorAlgebra<FSub, F>>::from_vertical(self.mixing_coeff),
173 |eval, (&vert_i, &hztl_i)| {
174 let vert_scaled = eval.clone().scale_vertical(vert_i);
177 let hztl_scaled = eval.clone().scale_horizontal(hztl_i);
178 eval + &vert_scaled + &hztl_scaled
179 },
180 );
181
182 let folded_eval = tensor_eval.fold_vertical(&self.row_batch_coeffs.coeffs);
183 Ok(folded_eval)
184 }
185
186 fn binary_tower_level(&self) -> usize {
187 F::TOWER_LEVEL
188 }
189}
190
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_compute::{ComputeData, ComputeHolder, cpu::layer::CpuLayerHolder};
    use binius_field::{BinaryField8b, BinaryField128b};
    use binius_math::{MultilinearQuery, eq_ind_partial_eval};
    use rand::{SeedableRng, prelude::StdRng};

    use super::*;

    /// Checks that host-side `evaluate` agrees with evaluating the
    /// device-computed multilinear extension at the same point.
    #[test]
    fn test_evaluation_consistency() {
        type FS = BinaryField8b;
        type F = BinaryField128b;
        let kappa = <TensorAlgebra<FS, F>>::kappa();
        let ell = 10;
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = ell - kappa;

        // NOTE: the sampling order below fixes the rng stream; keep it stable.
        let z_vals = repeat_with(|| <F as Field>::random(&mut rng))
            .take(n_vars)
            .collect::<Arc<[_]>>();

        let row_batch_challenges = repeat_with(|| <F as Field>::random(&mut rng))
            .take(kappa)
            .collect::<Vec<_>>();

        let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
            MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
        ));

        let eval_point = repeat_with(|| <F as Field>::random(&mut rng))
            .take(n_vars)
            .collect::<Vec<_>>();

        let mixing_coeff = <F as Field>::random(&mut rng);

        // Set up a CPU compute layer with host and device scratch space.
        let mut compute_holder = CpuLayerHolder::new(1 << 10, 1 << 10);

        let ComputeData {
            hal,
            dev_alloc,
            host_alloc,
            ..
        } = compute_holder.to_data();

        let precompute = RingSwitchEqInd::<FS, _>::precompute_values(
            z_vals.clone(),
            row_batch_coeffs.clone(),
            mixing_coeff,
            kappa,
            hal,
            &dev_alloc,
            &host_alloc,
        )
        .unwrap();

        let rs_eq = RingSwitchEqInd::<FS, _>::new(z_vals, row_batch_coeffs, mixing_coeff).unwrap();

        // Direct evaluation on the host.
        let cpu_eval = rs_eq.evaluate(&eval_point).unwrap();

        // Evaluate the device-side MLE at the same point by taking its inner
        // product with the equality-indicator expansion of the query.
        hal.execute(|exec| {
            let mle = rs_eq
                .multilinear_extension(precompute, exec, BinaryField8b::TOWER_LEVEL)
                .unwrap();

            let mle = SubfieldSlice::new(mle, BinaryField128b::TOWER_LEVEL);

            let query = eq_ind_partial_eval(&eval_point);

            let device_eval = exec.inner_product(mle, &query).unwrap();

            assert_eq!(cpu_eval, device_eval);
            Ok(vec![])
        })
        .unwrap();
    }
}
270}