1use std::{iter, mem, slice};
4
5use binius_compute::{
6 ComputeLayer, ComputeLayerExecutor, ComputeMemory, FSlice, KernelBuffer, KernelExecutor,
7 KernelMemMap, SizedSlice, SlicesBatch, alloc::ComputeAllocator, cpu::CpuMemory,
8};
9use binius_field::{
10 Field, TowerField,
11 util::{eq, powers},
12};
13use binius_math::{ArithCircuit, CompositionPoly, EvaluationOrder, evaluate_univariate};
14use binius_utils::bail;
15use itertools::Itertools;
16
17use super::bivariate_product::{PhaseState, SumcheckMultilinear};
18use crate::{
19 composition::{BivariateProduct, IndexComposition},
20 protocols::sumcheck::{
21 CompositeSumClaim, EqIndSumcheckClaim, Error, RoundCoeffs, prove::SumcheckProver,
22 },
23};
24
25pub struct BivariateMLEcheckProver<
32 'a,
33 F: Field,
34 Hal: ComputeLayer<F>,
35 HostAllocatorType,
36 DeviceAllocatorType,
37> where
38 HostAllocatorType: ComputeAllocator<F, CpuMemory>,
39 DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
40{
41 hal: &'a Hal,
42 dev_alloc: &'a DeviceAllocatorType,
43 host_alloc: &'a HostAllocatorType,
44 n_vars_initial: usize,
45 n_vars_remaining: usize,
46 multilins: Vec<SumcheckMultilinear<'a, F, Hal::DevMem>>,
47 compositions: Vec<IndexComposition<BivariateProduct, 2>>,
48 last_coeffs_or_sums: PhaseState<F>,
49 eq_ind_prefix_eval: F,
50 eq_ind_partial_evals: Option<SumcheckMultilinear<'a, F, Hal::DevMem>>,
52 eq_ind_challenges: Vec<F>,
53}
54
55impl<'a, F, Hal, HostAllocatorType, DeviceAllocatorType>
56 BivariateMLEcheckProver<'a, F, Hal, HostAllocatorType, DeviceAllocatorType>
57where
58 F: TowerField,
59 Hal: ComputeLayer<F>,
60 HostAllocatorType: ComputeAllocator<F, CpuMemory>,
61 DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
62{
63 #[allow(clippy::too_many_arguments)]
64 pub fn new(
65 hal: &'a Hal,
66 dev_alloc: &'a DeviceAllocatorType,
67 host_alloc: &'a HostAllocatorType,
68 claim: &EqIndSumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
69 multilins: Vec<FSlice<'a, F, Hal>>,
70 eq_ind_partial_evals: FSlice<'a, F, Hal>,
73 eq_ind_challenges: Vec<F>,
74 ) -> Result<Self, Error> {
75 let n_vars = claim.n_vars();
76
77 assert_eq!(claim.n_multilinears(), multilins.len());
79 for multilin in &multilins {
80 if multilin.len() != 1 << n_vars {
81 bail!(Error::NumberOfVariablesMismatch);
82 }
83 }
84
85 let multilins = multilins
87 .into_iter()
88 .map(SumcheckMultilinear::PreFold)
89 .collect();
90
91 let (compositions, sums) = claim
92 .eq_ind_composite_sums()
93 .iter()
94 .map(|CompositeSumClaim { composition, sum }| (composition.clone(), *sum))
95 .unzip();
96
97 if eq_ind_partial_evals.len() != 1 << n_vars.saturating_sub(1) {
100 bail!(Error::IncorrectEqIndPartialEvalsSize);
101 }
102
103 let eq_ind_partial_evals_buffer = SumcheckMultilinear::PreFold(eq_ind_partial_evals);
104
105 Ok(Self {
106 hal,
107 dev_alloc,
108 host_alloc,
109 n_vars_initial: n_vars,
110 n_vars_remaining: n_vars,
111 multilins,
112 compositions,
113 last_coeffs_or_sums: PhaseState::InitialSums(sums),
114 eq_ind_prefix_eval: F::ONE,
115 eq_ind_partial_evals: Some(eq_ind_partial_evals_buffer),
116 eq_ind_challenges,
117 })
118 }
119
120 fn update_eq_ind_prefix_eval(&mut self, challenge: F) {
121 self.eq_ind_prefix_eval *= eq(self.eq_ind_challenges[self.n_vars_remaining - 1], challenge);
123 }
124
125 pub fn required_host_memory(
127 claim: &EqIndSumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
128 ) -> usize {
129 claim.n_multilinears() + 1
132 }
133
134 pub fn required_device_memory(
136 claim: &EqIndSumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
137 with_eq_ind_partial_evals: bool,
138 ) -> usize {
139 let n_multilinears = claim.n_multilinears() + if with_eq_ind_partial_evals { 0 } else { 1 };
142 n_multilinears * (1 << (claim.n_vars() - 1))
143 }
144
145 pub fn fold_multilinears(&mut self, challenge: F) -> Result<(), Error> {
146 use binius_compute::{FSlice, FSliceMut};
147
148 type PreparedExtrapolateLineArgs<'a, F, Hal> = (FSliceMut<'a, F, Hal>, FSlice<'a, F, Hal>);
149
150 let prepared_extrapolate_line_ops = self
151 .multilins
152 .drain(..)
153 .map(
154 |multilin| -> Result<
155 PreparedExtrapolateLineArgs<'a, F, Hal>,
156 binius_compute::Error,
157 > {
158 match multilin {
159 SumcheckMultilinear::PreFold(evals) => {
160 debug_assert_eq!(evals.len(), 1 << self.n_vars_remaining);
161 let (evals_0, evals_1) =
162 Hal::DevMem::split_half(evals);
163
164 let mut folded_evals =
166 self.dev_alloc.alloc(1 << (self.n_vars_remaining - 1))?;
167 self.hal.copy_d2d(evals_0, &mut folded_evals)?;
168 Ok((folded_evals, evals_1))
169 }
170 SumcheckMultilinear::PostFold(evals) => {
171 debug_assert_eq!(evals.len(), 1 << self.n_vars_remaining);
172 let (evals_0, evals_1) =
173 Hal::DevMem::split_half_mut(evals);
174 Ok((evals_0, Hal::DevMem::to_const(evals_1)))
175 }
176 }
177 },
178 )
179 .collect_vec();
180 self.hal.execute(|exec| {
181 self.multilins = exec.map(
182 prepared_extrapolate_line_ops.into_iter(),
183 |exec, extrapolate_line_args| {
184 let (mut evals_0, evals_1) = extrapolate_line_args?;
185 exec.extrapolate_line(&mut evals_0, evals_1, challenge)?;
186 Ok(SumcheckMultilinear::<F, Hal::DevMem>::PostFold(evals_0))
187 },
188 )?;
189
190 Ok(Vec::new())
191 })?;
192 Ok(())
193 }
194
195 pub fn fold_eq_ind(&mut self) -> Result<(), Error> {
196 let eq_ind_partial_evals = mem::take(&mut self.eq_ind_partial_evals).expect("exist");
197
198 let split_n_vars = self.n_vars_remaining - 2;
199
200 let (mut evals_0, evals_1) = match eq_ind_partial_evals {
201 SumcheckMultilinear::PreFold(evals) => {
202 let (evals_0, evals_1) = Hal::DevMem::split_half(evals);
203
204 let mut buffer = self.dev_alloc.alloc(evals_0.len())?;
205 self.hal.copy_d2d(evals_0, &mut buffer)?;
206
207 (buffer, evals_1)
208 }
209 SumcheckMultilinear::PostFold(evals) => {
210 let (evals_0, evals_1) = Hal::DevMem::split_half_mut(evals);
211
212 let evals_1 = Hal::DevMem::to_const(evals_1);
213
214 (evals_0, evals_1)
215 }
216 };
217
218 let kernel_mappings = vec![
219 KernelMemMap::ChunkedMut {
220 data: Hal::DevMem::to_owned_mut(&mut evals_0),
221 log_min_chunk_size: 0,
222 },
223 KernelMemMap::Chunked {
224 data: Hal::DevMem::narrow(&evals_1),
225 log_min_chunk_size: 0,
226 },
227 ];
228
229 let _ = self.hal.execute(|exec| {
230 exec.map_kernels(
231 |local_exec, log_chunks, mut buffers| {
232 let log_chunk_size = split_n_vars - log_chunks;
233
234 let Ok([KernelBuffer::Mut(evals_0), KernelBuffer::Ref(evals_1)]) =
235 TryInto::<&mut [_; 2]>::try_into(buffers.as_mut_slice())
236 else {
237 panic!(
238 "exec_kernels did not create the mapped buffers struct according to the mapping"
239 );
240 };
241 local_exec.add_assign(log_chunk_size, *evals_1, evals_0)?;
242
243 Ok(())
244 },
245 kernel_mappings,
246 )?;
247
248 Ok(Vec::new())
249 })?;
250
251 self.eq_ind_partial_evals = Some(SumcheckMultilinear::PostFold(evals_0));
252
253 Ok(())
254 }
255}
256
257impl<F, Hal, HostAllocatorType, DeviceAllocatorType> SumcheckProver<F>
258 for BivariateMLEcheckProver<'_, F, Hal, HostAllocatorType, DeviceAllocatorType>
259where
260 F: TowerField,
261 Hal: ComputeLayer<F>,
262 HostAllocatorType: ComputeAllocator<F, CpuMemory>,
263 DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
264{
265 fn n_vars(&self) -> usize {
266 self.n_vars_initial
267 }
268
269 fn evaluation_order(&self) -> EvaluationOrder {
270 EvaluationOrder::HighToLow
271 }
272
273 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
274 let multilins = self
275 .multilins
276 .iter()
277 .map(|multilin| multilin.const_slice())
278 .collect::<Vec<_>>();
279
280 let round_evals = calculate_round_evals(
281 self.hal,
282 self.n_vars_remaining,
283 batch_coeff,
284 &multilins,
285 self.eq_ind_partial_evals
286 .as_ref()
287 .expect("eq_ind_partial_evals not None")
288 .const_slice(),
289 &self.compositions,
290 )?;
291
292 let batched_sum = match self.last_coeffs_or_sums {
293 PhaseState::Coeffs(_) => {
294 bail!(Error::ExpectedFold);
295 }
296 PhaseState::InitialSums(ref sums) => evaluate_univariate(sums, batch_coeff),
297 PhaseState::BatchedSum(sum) => sum,
298 };
299
300 let alpha = self.eq_ind_challenges[self.n_vars_remaining - 1];
301
302 let prime_coeffs = calculate_round_coeffs_from_evals(batched_sum, round_evals, alpha);
303
304 self.last_coeffs_or_sums = PhaseState::Coeffs(prime_coeffs.clone());
305
306 let prime_coeffs_scaled_by_constant_term = prime_coeffs.clone() * (F::ONE - alpha);
309
310 let mut prime_coeffs_scaled_by_linear_term = prime_coeffs * (alpha.double() - F::ONE);
311
312 prime_coeffs_scaled_by_linear_term.0.insert(0, F::ZERO); let coeffs = (prime_coeffs_scaled_by_constant_term + &prime_coeffs_scaled_by_linear_term)
315 * self.eq_ind_prefix_eval;
316
317 Ok(coeffs)
318 }
319
320 fn fold(&mut self, challenge: F) -> Result<(), Error> {
321 if self.n_vars_remaining == 0 {
322 bail!(Error::ExpectedFinish);
323 }
324
325 match self.last_coeffs_or_sums {
327 PhaseState::Coeffs(ref coeffs) => {
328 let new_sum = evaluate_univariate(&coeffs.0, challenge);
329 self.last_coeffs_or_sums = PhaseState::BatchedSum(new_sum);
330 }
331 PhaseState::InitialSums(_) | PhaseState::BatchedSum(_) => {
332 bail!(Error::ExpectedExecution);
333 }
334 }
335
336 self.update_eq_ind_prefix_eval(challenge);
337
338 self.fold_multilinears(challenge)?;
339
340 if self.n_vars_remaining - 1 != 0 {
341 self.fold_eq_ind()?;
342 }
343
344 self.n_vars_remaining -= 1;
345 Ok(())
346 }
347
348 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
349 match self.last_coeffs_or_sums {
350 PhaseState::Coeffs(_) => {
351 bail!(Error::ExpectedFold);
352 }
353 _ => match self.n_vars_remaining {
354 0 => {}
355 _ => bail!(Error::ExpectedExecution),
356 },
357 };
358
359 let buffer = self.host_alloc.alloc(self.multilins.len())?;
361 for (multilin, dst_i) in iter::zip(self.multilins, &mut *buffer) {
362 let vals = multilin.const_slice();
363 debug_assert_eq!(vals.len(), 1);
364 self.hal.copy_d2h(vals, slice::from_mut(dst_i))?;
365 }
366
367 let mut res = buffer.to_vec();
368
369 res.push(self.eq_ind_prefix_eval);
370
371 Ok(res)
372 }
373}
374
375fn calculate_round_coeffs_from_evals<F: Field>(sum: F, evals: [F; 2], alpha: F) -> RoundCoeffs<F> {
376 let [y_1, y_inf] = evals;
377
378 let y_0 = (sum - y_1 * alpha) * (F::ONE - alpha).invert_or_zero();
379
380 let c_0 = y_0;
386 let c_2 = y_inf;
387 let c_1 = y_1 - c_0 - c_2;
388 RoundCoeffs(vec![c_0, c_1, c_2])
389}
390
/// Evaluates the batched prime round polynomial at `X = 1` and at infinity on the device.
///
/// For every composition `comp_j`, the composition expression is multiplied by an extra
/// variable at index `multilins.len()`, which is bound to `eq_ind_partial_evals`. The results
/// are combined with weights `batch_coeff^j` and summed over all `2^(n_vars - 1)` points of
/// the lower variables.
///
/// Returns `[y_1, y_inf]`:
/// * `y_1` uses the upper half of each multilinear (highest remaining variable set to 1).
/// * `y_inf` uses `lower + upper` of each multilinear (see the comment at the `add` call).
///
/// `n_vars` must be at least 1; `n_vars - 1` underflows otherwise.
fn calculate_round_evals<'a, F: TowerField, Hal: ComputeLayer<F>>(
	hal: &Hal,
	n_vars: usize,
	batch_coeff: F,
	multilins: &[FSlice<'a, F, Hal>],
	eq_ind_partial_evals: FSlice<'a, F, Hal>,
	compositions: &[IndexComposition<BivariateProduct, 2>],
) -> Result<[F; 2], Error> {
	// Compile `comp_j(multilins...) * eq_ind` for each composition; eq_ind is appended as the
	// last column of every `SlicesBatch` built below.
	let prod_evaluators = compositions
		.iter()
		.map(|composition| {
			let mut prod_expr = CompositionPoly::<F>::expression(&composition);
			prod_expr *= ArithCircuit::var(multilins.len());

			hal.compile_expr(&prod_expr)
		})
		.collect::<Result<Vec<_>, _>>()?;

	let split_n_vars = n_vars - 1;
	// Buffer layout per multilinear: [lower half, upper half, local scratch for the
	// "infinity" evaluations]. Index arithmetic `i * 3 + k` in the kernel relies on this.
	let mut kernel_mappings = multilins
		.iter()
		.copied()
		.flat_map(|multilin| {
			let (lo_half, hi_half) = Hal::DevMem::split_half(multilin);
			[
				KernelMemMap::Chunked {
					data: lo_half,
					log_min_chunk_size: 0,
				},
				KernelMemMap::Chunked {
					data: hi_half,
					log_min_chunk_size: 0,
				},
				KernelMemMap::Local {
					log_size: split_n_vars,
				},
			]
		})
		.collect::<Vec<_>>();

	// eq_ind goes last so the kernel can `pop` it off before chunking the rest by 3.
	kernel_mappings.push(KernelMemMap::Chunked {
		data: eq_ind_partial_evals,
		log_min_chunk_size: 0,
	});

	let batch_coeffs = powers(batch_coeff)
		.take(compositions.len())
		.collect::<Vec<_>>();

	let evals = hal.execute(|exec| {
		exec.accumulate_kernels(
			|local_exec, log_chunks, mut buffers| {
				let log_chunk_size = split_n_vars - log_chunks;

				let eq_ind = buffers.pop().expect(
					"The presence of eq_ind in the buffer is due to it being added earlier in the code.",
				);

				// y_1: evaluate every composition on the upper halves (index i * 3 + 1).
				let mut acc_1 = local_exec.decl_value(F::ZERO)?;
				// Scoped so the immutable views into `buffers` end before `chunks_mut` below.
				{
					let mut eval_1s_with_eq_ind = (0..multilins.len())
						.map(|i| buffers[i * 3 + 1].to_ref())
						.collect::<Vec<_>>();

					eval_1s_with_eq_ind.push(eq_ind.to_ref());

					let eval_1s_with_eq_ind =
						SlicesBatch::new(eval_1s_with_eq_ind, 1 << log_chunk_size);

					for (&batch_coeff, evaluator) in iter::zip(&batch_coeffs, &prod_evaluators) {
						local_exec.sum_composition_evals(
							&eval_1s_with_eq_ind,
							evaluator,
							batch_coeff,
							&mut acc_1,
						)?;
					}
				}

				// Fill each scratch buffer with `lower + upper`. In characteristic 2
				// (presumably guaranteed by `TowerField`) this equals `upper - lower`, the
				// coefficient of the top variable, i.e. the multilinear's value "at infinity".
				for group in buffers.chunks_mut(3) {
					let Ok(
						[
							KernelBuffer::Ref(evals_0),
							KernelBuffer::Ref(evals_1),
							KernelBuffer::Mut(evals_inf),
						],
					) = TryInto::<&mut [_; 3]>::try_into(group)
					else {
						panic!(
							"exec_kernels did not create the mapped buffers struct according to the mapping"
						);
					};
					local_exec.add(log_chunk_size, *evals_0, *evals_1, evals_inf)?;
				}

				// y_inf: evaluate every composition on the scratch buffers (index i * 3 + 2).
				let mut acc_inf = local_exec.decl_value(F::ZERO)?;
				let mut eval_infs_with_eq_ind = (0..multilins.len())
					.map(|i| buffers[i * 3 + 2].to_ref())
					.collect::<Vec<_>>();

				eval_infs_with_eq_ind.push(eq_ind.to_ref());

				let eval_infs_with_eq_ind =
					SlicesBatch::new(eval_infs_with_eq_ind, 1 << log_chunk_size);

				for (&batch_coeff, evaluator) in iter::zip(&batch_coeffs, &prod_evaluators) {
					local_exec.sum_composition_evals(
						&eval_infs_with_eq_ind,
						evaluator,
						batch_coeff,
						&mut acc_inf,
					)?;
				}

				Ok(vec![acc_1, acc_inf])
			},
			kernel_mappings,
		)
	})?;

	let evals = TryInto::<[F; 2]>::try_into(evals).expect("kernel returns two values");
	Ok(evals)
}
520
#[cfg(test)]
mod tests {
	use binius_compute::cpu::layer::CpuLayerHolder;
	use binius_compute_test_utils::bivariate_sumcheck::generic_test_bivariate_mlecheck_prove_verify;
	use binius_fast_compute::layer::FastCpuLayerHolder;
	use binius_field::{
		arch::OptimalUnderlier, as_packed_field::PackedType, tower::CanonicalTowerFamily,
	};
	use binius_math::B128;

	/// Prove/verify round trip on the reference CPU compute layer.
	#[test]
	fn test_bivariate_mlecheck_prove_verify() {
		let (n_vars, n_multilins, n_compositions) = (8, 8, 8);
		generic_test_bivariate_mlecheck_prove_verify(
			CpuLayerHolder::<B128>::new(1 << 13, 1 << 12),
			n_vars,
			n_multilins,
			n_compositions,
		);
	}

	/// Same round trip on the packed-field "fast" CPU compute layer.
	#[test]
	fn test_bivariate_mlecheck_prove_verify_fast() {
		type P = PackedType<OptimalUnderlier, B128>;

		let (n_vars, n_multilins, n_compositions) = (8, 8, 8);
		generic_test_bivariate_mlecheck_prove_verify(
			FastCpuLayerHolder::<CanonicalTowerFamily, P>::new(1 << 13, 1 << 12),
			n_vars,
			n_multilins,
			n_compositions,
		);
	}
}