binius_core/protocols/sumcheck/v3/
bivariate_mlecheck.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::{iter, mem, slice};
4
5use binius_compute::{
6	ComputeLayer, ComputeLayerExecutor, ComputeMemory, FSlice, KernelBuffer, KernelExecutor,
7	KernelMemMap, SizedSlice, SlicesBatch, alloc::ComputeAllocator, cpu::CpuMemory,
8};
9use binius_field::{
10	Field, TowerField,
11	util::{eq, powers},
12};
13use binius_math::{ArithCircuit, CompositionPoly, EvaluationOrder, evaluate_univariate};
14use binius_utils::bail;
15use itertools::Itertools;
16
17use super::bivariate_product::{PhaseState, SumcheckMultilinear};
18use crate::{
19	composition::{BivariateProduct, IndexComposition},
20	protocols::sumcheck::{
21		CompositeSumClaim, EqIndSumcheckClaim, Error, RoundCoeffs, prove::SumcheckProver,
22	},
23};
24
25/// MLEcheck prover implementation for the special case of bivariate product compositions over
26/// large-field multilinears.
27///
28/// This implements the [`SumcheckProver`] interface. The implementation uses a [`ComputeLayer`]
29/// instance for expensive operations and the input multilinears are provided as device memory
30/// slices.
31pub struct BivariateMLEcheckProver<
32	'a,
33	F: Field,
34	Hal: ComputeLayer<F>,
35	HostAllocatorType,
36	DeviceAllocatorType,
37> where
38	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
39	DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
40{
41	hal: &'a Hal,
42	dev_alloc: &'a DeviceAllocatorType,
43	host_alloc: &'a HostAllocatorType,
44	n_vars_initial: usize,
45	n_vars_remaining: usize,
46	multilins: Vec<SumcheckMultilinear<'a, F, Hal::DevMem>>,
47	compositions: Vec<IndexComposition<BivariateProduct, 2>>,
48	last_coeffs_or_sums: PhaseState<F>,
49	eq_ind_prefix_eval: F,
50	// Wrapping it in Option is a temporary workaround for the lifetime problem
51	eq_ind_partial_evals: Option<SumcheckMultilinear<'a, F, Hal::DevMem>>,
52	eq_ind_challenges: Vec<F>,
53}
54
55impl<'a, F, Hal, HostAllocatorType, DeviceAllocatorType>
56	BivariateMLEcheckProver<'a, F, Hal, HostAllocatorType, DeviceAllocatorType>
57where
58	F: TowerField,
59	Hal: ComputeLayer<F>,
60	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
61	DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
62{
63	#[allow(clippy::too_many_arguments)]
64	pub fn new(
65		hal: &'a Hal,
66		dev_alloc: &'a DeviceAllocatorType,
67		host_alloc: &'a HostAllocatorType,
68		claim: &EqIndSumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
69		multilins: Vec<FSlice<'a, F, Hal>>,
70		// Specify an existing tensor expansion for `eq_ind_challenges`. Avoids
71		// duplicate work.
72		eq_ind_partial_evals: FSlice<'a, F, Hal>,
73		eq_ind_challenges: Vec<F>,
74	) -> Result<Self, Error> {
75		let n_vars = claim.n_vars();
76
77		// Check shape of multilinear witness inputs.
78		assert_eq!(claim.n_multilinears(), multilins.len());
79		for multilin in &multilins {
80			if multilin.len() != 1 << n_vars {
81				bail!(Error::NumberOfVariablesMismatch);
82			}
83		}
84
85		// Wrap multilinear witness inputs as SumcheckMultilinears.
86		let multilins = multilins
87			.into_iter()
88			.map(SumcheckMultilinear::PreFold)
89			.collect();
90
91		let (compositions, sums) = claim
92			.eq_ind_composite_sums()
93			.iter()
94			.map(|CompositeSumClaim { composition, sum }| (composition.clone(), *sum))
95			.unzip();
96
97		// Only one value of the expanded equality indicator is used per each
98		// 1-variable subcube, thus it should be twice smaller.
99		if eq_ind_partial_evals.len() != 1 << n_vars.saturating_sub(1) {
100			bail!(Error::IncorrectEqIndPartialEvalsSize);
101		}
102
103		let eq_ind_partial_evals_buffer = SumcheckMultilinear::PreFold(eq_ind_partial_evals);
104
105		Ok(Self {
106			hal,
107			dev_alloc,
108			host_alloc,
109			n_vars_initial: n_vars,
110			n_vars_remaining: n_vars,
111			multilins,
112			compositions,
113			last_coeffs_or_sums: PhaseState::InitialSums(sums),
114			eq_ind_prefix_eval: F::ONE,
115			eq_ind_partial_evals: Some(eq_ind_partial_evals_buffer),
116			eq_ind_challenges,
117		})
118	}
119
120	fn update_eq_ind_prefix_eval(&mut self, challenge: F) {
121		// Update the running eq ind evaluation.
122		self.eq_ind_prefix_eval *= eq(self.eq_ind_challenges[self.n_vars_remaining - 1], challenge);
123	}
124
125	/// Returns the amount of host memory this sumcheck requires.
126	pub fn required_host_memory(
127		claim: &EqIndSumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
128	) -> usize {
129		// In `finish()`, the prover allocates a temporary host buffer for each of the fully folded
130		// multilinear evaluations, plus `eq_ind` evaluations, which will be appended at the end.
131		claim.n_multilinears() + 1
132	}
133
134	/// Returns the amount of device memory this sumcheck requires.
135	pub fn required_device_memory(
136		claim: &EqIndSumcheckClaim<F, IndexComposition<BivariateProduct, 2>>,
137		with_eq_ind_partial_evals: bool,
138	) -> usize {
139		// In `fold()`, the prover allocates device buffers for each of the folded multilinears,
140		// plus for `eq_ind`. Each of them is half the size of the original multilinears.
141		let n_multilinears = claim.n_multilinears() + if with_eq_ind_partial_evals { 0 } else { 1 };
142		n_multilinears * (1 << (claim.n_vars() - 1))
143	}
144
145	pub fn fold_multilinears(&mut self, challenge: F) -> Result<(), Error> {
146		use binius_compute::{FSlice, FSliceMut};
147
148		type PreparedExtrapolateLineArgs<'a, F, Hal> = (FSliceMut<'a, F, Hal>, FSlice<'a, F, Hal>);
149
150		let prepared_extrapolate_line_ops = self
151			.multilins
152			.drain(..)
153			.map(
154				|multilin| -> Result<
155					PreparedExtrapolateLineArgs<'a, F, Hal>,
156					binius_compute::Error,
157				> {
158					match multilin {
159						SumcheckMultilinear::PreFold(evals) => {
160							debug_assert_eq!(evals.len(), 1 << self.n_vars_remaining);
161							let (evals_0, evals_1) =
162								Hal::DevMem::split_half(evals);
163
164							// Allocate new buffer for the folded evaluations and copy in evals_0.
165							let mut folded_evals =
166								self.dev_alloc.alloc(1 << (self.n_vars_remaining - 1))?;
167							self.hal.copy_d2d(evals_0, &mut folded_evals)?;
168							Ok((folded_evals, evals_1))
169						}
170						SumcheckMultilinear::PostFold(evals) => {
171							debug_assert_eq!(evals.len(), 1 << self.n_vars_remaining);
172							let (evals_0, evals_1) =
173								Hal::DevMem::split_half_mut(evals);
174							Ok((evals_0, Hal::DevMem::to_const(evals_1)))
175						}
176					}
177				},
178			)
179			.collect_vec();
180		self.hal.execute(|exec| {
181			self.multilins = exec.map(
182				prepared_extrapolate_line_ops.into_iter(),
183				|exec, extrapolate_line_args| {
184					let (mut evals_0, evals_1) = extrapolate_line_args?;
185					exec.extrapolate_line(&mut evals_0, evals_1, challenge)?;
186					Ok(SumcheckMultilinear::<F, Hal::DevMem>::PostFold(evals_0))
187				},
188			)?;
189
190			Ok(Vec::new())
191		})?;
192		Ok(())
193	}
194
195	pub fn fold_eq_ind(&mut self) -> Result<(), Error> {
196		let eq_ind_partial_evals = mem::take(&mut self.eq_ind_partial_evals).expect("exist");
197
198		let split_n_vars = self.n_vars_remaining - 2;
199
200		let (mut evals_0, evals_1) = match eq_ind_partial_evals {
201			SumcheckMultilinear::PreFold(evals) => {
202				let (evals_0, evals_1) = Hal::DevMem::split_half(evals);
203
204				let mut buffer = self.dev_alloc.alloc(evals_0.len())?;
205				self.hal.copy_d2d(evals_0, &mut buffer)?;
206
207				(buffer, evals_1)
208			}
209			SumcheckMultilinear::PostFold(evals) => {
210				let (evals_0, evals_1) = Hal::DevMem::split_half_mut(evals);
211
212				let evals_1 = Hal::DevMem::to_const(evals_1);
213
214				(evals_0, evals_1)
215			}
216		};
217
218		let kernel_mappings = vec![
219			KernelMemMap::ChunkedMut {
220				data: Hal::DevMem::to_owned_mut(&mut evals_0),
221				log_min_chunk_size: 0,
222			},
223			KernelMemMap::Chunked {
224				data: Hal::DevMem::narrow(&evals_1),
225				log_min_chunk_size: 0,
226			},
227		];
228
229		let _ = self.hal.execute(|exec| {
230			exec.map_kernels(
231				|local_exec, log_chunks, mut buffers| {
232					let log_chunk_size = split_n_vars - log_chunks;
233
234					let Ok([KernelBuffer::Mut(evals_0), KernelBuffer::Ref(evals_1)]) =
235						TryInto::<&mut [_; 2]>::try_into(buffers.as_mut_slice())
236					else {
237						panic!(
238							"exec_kernels did not create the mapped buffers struct according to the mapping"
239						);
240					};
241					local_exec.add_assign(log_chunk_size, *evals_1, evals_0)?;
242
243					Ok(())
244				},
245				kernel_mappings,
246			)?;
247
248			Ok(Vec::new())
249		})?;
250
251		self.eq_ind_partial_evals = Some(SumcheckMultilinear::PostFold(evals_0));
252
253		Ok(())
254	}
255}
256
257impl<F, Hal, HostAllocatorType, DeviceAllocatorType> SumcheckProver<F>
258	for BivariateMLEcheckProver<'_, F, Hal, HostAllocatorType, DeviceAllocatorType>
259where
260	F: TowerField,
261	Hal: ComputeLayer<F>,
262	HostAllocatorType: ComputeAllocator<F, CpuMemory>,
263	DeviceAllocatorType: ComputeAllocator<F, Hal::DevMem>,
264{
265	fn n_vars(&self) -> usize {
266		self.n_vars_initial
267	}
268
269	fn evaluation_order(&self) -> EvaluationOrder {
270		EvaluationOrder::HighToLow
271	}
272
273	fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
274		let multilins = self
275			.multilins
276			.iter()
277			.map(|multilin| multilin.const_slice())
278			.collect::<Vec<_>>();
279
280		let round_evals = calculate_round_evals(
281			self.hal,
282			self.n_vars_remaining,
283			batch_coeff,
284			&multilins,
285			self.eq_ind_partial_evals
286				.as_ref()
287				.expect("eq_ind_partial_evals not None")
288				.const_slice(),
289			&self.compositions,
290		)?;
291
292		let batched_sum = match self.last_coeffs_or_sums {
293			PhaseState::Coeffs(_) => {
294				bail!(Error::ExpectedFold);
295			}
296			PhaseState::InitialSums(ref sums) => evaluate_univariate(sums, batch_coeff),
297			PhaseState::BatchedSum(sum) => sum,
298		};
299
300		let alpha = self.eq_ind_challenges[self.n_vars_remaining - 1];
301
302		let prime_coeffs = calculate_round_coeffs_from_evals(batched_sum, round_evals, alpha);
303
304		self.last_coeffs_or_sums = PhaseState::Coeffs(prime_coeffs.clone());
305
306		// Convert v' polynomial into v polynomial
307		// eq(X, α) = (1 − α) + (2 α − 1) X
308		let prime_coeffs_scaled_by_constant_term = prime_coeffs.clone() * (F::ONE - alpha);
309
310		let mut prime_coeffs_scaled_by_linear_term = prime_coeffs * (alpha.double() - F::ONE);
311
312		prime_coeffs_scaled_by_linear_term.0.insert(0, F::ZERO); // Multiply prime polynomial by X
313
314		let coeffs = (prime_coeffs_scaled_by_constant_term + &prime_coeffs_scaled_by_linear_term)
315			* self.eq_ind_prefix_eval;
316
317		Ok(coeffs)
318	}
319
320	fn fold(&mut self, challenge: F) -> Result<(), Error> {
321		if self.n_vars_remaining == 0 {
322			bail!(Error::ExpectedFinish);
323		}
324
325		// Update the stored multilinear sums.
326		match self.last_coeffs_or_sums {
327			PhaseState::Coeffs(ref coeffs) => {
328				let new_sum = evaluate_univariate(&coeffs.0, challenge);
329				self.last_coeffs_or_sums = PhaseState::BatchedSum(new_sum);
330			}
331			PhaseState::InitialSums(_) | PhaseState::BatchedSum(_) => {
332				bail!(Error::ExpectedExecution);
333			}
334		}
335
336		self.update_eq_ind_prefix_eval(challenge);
337
338		self.fold_multilinears(challenge)?;
339
340		if self.n_vars_remaining - 1 != 0 {
341			self.fold_eq_ind()?;
342		}
343
344		self.n_vars_remaining -= 1;
345		Ok(())
346	}
347
348	fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
349		match self.last_coeffs_or_sums {
350			PhaseState::Coeffs(_) => {
351				bail!(Error::ExpectedFold);
352			}
353			_ => match self.n_vars_remaining {
354				0 => {}
355				_ => bail!(Error::ExpectedExecution),
356			},
357		};
358
359		// Copy the fully folded multilinear evaluations to the host.
360		let buffer = self.host_alloc.alloc(self.multilins.len())?;
361		for (multilin, dst_i) in iter::zip(self.multilins, &mut *buffer) {
362			let vals = multilin.const_slice();
363			debug_assert_eq!(vals.len(), 1);
364			self.hal.copy_d2h(vals, slice::from_mut(dst_i))?;
365		}
366
367		let mut res = buffer.to_vec();
368
369		res.push(self.eq_ind_prefix_eval);
370
371		Ok(res)
372	}
373}
374
375fn calculate_round_coeffs_from_evals<F: Field>(sum: F, evals: [F; 2], alpha: F) -> RoundCoeffs<F> {
376	let [y_1, y_inf] = evals;
377
378	let y_0 = (sum - y_1 * alpha) * (F::ONE - alpha).invert_or_zero();
379
380	// P(X) = c_2 x² + c_1 x + c_0
381	//
382	// P(0) =                  c_0
383	// P(1) = c_2    + c_1   + c_0
384	// P(∞) = c_2
385	let c_0 = y_0;
386	let c_2 = y_inf;
387	let c_1 = y_1 - c_0 - c_2;
388	RoundCoeffs(vec![c_0, c_1, c_2])
389}
390
391fn calculate_round_evals<'a, F: TowerField, Hal: ComputeLayer<F>>(
392	hal: &Hal,
393	n_vars: usize,
394	batch_coeff: F,
395	multilins: &[FSlice<'a, F, Hal>],
396	eq_ind_partial_evals: FSlice<'a, F, Hal>,
397	compositions: &[IndexComposition<BivariateProduct, 2>],
398) -> Result<[F; 2], Error> {
399	let prod_evaluators = compositions
400		.iter()
401		.map(|composition| {
402			let mut prod_expr = CompositionPoly::<F>::expression(&composition);
403			// add eq_ind
404			prod_expr *= ArithCircuit::var(multilins.len());
405
406			hal.compile_expr(&prod_expr)
407		})
408		.collect::<Result<Vec<_>, _>>()?;
409
410	// n_vars - 1 is the number of variables in the halves of the split multilinear.
411	let split_n_vars = n_vars - 1;
412	let mut kernel_mappings = multilins
413		.iter()
414		.copied()
415		.flat_map(|multilin| {
416			let (lo_half, hi_half) = Hal::DevMem::split_half(multilin);
417			[
418				KernelMemMap::Chunked {
419					data: lo_half,
420					log_min_chunk_size: 0,
421				},
422				KernelMemMap::Chunked {
423					data: hi_half,
424					log_min_chunk_size: 0,
425				},
426				// Evaluations of the multilinear at the extra evaluation point
427				KernelMemMap::Local {
428					log_size: split_n_vars,
429				},
430			]
431		})
432		.collect::<Vec<_>>();
433
434	kernel_mappings.push(KernelMemMap::Chunked {
435		data: eq_ind_partial_evals,
436		log_min_chunk_size: 0,
437	});
438
439	let batch_coeffs = powers(batch_coeff)
440		.take(compositions.len())
441		.collect::<Vec<_>>();
442
443	let evals = hal.execute(|exec| {
444		exec.accumulate_kernels(
445			|local_exec, log_chunks, mut buffers| {
446				let log_chunk_size = split_n_vars - log_chunks;
447
448				let eq_ind = buffers.pop().expect(
449					"The presence of eq_ind in the buffer is due to it being added earlier in the code.",
450				);
451
452				// Compute the composite evaluations at the point ONE.
453				let mut acc_1 = local_exec.decl_value(F::ZERO)?;
454				{
455					let mut eval_1s_with_eq_ind = (0..multilins.len())
456						.map(|i| buffers[i * 3 + 1].to_ref())
457						.collect::<Vec<_>>();
458
459					eval_1s_with_eq_ind.push(eq_ind.to_ref());
460
461					let eval_1s_with_eq_ind =
462						SlicesBatch::new(eval_1s_with_eq_ind, 1 << log_chunk_size);
463
464					for (&batch_coeff, evaluator) in iter::zip(&batch_coeffs, &prod_evaluators) {
465						local_exec.sum_composition_evals(
466							&eval_1s_with_eq_ind,
467							evaluator,
468							batch_coeff,
469							&mut acc_1,
470						)?;
471					}
472				}
473
474				// Extrapolate the multilinear evaluations at the point Infinity.
475				for group in buffers.chunks_mut(3) {
476					let Ok(
477						[
478							KernelBuffer::Ref(evals_0),
479							KernelBuffer::Ref(evals_1),
480							KernelBuffer::Mut(evals_inf),
481						],
482					) = TryInto::<&mut [_; 3]>::try_into(group)
483					else {
484						panic!(
485							"exec_kernels did not create the mapped buffers struct according to the mapping"
486						);
487					};
488					local_exec.add(log_chunk_size, *evals_0, *evals_1, evals_inf)?;
489				}
490
491				// Compute the composite evaluations at the point Infinity.
492				let mut acc_inf = local_exec.decl_value(F::ZERO)?;
493				let mut eval_infs_with_eq_ind = (0..multilins.len())
494					.map(|i| buffers[i * 3 + 2].to_ref())
495					.collect::<Vec<_>>();
496
497				eval_infs_with_eq_ind.push(eq_ind.to_ref());
498
499				let eval_infs_with_eq_ind =
500					SlicesBatch::new(eval_infs_with_eq_ind, 1 << log_chunk_size);
501
502				for (&batch_coeff, evaluator) in iter::zip(&batch_coeffs, &prod_evaluators) {
503					local_exec.sum_composition_evals(
504						&eval_infs_with_eq_ind,
505						evaluator,
506						batch_coeff,
507						&mut acc_inf,
508					)?;
509				}
510
511				Ok(vec![acc_1, acc_inf])
512			},
513			kernel_mappings,
514		)
515	})?;
516
517	let evals = TryInto::<[F; 2]>::try_into(evals).expect("kernel returns two values");
518	Ok(evals)
519}
520
#[cfg(test)]
mod tests {
	use binius_compute::cpu::layer::CpuLayerHolder;
	use binius_compute_test_utils::bivariate_sumcheck::generic_test_bivariate_mlecheck_prove_verify;
	use binius_fast_compute::layer::FastCpuLayerHolder;
	use binius_field::{
		arch::OptimalUnderlier, as_packed_field::PackedType, tower::CanonicalTowerFamily,
	};
	use binius_math::B128;

	/// Runs the generic prove/verify round trip on the reference CPU compute layer.
	#[test]
	fn test_bivariate_mlecheck_prove_verify() {
		let (n_vars, n_multilins, n_compositions) = (8, 8, 8);

		generic_test_bivariate_mlecheck_prove_verify(
			CpuLayerHolder::<B128>::new(1 << 13, 1 << 12),
			n_vars,
			n_multilins,
			n_compositions,
		);
	}

	/// Runs the generic prove/verify round trip on the fast CPU compute layer with packed
	/// fields.
	#[test]
	fn test_bivariate_mlecheck_prove_verify_fast() {
		type Packed = PackedType<OptimalUnderlier, B128>;

		let (n_vars, n_multilins, n_compositions) = (8, 8, 8);

		generic_test_bivariate_mlecheck_prove_verify(
			FastCpuLayerHolder::<CanonicalTowerFamily, Packed>::new(1 << 13, 1 << 12),
			n_vars,
			n_multilins,
			n_compositions,
		);
	}
}