binius_core/protocols/sumcheck/prove/front_loaded.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{collections::VecDeque, iter};
4
5use binius_field::{Field, TowerField};
6use binius_utils::sorting::is_sorted_ascending;
7use bytes::BufMut;
8
9use super::{batch_sumcheck::SumcheckProver, logging::PIOPCompilerFoldData};
10use crate::{
11	fiat_shamir::{CanSample, Challenger},
12	protocols::sumcheck::{BatchSumcheckOutput, Error, RoundCoeffs},
13	transcript::{ProverTranscript, TranscriptWriter},
14};
15
16/// Prover for a front-loaded batch sumcheck protocol execution.
17///
18/// Prover that satisfies the verification logic in
19/// [`crate::protocols::sumcheck::front_loaded`]. See that module for protocol information.
20///
21///
22/// This exposes a round-by-round interface so that the protocol call be interleaved with other
23/// interactive protocols, sharing the same sequence of challenges. The verification logic must be
24/// invoked with a specific sequence of calls, continuing for as many rounds as necessary until all
25/// claims are finished.
26///
27/// 1. construct a new verifier with [`BatchProver::new`]
28/// 2. if all rounds are complete, call [`BatchProver::finish`], otherwise proceed
29/// 3. call [`BatchProver::send_round_proof`]
30/// 4. sample a random challenge and call [`BatchProver::receive_challenge`] with it
31/// 5. repeat from step 2
32#[derive(Debug)]
33pub struct BatchProver<F: Field, Prover> {
34	provers: VecDeque<(Prover, F)>,
35	multilinear_evals: Vec<Vec<F>>,
36	round: usize,
37}
38
39impl<F, Prover> BatchProver<F, Prover>
40where
41	F: TowerField,
42	Prover: SumcheckProver<F>,
43{
44	/// Constructs a new prover for the front-loaded batched sumcheck.
45	///
46	/// The constructor samples batching coefficients from the proof transcript.
47	///
48	/// ## Throws
49	///
50	/// * if the claims are not sorted in ascending order by number of variables
51	pub fn new<Transcript>(provers: Vec<Prover>, transcript: &mut Transcript) -> Result<Self, Error>
52	where
53		Transcript: CanSample<F>,
54	{
55		// Sample batch mixing coefficients
56		let batch_coeffs = transcript.sample_vec(provers.len());
57
58		Self::new_prebatched(batch_coeffs, provers)
59	}
60
61	/// Returns total number of batched sumcheck rounds
62	pub fn total_rounds(&self) -> usize
63	where
64		Prover: SumcheckProver<F>,
65	{
66		self.provers
67			.back()
68			.map(|(prover, _)| prover.n_vars())
69			.unwrap_or(0)
70	}
71
72	/// Constructs a new prover for the front-loaded batched sumcheck with
73	/// specified batching coefficients.
74	///
75	/// ## Throws
76	///
77	/// * if the claims are not sorted in ascending order by number of variables
78	pub fn new_prebatched(batch_coeffs: Vec<F>, provers: Vec<Prover>) -> Result<Self, Error> {
79		if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars())) {
80			return Err(Error::ClaimsOutOfOrder);
81		}
82
83		if batch_coeffs.len() != provers.len() {
84			return Err(Error::IncorrectNumberOfBatchCoeffs);
85		}
86
87		if let Some(first_prover) = provers.first() {
88			if provers
89				.iter()
90				.any(|prover| prover.evaluation_order() != first_prover.evaluation_order())
91			{
92				return Err(Error::InconsistentEvaluationOrder);
93			}
94		}
95
96		let provers = iter::zip(provers, batch_coeffs).collect();
97
98		Ok(Self {
99			provers,
100			multilinear_evals: Vec::new(),
101			round: 0,
102		})
103	}
104
105	fn finish_claim_provers<B>(&mut self, transcript: &mut TranscriptWriter<B>) -> Result<(), Error>
106	where
107		B: BufMut,
108	{
109		while let Some((prover, _)) = self.provers.front() {
110			if prover.n_vars() != self.round {
111				break;
112			}
113			let (prover, _) = self.provers.pop_front().expect("front returned Some");
114			let claim_multilinear_evals = Box::new(prover).finish()?;
115			transcript.write_scalar_slice(&claim_multilinear_evals);
116			self.multilinear_evals.push(claim_multilinear_evals);
117		}
118		Ok(())
119	}
120
121	/// Computes the round message and writes it to the proof transcript.
122	pub fn send_round_proof<B>(&mut self, transcript: &mut TranscriptWriter<B>) -> Result<(), Error>
123	where
124		B: BufMut,
125	{
126		self.finish_claim_provers(transcript)?;
127
128		let mut round_coeffs = RoundCoeffs::default();
129		for (prover, batch_coeff) in &mut self.provers {
130			let prover_coeffs = prover.execute(*batch_coeff)?;
131			round_coeffs += &(prover_coeffs * *batch_coeff);
132		}
133
134		let round_proof = round_coeffs.truncate();
135		transcript.write_scalar_slice(round_proof.coeffs());
136		Ok(())
137	}
138
139	/// Finishes an interaction round by reducing the instance with the verifier challenge.
140	pub fn receive_challenge(&mut self, challenge: F) -> Result<(), Error> {
141		for (prover, _) in &mut self.provers {
142			let dimensions_data = PIOPCompilerFoldData::new(prover);
143			let _span = tracing::debug_span!(
144				"[task] (PIOP Compiler) Fold",
145				phase = "piop_compiler",
146				round = self.round,
147				dimensions_data = ?dimensions_data,
148			)
149			.entered();
150			prover.fold(challenge)?;
151		}
152		self.round += 1;
153		Ok(())
154	}
155
156	/// Finishes the remaining instance provers and checks that all rounds are completed.
157	pub fn finish<B>(mut self, transcript: &mut TranscriptWriter<B>) -> Result<Vec<Vec<F>>, Error>
158	where
159		B: BufMut,
160	{
161		self.finish_claim_provers(transcript)?;
162		if !self.provers.is_empty() {
163			return Err(Error::ExpectedFold);
164		}
165
166		Ok(self.multilinear_evals)
167	}
168
169	/// Returns the iterator over the provers.
170	pub fn provers(&self) -> impl Iterator<Item = &Prover> {
171		self.provers.iter().map(|(prover, _)| prover)
172	}
173
174	/// Proves a front-loaded batch sumcheck protocol execution.
175	pub fn run<Challenger_: Challenger>(
176		mut self,
177		transcript: &mut ProverTranscript<Challenger_>,
178	) -> Result<BatchSumcheckOutput<F>, Error> {
179		let round_count = self.total_rounds();
180
181		let mut challenges = Vec::with_capacity(round_count);
182		for _round_no in 0..round_count {
183			self.send_round_proof(&mut transcript.message())?;
184
185			let challenge = transcript.sample();
186			challenges.push(challenge);
187
188			self.receive_challenge(challenge)?;
189		}
190
191		let multilinear_evals = self.finish(&mut transcript.message())?;
192
193		Ok(BatchSumcheckOutput {
194			challenges,
195			multilinear_evals,
196		})
197	}
198}