binius_core/protocols/sumcheck/prove/
front_loaded.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{collections::VecDeque, iter};
4
5use binius_field::{Field, TowerField};
6use binius_utils::sorting::is_sorted_ascending;
7use bytes::BufMut;
8
9use super::batch_prove::SumcheckProver;
10use crate::{
11	fiat_shamir::CanSample,
12	protocols::sumcheck::{Error, RoundCoeffs},
13	transcript::TranscriptWriter,
14};
15
16/// Prover for a front-loaded batch sumcheck protocol execution.
17///
18/// Prover that satisfies the verification logic in
19/// [`crate::protocols::sumcheck::front_loaded`]. See that module for protocol information.
20///
21///
22/// This exposes a round-by-round interface so that the protocol call be interleaved with other
23/// interactive protocols, sharing the same sequence of challenges. The verification logic must be
24/// invoked with a specific sequence of calls, continuing for as many rounds as necessary until all
25/// claims are finished.
26///
27/// 1. construct a new verifier with [`BatchProver::new`]
28/// 2. if all rounds are complete, call [`BatchProver::finish`], otherwise proceed
29/// 3. call [`BatchProver::send_round_proof`]
30/// 4. sample a random challenge and call [`BatchProver::receive_challenge`] with it
31/// 5. repeat from step 2
32#[derive(Debug)]
33pub struct BatchProver<F: Field, Prover> {
34	provers: VecDeque<(Prover, F)>,
35	round: usize,
36}
37
38impl<F, Prover> BatchProver<F, Prover>
39where
40	F: TowerField,
41	Prover: SumcheckProver<F>,
42{
43	/// Constructs a new prover for the front-loaded batched sumcheck.
44	///
45	/// The constructor samples batching coefficients from the proof transcript.
46	///
47	/// ## Throws
48	///
49	/// * if the claims are not sorted in ascending order by number of variables
50	pub fn new<Transcript>(provers: Vec<Prover>, transcript: &mut Transcript) -> Result<Self, Error>
51	where
52		Transcript: CanSample<F>,
53	{
54		if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars())) {
55			return Err(Error::ClaimsOutOfOrder);
56		}
57
58		if let Some(first_prover) = provers.first() {
59			if provers
60				.iter()
61				.any(|prover| prover.evaluation_order() != first_prover.evaluation_order())
62			{
63				return Err(Error::InconsistentEvaluationOrder);
64			}
65		}
66
67		// Sample batch mixing coefficients
68		let batch_coeffs = transcript.sample_vec(provers.len());
69		let provers = iter::zip(provers, batch_coeffs).collect();
70
71		Ok(Self { provers, round: 0 })
72	}
73
74	fn finish_claim_provers<B>(&mut self, transcript: &mut TranscriptWriter<B>) -> Result<(), Error>
75	where
76		B: BufMut,
77	{
78		while let Some((prover, _)) = self.provers.front() {
79			if prover.n_vars() != self.round {
80				break;
81			}
82			let (prover, _) = self.provers.pop_front().expect("front returned Some");
83			let multilinear_evals = Box::new(prover).finish()?;
84			transcript.write_scalar_slice(&multilinear_evals);
85		}
86		Ok(())
87	}
88
89	/// Computes the round message and writes it to the proof transcript.
90	pub fn send_round_proof<B>(&mut self, transcript: &mut TranscriptWriter<B>) -> Result<(), Error>
91	where
92		B: BufMut,
93	{
94		self.finish_claim_provers(transcript)?;
95
96		let mut round_coeffs = RoundCoeffs::default();
97		for (prover, batch_coeff) in &mut self.provers {
98			let prover_coeffs = prover.execute(*batch_coeff)?;
99			round_coeffs += &(prover_coeffs * *batch_coeff);
100		}
101
102		let round_proof = round_coeffs.truncate();
103		transcript.write_scalar_slice(round_proof.coeffs());
104		Ok(())
105	}
106
107	/// Finishes an interaction round by reducing the instance with the verifier challenge.
108	pub fn receive_challenge(&mut self, challenge: F) -> Result<(), Error> {
109		for (prover, _) in &mut self.provers {
110			prover.fold(challenge)?;
111		}
112		self.round += 1;
113		Ok(())
114	}
115
116	/// Finishes the remaining instance provers and checks that all rounds are completed.
117	pub fn finish<B>(mut self, transcript: &mut TranscriptWriter<B>) -> Result<(), Error>
118	where
119		B: BufMut,
120	{
121		self.finish_claim_provers(transcript)?;
122		if !self.provers.is_empty() {
123			return Err(Error::ExpectedFold);
124		}
125		Ok(())
126	}
127}