binius_core/protocols/sumcheck/prove/
prover_state.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
// Copyright 2024-2025 Irreducible Inc.

use std::{
	iter,
	sync::atomic::{AtomicBool, Ordering},
};

use binius_field::{
	util::powers, ExtensionField, Field, PackedExtension, PackedField, RepackedExtension,
};
use binius_hal::{ComputationBackend, RoundEvals, SumcheckEvaluator, SumcheckMultilinear};
use binius_math::{
	evaluate_univariate, CompositionPolyOS, MLEDirectAdapter, MultilinearPoly, MultilinearQuery,
};
use binius_utils::bail;
use getset::CopyGetters;
use itertools::izip;
use rayon::prelude::*;
use tracing::instrument;

use crate::{
	polynomial::Error as PolynomialError,
	protocols::sumcheck::{
		common::{determine_switchovers, equal_n_vars_check, RoundCoeffs},
		error::Error,
	},
};

/// Converts the evaluations of one sumcheck round polynomial into monomial coefficients.
///
/// [`ProverState::calculate_round_coeffs_from_evals`] calls one implementor per stored sum,
/// passing that sum together with the matching round evaluations.
pub trait SumcheckInterpolator<F: Field> {
	/// Given evaluations of the round polynomial, interpolate and return monomial coefficients
	///
	/// ## Arguments
	///
	/// * `last_sum`: the sum the prover state currently holds for this claim, i.e. the initial
	///   claimed sum or the previous round polynomial evaluated at the last fold challenge
	/// * `round_evals`: the computed evaluations of the round polynomial
	///
	/// ## Returns
	///
	/// The coefficients of the round polynomial; failures are reported as a
	/// [`PolynomialError`].
	fn round_evals_to_coeffs(
		&self,
		last_sum: F,
		round_evals: Vec<F>,
	) -> Result<Vec<F>, PolynomialError>;
}

/// Tracks which step of a round the prover state is in.
///
/// The state starts as `Sums`, becomes `Coeffs` when round coefficients are calculated, and
/// returns to `Sums` when the round is folded with a challenge.
#[derive(Debug)]
enum ProverStateCoeffsOrSums<F: Field> {
	/// Round coefficients have been calculated and are waiting to be folded.
	Coeffs(Vec<RoundCoeffs<F>>),
	/// The current sums: the claimed sums at construction, or the previous round's
	/// coefficients evaluated at the fold challenge.
	Sums(Vec<F>),
}

/// The stored state of a sumcheck prover, which encapsulates common implementation logic.
///
/// We expect that CPU sumcheck provers will internally maintain a [`ProverState`] instance and
/// customize the sumcheck logic through different [`SumcheckEvaluator`] implementations passed to
/// the common state object.
#[derive(Debug, CopyGetters)]
pub struct ProverState<'a, FDomain, P, M, Backend>
where
	FDomain: Field,
	P: PackedField,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	/// The number of variables in the folded multilinears. This value decrements each round the
	/// state is folded.
	#[getset(get_copy = "pub")]
	n_vars: usize,
	/// One entry per input multilinear. Each starts as `Transparent` with its switchover round
	/// and is replaced by a `Folded` entry when that round is reached during folding.
	multilinears: Vec<SumcheckMultilinear<P, M>>,
	/// Points passed unchanged to the backend whenever round evaluations are computed.
	evaluation_points: Vec<FDomain>,
	/// Query updated with every fold challenge. It is reset to `None` once a fold leaves no
	/// transparent multilinear behind.
	tensor_query: Option<MultilinearQuery<P>>,
	/// Either the current sums or the round coefficients awaiting a fold; used to reject calls
	/// made in the wrong order.
	last_coeffs_or_sums: ProverStateCoeffsOrSums<P::Scalar>,
	/// Backend that performs the round evaluation computations.
	backend: &'a Backend,
}

