binius_core/constraint_system/validate.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{BinaryField1b, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackendExt;
5use binius_math::MultilinearPoly;
6use binius_utils::bail;
7
8use super::{
9	channel::{self, Boundary},
10	error::Error,
11	ConstraintSystem,
12};
13use crate::{
14	oracle::{
15		ConstraintPredicate, MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant,
16		ShiftVariant,
17	},
18	polynomial::{
19		test_utils::decompose_index_to_hypercube_point, ArithCircuitPoly, MultilinearComposite,
20	},
21	protocols::sumcheck::prove::zerocheck,
22	witness::MultilinearExtensionIndex,
23};
24
25pub fn validate_witness<F, P>(
26	constraint_system: &ConstraintSystem<F>,
27	boundaries: &[Boundary<F>],
28	witness: &MultilinearExtensionIndex<'_, P>,
29) -> Result<(), Error>
30where
31	P: PackedField<Scalar = F> + PackedExtension<BinaryField1b>,
32	F: TowerField,
33{
34	// Check the constraint sets
35	for constraint_set in &constraint_system.table_constraints {
36		let multilinears = constraint_set
37			.oracle_ids
38			.iter()
39			.map(|id| witness.get_multilin_poly(*id))
40			.collect::<Result<Vec<_>, _>>()?;
41
42		let mut zero_claims = vec![];
43		for constraint in &constraint_set.constraints {
44			match constraint.predicate {
45				ConstraintPredicate::Zero => zero_claims.push((
46					constraint.name.clone(),
47					ArithCircuitPoly::with_n_vars(
48						multilinears.len(),
49						constraint.composition.clone(),
50					)?,
51				)),
52				ConstraintPredicate::Sum(_) => unimplemented!(),
53			}
54		}
55		zerocheck::validate_witness(&multilinears, &zero_claims)?;
56	}
57
58	// Check that nonzero oracles are non-zero over the entire hypercube
59	nonzerocheck::validate_witness(
60		witness,
61		&constraint_system.oracles,
62		&constraint_system.non_zero_oracle_ids,
63	)?;
64
65	// Check that the channels balance with flushes and boundaries
66	channel::validate_witness(
67		witness,
68		&constraint_system.flushes,
69		boundaries,
70		constraint_system.max_channel_id,
71	)?;
72
73	// Check consistency of virtual oracle witnesses (eg. that shift polynomials are actually shifts).
74	for oracle in constraint_system.oracles.iter() {
75		validate_virtual_oracle_witness(oracle, &constraint_system.oracles, witness)?;
76	}
77
78	Ok(())
79}
80
81pub fn validate_virtual_oracle_witness<F, P>(
82	oracle: MultilinearPolyOracle<F>,
83	oracles: &MultilinearOracleSet<F>,
84	witness: &MultilinearExtensionIndex<P>,
85) -> Result<(), Error>
86where
87	P: PackedField<Scalar = F>,
88	F: TowerField,
89{
90	let oracle_label = &oracle.label();
91	let n_vars = oracle.n_vars();
92	let poly = witness.get_multilin_poly(oracle.id())?;
93
94	if poly.n_vars() != n_vars {
95		bail!(Error::VirtualOracleNvarsMismatch {
96			oracle: oracle_label.into(),
97			oracle_num_vars: n_vars,
98			witness_num_vars: poly.n_vars(),
99		})
100	}
101
102	match oracle.variant {
103		MultilinearPolyVariant::Committed => {
104			// Committed oracles don't need to be checked as they are allowed to contain any data here
105		}
106		MultilinearPolyVariant::Transparent(inner) => {
107			for i in 0..1 << n_vars {
108				let got = poly.evaluate_on_hypercube(i)?;
109				let expected = inner
110					.poly()
111					.evaluate(&decompose_index_to_hypercube_point(n_vars, i))?;
112				check_eval(oracle_label, i, expected, got)?;
113			}
114		}
115		MultilinearPolyVariant::LinearCombination(linear_combination) => {
116			let uncombined_polys = linear_combination
117				.polys()
118				.map(|id| witness.get_multilin_poly(id))
119				.collect::<Result<Vec<_>, _>>()?;
120			for i in 0..1 << n_vars {
121				let got = poly.evaluate_on_hypercube(i)?;
122				let expected = linear_combination
123					.coefficients()
124					.zip(uncombined_polys.iter())
125					.try_fold(linear_combination.offset(), |acc, (coeff, poly)| {
126						Ok::<F, Error>(acc + poly.evaluate_on_hypercube_and_scale(i, coeff)?)
127					})?;
128				check_eval(oracle_label, i, expected, got)?;
129			}
130		}
131		MultilinearPolyVariant::Repeating { id, .. } => {
132			let unrepeated_poly = witness.get_multilin_poly(id)?;
133			let unrepeated_n_vars = oracles.n_vars(id);
134			for i in 0..1 << n_vars {
135				let got = poly.evaluate_on_hypercube(i)?;
136				let expected =
137					unrepeated_poly.evaluate_on_hypercube(i % (1 << unrepeated_n_vars))?;
138				check_eval(oracle_label, i, expected, got)?;
139			}
140		}
141		MultilinearPolyVariant::Shifted(shifted) => {
142			let unshifted_poly = witness.get_multilin_poly(shifted.id())?;
143			let block_len = 1 << shifted.block_size();
144			let shift_offset = shifted.shift_offset();
145			for block_start in (0..1 << n_vars).step_by(block_len) {
146				match shifted.shift_variant() {
147					ShiftVariant::CircularLeft => {
148						for offset_after in 0..block_len {
149							check_eval(
150								oracle_label,
151								block_start + offset_after,
152								unshifted_poly.evaluate_on_hypercube(
153									block_start
154										+ (offset_after + (block_len - shift_offset)) % block_len,
155								)?,
156								poly.evaluate_on_hypercube(block_start + offset_after)?,
157							)?;
158						}
159					}
160					ShiftVariant::LogicalLeft => {
161						for offset_after in 0..shift_offset {
162							check_eval(
163								oracle_label,
164								block_start + offset_after,
165								F::ZERO,
166								poly.evaluate_on_hypercube(block_start + offset_after)?,
167							)?;
168						}
169						for offset_after in shift_offset..block_len {
170							check_eval(
171								oracle_label,
172								block_start + offset_after,
173								unshifted_poly.evaluate_on_hypercube(
174									block_start + offset_after - shift_offset,
175								)?,
176								poly.evaluate_on_hypercube(block_start + offset_after)?,
177							)?;
178						}
179					}
180					ShiftVariant::LogicalRight => {
181						for offset_after in 0..block_len - shift_offset {
182							check_eval(
183								oracle_label,
184								block_start + offset_after,
185								unshifted_poly.evaluate_on_hypercube(
186									block_start + offset_after + shift_offset,
187								)?,
188								poly.evaluate_on_hypercube(block_start + offset_after)?,
189							)?;
190						}
191						for offset_after in block_len - shift_offset..block_len {
192							check_eval(
193								oracle_label,
194								block_start + offset_after,
195								F::ZERO,
196								poly.evaluate_on_hypercube(block_start + offset_after)?,
197							)?;
198						}
199					}
200				}
201			}
202		}
203		MultilinearPolyVariant::Projected(projected) => {
204			let unprojected_poly = witness.get_multilin_poly(projected.id())?;
205			let partial_query =
206				binius_hal::make_portable_backend().multilinear_query(projected.values())?;
207			let projected_poly = unprojected_poly
208				.evaluate_partial(partial_query.to_ref(), projected.start_index())?;
209
210			for i in 0..1 << n_vars {
211				check_eval(
212					oracle_label,
213					i,
214					projected_poly.evaluate_on_hypercube(i)?,
215					poly.evaluate_on_hypercube(i)?,
216				)?;
217			}
218		}
219		MultilinearPolyVariant::ZeroPadded(padded) => {
220			let inner_id = padded.id();
221			let unpadded_poly = witness.get_multilin_poly(inner_id)?;
222			let n_pad_vars = padded.n_pad_vars();
223			let start_index = padded.start_index();
224			let nonzero_index = padded.nonzero_index();
225			let padded_poly = unpadded_poly.zero_pad(n_pad_vars, start_index, nonzero_index)?;
226			for i in 0..1 << n_pad_vars {
227				check_eval(
228					oracle_label,
229					i,
230					padded_poly.evaluate_on_hypercube(i)?,
231					poly.evaluate_on_hypercube(i)?,
232				)?;
233			}
234		}
235		MultilinearPolyVariant::Packed(ref packed) => {
236			let expected = witness.get_multilin_poly(packed.id())?;
237			let got = witness.get_multilin_poly(oracle.id())?;
238			if expected.packed_evals() != got.packed_evals() {
239				return Err(Error::PackedUnderlierMismatch {
240					oracle: oracle_label.into(),
241				});
242			}
243		}
244		MultilinearPolyVariant::Composite(composite_mle) => {
245			let inner_polys = composite_mle
246				.polys()
247				.map(|id| witness.get_multilin_poly(id))
248				.collect::<Result<Vec<_>, _>>()?;
249			let composite = MultilinearComposite::new(n_vars, composite_mle.c(), inner_polys)?;
250			for i in 0..1 << n_vars {
251				let got = poly.evaluate_on_hypercube(i)?;
252				let expected = composite.evaluate_on_hypercube(i)?;
253				check_eval(oracle_label, i, expected, got)?;
254			}
255		}
256	}
257	Ok(())
258}
259
260fn check_eval<F: TowerField>(
261	oracle_label: &str,
262	index: usize,
263	expected: F,
264	got: F,
265) -> Result<(), Error> {
266	if expected == got {
267		Ok(())
268	} else {
269		Err(Error::VirtualOracleEvalMismatch {
270			oracle: oracle_label.into(),
271			index,
272			reason: format!("Expected {expected}, got {got}"),
273		})
274	}
275}
276
277pub mod nonzerocheck {
278	use binius_field::{PackedField, TowerField};
279	use binius_math::MultilinearPoly;
280	use binius_maybe_rayon::prelude::*;
281	use binius_utils::bail;
282
283	use crate::{
284		oracle::{MultilinearOracleSet, OracleId},
285		protocols::sumcheck::Error,
286		witness::MultilinearExtensionIndex,
287	};
288
289	pub fn validate_witness<F, P>(
290		witness: &MultilinearExtensionIndex<P>,
291		oracles: &MultilinearOracleSet<P::Scalar>,
292		oracle_ids: &[OracleId],
293	) -> Result<(), Error>
294	where
295		P: PackedField<Scalar = F>,
296		F: TowerField,
297	{
298		oracle_ids.into_par_iter().try_for_each(|id| {
299			let multilinear = witness.get_multilin_poly(*id)?;
300			(0..(1 << multilinear.n_vars()))
301				.into_par_iter()
302				.try_for_each(|hypercube_index| {
303					if multilinear.evaluate_on_hypercube(hypercube_index)? == F::ZERO {
304						bail!(Error::NonzerocheckNaiveValidationFailure {
305							hypercube_index,
306							oracle: oracles.oracle(*id).label()
307						})
308					}
309					Ok(())
310				})
311		})
312	}
313}