binius_core/protocols/sumcheck/
common.rs1use std::ops::{Add, AddAssign, Mul, MulAssign};
4
5use binius_field::{
6 util::{inner_product_unchecked, powers},
7 ExtensionField, Field, PackedField,
8};
9use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly};
10use binius_utils::bail;
11use getset::{CopyGetters, Getters};
12
13use super::error::Error;
14
15#[derive(Debug, Clone, Getters, CopyGetters)]
23pub struct CompositeSumClaim<F: Field, Composition> {
24 pub composition: Composition,
25 pub sum: F,
26}
27
28#[derive(Debug, Clone, CopyGetters)]
37pub struct SumcheckClaim<F: Field, C> {
38 #[getset(get_copy = "pub")]
39 n_vars: usize,
40 #[getset(get_copy = "pub")]
41 n_multilinears: usize,
42 composite_sums: Vec<CompositeSumClaim<F, C>>,
43}
44
45impl<F: Field, Composition> SumcheckClaim<F, Composition>
46where
47 Composition: CompositionPoly<F>,
48{
49 pub fn new(
56 n_vars: usize,
57 n_multilinears: usize,
58 composite_sums: Vec<CompositeSumClaim<F, Composition>>,
59 ) -> Result<Self, Error> {
60 for CompositeSumClaim {
61 ref composition, ..
62 } in &composite_sums
63 {
64 if composition.n_vars() != n_multilinears {
65 bail!(Error::InvalidComposition {
66 actual: composition.n_vars(),
67 expected: n_multilinears,
68 });
69 }
70 }
71 Ok(Self {
72 n_vars,
73 n_multilinears,
74 composite_sums,
75 })
76 }
77
78 pub fn max_individual_degree(&self) -> usize {
80 self.composite_sums
81 .iter()
82 .map(|composite_sum| composite_sum.composition.degree())
83 .max()
84 .unwrap_or(0)
85 }
86
87 pub fn composite_sums(&self) -> &[CompositeSumClaim<F, Composition>] {
88 &self.composite_sums
89 }
90}
91
92#[derive(Debug, Default, Clone, PartialEq, Eq)]
96pub struct RoundCoeffs<F: Field>(pub Vec<F>);
97
98impl<F: Field> RoundCoeffs<F> {
99 pub fn isomorphic<FI: Field + From<F>>(self) -> RoundCoeffs<FI> {
101 RoundCoeffs(self.0.into_iter().map(Into::into).collect())
102 }
103
104 pub fn truncate(mut self) -> RoundProof<F> {
106 self.0.pop();
107 RoundProof(self)
108 }
109}
110
111impl<F: Field> Add<&Self> for RoundCoeffs<F> {
112 type Output = Self;
113
114 fn add(mut self, rhs: &Self) -> Self::Output {
115 self += rhs;
116 self
117 }
118}
119
120impl<F: Field> AddAssign<&Self> for RoundCoeffs<F> {
121 fn add_assign(&mut self, rhs: &Self) {
122 if self.0.len() < rhs.0.len() {
123 self.0.resize(rhs.0.len(), F::ZERO);
124 }
125
126 for (lhs_i, &rhs_i) in self.0.iter_mut().zip(rhs.0.iter()) {
127 *lhs_i += rhs_i;
128 }
129 }
130}
131
132impl<F: Field> Mul<F> for RoundCoeffs<F> {
133 type Output = Self;
134
135 fn mul(mut self, rhs: F) -> Self::Output {
136 self *= rhs;
137 self
138 }
139}
140
141impl<F: Field> MulAssign<F> for RoundCoeffs<F> {
142 fn mul_assign(&mut self, rhs: F) {
143 for coeff in &mut self.0 {
144 *coeff *= rhs;
145 }
146 }
147}
148
149#[derive(Debug, Default, Clone, PartialEq, Eq)]
156pub struct RoundProof<F: Field>(pub RoundCoeffs<F>);
157
158impl<F: Field> RoundProof<F> {
159 pub fn recover(self, sum: F) -> RoundCoeffs<F> {
180 let Self(RoundCoeffs(mut coeffs)) = self;
181 let first_coeff = coeffs.first().copied().unwrap_or(F::ZERO);
182 let last_coeff = sum - first_coeff - coeffs.iter().sum::<F>();
183 coeffs.push(last_coeff);
184 RoundCoeffs(coeffs)
185 }
186
187 pub fn coeffs(&self) -> &[F] {
189 &self.0 .0
190 }
191
192 pub fn isomorphic<FI: Field + From<F>>(self) -> RoundProof<FI> {
194 RoundProof(self.0.isomorphic())
195 }
196}
197
198#[derive(Debug, Default, Clone, PartialEq, Eq)]
200pub struct Proof<F: Field> {
201 pub rounds: Vec<RoundProof<F>>,
203 pub multilinear_evals: Vec<Vec<F>>,
210}
211
212#[derive(Debug, PartialEq, Eq)]
213pub struct BatchSumcheckOutput<F: Field> {
214 pub challenges: Vec<F>,
215 pub multilinear_evals: Vec<Vec<F>>,
216}
217
218impl<F: Field> BatchSumcheckOutput<F> {
219 pub fn isomorphic<FI: Field + From<F>>(self) -> BatchSumcheckOutput<FI> {
220 BatchSumcheckOutput {
221 challenges: self.challenges.into_iter().map(Into::into).collect(),
222 multilinear_evals: self
223 .multilinear_evals
224 .into_iter()
225 .map(|prover_evals| prover_evals.into_iter().map(Into::into).collect())
226 .collect(),
227 }
228 }
229}
230
/// Returns a switchover heuristic parameterised by the offset `k`.
///
/// The closure maps an extension degree `d` to `max(floor(log2(d)) + k, 0) - 1`, with the
/// final subtraction saturating at `0`.
///
/// # Panics
///
/// The returned closure panics when called with an extension degree of `0`, because
/// `usize::ilog2` is undefined there.
pub fn standard_switchover_heuristic(k: isize) -> impl Fn(usize) -> usize + Copy {
    move |extension_degree: usize| {
        let log_degree = extension_degree.ilog2() as isize;
        let clamped = (log_degree + k).max(0) as usize;
        clamped.saturating_sub(1)
    }
}
238
/// A switchover heuristic that ignores the extension degree and always returns `0`.
pub const fn immediate_switchover_heuristic(_extension_degree: usize) -> usize {
    0usize
}
243
244pub fn equal_n_vars_check<'a, P, M>(
246 multilinears: impl IntoIterator<Item = &'a M>,
247) -> Result<usize, Error>
248where
249 P: PackedField,
250 M: MultilinearPoly<P> + 'a,
251{
252 let mut multilinears = multilinears.into_iter();
253 let n_vars = multilinears
254 .next()
255 .map(|multilinear| multilinear.n_vars())
256 .unwrap_or_default();
257 for multilinear in multilinears {
258 if multilinear.n_vars() != n_vars {
259 bail!(Error::NumberOfVariablesMismatch);
260 }
261 }
262 Ok(n_vars)
263}
264
265pub fn small_field_embedding_degree_check<F, FBase, P, M>(multilinears: &[M]) -> Result<(), Error>
270where
271 F: Field + ExtensionField<FBase>,
272 FBase: Field,
273 P: PackedField<Scalar = F>,
274 M: MultilinearPoly<P>,
275{
276 for multilinear in multilinears {
277 if multilinear.log_extension_degree() < F::LOG_DEGREE {
278 bail!(Error::MultilinearEvalsCannotBeEmbeddedInBaseField);
279 }
280 }
281
282 Ok(())
283}
284
285pub fn batch_weighted_value<F: Field>(batch_coeff: F, values: impl Iterator<Item = F>) -> F {
287 batch_coeff * inner_product_unchecked(powers(batch_coeff), values)
289}
290
291pub fn interpolation_domains_for_composition_degrees<FDomain>(
293 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
294 degrees: impl IntoIterator<Item = usize>,
295) -> Result<Vec<InterpolationDomain<FDomain>>, Error>
296where
297 FDomain: Field,
298{
299 degrees
300 .into_iter()
301 .map(|degree| Ok(evaluation_domain_factory.create(degree + 1)?.into()))
302 .collect()
303}
304
305pub fn get_nontrivial_evaluation_points<F: Field>(
310 domains: &[InterpolationDomain<F>],
311) -> Result<Vec<F>, Error> {
312 let Some(largest_domain) = domains.iter().max_by_key(|domain| domain.size()) else {
313 return Ok(Vec::new());
314 };
315
316 #[allow(clippy::get_first)]
317 if !domains.iter().all(|domain| {
318 (domain.size() <= 2 || domain.with_infinity())
319 && domain.finite_points().get(0).unwrap_or(&F::ZERO) == &F::ZERO
320 && domain.finite_points().get(1).unwrap_or(&F::ONE) == &F::ONE
321 }) {
322 bail!(Error::IncorrectSumcheckEvaluationDomain);
323 }
324
325 let finite_points = largest_domain.finite_points();
326
327 if domains
328 .iter()
329 .any(|domain| !finite_points.starts_with(domain.finite_points()))
330 {
331 bail!(Error::NonProperPrefixEvaluationDomain);
332 }
333
334 let nontrivial_evaluation_points = finite_points[2.min(finite_points.len())..].to_vec();
335 Ok(nontrivial_evaluation_points)
336}
337
#[cfg(test)]
mod tests {
    use binius_field::BinaryField64b;

