binius_core/ring_switch/prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, sync::Arc};
4
5use binius_field::{PackedField, PackedFieldIndexable, TowerField};
6use binius_hal::{ComputationBackend, ComputationBackendExt};
7use binius_math::{MLEDirectAdapter, MultilinearPoly, MultilinearQuery};
8use binius_maybe_rayon::prelude::*;
9use binius_utils::checked_arithmetics::log2_ceil_usize;
10use tracing::instrument;
11
12use super::{
13	common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc},
14	eq_ind::RowBatchCoeffs,
15	error::Error,
16	tower_tensor_algebra::TowerTensorAlgebra,
17};
18use crate::{
19	fiat_shamir::{CanSample, Challenger},
20	piop::PIOPSumcheckClaim,
21	ring_switch::{common::EvalClaimSuffixDesc, eq_ind::RingSwitchEqInd},
22	tower::{PackedTop, TowerFamily},
23	transcript::ProverTranscript,
24	witness::MultilinearWitness,
25};
26
27type FExt<Tower> = <Tower as TowerFamily>::B128;
28
29#[derive(Debug)]
30pub struct ReducedWitness<P: PackedField> {
31	pub transparents: Vec<MultilinearWitness<'static, P>>,
32	pub sumcheck_claims: Vec<PIOPSumcheckClaim<P::Scalar>>,
33}
34
35#[tracing::instrument("ring_switch::prove", skip_all)]
36pub fn prove<F, P, M, Tower, Challenger_, Backend>(
37	system: &EvalClaimSystem<F>,
38	witnesses: &[M],
39	transcript: &mut ProverTranscript<Challenger_>,
40	backend: &Backend,
41) -> Result<ReducedWitness<P>, Error>
42where
43	F: TowerField,
44	P: PackedFieldIndexable<Scalar = F>,
45	M: MultilinearPoly<P> + Sync,
46	Tower: TowerFamily<B128 = F>,
47	F: PackedTop<Tower>,
48	Challenger_: Challenger,
49	Backend: ComputationBackend,
50{
51	if witnesses.len() != system.commit_meta.total_multilins() {
52		return Err(Error::InvalidWitness(
53			"witness length does not match the number of multilinears".into(),
54		));
55	}
56
57	// Sample enough randomness to batch tensor elements corresponding to claims that share an
58	// evaluation point prefix.
59	let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
60	let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
61	let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
62
63	// For each evaluation point prefix, send one batched partial evaluation.
64	let tensor_elems = compute_partial_evals::<_, _, _, Tower, _>(system, witnesses, backend)?;
65	let scaled_tensor_elems = scale_tensor_elems(tensor_elems, &mixing_coeffs);
66	let mixed_tensor_elems = mix_tensor_elems_for_prefixes(
67		&scaled_tensor_elems,
68		&system.prefix_descs,
69		&system.eval_claim_to_prefix_desc_index,
70	)?;
71	let mut writer = transcript.message();
72	for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) {
73		debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa());
74		writer.write_scalar_slice(mixed_tensor_elem.vertical_elems());
75	}
76
77	// Sample the row-batching randomness.
78	let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
79	let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
80		MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
81	));
82
83	let row_batched_evals =
84		compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs());
85	transcript.message().write_scalar_slice(&row_batched_evals);
86
87	// Create the reduced PIOP sumcheck witnesses.
88	let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, P, Tower>(
89		&system.sumcheck_claim_descs,
90		&system.suffix_descs,
91		row_batch_coeffs,
92		&mixing_coeffs,
93	)?;
94	let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
95		.enumerate()
96		.map(|(idx, (claim_desc, eval))| {
97			let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
98			PIOPSumcheckClaim {
99				n_vars: suffix_desc.suffix.len(),
100				committed: claim_desc.committed_idx,
101				transparent: idx,
102				sum: eval,
103			}
104		})
105		.collect::<Vec<_>>();
106
107	Ok(ReducedWitness {
108		transparents: ring_switch_eq_inds,
109		sumcheck_claims,
110	})
111}
112
113#[instrument(skip_all)]
114fn compute_partial_evals<F, P, M, Tower, Backend>(
115	system: &EvalClaimSystem<F>,
116	witnesses: &[M],
117	backend: &Backend,
118) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
119where
120	F: TowerField,
121	P: PackedField<Scalar = F>,
122	M: MultilinearPoly<P> + Sync,
123	Tower: TowerFamily<B128 = F>,
124	Backend: ComputationBackend,
125{
126	let suffix_queries = system
127		.suffix_descs
128		.par_iter()
129		.map(|desc| backend.multilinear_query(&desc.suffix))
130		.collect::<Result<Vec<_>, _>>()?;
131
132	let tensor_elems = system
133		.sumcheck_claim_descs
134		.par_iter()
135		.map(
136			|PIOPSumcheckClaimDesc {
137			     committed_idx,
138			     suffix_desc_idx,
139			     eval_claim: _,
140			 }| {
141				let suffix = &system.suffix_descs[*suffix_desc_idx];
142				let suffix_query = &suffix_queries[*suffix_desc_idx];
143				let partial_eval =
144					witnesses[*committed_idx].evaluate_partial_high(suffix_query.to_ref())?;
145				TowerTensorAlgebra::new(
146					suffix.kappa,
147					PackedField::iter_slice(partial_eval.evals())
148						.take(1 << suffix.kappa)
149						.collect(),
150				)
151			},
152		)
153		.collect::<Result<Vec<_>, _>>()?;
154
155	Ok(tensor_elems)
156}
157
158fn scale_tensor_elems<F, Tower>(
159	tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
160	mixing_coeffs: &[F],
161) -> Vec<TowerTensorAlgebra<Tower>>
162where
163	F: TowerField,
164	Tower: TowerFamily<B128 = F>,
165{
166	// Precondition
167	assert!(tensor_elems.len() <= mixing_coeffs.len());
168
169	tensor_elems
170		.into_par_iter()
171		.zip(mixing_coeffs)
172		.map(|(tensor_elem, &mixing_coeff)| tensor_elem.scale_vertical(mixing_coeff))
173		.collect()
174}
175
176fn mix_tensor_elems_for_prefixes<F, Tower>(
177	scaled_tensor_elems: &[TowerTensorAlgebra<Tower>],
178	prefix_descs: &[EvalClaimPrefixDesc<F>],
179	eval_claim_to_prefix_desc_index: &[usize],
180) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
181where
182	F: TowerField,
183	Tower: TowerFamily<B128 = F>,
184{
185	// Precondition
186	assert_eq!(scaled_tensor_elems.len(), eval_claim_to_prefix_desc_index.len());
187
188	let mut batched_tensor_elems = prefix_descs
189		.iter()
190		.map(|desc| TowerTensorAlgebra::zero(desc.kappa()))
191		.collect::<Result<Vec<_>, _>>()?;
192	for (tensor_elem, &desc_index) in
193		iter::zip(scaled_tensor_elems, eval_claim_to_prefix_desc_index)
194	{
195		let mixed_val = &mut batched_tensor_elems[desc_index];
196		debug_assert_eq!(mixed_val.kappa(), tensor_elem.kappa());
197		mixed_val.add_assign(tensor_elem)?;
198	}
199	Ok(batched_tensor_elems)
200}
201
202#[instrument(skip_all)]
203fn compute_row_batched_sumcheck_evals<F, Tower>(
204	tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
205	row_batch_coeffs: &[F],
206) -> Vec<F>
207where
208	F: TowerField,
209	Tower: TowerFamily<B128 = F>,
210	F: PackedTop<Tower>,
211{
212	tensor_elems
213		.into_par_iter()
214		.map(|tensor_elem| tensor_elem.fold_vertical(row_batch_coeffs))
215		.collect()
216}
217
218#[instrument(skip_all)]
219fn make_ring_switch_eq_inds<F, P, Tower>(
220	sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
221	suffix_descs: &[EvalClaimSuffixDesc<F>],
222	row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
223	mixing_coeffs: &[F],
224) -> Result<Vec<MultilinearWitness<'static, P>>, Error>
225where
226	F: TowerField,
227	P: PackedFieldIndexable<Scalar = F>,
228	Tower: TowerFamily<B128 = F>,
229	F: PackedTop<Tower>,
230{
231	sumcheck_claim_descs
232		.par_iter()
233		.zip(mixing_coeffs)
234		.map(|(claim_desc, &mixing_coeff)| {
235			let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
236			make_ring_switch_eq_ind::<P, Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
237		})
238		.collect()
239}
240
241fn make_ring_switch_eq_ind<P, Tower>(
242	suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
243	row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
244	mixing_coeff: FExt<Tower>,
245) -> Result<MultilinearWitness<'static, P>, Error>
246where
247	P: PackedFieldIndexable<Scalar = FExt<Tower>>,
248	Tower: TowerFamily,
249{
250	let eq_ind = match suffix_desc.kappa {
251		7 => RingSwitchEqInd::<Tower::B1, _>::new(
252			suffix_desc.suffix.clone(),
253			row_batch_coeffs,
254			mixing_coeff,
255		)?
256		.multilinear_extension::<P>(),
257		4 => RingSwitchEqInd::<Tower::B8, _>::new(
258			suffix_desc.suffix.clone(),
259			row_batch_coeffs,
260			mixing_coeff,
261		)?
262		.multilinear_extension(),
263		3 => RingSwitchEqInd::<Tower::B16, _>::new(
264			suffix_desc.suffix.clone(),
265			row_batch_coeffs,
266			mixing_coeff,
267		)?
268		.multilinear_extension(),
269		2 => RingSwitchEqInd::<Tower::B32, _>::new(
270			suffix_desc.suffix.clone(),
271			row_batch_coeffs,
272			mixing_coeff,
273		)?
274		.multilinear_extension(),
275		1 => RingSwitchEqInd::<Tower::B64, _>::new(
276			suffix_desc.suffix.clone(),
277			row_batch_coeffs,
278			mixing_coeff,
279		)?
280		.multilinear_extension(),
281		0 => RingSwitchEqInd::<Tower::B128, _>::new(
282			suffix_desc.suffix.clone(),
283			row_batch_coeffs,
284			mixing_coeff,
285		)?
286		.multilinear_extension(),
287		_ => Err(Error::PackingDegreeNotSupported {
288			kappa: suffix_desc.kappa,
289		}),
290	}?;
291	Ok(MLEDirectAdapter::from(eq_ind).upcast_arc_dyn())
292}