binius_core/protocols/sumcheck/prove/
batch_sumcheck.rs1use binius_field::{Field, TowerField};
4use binius_math::EvaluationOrder;
5use binius_utils::{bail, sorting::is_sorted_ascending};
6use itertools::izip;
7use tracing::instrument;
8
9use crate::{
10 fiat_shamir::{CanSample, Challenger},
11 protocols::sumcheck::{
12 common::{BatchSumcheckOutput, RoundCoeffs},
13 error::Error,
14 },
15 transcript::ProverTranscript,
16};
17
18pub trait SumcheckProver<F: Field> {
39 fn n_vars(&self) -> usize;
43
44 fn evaluation_order(&self) -> EvaluationOrder;
46
47 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error>;
61
62 fn fold(&mut self, challenge: F) -> Result<(), Error>;
64
65 fn finish(self: Box<Self>) -> Result<Vec<F>, Error>;
68}
69
70impl<F: Field, Prover: SumcheckProver<F> + ?Sized> SumcheckProver<F> for Box<Prover> {
72 fn n_vars(&self) -> usize {
73 (**self).n_vars()
74 }
75
76 fn evaluation_order(&self) -> EvaluationOrder {
77 (**self).evaluation_order()
78 }
79
80 fn execute(&mut self, batch_coeff: F) -> Result<RoundCoeffs<F>, Error> {
81 (**self).execute(batch_coeff)
82 }
83
84 fn fold(&mut self, challenge: F) -> Result<(), Error> {
85 (**self).fold(challenge)
86 }
87
88 fn finish(self: Box<Self>) -> Result<Vec<F>, Error> {
89 (*self).finish()
90 }
91}
92
93#[instrument(skip_all, name = "sumcheck::batch_prove")]
102pub fn batch_prove<F, Prover, Challenger_>(
103 mut provers: Vec<Prover>,
104 transcript: &mut ProverTranscript<Challenger_>,
105) -> Result<BatchSumcheckOutput<F>, Error>
106where
107 F: TowerField,
108 Prover: SumcheckProver<F>,
109 Challenger_: Challenger,
110{
111 let Some(first_prover) = provers.first() else {
112 return Ok(BatchSumcheckOutput {
113 challenges: Vec::new(),
114 multilinear_evals: Vec::new(),
115 });
116 };
117
118 let evaluation_order = first_prover.evaluation_order();
119
120 if provers
121 .iter()
122 .any(|prover| prover.evaluation_order() != evaluation_order)
123 {
124 bail!(Error::InconsistentEvaluationOrder);
125 }
126
127 if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars()).rev()) {
129 bail!(Error::ClaimsOutOfOrder);
130 }
131
132 let n_rounds = provers
133 .iter()
134 .map(|prover| prover.n_vars())
135 .max()
136 .unwrap_or(0);
137
138 let mut batch_coeffs = Vec::with_capacity(provers.len());
139 let mut challenges = Vec::with_capacity(n_rounds);
140 for round_no in 0..n_rounds {
141 let n_vars = n_rounds - round_no;
142
143 while let Some(prover) = provers.get(batch_coeffs.len()) {
145 if prover.n_vars() != n_vars {
146 break;
147 }
148
149 let next_batch_coeff = transcript.sample();
150 batch_coeffs.push(next_batch_coeff);
151 }
152
153 let mut round_coeffs = RoundCoeffs::default();
155 for (&batch_coeff, prover) in izip!(&batch_coeffs, &mut provers) {
156 let prover_coeffs = prover.execute(batch_coeff)?;
157 round_coeffs += &(prover_coeffs * batch_coeff);
158 }
159
160 let round_proof = round_coeffs.truncate();
161 transcript
162 .message()
163 .write_scalar_slice(round_proof.coeffs());
164
165 let challenge = transcript.sample();
166 challenges.push(challenge);
167
168 for prover in &mut provers[..batch_coeffs.len()] {
169 prover.fold(challenge)?;
170 }
171 }
172
173 for prover in &provers[batch_coeffs.len()..] {
175 debug_assert_eq!(prover.n_vars(), 0);
176 let _next_batch_coeff: F = transcript.sample();
177 }
178
179 let multilinear_evals = provers
180 .into_iter()
181 .map(|prover| Box::new(prover).finish())
182 .collect::<Result<Vec<_>, _>>()?;
183
184 let mut writer = transcript.message();
185 for multilinear_evals in &multilinear_evals {
186 writer.write_scalar_slice(multilinear_evals);
187 }
188
189 if EvaluationOrder::HighToLow == evaluation_order {
190 challenges.reverse();
191 }
192
193 let output = BatchSumcheckOutput {
194 challenges,
195 multilinear_evals,
196 };
197
198 Ok(output)
199}