1use binius_field::{
4 BinaryField, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable,
5 TowerField,
6};
7use binius_hal::ComputationBackend;
8use binius_math::{EvaluationDomainFactory, EvaluationOrder};
9use binius_utils::{bail, sorting::is_sorted_ascending};
10use itertools::izip;
11use tracing::instrument;
12
13use super::{
14 common::{BaseExpReductionOutput, ExpClaim, GKRExpProver, LayerClaim},
15 compositions::ProverExpComposition,
16 error::Error,
17 provers::{
18 CompositeSumClaimWithMultilinears, DynamicBaseExpProver, ExpProver, GeneratorExpProver,
19 },
20 witness::BaseExpWitness,
21};
22use crate::{
23 fiat_shamir::Challenger,
24 protocols::sumcheck::{self, BatchSumcheckOutput, CompositeSumClaim},
25 transcript::ProverTranscript,
26 witness::MultilinearWitness,
27};
28
29pub fn batch_prove<'a, FBase, F, P, FDomain, Challenger_, Backend>(
46 evaluation_order: EvaluationOrder,
47 witnesses: impl IntoIterator<Item = BaseExpWitness<'a, P, FBase>>,
48 claims: &[ExpClaim<F>],
49 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
50 transcript: &mut ProverTranscript<Challenger_>,
51 backend: &Backend,
52) -> Result<BaseExpReductionOutput<F>, Error>
53where
54 F: ExtensionField<FBase> + ExtensionField<FDomain> + TowerField,
55 FDomain: Field,
56 FBase: TowerField + ExtensionField<FDomain>,
57 P: PackedFieldIndexable<Scalar = F>
58 + PackedExtension<F, PackedSubfield = P>
59 + PackedExtension<FDomain>,
60 Backend: ComputationBackend,
61 Challenger_: Challenger,
62{
63 let witnesses = witnesses.into_iter().collect::<Vec<_>>();
64
65 if witnesses.len() != claims.len() {
66 bail!(Error::MismatchedWitnessClaimLength);
67 }
68
69 let mut layers_claims = Vec::new();
70
71 if witnesses.is_empty() {
72 return Ok(BaseExpReductionOutput { layers_claims });
73 }
74
75 if !is_sorted_ascending(claims.iter().map(|claim| claim.n_vars).rev()) {
77 bail!(Error::ClaimsOutOfOrder);
78 }
79
80 let mut provers = make_provers::<_, FBase>(witnesses, claims)?;
81
82 let max_exponent_bit_number = provers.first().map(|p| p.exponent_bit_width()).unwrap_or(0);
83
84 for layer_no in 0..max_exponent_bit_number {
85 let gkr_sumcheck_provers = build_layer_gkr_sumcheck_provers(
86 evaluation_order,
87 &mut provers,
88 layer_no,
89 evaluation_domain_factory.clone(),
90 backend,
91 )?;
92
93 let sumcheck_proof_output = sumcheck::batch_prove(gkr_sumcheck_provers, transcript)?;
94
95 let layer_exponent_claims = build_layer_exponent_bit_claims(
96 evaluation_order,
97 &mut provers,
98 sumcheck_proof_output,
99 layer_no,
100 )?;
101
102 layers_claims.push(layer_exponent_claims);
103
104 provers.retain(|prover| !prover.is_last_layer(layer_no));
105 }
106
107 Ok(BaseExpReductionOutput { layers_claims })
108}
109
110type GKRExpProvers<'a, F, P, FDomain, Backend> =
111 Vec<GKRExpProver<'a, FDomain, P, ProverExpComposition<F>, MultilinearWitness<'a, P>, Backend>>;
112
113#[instrument(skip_all, level = "debug")]
115fn build_layer_gkr_sumcheck_provers<'a, P, FDomain, Backend>(
116 evaluation_order: EvaluationOrder,
117 provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
118 layer_no: usize,
119 evaluation_domain_factory: impl EvaluationDomainFactory<FDomain>,
120 backend: &'a Backend,
121) -> Result<GKRExpProvers<'a, P::Scalar, P, FDomain, Backend>, Error>
122where
123 FDomain: Field,
124 P: PackedFieldIndexable + PackedExtension<FDomain>,
125 P::Scalar: TowerField + ExtensionField<FDomain>,
126 Backend: ComputationBackend,
127{
128 assert!(!provers.is_empty());
129
130 let mut composite_claims = Vec::new();
131 let mut multilinears = Vec::new();
132
133 let first_eval_point = provers[0].layer_claim_eval_point().to_vec();
134 let mut eval_points = vec![first_eval_point];
135
136 let mut active_index = 0;
137
138 for i in 0..provers.len() {
140 if provers[i].layer_claim_eval_point() != eval_points[eval_points.len() - 1] {
141 let CompositeSumClaimsWithMultilinears {
142 composite_claims: eval_point_composite_claims,
143 multilinears: eval_point_multilinears,
144 } = build_eval_point_claims::<P>(&mut provers[active_index..i], layer_no)?;
145
146 if eval_point_composite_claims.is_empty() {
147 eval_points.pop();
149 } else {
150 composite_claims.push(eval_point_composite_claims);
151 multilinears.push(eval_point_multilinears);
152 }
153
154 eval_points.push(provers[i].layer_claim_eval_point().to_vec());
155 active_index = i;
156 }
157
158 if i == provers.len() - 1 {
159 let CompositeSumClaimsWithMultilinears {
160 composite_claims: eval_point_composite_claims,
161 multilinears: eval_point_multilinears,
162 } = build_eval_point_claims::<P>(&mut provers[active_index..], layer_no)?;
163
164 if !eval_point_composite_claims.is_empty() {
165 composite_claims.push(eval_point_composite_claims);
166 multilinears.push(eval_point_multilinears);
167 }
168 }
169 }
170
171 izip!(composite_claims, multilinears, eval_points)
172 .map(|(composite_claims, multilinears, eval_point)| {
173 GKRExpProver::<'a, FDomain, P, _, _, Backend>::new(
174 evaluation_order,
175 multilinears,
176 None,
177 composite_claims,
178 evaluation_domain_factory.clone(),
179 &eval_point,
180 backend,
181 )
182 })
183 .collect::<Result<Vec<_>, _>>()
184 .map_err(Error::from)
185}
186
187struct CompositeSumClaimsWithMultilinears<'a, P: PackedField> {
188 composite_claims: Vec<CompositeSumClaim<P::Scalar, ProverExpComposition<P::Scalar>>>,
189 multilinears: Vec<MultilinearWitness<'a, P>>,
190}
191
192fn build_eval_point_claims<'a, P>(
194 provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
195 layer_no: usize,
196) -> Result<CompositeSumClaimsWithMultilinears<'a, P>, Error>
197where
198 P: PackedField,
199{
200 let (composite_claims_n_multilinears, n_claims) =
201 provers
202 .iter()
203 .fold((0, 0), |(n_multilinears, n_claims), prover| {
204 let layer_n_multilinears = prover.layer_n_multilinears(layer_no);
205 let layer_n_claims = prover.layer_n_claims(layer_no);
206
207 (n_multilinears + layer_n_multilinears, n_claims + layer_n_claims)
208 });
209
210 let mut multilinears = Vec::with_capacity(composite_claims_n_multilinears);
211
212 let mut composite_claims = Vec::with_capacity(n_claims);
213
214 for prover in provers {
215 let multilinears_index = multilinears.len();
216
217 let meta = prover.layer_composite_sum_claim(
218 layer_no,
219 composite_claims_n_multilinears,
220 multilinears_index,
221 )?;
222
223 if let Some(meta) = meta {
224 let CompositeSumClaimWithMultilinears {
225 claim,
226 multilinears: this_layer_multilinears,
227 } = meta;
228
229 composite_claims.push(claim);
230
231 multilinears.extend(this_layer_multilinears);
232 }
233 }
234 Ok(CompositeSumClaimsWithMultilinears {
235 composite_claims,
236 multilinears,
237 })
238}
239
240fn build_layer_exponent_bit_claims<'a, P>(
242 evaluation_order: EvaluationOrder,
243 provers: &mut [Box<dyn ExpProver<'a, P> + 'a>],
244 mut sumcheck_output: BatchSumcheckOutput<P::Scalar>,
245 layer_no: usize,
246) -> Result<Vec<LayerClaim<P::Scalar>>, Error>
247where
248 P: PackedField,
249{
250 let mut eval_claims_on_exponent_bit_columns = Vec::new();
251
252 for multilinear_evals in &mut sumcheck_output.multilinear_evals {
254 multilinear_evals.pop();
255 }
256
257 let mut multilinear_evals = sumcheck_output.multilinear_evals.into_iter().flatten();
258
259 for prover in provers {
260 let this_prover_n_multilinears = prover.layer_n_multilinears(layer_no);
261
262 let this_prover_multilinear_evals = multilinear_evals
263 .by_ref()
264 .take(this_prover_n_multilinears)
265 .collect::<Vec<_>>();
266
267 let exponent_bit_claims = prover.finish_layer(
268 evaluation_order,
269 layer_no,
270 &this_prover_multilinear_evals,
271 &sumcheck_output.challenges,
272 );
273
274 eval_claims_on_exponent_bit_columns.extend(exponent_bit_claims);
275 }
276
277 Ok(eval_claims_on_exponent_bit_columns)
278}
279
280fn make_provers<'a, P, FBase>(
282 witnesses: Vec<BaseExpWitness<'a, P, FBase>>,
283 claims: &[ExpClaim<P::Scalar>],
284) -> Result<Vec<Box<dyn ExpProver<'a, P> + 'a>>, Error>
285where
286 P: PackedField,
287 FBase: BinaryField,
288 P::Scalar: BinaryField + ExtensionField<FBase>,
289{
290 witnesses
291 .into_iter()
292 .zip(claims)
293 .map(|(witness, claim)| {
294 if witness.uses_dynamic_base() {
295 DynamicBaseExpProver::new(witness, claim)
296 .map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
297 } else {
298 GeneratorExpProver::<'a, P, FBase>::new(witness, claim)
299 .map(|prover| Box::new(prover) as Box<dyn ExpProver<'a, P> + 'a>)
300 }
301 })
302 .collect::<Result<Vec<_>, Error>>()
303}