1use std::{iter, slice};
4
5use binius_compute::{
6 ComputeLayer, ComputeLayerExecutor, ComputeMemory, FSlice, KernelBuffer, KernelExecutor,
7 KernelMemMap, SizedSlice, SlicesBatch, alloc::ComputeAllocator, cpu::CpuMemory,
8};
9use binius_field::{Field, TowerField, util::powers};
10use binius_math::{CompositionPoly, EvaluationOrder, evaluate_univariate};
11use binius_utils::bail;
12use itertools::Itertools;
13
14use crate::{
15 composition::{BivariateProduct, IndexComposition},
16 protocols::sumcheck::{
17 CompositeSumClaim, Error, RoundCoeffs, SumcheckClaim, prove::SumcheckProver,
18 },
19};
20
21pub struct BivariateSumcheckProver<
28 'a,
29 'b,
30 F: Field,
31 Hal: ComputeLayer<F>,
32 DeviceAllocatorType,
33 HostAllocatorType,
34> where
35 DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
36 HostAllocatorType: ComputeAllocator<F, CpuMemory>,
37 'a: 'b,
38{
39 hal: &'a Hal,
40 dev_alloc: &'a DeviceAllocatorType,
41 host_alloc: &'a HostAllocatorType,
42 n_vars_initial: usize,
43 n_vars_remaining: usize,
44 multilins: Vec<SumcheckMultilinear<'b, F, Hal::DevMem>>,
45 compositions: Vec<IndexComposition<BivariateProduct, 2>>,
46 last_coeffs_or_sums: PhaseState<F>,
47}
48
49impl<'a, 'b, F, Hal, DeviceAllocatorType, HostAllocatorType>
50 BivariateSumcheckProver<'a, 'b, F, Hal, DeviceAllocatorType, HostAllocatorType>
51where
52 F: TowerField,
53 Hal: ComputeLayer<F>,
54 DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
55 HostAllocatorType: ComputeAllocator<F, CpuMemory>,
56{
57 pub fn new(
58 hal: &'a Hal,
59 dev_alloc: &'a DeviceAllocatorType,
60 host_alloc: &'a HostAllocatorType,
61 claim: &SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
62 multilins: Vec<FSlice<'b, F, Hal>>,
63 ) -> Result<Self, Error> {
64 let n_vars = claim.n_vars();
65
66 assert_eq!(claim.n_multilinears(), multilins.len());
68 for multilin in &multilins {
69 if multilin.len() != 1 << n_vars {
70 bail!(Error::NumberOfVariablesMismatch);
71 }
72 }
73
74 let multilins = multilins
76 .into_iter()
77 .map(SumcheckMultilinear::PreFold)
78 .collect();
79
80 let (compositions, sums) = claim
81 .composite_sums()
82 .iter()
83 .map(|CompositeSumClaim { composition, sum }| (composition.clone(), *sum))
84 .unzip();
85
86 Ok(Self {
87 hal,
88 dev_alloc,
89 host_alloc,
90 n_vars_initial: n_vars,
91 n_vars_remaining: n_vars,
92 multilins,
93 compositions,
94 last_coeffs_or_sums: PhaseState::InitialSums(sums),
95 })
96 }
97
98 pub fn required_host_memory(
100 claim: &SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
101 ) -> usize {
102 claim.n_multilinears()
105 }
106
107 pub fn required_device_memory(
109 claim: &SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
110 ) -> usize {
111 claim.n_multilinears() * (1 << (claim.n_vars() - 1))
114 }
115}
116
117impl<'a, 'b, F, Hal, DeviceAllocatorType, HostAllocatorType> SumcheckProver<F>
118 for BivariateSumcheckProver<'a, 'b, F, Hal, DeviceAllocatorType, HostAllocatorType>
119where
120 F: TowerField,
121 Hal: ComputeLayer<F>,
122 DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
123 HostAllocatorType: ComputeAllocator<F, CpuMemory>,
124{
125 fn n_vars(&self) -> usize {
126 self.n_vars_initial
127 }
128
129 fn evaluation_order(&self) -> EvaluationOrder {
130 EvaluationOrder::HighToLow
131 }
132
133 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
134 let multilins = self
135 .multilins
136 .iter()
137 .map(|multilin| multilin.const_slice())
138 .collect::<Vec<_>>();
139 let round_evals = calculate_round_evals(
140 self.hal,
141 self.n_vars_remaining,
142 batch_coeff,
143 &multilins,
144 &self.compositions,
145 )?;
146
147 let batched_sum = match self.last_coeffs_or_sums {
148 PhaseState::Coeffs(_) => {
149 bail!(Error::ExpectedFold);
150 }
151 PhaseState::InitialSums(ref sums) => evaluate_univariate(sums, batch_coeff),
152 PhaseState::BatchedSum(sum) => sum,
153 };
154 let round_coeffs = calculate_round_coeffs_from_evals(batched_sum, round_evals);
155 self.last_coeffs_or_sums = PhaseState::Coeffs(round_coeffs.clone());
156
157 if self.compositions.is_empty() {
162 Ok(RoundCoeffs(vec![]))
163 } else {
164 Ok(round_coeffs)
165 }
166 }
167
168 fn fold(&mut self, challenge: F) -> Result<(), Error> {
169 use binius_compute::{FSlice, FSliceMut};
170
171 type PreparedExtrapolateLineArgs<'a, F, Hal> = (FSliceMut<'a, F, Hal>, FSlice<'a, F, Hal>);
172
173 if self.n_vars_remaining == 0 {
174 bail!(Error::ExpectedFinish);
175 }
176
177 match self.last_coeffs_or_sums {
179 PhaseState::Coeffs(ref coeffs) => {
180 let new_sum = evaluate_univariate(&coeffs.0, challenge);
181 self.last_coeffs_or_sums = PhaseState::BatchedSum(new_sum);
182 }
183 PhaseState::InitialSums(_) | PhaseState::BatchedSum(_) => {
184 bail!(Error::ExpectedExecution);
185 }
186 }
187
188 let prepared_extrapolate_line_ops =
189 self.multilins
190 .drain(..)
191 .map(
192 |multilin| -> Result<
193 PreparedExtrapolateLineArgs<'b, F, Hal>,
194 binius_compute::Error,
195 > {
196 match multilin {
197 SumcheckMultilinear::PreFold(evals) => {
198 debug_assert_eq!(evals.len(), 1 << self.n_vars_remaining);
199 let (evals_0, evals_1) = Hal::DevMem::split_half(evals);
200 let mut folded_evals= self.dev_alloc.alloc(1 << (self.n_vars_remaining - 1))?;
203 self.hal.copy_d2d(evals_0, &mut folded_evals)?;
204 Ok((folded_evals, evals_1))
205 }
206 SumcheckMultilinear::PostFold(evals) => {
207 debug_assert_eq!(evals.len(), 1 << self.n_vars_remaining);
208 let (evals_0, evals_1) = Hal::DevMem::split_half_mut(evals);
209 Ok((evals_0, Hal::DevMem::to_const(evals_1)))
210 }
211 }
212 },
213 )
214 .collect_vec();
215
216 let _ = self.hal.execute(|exec| {
218 self.multilins = exec.map(
219 prepared_extrapolate_line_ops.into_iter(),
220 |exec, extrapolate_line_args| {
221 let (mut evals_0, evals_1) = extrapolate_line_args?;
222 exec.extrapolate_line(&mut evals_0, evals_1, challenge)?;
223 Ok(SumcheckMultilinear::<F, Hal::DevMem>::PostFold(evals_0))
224 },
225 )?;
226
227 Ok(Vec::new())
228 })?;
229
230 self.n_vars_remaining -= 1;
231 Ok(())
232 }
233
234 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
235 match self.last_coeffs_or_sums {
236 PhaseState::Coeffs(_) => {
237 bail!(Error::ExpectedFold);
238 }
239 _ => match self.n_vars_remaining {
240 0 => {}
241 _ => bail!(Error::ExpectedExecution),
242 },
243 };
244
245 let buffer = self.host_alloc.alloc(self.multilins.len())?;
247 for (multilin, dst_i) in iter::zip(self.multilins, &mut *buffer) {
248 let vals = multilin.const_slice();
249 debug_assert_eq!(vals.len(), 1);
250 self.hal.copy_d2h(vals, slice::from_mut(dst_i))?;
251 }
252 Ok(buffer.to_vec())
253 }
254}
255
256#[derive(Debug, Clone)]
258pub enum SumcheckMultilinear<'a, F, Mem: ComputeMemory<F>> {
259 PreFold(Mem::FSlice<'a>),
260 PostFold(Mem::FSliceMut<'a>),
261}
262
263impl<'a, F, Mem: ComputeMemory<F>> SumcheckMultilinear<'a, F, Mem> {
264 pub fn const_slice(&self) -> Mem::FSlice<'_> {
265 match self {
266 Self::PreFold(slice) => Mem::narrow(slice),
267 Self::PostFold(slice) => Mem::as_const(slice),
268 }
269 }
270}
271
272pub fn calculate_round_evals<'a, F: TowerField, HAL: ComputeLayer<F>>(
304 hal: &HAL,
305 n_vars: usize,
306 batch_coeff: F,
307 multilins: &[FSlice<'a, F, HAL>],
308 compositions: &[IndexComposition<BivariateProduct, 2>],
309) -> Result<[F; 2], Error> {
310 let prod_evaluators = compositions
311 .iter()
312 .map(|composition| hal.compile_expr(&CompositionPoly::<F>::expression(&composition)))
313 .collect::<Result<Vec<_>, _>>()?;
314
315 let split_n_vars = n_vars - 1;
317 let kernel_mappings = multilins
318 .iter()
319 .copied()
320 .flat_map(|multilin| {
321 let (lo_half, hi_half) = HAL::DevMem::split_half(multilin);
322 [
323 KernelMemMap::Chunked {
324 data: lo_half,
325 log_min_chunk_size: 0,
326 },
327 KernelMemMap::Chunked {
328 data: hi_half,
329 log_min_chunk_size: 0,
330 },
331 KernelMemMap::Local {
333 log_size: split_n_vars,
334 },
335 ]
336 })
337 .collect();
338
339 let batch_coeffs = powers(batch_coeff)
340 .take(compositions.len())
341 .collect::<Vec<_>>();
342
343 let evals = hal.execute(|exec| {
344 exec.accumulate_kernels(
345 |local_exec, log_chunks, mut buffers| {
346 let log_chunk_size = split_n_vars - log_chunks;
347
348 let mut acc_1 = local_exec.decl_value(F::ZERO)?;
350 {
351 let eval_1s = SlicesBatch::new(
352 (0..multilins.len())
353 .map(|i| buffers[i * 3 + 1].to_ref())
354 .collect(),
355 1 << log_chunk_size,
356 );
357 for (&batch_coeff, evaluator) in iter::zip(&batch_coeffs, &prod_evaluators) {
358 local_exec.sum_composition_evals(
359 &eval_1s,
360 evaluator,
361 batch_coeff,
362 &mut acc_1,
363 )?;
364 }
365 }
366
367 for group in buffers.chunks_mut(3) {
369 let Ok(
370 [
371 KernelBuffer::Ref(evals_0),
372 KernelBuffer::Ref(evals_1),
373 KernelBuffer::Mut(evals_inf),
374 ],
375 ) = TryInto::<&mut [_; 3]>::try_into(group)
376 else {
377 panic!(
378 "exec_kernels did not create the mapped buffers struct according to the mapping"
379 );
380 };
381 local_exec.add(log_chunk_size, *evals_0, *evals_1, evals_inf)?;
382 }
383
384 let mut acc_inf = local_exec.decl_value(F::ZERO)?;
386 let eval_infs = SlicesBatch::new(
387 (0..multilins.len())
388 .map(|i| buffers[i * 3 + 2].to_ref())
389 .collect(),
390 1 << log_chunk_size,
391 );
392 for (&batch_coeff, evaluator) in iter::zip(&batch_coeffs, &prod_evaluators) {
393 local_exec.sum_composition_evals(
394 &eval_infs,
395 evaluator,
396 batch_coeff,
397 &mut acc_inf,
398 )?;
399 }
400
401 Ok(vec![acc_1, acc_inf])
402 },
403 kernel_mappings,
404 )
405 })?;
406 let evals = TryInto::<[F; 2]>::try_into(evals).expect("kernel returns two values");
407 Ok(evals)
408}
409
410fn calculate_round_coeffs_from_evals<F: Field>(sum: F, evals: [F; 2]) -> RoundCoeffs<F> {
411 let [y_1, y_inf] = evals;
412 let y_0 = sum - y_1;
413
414 let c_0 = y_0;
421 let c_2 = y_inf;
422 let c_1 = y_1 - c_0 - c_2;
423 RoundCoeffs(vec![c_0, c_1, c_2])
424}
425
426#[derive(Debug)]
427pub enum PhaseState<F: Field> {
428 Coeffs(RoundCoeffs<F>),
429 InitialSums(Vec<F>),
430 BatchedSum(F),
431}
432
#[cfg(test)]
mod tests {
	use binius_compute::cpu::layer::CpuLayerHolder;
	use binius_compute_test_utils::bivariate_sumcheck::{
		generic_test_bivariate_sumcheck_prove_verify, generic_test_calculate_round_evals,
	};
	use binius_fast_compute::layer::FastCpuLayerHolder;
	use binius_field::{
		BinaryField128b, arch::OptimalUnderlier, as_packed_field::PackedType,
		tower::CanonicalTowerFamily,
	};
	use binius_math::B128;

