binius_core/constraint_system/
validate.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{BinaryField1b, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackendExt;
5use binius_math::MultilinearPoly;
6use binius_utils::bail;
7
8use super::{
9	channel::{self, Boundary},
10	error::Error,
11	ConstraintSystem,
12};
13use crate::{
14	oracle::{
15		ConstraintPredicate, MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant,
16		ShiftVariant,
17	},
18	polynomial::{
19		test_utils::decompose_index_to_hypercube_point, ArithCircuitPoly, MultilinearComposite,
20	},
21	protocols::sumcheck::prove::zerocheck,
22	witness::MultilinearExtensionIndex,
23};
24
25pub fn validate_witness<F, P>(
26	constraint_system: &ConstraintSystem<F>,
27	boundaries: &[Boundary<F>],
28	witness: &MultilinearExtensionIndex<'_, P>,
29) -> Result<(), Error>
30where
31	P: PackedField<Scalar = F> + PackedExtension<BinaryField1b>,
32	F: TowerField,
33{
34	// Check the constraint sets
35	for constraint_set in &constraint_system.table_constraints {
36		let multilinears = constraint_set
37			.oracle_ids
38			.iter()
39			.map(|id| witness.get_multilin_poly(*id))
40			.collect::<Result<Vec<_>, _>>()?;
41
42		let mut zero_claims = vec![];
43		for constraint in &constraint_set.constraints {
44			match constraint.predicate {
45				ConstraintPredicate::Zero => zero_claims.push((
46					constraint.name.clone(),
47					ArithCircuitPoly::with_n_vars(
48						multilinears.len(),
49						constraint.composition.clone(),
50					)?,
51				)),
52				ConstraintPredicate::Sum(_) => unimplemented!(),
53			}
54		}
55		zerocheck::validate_witness(&multilinears, &zero_claims)?;
56	}
57
58	// Check that nonzero oracles are non-zero over the entire hypercube
59	nonzerocheck::validate_witness(
60		witness,
61		&constraint_system.oracles,
62		&constraint_system.non_zero_oracle_ids,
63	)?;
64
65	// Check that the channels balance with flushes and boundaries
66	channel::validate_witness(
67		witness,
68		&constraint_system.flushes,
69		boundaries,
70		constraint_system.max_channel_id,
71	)?;
72
73	// Check consistency of virtual oracle witnesses (eg. that shift polynomials are actually shifts).
74	for oracle in constraint_system.oracles.iter() {
75		validate_virtual_oracle_witness(oracle, &constraint_system.oracles, witness)?;
76	}
77
78	Ok(())
79}
80
81pub fn validate_virtual_oracle_witness<F, P>(
82	oracle: MultilinearPolyOracle<F>,
83	oracles: &MultilinearOracleSet<F>,
84	witness: &MultilinearExtensionIndex<P>,
85) -> Result<(), Error>
86where
87	P: PackedField<Scalar = F>,
88	F: TowerField,
89{
90	let oracle_label = &oracle.label();
91	let n_vars = oracle.n_vars();
92	let poly = witness.get_multilin_poly(oracle.id())?;
93
94	if poly.n_vars() != n_vars {
95		bail!(Error::VirtualOracleNvarsMismatch {
96			oracle: oracle_label.into(),
97			oracle_num_vars: n_vars,
98			witness_num_vars: poly.n_vars(),
99		})
100	}
101
102	match oracle.variant {
103		MultilinearPolyVariant::Committed => {
104			// Committed oracles don't need to be checked as they are allowed to contain any data here
105		}
106		MultilinearPolyVariant::Transparent(inner) => {
107			for i in 0..1 << n_vars {
108				let got = poly.evaluate_on_hypercube(i)?;
109				let expected = inner
110					.poly()
111					.evaluate(&decompose_index_to_hypercube_point(n_vars, i))?;
112				check_eval(oracle_label, i, expected, got)?;
113			}
114		}
115		MultilinearPolyVariant::LinearCombination(linear_combination) => {
116			let uncombined_polys = linear_combination
117				.polys()
118				.map(|id| witness.get_multilin_poly(id))
119				.collect::<Result<Vec<_>, _>>()?;
120			for i in 0..1 << n_vars {
121				let got = poly.evaluate_on_hypercube(i)?;
122				let expected = linear_combination
123					.coefficients()
124					.zip(uncombined_polys.iter())
125					.try_fold(linear_combination.offset(), |acc, (coeff, poly)| {
126						Ok::<F, Error>(acc + poly.evaluate_on_hypercube_and_scale(i, coeff)?)
127					})?;
128				check_eval(oracle_label, i, expected, got)?;
129			}
130		}
131		MultilinearPolyVariant::Repeating { id, .. } => {
132			let unrepeated_poly = witness.get_multilin_poly(id)?;
133			let unrepeated_n_vars = oracles.n_vars(id);
134			for i in 0..1 << n_vars {
135				let got = poly.evaluate_on_hypercube(i)?;
136				let expected =
137					unrepeated_poly.evaluate_on_hypercube(i % (1 << unrepeated_n_vars))?;
138				check_eval(oracle_label, i, expected, got)?;
139			}
140		}
141		MultilinearPolyVariant::Shifted(shifted) => {
142			let unshifted_poly = witness.get_multilin_poly(shifted.id())?;
143			let block_len = 1 << shifted.block_size();
144			let shift_offset = shifted.shift_offset();
145			for block_start in (0..1 << n_vars).step_by(block_len) {
146				match shifted.shift_variant() {
147					ShiftVariant::CircularLeft => {
148						for offset_after in 0..block_len {
149							check_eval(
150								oracle_label,
151								block_start + offset_after,
152								unshifted_poly.evaluate_on_hypercube(
153									block_start
154										+ (offset_after + (block_len - shift_offset)) % block_len,
155								)?,
156								poly.evaluate_on_hypercube(block_start + offset_after)?,
157							)?;
158						}
159					}
160					ShiftVariant::LogicalLeft => {
161						for offset_after in 0..shift_offset {
162							check_eval(
163								oracle_label,
164								block_start + offset_after,
165								F::ZERO,
166								poly.evaluate_on_hypercube(block_start + offset_after)?,
167							)?;
168						}
169						for offset_after in shift_offset..block_len {
170							check_eval(
171								oracle_label,
172								block_start + offset_after,
173								unshifted_poly.evaluate_on_hypercube(
174									block_start + offset_after - shift_offset,
175								)?,
176								poly.evaluate_on_hypercube(block_start + offset_after)?,
177							)?;
178						}
179					}
180					ShiftVariant::LogicalRight => {
181						for offset_after in 0..block_len - shift_offset {
182							check_eval(
183								oracle_label,
184								block_start + offset_after,
185								unshifted_poly.evaluate_on_hypercube(
186									block_start + offset_after + shift_offset,
187								)?,
188								poly.evaluate_on_hypercube(block_start + offset_after)?,
189							)?;
190						}
191						for offset_after in block_len - shift_offset..block_len {
192							check_eval(
193								oracle_label,
194								block_start + offset_after,
195								F::ZERO,
196								poly.evaluate_on_hypercube(block_start + offset_after)?,
197							)?;
198						}
199					}
200				}
201			}
202		}
203		MultilinearPolyVariant::Projected(projected) => {
204			let unprojected_poly = witness.get_multilin_poly(projected.id())?;
205			let partial_query =
206				binius_hal::make_portable_backend().multilinear_query(projected.values())?;
207			let projected_poly = unprojected_poly
208				.evaluate_partial(partial_query.to_ref(), projected.start_index())?;
209
210			for i in 0..1 << n_vars {
211				check_eval(
212					oracle_label,
213					i,
214					projected_poly.evaluate_on_hypercube(i)?,
215					poly.evaluate_on_hypercube(i)?,
216				)?;
217			}
218		}
219		MultilinearPolyVariant::ZeroPadded(inner_id) => {
220			let unpadded_poly = witness.get_multilin_poly(inner_id)?;
221			for i in 0..1 << unpadded_poly.n_vars() {
222				check_eval(
223					oracle_label,
224					i,
225					unpadded_poly.evaluate_on_hypercube(i)?,
226					poly.evaluate_on_hypercube(i)?,
227				)?;
228			}
229			for i in 1 << unpadded_poly.n_vars()..1 << n_vars {
230				check_eval(oracle_label, i, F::ZERO, poly.evaluate_on_hypercube(i)?)?;
231			}
232		}
233		MultilinearPolyVariant::Packed(ref packed) => {
234			let expected = witness.get_multilin_poly(packed.id())?;
235			let got = witness.get_multilin_poly(oracle.id())?;
236			if expected.packed_evals() != got.packed_evals() {
237				return Err(Error::PackedUnderlierMismatch {
238					oracle: oracle_label.into(),
239				});
240			}
241		}
242		MultilinearPolyVariant::Composite(composite_mle) => {
243			let inner_polys = composite_mle
244				.polys()
245				.map(|id| witness.get_multilin_poly(id))
246				.collect::<Result<Vec<_>, _>>()?;
247			let composite = MultilinearComposite::new(n_vars, composite_mle.c(), inner_polys)?;
248			for i in 0..1 << n_vars {
249				let got = poly.evaluate_on_hypercube(i)?;
250				let expected = composite.evaluate_on_hypercube(i)?;
251				check_eval(oracle_label, i, expected, got)?;
252			}
253		}
254	}
255	Ok(())
256}
257
258fn check_eval<F: TowerField>(
259	oracle_label: &str,
260	index: usize,
261	expected: F,
262	got: F,
263) -> Result<(), Error> {
264	if expected == got {
265		Ok(())
266	} else {
267		Err(Error::VirtualOracleEvalMismatch {
268			oracle: oracle_label.into(),
269			index,
270			reason: format!("Expected {expected}, got {got}"),
271		})
272	}
273}
274
275pub mod nonzerocheck {
276	use binius_field::{PackedField, TowerField};
277	use binius_math::MultilinearPoly;
278	use binius_maybe_rayon::prelude::*;
279	use binius_utils::bail;
280
281	use crate::{
282		oracle::{MultilinearOracleSet, OracleId},
283		protocols::sumcheck::Error,
284		witness::MultilinearExtensionIndex,
285	};
286
287	pub fn validate_witness<F, P>(
288		witness: &MultilinearExtensionIndex<P>,
289		oracles: &MultilinearOracleSet<P::Scalar>,
290		oracle_ids: &[OracleId],
291	) -> Result<(), Error>
292	where
293		P: PackedField<Scalar = F>,
294		F: TowerField,
295	{
296		oracle_ids.into_par_iter().try_for_each(|id| {
297			let multilinear = witness.get_multilin_poly(*id)?;
298			(0..(1 << multilinear.n_vars()))
299				.into_par_iter()
300				.try_for_each(|hypercube_index| {
301					if multilinear.evaluate_on_hypercube(hypercube_index)? == F::ZERO {
302						bail!(Error::NonzerocheckNaiveValidationFailure {
303							hypercube_index,
304							oracle: oracles.oracle(*id).label()
305						})
306					}
307					Ok(())
308				})
309		})
310	}
311}