    use super::*;

    type F = BinaryField64b;

    #[test]
    fn test_round_coeffs_truncate_non_empty() {
        let coeffs = RoundCoeffs(vec![F::from(1), F::from(2), F::from(3)]);
        let truncated = coeffs.truncate();
        assert_eq!(truncated.0 .0, vec![F::from(1), F::from(2)]);
    }

    #[test]
    fn test_round_coeffs_truncate_empty() {
        let coeffs = RoundCoeffs::<F>(vec![]);
        let truncated = coeffs.truncate();
        assert!(truncated.0 .0.is_empty());
    }

    /// `recover` must undo `truncate` when given the round sum `p(0) + p(1)`.
    #[test]
    fn test_round_proof_recover_inverts_truncate() {
        let coeffs = RoundCoeffs(vec![F::from(1), F::from(2), F::from(3)]);
        // p(0) = c_0 and p(1) = c_0 + c_1 + c_2.
        let sum = coeffs.0[0] + coeffs.0.iter().sum::<F>();
        let recovered = coeffs.clone().truncate().recover(sum);
        assert_eq!(recovered, coeffs);
    }

    /// An empty proof has no constant term, so the recovered list is just `[sum]`.
    #[test]
    fn test_round_proof_recover_empty() {
        let sum = F::from(5);
        let recovered = RoundProof::<F>::default().recover(sum);
        assert_eq!(recovered.0, vec![sum]);
    }

    /// Adding a longer right-hand side zero-pads the left-hand side first.
    #[test]
    fn test_round_coeffs_add_pads_shorter_lhs() {
        let lhs = RoundCoeffs(vec![F::from(1)]);
        let rhs = RoundCoeffs(vec![F::from(2), F::from(4)]);
        let total = lhs + &rhs;
        assert_eq!(total.0, vec![F::from(1) + F::from(2), F::from(4)]);
    }
}