binius_core/protocols/sumcheck/prove/
batch_prove_univariate_zerocheck.rs1use binius_field::{Field, TowerField};
4use binius_utils::{bail, sorting::is_sorted_ascending};
5use tracing::instrument;
6
7use crate::{
8 fiat_shamir::{CanSample, Challenger},
9 protocols::sumcheck::{
10 prove::{batch_prove::BatchProveStart, SumcheckProver},
11 univariate::LagrangeRoundEvals,
12 Error,
13 },
14 transcript::ProverTranscript,
15};
16
17pub trait UnivariateZerocheckProver<'a, F: Field> {
31 fn n_vars(&self) -> usize;
33
34 fn domain_size(&self, skip_rounds: usize) -> usize;
36
37 fn execute_univariate_round(
45 &mut self,
46 skip_rounds: usize,
47 max_domain_size: usize,
48 batch_coeff: F,
49 ) -> Result<LagrangeRoundEvals<F>, Error>;
50
51 fn fold_univariate_round(
53 self: Box<Self>,
54 challenge: F,
55 ) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error>;
56}
57
58impl<'a, F: Field, Prover: UnivariateZerocheckProver<'a, F> + ?Sized>
60 UnivariateZerocheckProver<'a, F> for Box<Prover>
61{
62 fn n_vars(&self) -> usize {
63 (**self).n_vars()
64 }
65
66 fn domain_size(&self, skip_rounds: usize) -> usize {
67 (**self).domain_size(skip_rounds)
68 }
69
70 fn execute_univariate_round(
71 &mut self,
72 skip_rounds: usize,
73 max_domain_size: usize,
74 batch_coeff: F,
75 ) -> Result<LagrangeRoundEvals<F>, Error> {
76 (**self).execute_univariate_round(skip_rounds, max_domain_size, batch_coeff)
77 }
78
79 fn fold_univariate_round(
80 self: Box<Self>,
81 challenge: F,
82 ) -> Result<Box<dyn SumcheckProver<F> + 'a>, Error> {
83 (*self).fold_univariate_round(challenge)
84 }
85}
86
87#[derive(Debug)]
88pub struct BatchZerocheckUnivariateProveOutput<F: Field, Prover> {
89 pub univariate_challenge: F,
90 pub batch_prove_start: BatchProveStart<F, Prover>,
91}
92
93#[allow(clippy::type_complexity)]
103#[instrument(skip_all, level = "debug")]
104pub fn batch_prove_zerocheck_univariate_round<'a, F, Prover, Challenger_>(
105 mut provers: Vec<Prover>,
106 skip_rounds: usize,
107 transcript: &mut ProverTranscript<Challenger_>,
108) -> Result<BatchZerocheckUnivariateProveOutput<F, Box<dyn SumcheckProver<F> + 'a>>, Error>
109where
110 F: TowerField,
111 Prover: UnivariateZerocheckProver<'a, F>,
112 Challenger_: Challenger,
113{
114 if !is_sorted_ascending(provers.iter().map(|prover| prover.n_vars()).rev()) {
116 bail!(Error::ClaimsOutOfOrder);
117 }
118
119 let max_n_vars = provers.first().map(|prover| prover.n_vars()).unwrap_or(0);
120 let min_n_vars = provers.last().map(|prover| prover.n_vars()).unwrap_or(0);
121
122 if max_n_vars - min_n_vars > skip_rounds {
123 bail!(Error::TooManySkippedRounds);
124 }
125
126 let max_domain_size = provers
127 .iter()
128 .map(|prover| prover.domain_size(skip_rounds + prover.n_vars() - max_n_vars))
129 .max()
130 .unwrap_or(0);
131
132 let mut batch_coeffs = Vec::with_capacity(provers.len());
133 let mut round_evals = LagrangeRoundEvals::zeros(max_domain_size);
134 for prover in &mut provers {
135 let next_batch_coeff = transcript.sample();
136 batch_coeffs.push(next_batch_coeff);
137
138 let prover_round_evals = prover.execute_univariate_round(
139 skip_rounds + prover.n_vars() - max_n_vars,
140 max_domain_size,
141 next_batch_coeff,
142 )?;
143
144 round_evals.add_assign_lagrange(&(prover_round_evals * next_batch_coeff))?;
145 }
146
147 let zeros_prefix_len = (1 << (skip_rounds + min_n_vars - max_n_vars)).min(max_domain_size);
148 if zeros_prefix_len != round_evals.zeros_prefix_len {
149 bail!(Error::IncorrectZerosPrefixLen);
150 }
151
152 transcript.message().write_scalar_slice(&round_evals.evals);
153 let univariate_challenge = transcript.sample();
154
155 let mut reduction_provers = Vec::with_capacity(provers.len());
156 for prover in provers {
157 let regular_prover = Box::new(prover).fold_univariate_round(univariate_challenge)?;
158 reduction_provers.push(regular_prover);
159 }
160
161 let batch_prove_start = BatchProveStart {
162 batch_coeffs,
163 reduction_provers,
164 };
165
166 let output = BatchZerocheckUnivariateProveOutput {
167 univariate_challenge,
168 batch_prove_start,
169 };
170
171 Ok(output)
172}