binius_core/protocols/sumcheck/v3/
bivariate_product.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::{iter, slice};
4
5use binius_compute::{
6	ComputeLayer, ComputeLayerExecutor, ComputeMemory, FSlice, KernelBuffer, KernelExecutor,
7	KernelMemMap, SizedSlice, SlicesBatch, alloc::ComputeAllocator, cpu::CpuMemory,
8};
9use binius_field::{Field, TowerField, util::powers};
10use binius_math::{CompositionPoly, EvaluationOrder, evaluate_univariate};
11use binius_utils::bail;
12use itertools::Itertools;
13
14use crate::{
15	composition::{BivariateProduct, IndexComposition},
16	protocols::sumcheck::{
17		CompositeSumClaim, Error, RoundCoeffs, SumcheckClaim, prove::SumcheckProver,
18	},
19};
20
21/// Sumcheck prover implementation for the special case of bivariate product compositions over
22/// large-field multilinears.
23///
24/// This implements the [`SumcheckProver`] interface. The implementation uses a [`ComputeLayer`]
25/// instance for expensive operations and the input multilinears are provided as device memory
26/// slices.
27pub struct BivariateSumcheckProver<
28	'a,
29	'b,
30	F: Field,
31	Hal: ComputeLayer<F>,
32	DeviceAllocatorType,
33	HostAllocatorType,
34> where
35	DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
36	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
37	'a: 'b,
38{
39	hal: &'a Hal,
40	dev_alloc: &'a DeviceAllocatorType,
41	host_alloc: &'a HostAllocatorType,
42	n_vars_initial: usize,
43	n_vars_remaining: usize,
44	multilins: Vec<SumcheckMultilinear<'b, F, Hal::DevMem>>,
45	compositions: Vec<IndexComposition<BivariateProduct, 2>>,
46	last_coeffs_or_sums: PhaseState<F>,
47}
48
49impl<'a, 'b, F, Hal, DeviceAllocatorType, HostAllocatorType>
50	BivariateSumcheckProver<'a, 'b, F, Hal, DeviceAllocatorType, HostAllocatorType>
51where
52	F: TowerField,
53	Hal: ComputeLayer<F>,
54	DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
55	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
56{
57	pub fn new(
58		hal: &'a Hal,
59		dev_alloc: &'a DeviceAllocatorType,
60		host_alloc: &'a HostAllocatorType,
61		claim: &SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
62		multilins: Vec<FSlice<'b, F, Hal>>,
63	) -> Result<Self, Error> {
64		let n_vars = claim.n_vars();
65
66		// Check shape of multilinear witness inputs.
67		assert_eq!(claim.n_multilinears(), multilins.len());
68		for multilin in &multilins {
69			if multilin.len() != 1 << n_vars {
70				bail!(Error::NumberOfVariablesMismatch);
71			}
72		}
73
74		// Wrap multilinear witness inputs as SumcheckMultilinears.
75		let multilins = multilins
76			.into_iter()
77			.map(SumcheckMultilinear::PreFold)
78			.collect();
79
80		let (compositions, sums) = claim
81			.composite_sums()
82			.iter()
83			.map(|CompositeSumClaim { composition, sum }| (composition.clone(), *sum))
84			.unzip();
85
86		Ok(Self {
87			hal,
88			dev_alloc,
89			host_alloc,
90			n_vars_initial: n_vars,
91			n_vars_remaining: n_vars,
92			multilins,
93			compositions,
94			last_coeffs_or_sums: PhaseState::InitialSums(sums),
95		})
96	}
97
98	/// Returns the amount of host memory this sumcheck requires.
99	pub fn required_host_memory(
100		claim: &SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
101	) -> usize {
102		// In `finish()`, prover allocates a temporary host buffer for each of the fully folded
103		// multilinear evaluations.
104		claim.n_multilinears()
105	}
106
107	/// Returns the amount of device memory this sumcheck requires.
108	pub fn required_device_memory(
109		claim: &SumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
110	) -> usize {
111		// In `fold()`, prover allocates device buffers for each of the folded multilinears. They
112		// are each half of the size of the original multilinears.
113		claim.n_multilinears() * (1 << (claim.n_vars() - 1))
114	}
115}
116
117impl<'a, 'b, F, Hal, DeviceAllocatorType, HostAllocatorType> SumcheckProver<F>
118	for BivariateSumcheckProver<'a, 'b, F, Hal, DeviceAllocatorType, HostAllocatorType>
119where
120	F: TowerField,
121	Hal: ComputeLayer<F>,
122	DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
123	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
124{
125	fn n_vars(&self) -> usize {
126		self.n_vars_initial
127	}
128
129	fn evaluation_order(&self) -> EvaluationOrder {
130		EvaluationOrder::HighToLow
131	}
132
133	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
134		let multilins = self
135			.multilins
136			.iter()
137			.map(|multilin| multilin.const_slice())
138			.collect::<Vec<_>>();
139		let round_evals = calculate_round_evals(
140			self.hal,
141			self.n_vars_remaining,
142			batch_coeff,
143			&multilins,
144			&self.compositions,
145		)?;
146
147		let batched_sum = match self.last_coeffs_or_sums {
148			PhaseState::Coeffs(_) => {
149				bail!(Error::ExpectedFold);
150			}
151			PhaseState::InitialSums(ref sums) => evaluate_univariate(sums, batch_coeff),
152			PhaseState::BatchedSum(sum) => sum,
153		};
154		let round_coeffs = calculate_round_coeffs_from_evals(batched_sum, round_evals);
155		self.last_coeffs_or_sums = PhaseState::Coeffs(round_coeffs.clone());
156
157		// This is because the batched verifier reads a batched polynomial from the transcript with
158		// max degree from all compositions being proven. If our compoisition is empty, we add a
159		// degree 0 polynomial to whatever is already being written to the transcript by the batch
160		// prover.
161		if self.compositions.is_empty() {
162			Ok(RoundCoeffs(vec![]))
163		} else {
164			Ok(round_coeffs)
165		}
166	}
167
168	fn fold(&mut self, challenge: F) -> Result<(), Error> {
169		use binius_compute::{FSlice, FSliceMut};
170
171		type PreparedExtrapolateLineArgs<'a, F, Hal> = (FSliceMut<'a, F, Hal>, FSlice<'a, F, Hal>);
172
173		if self.n_vars_remaining == 0 {
174			bail!(Error::ExpectedFinish);
175		}
176
177		// Update the stored multilinear sums.
178		match self.last_coeffs_or_sums {
179			PhaseState::Coeffs(ref coeffs) => {
180				let new_sum = evaluate_univariate(&coeffs.0, challenge);
181				self.last_coeffs_or_sums = PhaseState::BatchedSum(new_sum);
182			}
183			PhaseState::InitialSums(_) | PhaseState::BatchedSum(_) => {
184				bail!(Error::ExpectedExecution);
185			}
186		}
187
188		let prepared_extrapolate_line_ops =
189			self.multilins
190				.drain(..)
191				.map(
192					|multilin| -> Result<
193						PreparedExtrapolateLineArgs<'b, F, Hal>,
194						binius_compute::Error,
195					> {
196						match multilin {
197							SumcheckMultilinear::PreFold(evals) => {
198								debug_assert_eq!(evals.len(), 1 << self.n_vars_remaining);
199								let (evals_0, evals_1) = Hal::DevMem::split_half(evals);
200								// Allocate new buffer for the folded evaluations and copy in
201								// evals_0.
202								let mut folded_evals= self.dev_alloc.alloc(1 << (self.n_vars_remaining - 1))?;
203								self.hal.copy_d2d(evals_0, &mut folded_evals)?;
204								Ok((folded_evals, evals_1))
205							}
206							SumcheckMultilinear::PostFold(evals) => {
207								debug_assert_eq!(evals.len(), 1 << self.n_vars_remaining);
208								let (evals_0, evals_1) = Hal::DevMem::split_half_mut(evals);
209								Ok((evals_0, Hal::DevMem::to_const(evals_1)))
210							}
211						}
212					},
213				)
214				.collect_vec();
215
216		// Fold the multilinears
217		let _ = self.hal.execute(|exec| {
218			self.multilins = exec.map(
219				prepared_extrapolate_line_ops.into_iter(),
220				|exec, extrapolate_line_args| {
221					let (mut evals_0, evals_1) = extrapolate_line_args?;
222					exec.extrapolate_line(&mut evals_0, evals_1, challenge)?;
223					Ok(SumcheckMultilinear::<F, Hal::DevMem>::PostFold(evals_0))
224				},
225			)?;
226
227			Ok(Vec::new())
228		})?;
229
230		self.n_vars_remaining -= 1;
231		Ok(())
232	}
233
234	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
235		match self.last_coeffs_or_sums {
236			PhaseState::Coeffs(_) => {
237				bail!(Error::ExpectedFold);
238			}
239			_ => match self.n_vars_remaining {
240				0 => {}
241				_ => bail!(Error::ExpectedExecution),
242			},
243		};
244
245		// Copy the fully folded multilinear evaluations to the host.
246		let buffer = self.host_alloc.alloc(self.multilins.len())?;
247		for (multilin, dst_i) in iter::zip(self.multilins, &mut *buffer) {
248			let vals = multilin.const_slice();
249			debug_assert_eq!(vals.len(), 1);
250			self.hal.copy_d2h(vals, slice::from_mut(dst_i))?;
251		}
252		Ok(buffer.to_vec())
253	}
254}
255
256/// A multilinear polynomial that is being processed by a sumcheck prover.
257#[derive(Debug, Clone)]
258pub enum SumcheckMultilinear<'a, F, Mem: ComputeMemory<F>> {
259	PreFold(Mem::FSlice<'a>),
260	PostFold(Mem::FSliceMut<'a>),
261}
262
263impl<'a, F, Mem: ComputeMemory<F>> SumcheckMultilinear<'a, F, Mem> {
264	pub fn const_slice(&self) -> Mem::FSlice<'_> {
265		match self {
266			Self::PreFold(slice) => Mem::narrow(slice),
267			Self::PostFold(slice) => Mem::as_const(slice),
268		}
269	}
270}
271
272/// Calculates the evaluations of the products of pairs of partially specialized multilinear
273/// polynomials for sumcheck.
274///
275/// This performs round evaluation for a special case sumcheck prover for a sumcheck over bivariate
276/// products of multilinear polynomials, defined over the same field as the sumcheck challenges,
277/// using high-to-low variable binding order.
278///
279/// In more detail, this function takes a slice of multilinear polynomials over a large field `F`
280/// and a description of pairs of them, and computes the hypercube sum of the evaluations of these
281/// pairs of multilinears at select points. The evaluation points are 0 and the "infinity" point.
282/// The meaning of the infinity evaluation point is described in the documentation of
283/// [`binius_math::EvaluationDomain`].
284///
285/// The evaluations are batched by mixing with the powers of a batch coefficient.
286///
287/// ## Mathematical Definition
288///
289/// Let $\alpha$ be the batching coefficient, $P_0, \ldots, P_{m-1}$ be the multilinears, and
290/// $t \in \left(\mathbb{N}^2 \right)^k$ be a sequence of tuples of indices. This returns two
291/// values:
292///
293/// $$
294/// z_1 = \sum_{i=0}^{k-1} \alpha^i \sum_{v \in B_{m-1}} P_{t_{0,i}}(v || 1) P_{t_{1,i}}(v || 1)
295/// \\\\
296/// z_\infty = \sum_{i=0}^{k-1} \alpha^i \sum_{v \in B_{m-1}} P_{t_{0,i}}(v || \infty)
297/// P_{t_{1,i}}(v || \infty)
298/// $$
299///
300/// ## Returns
301///
302/// Returns the batched, summed evaluations at 1 and infinity.
303pub fn calculate_round_evals<'a, F: TowerField, HAL: ComputeLayer<F>>(
304	hal: &HAL,
305	n_vars: usize,
306	batch_coeff: F,
307	multilins: &[FSlice<'a, F, HAL>],
308	compositions: &[IndexComposition<BivariateProduct, 2>],
309) -> Result<[F; 2], Error> {
310	let prod_evaluators = compositions
311		.iter()
312		.map(|composition| hal.compile_expr(&CompositionPoly::<F>::expression(&composition)))
313		.collect::<Result<Vec<_>, _>>()?;
314
315	// n_vars - 1 is the number of variables in the halves of the split multilinear.
316	let split_n_vars = n_vars - 1;
317	let kernel_mappings = multilins
318		.iter()
319		.copied()
320		.flat_map(|multilin| {
321			let (lo_half, hi_half) = HAL::DevMem::split_half(multilin);
322			[
323				KernelMemMap::Chunked {
324					data: lo_half,
325					log_min_chunk_size: 0,
326				},
327				KernelMemMap::Chunked {
328					data: hi_half,
329					log_min_chunk_size: 0,
330				},
331				// Evaluations of the multilinear at the extra evaluation point
332				KernelMemMap::Local {
333					log_size: split_n_vars,
334				},
335			]
336		})
337		.collect();
338
339	let batch_coeffs = powers(batch_coeff)
340		.take(compositions.len())
341		.collect::<Vec<_>>();
342
343	let evals = hal.execute(|exec| {
344		exec.accumulate_kernels(
345			|local_exec, log_chunks, mut buffers| {
346				let log_chunk_size = split_n_vars - log_chunks;
347
348				// Compute the composite evaluations at the point ONE.
349				let mut acc_1 = local_exec.decl_value(F::ZERO)?;
350				{
351					let eval_1s = SlicesBatch::new(
352						(0..multilins.len())
353							.map(|i| buffers[i * 3 + 1].to_ref())
354							.collect(),
355						1 << log_chunk_size,
356					);
357					for (&batch_coeff, evaluator) in iter::zip(&batch_coeffs, &prod_evaluators) {
358						local_exec.sum_composition_evals(
359							&eval_1s,
360							evaluator,
361							batch_coeff,
362							&mut acc_1,
363						)?;
364					}
365				}
366
367				// Extrapolate the multilinear evaluations at the point Infinity.
368				for group in buffers.chunks_mut(3) {
369					let Ok(
370						[
371							KernelBuffer::Ref(evals_0),
372							KernelBuffer::Ref(evals_1),
373							KernelBuffer::Mut(evals_inf),
374						],
375					) = TryInto::<&mut [_; 3]>::try_into(group)
376					else {
377						panic!(
378							"exec_kernels did not create the mapped buffers struct according to the mapping"
379						);
380					};
381					local_exec.add(log_chunk_size, *evals_0, *evals_1, evals_inf)?;
382				}
383
384				// Compute the composite evaluations at the point Infinity.
385				let mut acc_inf = local_exec.decl_value(F::ZERO)?;
386				let eval_infs = SlicesBatch::new(
387					(0..multilins.len())
388						.map(|i| buffers[i * 3 + 2].to_ref())
389						.collect(),
390					1 << log_chunk_size,
391				);
392				for (&batch_coeff, evaluator) in iter::zip(&batch_coeffs, &prod_evaluators) {
393					local_exec.sum_composition_evals(
394						&eval_infs,
395						evaluator,
396						batch_coeff,
397						&mut acc_inf,
398					)?;
399				}
400
401				Ok(vec![acc_1, acc_inf])
402			},
403			kernel_mappings,
404		)
405	})?;
406	let evals = TryInto::<[F; 2]>::try_into(evals).expect("kernel returns two values");
407	Ok(evals)
408}
409
410fn calculate_round_coeffs_from_evals<F: Field>(sum: F, evals: [F; 2]) -> RoundCoeffs<F> {
411	let [y_1, y_inf] = evals;
412	let y_0 = sum - y_1;
413
414	// P(X) = c_2 x² + c_1 x + c_0
415	//
416	// P(0) =                  c_0
417	// P(1) = c_2    + c_1   + c_0
418	// P(∞) = c_2
419
420	let c_0 = y_0;
421	let c_2 = y_inf;
422	let c_1 = y_1 - c_0 - c_2;
423	RoundCoeffs(vec![c_0, c_1, c_2])
424}
425
426#[derive(Debug)]
427pub enum PhaseState<F: Field> {
428	Coeffs(RoundCoeffs<F>),
429	InitialSums(Vec<F>),
430	BatchedSum(F),
431}
432
#[cfg(test)]
mod tests {
	use binius_compute::cpu::layer::CpuLayerHolder;
	use binius_compute_test_utils::bivariate_sumcheck::{
		generic_test_bivariate_sumcheck_prove_verify, generic_test_calculate_round_evals,
	};
	use binius_fast_compute::layer::FastCpuLayerHolder;
	use binius_field::{
		BinaryField128b, arch::OptimalUnderlier, as_packed_field::PackedType,
		tower::CanonicalTowerFamily,
	};
	use binius_math::B128;