impl<'a, FDomain, F, P, M, Backend> ProverState<'a, FDomain, P, M, Backend>
where
	FDomain: Field,
	F: Field + ExtensionField<FDomain>,
	P: PackedField<Scalar = F> + PackedExtension<FDomain>,
	M: MultilinearPoly<P> + Send + Sync,
	Backend: ComputationBackend,
{
	/// Creates a prover state, deriving the switchover round of each multilinear from
	/// `switchover_fn`.
	///
	/// The rounds are computed with `determine_switchovers`; everything else is delegated to
	/// [`Self::new_with_switchover_rounds`], whose errors are returned unchanged.
	pub fn new(
		multilinears: Vec<M>,
		claimed_sums: Vec<F>,
		evaluation_points: Vec<FDomain>,
		switchover_fn: impl Fn(usize) -> usize,
		backend: &'a Backend,
	) -> Result<Self, Error> {
		let rounds = determine_switchovers(&multilinears, switchover_fn);
		let rounds: &[usize] = &rounds;

		Self::new_with_switchover_rounds(
			multilinears,
			rounds,
			claimed_sums,
			evaluation_points,
			backend,
		)
	}

	/// Creates a prover state from multilinears and an explicit switchover round for each one.
	///
	/// ## Errors
	///
	/// * propagates the error of `equal_n_vars_check` on the multilinears
	/// * [`Error::MultilinearSwitchoverSizeMismatch`] when `switchover_rounds` does not have
	///   exactly one entry per multilinear
	#[instrument(
		skip_all,
		level = "debug",
		name = "ProverState::new_with_switchover_rounds"
	)]
	pub fn new_with_switchover_rounds(
		multilinears: Vec<M>,
		switchover_rounds: &[usize],
		claimed_sums: Vec<F>,
		evaluation_points: Vec<FDomain>,
		backend: &'a Backend,
	) -> Result<Self, Error> {
		let n_vars = equal_n_vars_check(&multilinears)?;

		if switchover_rounds.len() != multilinears.len() {
			bail!(Error::MultilinearSwitchoverSizeMismatch);
		}

		// The query capacity is one more than the latest switchover round (1 when there are
		// no multilinears at all).
		let query_capacity = switchover_rounds.iter().copied().max().unwrap_or_default() + 1;
		let tensor_query = Some(MultilinearQuery::with_capacity(query_capacity));

		// Every multilinear starts out transparent, paired with its own switchover round.
		let multilinears = iter::zip(multilinears, switchover_rounds.iter().copied())
			.map(|(multilinear, switchover_round)| SumcheckMultilinear::Transparent {
				multilinear,
				switchover_round,
			})
			.collect();

		let last_coeffs_or_sums = ProverStateCoeffsOrSums::Sums(claimed_sums);

		Ok(Self {
			n_vars,
			multilinears,
			evaluation_points,
			tensor_query,
			last_coeffs_or_sums,
			backend,
		})
	}

	/// Folds the prover state with the verifier `challenge`, finishing the current round.
	///
	/// In order, this:
	/// 1. replaces the stored round coefficients by their evaluations at `challenge`,
	/// 2. updates the tensor query (if still present) with `challenge`,
	/// 3. processes every multilinear in parallel: a transparent multilinear whose remaining
	///    switchover count is zero is partially evaluated with the tensor query and becomes
	///    folded, any other transparent multilinear has its count decremented, and an already
	///    folded multilinear has `challenge` substituted for its zeroth variable,
	/// 4. drops the tensor query when no transparent multilinear is left, and
	/// 5. decrements `n_vars`.
	///
	/// ## Errors
	///
	/// * [`Error::ExpectedFinish`] when there are no variables left to fold
	/// * [`Error::ExpectedExecution`] when no round coefficients have been calculated since
	///   the last fold
	/// * any error raised while updating the tensor query or evaluating a multilinear
	///
	/// NOTE(review): the sums are replaced before the fallible tensor-query and multilinear
	/// updates, so an error from those steps leaves the state partially updated.
	#[instrument(skip_all, name = "ProverState::fold", level = "debug")]
	pub fn fold(&mut self, challenge: F) -> Result<(), Error> {
		if self.n_vars == 0 {
			bail!(Error::ExpectedFinish);
		}

		// Update the stored multilinear sums.
		match self.last_coeffs_or_sums {
			ProverStateCoeffsOrSums::Coeffs(ref round_coeffs) => {
				// Each new sum is the corresponding round polynomial evaluated at the challenge.
				let new_sums = round_coeffs
					.par_iter()
					.map(|coeffs| evaluate_univariate(&coeffs.0, challenge))
					.collect();
				self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Sums(new_sums);
			}
			ProverStateCoeffsOrSums::Sums(_) => {
				bail!(Error::ExpectedExecution);
			}
		}

		// Update the tensor query.
		// `take` moves the query out so that `update` can consume it by value.
		if let Some(tensor_query) = self.tensor_query.take() {
			self.tensor_query = Some(tensor_query.update(&[challenge])?);
		}

		// Use Relaxed ordering for writes and the read, because:
		// * all writes can only update this value in the same direction of false->true
		// * the barrier at the end of rayon "parallel for" is a big enough synchronization
		//   point to be Relaxed about memory ordering of accesses to this Atomic.
		let any_transparent_left = AtomicBool::new(false);
		self.multilinears
			.par_iter_mut()
			.try_for_each(|multilinear| {
				match multilinear {
					SumcheckMultilinear::Transparent {
						multilinear: inner_multilinear,
						ref mut switchover_round,
					} => {
						// `switchover_round` is a countdown: it is decremented on every fold
						// and the switchover happens on the fold that finds it at zero.
						if *switchover_round == 0 {
							let tensor_query = self.tensor_query.as_ref()
							.expect(
								"tensor_query is guaranteed to be Some while there is still a transparent multilinear"
							);

							// At switchover, perform inner products in large field and save them in a
							// newly created MLE.
							let large_field_folded_multilinear = MLEDirectAdapter::from(
								inner_multilinear.evaluate_partial_low(tensor_query.to_ref())?,
							);

							// Overwrite the entry in place; it is handled as folded from now on.
							*multilinear = SumcheckMultilinear::Folded {
								large_field_folded_multilinear,
							};
						} else {
							*switchover_round -= 1;
							// This multilinear still needs the tensor query in a later round.
							any_transparent_left.store(true, Ordering::Relaxed);
						}
					}
					SumcheckMultilinear::Folded {
						ref mut large_field_folded_multilinear,
					} => {
						// Post-switchover, simply plug in challenge for the zeroth variable.
						*large_field_folded_multilinear = MLEDirectAdapter::from(
							large_field_folded_multilinear.evaluate_zeroth_variable(challenge)?,
						);
					}
				};
				Ok::<(), Error>(())
			})?;

		// Once every multilinear has been folded, the tensor query is no longer needed.
		if !any_transparent_left.load(Ordering::Relaxed) {
			self.tensor_query = None;
		}

		self.n_vars -= 1;
		Ok(())
	}

	pub fn finish(self) -> Result<Vec<F>, Error> {
		match self.last_coeffs_or_sums {
			ProverStateCoeffsOrSums::Coeffs(_) => {
				bail!(Error::ExpectedFold);
			}
			ProverStateCoeffsOrSums::Sums(_) => match self.n_vars {
				0 => {}
				_ => bail!(Error::ExpectedExecution),
			},
		};

		let empty_query = MultilinearQuery::with_capacity(0);
		self.multilinears
			.into_iter()
			.map(|multilinear| {
				match multilinear {
					SumcheckMultilinear::Transparent {
						multilinear: inner_multilinear,
						..
					} => {
						let tensor_query = self.tensor_query.as_ref()
							.expect(
								"tensor_query is guaranteed to be Some while there is still a transparent multilinear"
							);
						inner_multilinear.evaluate(tensor_query.to_ref())
					}
					SumcheckMultilinear::Folded {
						large_field_folded_multilinear,
					} => large_field_folded_multilinear.evaluate(empty_query.to_ref()),
				}
				.map_err(Error::MathError)
			})
			.collect()
	}

	/// Computes the accumulated round evaluations for the first sumcheck round.
	///
	/// The work is delegated to the backend's `sumcheck_compute_first_round_evals`, which
	/// receives the current number of variables, the multilinears, the given `evaluators` and
	/// the stored evaluation points. Backend errors are converted into [`Error`].
	#[instrument(skip_all, level = "debug")]
	pub fn calculate_first_round_evals<PBase, Evaluator, Composition>(
		&self,
		evaluators: &[Evaluator],
	) -> Result<Vec<RoundEvals<F>>, Error>
	where
		PBase: PackedField<Scalar: ExtensionField<FDomain>> + PackedExtension<FDomain>,
		P: PackedField<Scalar: ExtensionField<PBase::Scalar>> + RepackedExtension<PBase>,
		Evaluator: SumcheckEvaluator<PBase, P, Composition> + Sync,
		Composition: CompositionPolyOS<P>,
	{
		let round_evals = self.backend.sumcheck_compute_first_round_evals(
			self.n_vars,
			&self.multilinears,
			evaluators,
			&self.evaluation_points,
		)?;

		Ok(round_evals)
	}

	/// Computes the accumulated round evaluations for an arbitrary sumcheck round.
	///
	/// Unlike [`Self::calculate_first_round_evals`], which is an optimized variant operating
	/// over small fields in the first round, this method also hands the current tensor query
	/// (when one is still held) to the backend. Backend errors are converted into [`Error`].
	#[instrument(skip_all, level = "debug")]
	pub fn calculate_later_round_evals<Evaluator, Composition>(
		&self,
		evaluators: &[Evaluator],
	) -> Result<Vec<RoundEvals<F>>, Error>
	where
		Evaluator: SumcheckEvaluator<P, P, Composition> + Sync,
		Composition: CompositionPolyOS<P>,
	{
		let round_evals = self.backend.sumcheck_compute_later_round_evals(
			self.n_vars,
			self.tensor_query.as_ref().map(Into::into),
			&self.multilinears,
			evaluators,
			&self.evaluation_points,
		)?;

		Ok(round_evals)
	}

	/// Calculate the batched round coefficients from the domain evaluations.
	///
	/// This both performs the polynomial interpolation over the evaluations and the mixing with
	/// the batching coefficient.
	///
	/// ## Arguments
	///
	/// * `interpolators`: one interpolator per stored sum, in the same order as the sums
	/// * `batch_coeff`: batching coefficient; the i-th interpolated polynomial is scaled by the
	///   i-th element of `powers(batch_coeff)` before the polynomials are added together
	/// * `evals`: one set of round evaluations per stored sum, in the same order as the sums
	///
	/// On success the unbatched per-claim coefficients are stored in the state, so the next
	/// call must be [`Self::fold`].
	///
	/// ## Errors
	///
	/// * [`Error::ExpectedFold`] when coefficients were already calculated and not yet folded
	/// * [`Error::IncorrectNumberOfEvaluators`] when `interpolators` or `evals` does not have
	///   exactly one entry per stored sum
	/// * any error returned by an interpolator
	pub fn calculate_round_coeffs_from_evals<Interpolator: SumcheckInterpolator<F>>(
		&mut self,
		interpolators: &[Interpolator],
		batch_coeff: F,
		evals: Vec<RoundEvals<F>>,
	) -> Result<RoundCoeffs<F>, Error> {
		let coeffs = match self.last_coeffs_or_sums {
			ProverStateCoeffsOrSums::Coeffs(_) => {
				bail!(Error::ExpectedFold);
			}
			ProverStateCoeffsOrSums::Sums(ref sums) => {
				// `izip!` stops at its shortest input, so every input must be checked here;
				// otherwise a short `evals` would silently drop claims from the round.
				if interpolators.len() != sums.len() || evals.len() != sums.len() {
					bail!(Error::IncorrectNumberOfEvaluators {
						expected: sums.len(),
					});
				}

				let coeffs = izip!(interpolators, sums, evals)
					.map(|(interpolator, &sum, RoundEvals(evals))| {
						let coeffs = interpolator.round_evals_to_coeffs(sum, evals)?;
						Ok::<_, Error>(RoundCoeffs(coeffs))
					})
					.collect::<Result<Vec<_>, _>>()?;

				// Keep the unbatched coefficients: `fold` evaluates them at the challenge to
				// obtain the next round's sums.
				self.last_coeffs_or_sums = ProverStateCoeffsOrSums::Coeffs(coeffs.clone());
				coeffs
			}
		};

		// Scale each polynomial by its power of the batching coefficient and add them up.
		let batched_coeffs = coeffs
			.into_iter()
			.zip(powers(batch_coeff))
			.map(|(coeffs, scalar)| coeffs * scalar)
			.fold(RoundCoeffs::default(), |accum, coeffs| accum + &coeffs);

		Ok(batched_coeffs)
	}
}