binius_core/ring_switch/
prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, sync::Arc};
4
5use binius_field::{PackedField, PackedFieldIndexable, TowerField};
6use binius_hal::ComputationBackend;
7use binius_math::{MLEDirectAdapter, MultilinearPoly, MultilinearQuery};
8use binius_maybe_rayon::prelude::*;
9use binius_utils::checked_arithmetics::log2_ceil_usize;
10use tracing::instrument;
11
12use super::{
13	common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc},
14	eq_ind::RowBatchCoeffs,
15	error::Error,
16	tower_tensor_algebra::TowerTensorAlgebra,
17};
18use crate::{
19	fiat_shamir::{CanSample, Challenger},
20	piop::PIOPSumcheckClaim,
21	protocols::evalcheck::subclaims::MemoizedData,
22	ring_switch::{common::EvalClaimSuffixDesc, eq_ind::RingSwitchEqInd},
23	tower::{PackedTop, TowerFamily},
24	transcript::ProverTranscript,
25	witness::MultilinearWitness,
26};
27
28type FExt<Tower> = <Tower as TowerFamily>::B128;
29
30#[derive(Debug)]
31pub struct ReducedWitness<P: PackedField> {
32	pub transparents: Vec<MultilinearWitness<'static, P>>,
33	pub sumcheck_claims: Vec<PIOPSumcheckClaim<P::Scalar>>,
34}
35
36#[tracing::instrument("ring_switch::prove", skip_all)]
37pub fn prove<F, P, M, Tower, Challenger_, Backend>(
38	system: &EvalClaimSystem<F>,
39	witnesses: &[M],
40	transcript: &mut ProverTranscript<Challenger_>,
41	memoized_data: MemoizedData<P, Backend>,
42	backend: &Backend,
43) -> Result<ReducedWitness<P>, Error>
44where
45	F: TowerField,
46	P: PackedFieldIndexable<Scalar = F>,
47	M: MultilinearPoly<P> + Sync,
48	Tower: TowerFamily<B128 = F>,
49	F: PackedTop<Tower>,
50	Challenger_: Challenger,
51	Backend: ComputationBackend,
52{
53	if witnesses.len() != system.commit_meta.total_multilins() {
54		return Err(Error::InvalidWitness(
55			"witness length does not match the number of multilinears".into(),
56		));
57	}
58
59	// Sample enough randomness to batch tensor elements corresponding to claims that share an
60	// evaluation point prefix.
61	let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
62	let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
63	let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
64
65	// For each evaluation point prefix, send one batched partial evaluation.
66	let tensor_elems =
67		compute_partial_evals::<_, _, _, Tower, _>(system, witnesses, memoized_data, backend)?;
68	let scaled_tensor_elems = scale_tensor_elems(tensor_elems, &mixing_coeffs);
69	let mixed_tensor_elems = mix_tensor_elems_for_prefixes(
70		&scaled_tensor_elems,
71		&system.prefix_descs,
72		&system.eval_claim_to_prefix_desc_index,
73	)?;
74	let mut writer = transcript.message();
75	for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) {
76		debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa());
77		writer.write_scalar_slice(mixed_tensor_elem.vertical_elems());
78	}
79
80	// Sample the row-batching randomness.
81	let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
82	let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
83		MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
84	));
85
86	let row_batched_evals =
87		compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs());
88	transcript.message().write_scalar_slice(&row_batched_evals);
89
90	// Create the reduced PIOP sumcheck witnesses.
91	let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, P, Tower>(
92		&system.sumcheck_claim_descs,
93		&system.suffix_descs,
94		row_batch_coeffs,
95		&mixing_coeffs,
96	)?;
97	let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
98		.enumerate()
99		.map(|(idx, (claim_desc, eval))| {
100			let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
101			PIOPSumcheckClaim {
102				n_vars: suffix_desc.suffix.len(),
103				committed: claim_desc.committed_idx,
104				transparent: idx,
105				sum: eval,
106			}
107		})
108		.collect::<Vec<_>>();
109
110	Ok(ReducedWitness {
111		transparents: ring_switch_eq_inds,
112		sumcheck_claims,
113	})
114}
115
116#[instrument(skip_all)]
117fn compute_partial_evals<F, P, M, Tower, Backend>(
118	system: &EvalClaimSystem<F>,
119	witnesses: &[M],
120	mut memoized_data: MemoizedData<P, Backend>,
121	backend: &Backend,
122) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
123where
124	F: TowerField,
125	P: PackedField<Scalar = F>,
126	M: MultilinearPoly<P> + Sync,
127	Tower: TowerFamily<B128 = F>,
128	Backend: ComputationBackend,
129{
130	let suffixes = system
131		.suffix_descs
132		.iter()
133		.map(|desc| Arc::as_ref(&desc.suffix))
134		.collect::<Vec<_>>();
135
136	memoized_data.memoize_query_par(&suffixes, backend)?;
137
138	let tensor_elems = system
139		.sumcheck_claim_descs
140		.par_iter()
141		.map(
142			|PIOPSumcheckClaimDesc {
143			     committed_idx,
144			     suffix_desc_idx,
145			     eval_claim,
146			 }| {
147				let suffix_desc = &system.suffix_descs[*suffix_desc_idx];
148
149				let elems = if let Some(partial_eval) =
150					memoized_data.partial_eval(eval_claim.id, Arc::as_ref(&suffix_desc.suffix))
151				{
152					PackedField::iter_slice(
153						partial_eval.packed_evals().expect("packed_evals exist"),
154					)
155					.take(1 << suffix_desc.kappa)
156					.collect()
157				} else {
158					let suffix_query = memoized_data
159						.full_query_readonly(&suffix_desc.suffix)
160						.expect("memoized above");
161					let partial_eval =
162						witnesses[*committed_idx].evaluate_partial_high(suffix_query.into())?;
163					PackedField::iter_slice(partial_eval.evals())
164						.take(1 << suffix_desc.kappa)
165						.collect()
166				};
167
168				TowerTensorAlgebra::new(suffix_desc.kappa, elems)
169			},
170		)
171		.collect::<Result<Vec<_>, _>>()?;
172
173	Ok(tensor_elems)
174}
175
176fn scale_tensor_elems<F, Tower>(
177	tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
178	mixing_coeffs: &[F],
179) -> Vec<TowerTensorAlgebra<Tower>>
180where
181	F: TowerField,
182	Tower: TowerFamily<B128 = F>,
183{
184	// Precondition
185	assert!(tensor_elems.len() <= mixing_coeffs.len());
186
187	tensor_elems
188		.into_par_iter()
189		.zip(mixing_coeffs)
190		.map(|(tensor_elem, &mixing_coeff)| tensor_elem.scale_vertical(mixing_coeff))
191		.collect()
192}
193
194fn mix_tensor_elems_for_prefixes<F, Tower>(
195	scaled_tensor_elems: &[TowerTensorAlgebra<Tower>],
196	prefix_descs: &[EvalClaimPrefixDesc<F>],
197	eval_claim_to_prefix_desc_index: &[usize],
198) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
199where
200	F: TowerField,
201	Tower: TowerFamily<B128 = F>,
202{
203	// Precondition
204	assert_eq!(scaled_tensor_elems.len(), eval_claim_to_prefix_desc_index.len());
205
206	let mut batched_tensor_elems = prefix_descs
207		.iter()
208		.map(|desc| TowerTensorAlgebra::zero(desc.kappa()))
209		.collect::<Result<Vec<_>, _>>()?;
210	for (tensor_elem, &desc_index) in
211		iter::zip(scaled_tensor_elems, eval_claim_to_prefix_desc_index)
212	{
213		let mixed_val = &mut batched_tensor_elems[desc_index];
214		debug_assert_eq!(mixed_val.kappa(), tensor_elem.kappa());
215		mixed_val.add_assign(tensor_elem)?;
216	}
217	Ok(batched_tensor_elems)
218}
219
220#[instrument(skip_all)]
221fn compute_row_batched_sumcheck_evals<F, Tower>(
222	tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
223	row_batch_coeffs: &[F],
224) -> Vec<F>
225where
226	F: TowerField,
227	Tower: TowerFamily<B128 = F>,
228	F: PackedTop<Tower>,
229{
230	tensor_elems
231		.into_par_iter()
232		.map(|tensor_elem| tensor_elem.fold_vertical(row_batch_coeffs))
233		.collect()
234}
235
236#[instrument(skip_all)]
237fn make_ring_switch_eq_inds<F, P, Tower>(
238	sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
239	suffix_descs: &[EvalClaimSuffixDesc<F>],
240	row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
241	mixing_coeffs: &[F],
242) -> Result<Vec<MultilinearWitness<'static, P>>, Error>
243where
244	F: TowerField,
245	P: PackedFieldIndexable<Scalar = F>,
246	Tower: TowerFamily<B128 = F>,
247	F: PackedTop<Tower>,
248{
249	sumcheck_claim_descs
250		.par_iter()
251		.zip(mixing_coeffs)
252		.map(|(claim_desc, &mixing_coeff)| {
253			let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
254			make_ring_switch_eq_ind::<P, Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
255		})
256		.collect()
257}
258
259fn make_ring_switch_eq_ind<P, Tower>(
260	suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
261	row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
262	mixing_coeff: FExt<Tower>,
263) -> Result<MultilinearWitness<'static, P>, Error>
264where
265	P: PackedFieldIndexable<Scalar = FExt<Tower>>,
266	Tower: TowerFamily,
267{
268	let eq_ind = match suffix_desc.kappa {
269		7 => RingSwitchEqInd::<Tower::B1, _>::new(
270			suffix_desc.suffix.clone(),
271			row_batch_coeffs,
272			mixing_coeff,
273		)?
274		.multilinear_extension::<P>(),
275		4 => RingSwitchEqInd::<Tower::B8, _>::new(
276			suffix_desc.suffix.clone(),
277			row_batch_coeffs,
278			mixing_coeff,
279		)?
280		.multilinear_extension(),
281		3 => RingSwitchEqInd::<Tower::B16, _>::new(
282			suffix_desc.suffix.clone(),
283			row_batch_coeffs,
284			mixing_coeff,
285		)?
286		.multilinear_extension(),
287		2 => RingSwitchEqInd::<Tower::B32, _>::new(
288			suffix_desc.suffix.clone(),
289			row_batch_coeffs,
290			mixing_coeff,
291		)?
292		.multilinear_extension(),
293		1 => RingSwitchEqInd::<Tower::B64, _>::new(
294			suffix_desc.suffix.clone(),
295			row_batch_coeffs,
296			mixing_coeff,
297		)?
298		.multilinear_extension(),
299		0 => RingSwitchEqInd::<Tower::B128, _>::new(
300			suffix_desc.suffix.clone(),
301			row_batch_coeffs,
302			mixing_coeff,
303		)?
304		.multilinear_extension(),
305		_ => Err(Error::PackingDegreeNotSupported {
306			kappa: suffix_desc.kappa,
307		}),
308	}?;
309	Ok(MLEDirectAdapter::from(eq_ind).upcast_arc_dyn())
310}