	/// Number of variables of every multilinear exercised by these tests.
	const N_VARS: usize = 8;
	/// Number of multilinears in the prove/verify tests.
	const N_MULTILINS: usize = 8;
	/// Number of bivariate product compositions in the prove/verify tests.
	const N_COMPOSITIONS: usize = 8;

	/// Packed 128-bit binary field elements used by the fast CPU compute layer.
	type Packed128 = PackedType<OptimalUnderlier, BinaryField128b>;

	#[test]
	fn test_calculate_round_evals() {
		generic_test_calculate_round_evals(CpuLayerHolder::new(1 << 11, 1 << 10), N_VARS)
	}

	#[test]
	fn test_calculate_round_evals_fast_cpu() {
		let holder =
			FastCpuLayerHolder::<CanonicalTowerFamily, Packed128>::new(1 << 11, 1 << 10);
		generic_test_calculate_round_evals(holder, N_VARS)
	}

	#[test]
	fn test_bivariate_sumcheck_prove_verify() {
		let holder = CpuLayerHolder::<B128>::new(1 << 13, 1 << 12);
		generic_test_bivariate_sumcheck_prove_verify(holder, N_VARS, N_MULTILINS, N_COMPOSITIONS)
	}

	#[test]
	fn test_bivariate_sumcheck_prove_verify_fast() {
		let holder =
			FastCpuLayerHolder::<CanonicalTowerFamily, Packed128>::new(1 << 13, 1 << 12);
		generic_test_bivariate_sumcheck_prove_verify(holder, N_VARS, N_MULTILINS, N_COMPOSITIONS)
	}
}