	/// Packed 128-bit field type used by the fast CPU compute layer in these tests.
	type FastPacked = PackedType<OptimalUnderlier, BinaryField128b>;

	#[test]
	fn test_calculate_round_evals() {
		let holder = CpuLayerHolder::new(1 << 11, 1 << 10);
		let log_size = 8;
		generic_test_calculate_round_evals(holder, log_size)
	}

	#[test]
	fn test_calculate_round_evals_fast_cpu() {
		let holder = FastCpuLayerHolder::<CanonicalTowerFamily, FastPacked>::new(1 << 11, 1 << 10);
		let log_size = 8;
		generic_test_calculate_round_evals(holder, log_size)
	}

	#[test]
	fn test_bivariate_sumcheck_prove_verify() {
		let (log_size, multilin_count, composition_count) = (8, 8, 8);

		let holder = CpuLayerHolder::<B128>::new(1 << 13, 1 << 12);
		generic_test_bivariate_sumcheck_prove_verify(
			holder,
			log_size,
			multilin_count,
			composition_count,
		)
	}

	#[test]
	fn test_bivariate_sumcheck_prove_verify_fast() {
		let (log_size, multilin_count, composition_count) = (8, 8, 8);

		let holder = FastCpuLayerHolder::<CanonicalTowerFamily, FastPacked>::new(1 << 13, 1 << 12);
		generic_test_bivariate_sumcheck_prove_verify(
			holder,
			log_size,
			multilin_count,
			composition_count,
		)
	}
}