binius_core/protocols/sumcheck/prove/
front_loaded.rs1use std::{collections::VecDeque, iter};
4
5use binius_field::{Field, TowerField};
6use binius_utils::sorting::is_sorted_ascending;
7use bytes::BufMut;
8
9use super::batch_prove::SumcheckProver;
10use crate::{
11 fiat_shamir::CanSample,
12 protocols::sumcheck::{Error, RoundCoeffs},
13 transcript::TranscriptWriter,
14};
15
16#[derive(Debug)]
33pub struct BatchProver<F: Field, Prover> {
34 provers: VecDeque<(Prover, F)>,
35 round: usize,
36}
37
38impl<F, Prover> BatchProver<F, Prover>
39where
40 F: TowerField,
41 Prover: SumcheckProver<F>,
42{
43 pub fn new<Transcript>(provers: Vec<Prover>, transcript: &mut Transcript) -> Result<Self, Error>
51 where
52 Transcript: CanSample<F>,
53 {
54 if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars())) {
55 return Err(Error::ClaimsOutOfOrder);
56 }
57
58 if let Some(first_prover) = provers.first() {
59 if provers
60 .iter()
61 .any(|prover| prover.evaluation_order() != first_prover.evaluation_order())
62 {
63 return Err(Error::InconsistentEvaluationOrder);
64 }
65 }
66
67 let batch_coeffs = transcript.sample_vec(provers.len());
69 let provers = iter::zip(provers, batch_coeffs).collect();
70
71 Ok(Self { provers, round: 0 })
72 }
73
74 fn finish_claim_provers<B>(&mut self, transcript: &mut TranscriptWriter<B>) -> Result<(), Error>
75 where
76 B: BufMut,
77 {
78 while let Some((prover, _)) = self.provers.front() {
79 if prover.n_vars() != self.round {
80 break;
81 }
82 let (prover, _) = self.provers.pop_front().expect("front returned Some");
83 let multilinear_evals = Box::new(prover).finish()?;
84 transcript.write_scalar_slice(&multilinear_evals);
85 }
86 Ok(())
87 }
88
89 pub fn send_round_proof<B>(&mut self, transcript: &mut TranscriptWriter<B>) -> Result<(), Error>
91 where
92 B: BufMut,
93 {
94 self.finish_claim_provers(transcript)?;
95
96 let mut round_coeffs = RoundCoeffs::default();
97 for (prover, batch_coeff) in &mut self.provers {
98 let prover_coeffs = prover.execute(*batch_coeff)?;
99 round_coeffs += &(prover_coeffs * *batch_coeff);
100 }
101
102 let round_proof = round_coeffs.truncate();
103 transcript.write_scalar_slice(round_proof.coeffs());
104 Ok(())
105 }
106
107 pub fn receive_challenge(&mut self, challenge: F) -> Result<(), Error> {
109 for (prover, _) in &mut self.provers {
110 prover.fold(challenge)?;
111 }
112 self.round += 1;
113 Ok(())
114 }
115
116 pub fn finish<B>(mut self, transcript: &mut TranscriptWriter<B>) -> Result<(), Error>
118 where
119 B: BufMut,
120 {
121 self.finish_claim_provers(transcript)?;
122 if !self.provers.is_empty() {
123 return Err(Error::ExpectedFold);
124 }
125 Ok(())
126 }
127}