// File: binius_core/ring_switch/prove.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, sync::Arc};
4
5use binius_field::{
6	tower::{PackedTop, TowerFamily},
7	PackedField, PackedFieldIndexable, TowerField,
8};
9use binius_hal::ComputationBackend;
10use binius_math::{MLEDirectAdapter, MultilinearPoly, MultilinearQuery};
11use binius_maybe_rayon::prelude::*;
12use binius_utils::checked_arithmetics::log2_ceil_usize;
13use tracing::instrument;
14
15use super::{
16	common::{EvalClaimPrefixDesc, EvalClaimSystem, PIOPSumcheckClaimDesc},
17	eq_ind::RowBatchCoeffs,
18	error::Error,
19	logging::MLEFoldHisgDimensionsData,
20	tower_tensor_algebra::TowerTensorAlgebra,
21};
22use crate::{
23	fiat_shamir::{CanSample, Challenger},
24	piop::PIOPSumcheckClaim,
25	protocols::evalcheck::subclaims::MemoizedData,
26	ring_switch::{
27		common::EvalClaimSuffixDesc, eq_ind::RingSwitchEqInd, logging::CalculateRingSwitchEqIndData,
28	},
29	transcript::ProverTranscript,
30	witness::MultilinearWitness,
31};
32
33type FExt<Tower> = <Tower as TowerFamily>::B128;
34
35#[derive(Debug)]
36pub struct ReducedWitness<P: PackedField> {
37	pub transparents: Vec<MultilinearWitness<'static, P>>,
38	pub sumcheck_claims: Vec<PIOPSumcheckClaim<P::Scalar>>,
39}
40
41pub fn prove<F, P, M, Tower, Challenger_, Backend>(
42	system: &EvalClaimSystem<F>,
43	witnesses: &[M],
44	transcript: &mut ProverTranscript<Challenger_>,
45	memoized_data: MemoizedData<P, Backend>,
46	backend: &Backend,
47) -> Result<ReducedWitness<P>, Error>
48where
49	F: TowerField + PackedTop<Tower>,
50	P: PackedFieldIndexable<Scalar = F>,
51	M: MultilinearPoly<P> + Sync,
52	Tower: TowerFamily<B128 = F>,
53	Challenger_: Challenger,
54	Backend: ComputationBackend,
55{
56	if witnesses.len() != system.commit_meta.total_multilins() {
57		return Err(Error::InvalidWitness(
58			"witness length does not match the number of multilinears".into(),
59		));
60	}
61
62	// Sample enough randomness to batch tensor elements corresponding to claims that share an
63	// evaluation point prefix.
64	let n_mixing_challenges = log2_ceil_usize(system.sumcheck_claim_descs.len());
65	let mixing_challenges = transcript.sample_vec(n_mixing_challenges);
66	let dimensions_data = MLEFoldHisgDimensionsData::new(witnesses);
67	let mle_fold_high_span = tracing::debug_span!(
68		"[task] (Ring Switch) MLE Fold High",
69		phase = "ring_switch",
70		perfetto_category = "task.main",
71		dimensions_data = ?dimensions_data,
72	)
73	.entered();
74
75	let mixing_coeffs = MultilinearQuery::expand(&mixing_challenges).into_expansion();
76
77	// For each evaluation point prefix, send one batched partial evaluation.
78	let tensor_elems =
79		compute_partial_evals::<_, _, _, Tower, _>(system, witnesses, memoized_data, backend)?;
80	let scaled_tensor_elems = scale_tensor_elems(tensor_elems, &mixing_coeffs);
81	let mixed_tensor_elems = mix_tensor_elems_for_prefixes(
82		&scaled_tensor_elems,
83		&system.prefix_descs,
84		&system.eval_claim_to_prefix_desc_index,
85	)?;
86	drop(mle_fold_high_span);
87	let mut writer = transcript.message();
88	for (mixed_tensor_elem, prefix_desc) in iter::zip(mixed_tensor_elems, &system.prefix_descs) {
89		debug_assert_eq!(mixed_tensor_elem.vertical_elems().len(), 1 << prefix_desc.kappa());
90		writer.write_scalar_slice(mixed_tensor_elem.vertical_elems());
91	}
92
93	// Sample the row-batching randomness.
94	let row_batch_challenges = transcript.sample_vec(system.max_claim_kappa());
95	let row_batch_coeffs = Arc::new(RowBatchCoeffs::new(
96		MultilinearQuery::<F, _>::expand(&row_batch_challenges).into_expansion(),
97	));
98
99	let row_batched_evals =
100		compute_row_batched_sumcheck_evals(scaled_tensor_elems, row_batch_coeffs.coeffs());
101	transcript.message().write_scalar_slice(&row_batched_evals);
102
103	// Create the reduced PIOP sumcheck witnesses.
104	let dimensions_data = CalculateRingSwitchEqIndData::new(system.suffix_descs.iter());
105	let calculate_ring_switch_eq_ind_span = tracing::debug_span!(
106		"[task] Calculate Ring Switch Eq Ind",
107		phase = "ring_switch",
108		perfetto_category = "task.main",
109		dimensions_data = ?dimensions_data,
110	)
111	.entered();
112
113	let ring_switch_eq_inds = make_ring_switch_eq_inds::<_, P, Tower>(
114		&system.sumcheck_claim_descs,
115		&system.suffix_descs,
116		row_batch_coeffs,
117		&mixing_coeffs,
118	)?;
119	drop(calculate_ring_switch_eq_ind_span);
120
121	let sumcheck_claims = iter::zip(&system.sumcheck_claim_descs, row_batched_evals)
122		.enumerate()
123		.map(|(idx, (claim_desc, eval))| {
124			let suffix_desc = &system.suffix_descs[claim_desc.suffix_desc_idx];
125			PIOPSumcheckClaim {
126				n_vars: suffix_desc.suffix.len(),
127				committed: claim_desc.committed_idx,
128				transparent: idx,
129				sum: eval,
130			}
131		})
132		.collect::<Vec<_>>();
133
134	Ok(ReducedWitness {
135		transparents: ring_switch_eq_inds,
136		sumcheck_claims,
137	})
138}
139
140#[instrument(skip_all)]
141fn compute_partial_evals<F, P, M, Tower, Backend>(
142	system: &EvalClaimSystem<F>,
143	witnesses: &[M],
144	mut memoized_data: MemoizedData<P, Backend>,
145	backend: &Backend,
146) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
147where
148	F: TowerField,
149	P: PackedField<Scalar = F>,
150	M: MultilinearPoly<P> + Sync,
151	Tower: TowerFamily<B128 = F>,
152	Backend: ComputationBackend,
153{
154	let suffixes = system
155		.suffix_descs
156		.iter()
157		.map(|desc| Arc::as_ref(&desc.suffix))
158		.collect::<Vec<_>>();
159
160	memoized_data.memoize_query_par(suffixes, backend)?;
161
162	let tensor_elems = system
163		.sumcheck_claim_descs
164		.par_iter()
165		.map(
166			|PIOPSumcheckClaimDesc {
167			     committed_idx,
168			     suffix_desc_idx,
169			     eval_claim,
170			 }| {
171				let suffix_desc = &system.suffix_descs[*suffix_desc_idx];
172
173				let mut elems = if let Some(partial_eval) =
174					memoized_data.partial_eval(eval_claim.id, Arc::as_ref(&suffix_desc.suffix))
175				{
176					PackedField::iter_slice(
177						partial_eval.packed_evals().expect("packed_evals exist"),
178					)
179					.take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
180					.collect::<Vec<_>>()
181				} else {
182					let suffix_query = memoized_data
183						.full_query_readonly(&suffix_desc.suffix)
184						.expect("memoized above");
185					let partial_eval =
186						witnesses[*committed_idx].evaluate_partial_high(suffix_query.into())?;
187					PackedField::iter_slice(partial_eval.evals())
188						.take((1 << suffix_desc.kappa).min(1 << partial_eval.n_vars()))
189						.collect::<Vec<_>>()
190				};
191
192				if elems.len() < (1 << suffix_desc.kappa) {
193					elems = elems
194						.into_iter()
195						.cycle()
196						.take(1 << suffix_desc.kappa)
197						.collect();
198				}
199				TowerTensorAlgebra::new(suffix_desc.kappa, elems)
200			},
201		)
202		.collect::<Result<Vec<_>, _>>()?;
203
204	Ok(tensor_elems)
205}
206
207fn scale_tensor_elems<F, Tower>(
208	tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
209	mixing_coeffs: &[F],
210) -> Vec<TowerTensorAlgebra<Tower>>
211where
212	F: TowerField,
213	Tower: TowerFamily<B128 = F>,
214{
215	// Precondition
216	assert!(tensor_elems.len() <= mixing_coeffs.len());
217
218	tensor_elems
219		.into_par_iter()
220		.zip(mixing_coeffs)
221		.map(|(tensor_elem, &mixing_coeff)| tensor_elem.scale_vertical(mixing_coeff))
222		.collect()
223}
224
225fn mix_tensor_elems_for_prefixes<F, Tower>(
226	scaled_tensor_elems: &[TowerTensorAlgebra<Tower>],
227	prefix_descs: &[EvalClaimPrefixDesc<F>],
228	eval_claim_to_prefix_desc_index: &[usize],
229) -> Result<Vec<TowerTensorAlgebra<Tower>>, Error>
230where
231	F: TowerField,
232	Tower: TowerFamily<B128 = F>,
233{
234	// Precondition
235	assert_eq!(scaled_tensor_elems.len(), eval_claim_to_prefix_desc_index.len());
236
237	let mut batched_tensor_elems = prefix_descs
238		.iter()
239		.map(|desc| TowerTensorAlgebra::zero(desc.kappa()))
240		.collect::<Result<Vec<_>, _>>()?;
241	for (tensor_elem, &desc_index) in
242		iter::zip(scaled_tensor_elems, eval_claim_to_prefix_desc_index)
243	{
244		let mixed_val = &mut batched_tensor_elems[desc_index];
245		debug_assert_eq!(mixed_val.kappa(), tensor_elem.kappa());
246		mixed_val.add_assign(tensor_elem)?;
247	}
248	Ok(batched_tensor_elems)
249}
250
251#[instrument(skip_all)]
252fn compute_row_batched_sumcheck_evals<F, Tower>(
253	tensor_elems: Vec<TowerTensorAlgebra<Tower>>,
254	row_batch_coeffs: &[F],
255) -> Vec<F>
256where
257	F: TowerField,
258	Tower: TowerFamily<B128 = F>,
259	F: PackedTop<Tower>,
260{
261	tensor_elems
262		.into_par_iter()
263		.map(|tensor_elem| tensor_elem.fold_vertical(row_batch_coeffs))
264		.collect()
265}
266
267fn make_ring_switch_eq_inds<F, P, Tower>(
268	sumcheck_claim_descs: &[PIOPSumcheckClaimDesc<F>],
269	suffix_descs: &[EvalClaimSuffixDesc<F>],
270	row_batch_coeffs: Arc<RowBatchCoeffs<F>>,
271	mixing_coeffs: &[F],
272) -> Result<Vec<MultilinearWitness<'static, P>>, Error>
273where
274	F: TowerField + PackedTop<Tower>,
275	P: PackedFieldIndexable<Scalar = F>,
276	Tower: TowerFamily<B128 = F>,
277{
278	sumcheck_claim_descs
279		.par_iter()
280		.zip(mixing_coeffs)
281		.map(|(claim_desc, &mixing_coeff)| {
282			let suffix_desc = &suffix_descs[claim_desc.suffix_desc_idx];
283			make_ring_switch_eq_ind::<P, Tower>(suffix_desc, row_batch_coeffs.clone(), mixing_coeff)
284		})
285		.collect()
286}
287
288fn make_ring_switch_eq_ind<P, Tower>(
289	suffix_desc: &EvalClaimSuffixDesc<FExt<Tower>>,
290	row_batch_coeffs: Arc<RowBatchCoeffs<FExt<Tower>>>,
291	mixing_coeff: FExt<Tower>,
292) -> Result<MultilinearWitness<'static, P>, Error>
293where
294	P: PackedFieldIndexable<Scalar = FExt<Tower>>,
295	Tower: TowerFamily,
296{
297	let eq_ind = match suffix_desc.kappa {
298		7 => RingSwitchEqInd::<Tower::B1, _>::new(
299			suffix_desc.suffix.clone(),
300			row_batch_coeffs,
301			mixing_coeff,
302		)?
303		.multilinear_extension::<P>(),
304		4 => RingSwitchEqInd::<Tower::B8, _>::new(
305			suffix_desc.suffix.clone(),
306			row_batch_coeffs,
307			mixing_coeff,
308		)?
309		.multilinear_extension(),
310		3 => RingSwitchEqInd::<Tower::B16, _>::new(
311			suffix_desc.suffix.clone(),
312			row_batch_coeffs,
313			mixing_coeff,
314		)?
315		.multilinear_extension(),
316		2 => RingSwitchEqInd::<Tower::B32, _>::new(
317			suffix_desc.suffix.clone(),
318			row_batch_coeffs,
319			mixing_coeff,
320		)?
321		.multilinear_extension(),
322		1 => RingSwitchEqInd::<Tower::B64, _>::new(
323			suffix_desc.suffix.clone(),
324			row_batch_coeffs,
325			mixing_coeff,
326		)?
327		.multilinear_extension(),
328		0 => RingSwitchEqInd::<Tower::B128, _>::new(
329			suffix_desc.suffix.clone(),
330			row_batch_coeffs,
331			mixing_coeff,
332		)?
333		.multilinear_extension(),
334		_ => Err(Error::PackingDegreeNotSupported {
335			kappa: suffix_desc.kappa,
336		}),
337	}?;
338	Ok(MLEDirectAdapter::from(eq_ind).upcast_arc_dyn())
339}