binius_core/constraint_system/
validate.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_fast_compute::arith_circuit::ArithCircuitPoly;
4use binius_field::{BinaryField1b, PackedExtension, PackedField, TowerField};
5use binius_hal::ComputationBackendExt;
6use binius_math::MultilinearPoly;
7use binius_utils::bail;
8
9use super::{
10	ConstraintSystem,
11	channel::{self, Boundary},
12	error::Error,
13};
14use crate::{
15	oracle::{
16		ConstraintPredicate, MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant,
17		ShiftVariant,
18	},
19	polynomial::{MultilinearComposite, test_utils::decompose_index_to_hypercube_point},
20	protocols::sumcheck::prove::zerocheck,
21	witness::MultilinearExtensionIndex,
22};
23
24pub fn validate_witness<F, P>(
25	constraint_system: &ConstraintSystem<F>,
26	boundaries: &[Boundary<F>],
27	witness: &MultilinearExtensionIndex<'_, P>,
28) -> Result<(), Error>
29where
30	P: PackedField<Scalar = F> + PackedExtension<BinaryField1b>,
31	F: TowerField,
32{
33	// Check the constraint sets
34	for constraint_set in &constraint_system.table_constraints {
35		let multilinears = constraint_set
36			.oracle_ids
37			.iter()
38			.map(|id| witness.get_multilin_poly(*id))
39			.collect::<Result<Vec<_>, _>>()?;
40
41		let mut zero_claims = vec![];
42		for constraint in &constraint_set.constraints {
43			match constraint.predicate {
44				ConstraintPredicate::Zero => zero_claims.push((
45					constraint.name.clone(),
46					ArithCircuitPoly::with_n_vars(
47						multilinears.len(),
48						constraint.composition.clone(),
49					)?,
50				)),
51				ConstraintPredicate::Sum(_) => unimplemented!(),
52			}
53		}
54		zerocheck::validate_witness(&multilinears, &zero_claims)?;
55	}
56
57	// Check that nonzero oracles are non-zero over the entire hypercube
58	nonzerocheck::validate_witness(
59		witness,
60		&constraint_system.oracles,
61		&constraint_system.non_zero_oracle_ids,
62	)?;
63
64	// Check that the channels balance with flushes and boundaries
65	channel::validate_witness(
66		witness,
67		&constraint_system.flushes,
68		boundaries,
69		constraint_system.max_channel_id,
70	)?;
71
72	// Check consistency of virtual oracle witnesses (eg. that shift polynomials are actually
73	// shifts).
74	for oracle in constraint_system.oracles.polys() {
75		validate_virtual_oracle_witness(oracle, &constraint_system.oracles, witness)?;
76	}
77
78	Ok(())
79}
80
81pub fn validate_virtual_oracle_witness<F, P>(
82	oracle: &MultilinearPolyOracle<F>,
83	oracles: &MultilinearOracleSet<F>,
84	witness: &MultilinearExtensionIndex<P>,
85) -> Result<(), Error>
86where
87	P: PackedField<Scalar = F>,
88	F: TowerField,
89{
90	let oracle_label = &oracle.label();
91	let n_vars = oracle.n_vars();
92	let poly = witness.get_multilin_poly(oracle.id())?;
93
94	match oracle.variant {
95		MultilinearPolyVariant::Structured(_) if poly.n_vars() > n_vars => {
96			bail!(Error::VirtualOracleNvarsMismatch {
97				oracle: oracle_label.into(),
98				condition: '<',
99				oracle_num_vars: n_vars,
100				witness_num_vars: poly.n_vars(),
101			})
102		}
103		_ if poly.n_vars() != n_vars => {
104			bail!(Error::VirtualOracleNvarsMismatch {
105				oracle: oracle_label.into(),
106				condition: '=',
107				oracle_num_vars: n_vars,
108				witness_num_vars: poly.n_vars(),
109			})
110		}
111		_ => (),
112	}
113
114	match &oracle.variant {
115		MultilinearPolyVariant::Committed => {
116			// Committed oracles don't need to be checked as they are allowed to contain any data
117			// here
118		}
119		MultilinearPolyVariant::Transparent(inner) => {
120			for i in 0..1 << n_vars {
121				let got = poly.evaluate_on_hypercube(i)?;
122				let expected = inner
123					.poly()
124					.evaluate(&decompose_index_to_hypercube_point(n_vars, i))?;
125				check_eval(oracle_label, i, expected, got)?;
126			}
127		}
128		MultilinearPolyVariant::Structured(inner) => {
129			for i in 0..1 << n_vars {
130				let got = poly.evaluate_on_hypercube(i)?;
131				// NOTE: that the arith circuit can have more query parameters than `n_vars`, that's
132				//       the reason we use the arith circuit's n_vars here.
133				let eval_point = decompose_index_to_hypercube_point(inner.n_vars(), i);
134				let expected = inner.evaluate(&eval_point)?;
135				check_eval(oracle_label, i, expected, got)?;
136			}
137		}
138		MultilinearPolyVariant::LinearCombination(linear_combination) => {
139			let uncombined_polys = linear_combination
140				.polys()
141				.map(|id| witness.get_multilin_poly(id))
142				.collect::<Result<Vec<_>, _>>()?;
143			for i in 0..1 << n_vars {
144				let got = poly.evaluate_on_hypercube(i)?;
145				let expected = linear_combination
146					.coefficients()
147					.zip(uncombined_polys.iter())
148					.try_fold(linear_combination.offset(), |acc, (coeff, poly)| {
149						Ok::<F, Error>(acc + poly.evaluate_on_hypercube_and_scale(i, coeff)?)
150					})?;
151				check_eval(oracle_label, i, expected, got)?;
152			}
153		}
154		MultilinearPolyVariant::Repeating { id, .. } => {
155			let unrepeated_poly = witness.get_multilin_poly(*id)?;
156			let unrepeated_n_vars = oracles.n_vars(*id);
157			for i in 0..1 << n_vars {
158				let got = poly.evaluate_on_hypercube(i)?;
159				let expected =
160					unrepeated_poly.evaluate_on_hypercube(i % (1 << unrepeated_n_vars))?;
161				check_eval(oracle_label, i, expected, got)?;
162			}
163		}
164		MultilinearPolyVariant::Shifted(shifted) => {
165			let unshifted_poly = witness.get_multilin_poly(shifted.id())?;
166			let block_len = 1 << shifted.block_size();
167			let shift_offset = shifted.shift_offset();
168			for block_start in (0..1 << n_vars).step_by(block_len) {
169				match shifted.shift_variant() {
170					ShiftVariant::CircularLeft => {
171						for offset_after in 0..block_len {
172							check_eval(
173								oracle_label,
174								block_start + offset_after,
175								unshifted_poly.evaluate_on_hypercube(
176									block_start
177										+ (offset_after + (block_len - shift_offset)) % block_len,
178								)?,
179								poly.evaluate_on_hypercube(block_start + offset_after)?,
180							)?;
181						}
182					}
183					ShiftVariant::LogicalLeft => {
184						for offset_after in 0..shift_offset {
185							check_eval(
186								oracle_label,
187								block_start + offset_after,
188								F::ZERO,
189								poly.evaluate_on_hypercube(block_start + offset_after)?,
190							)?;
191						}
192						for offset_after in shift_offset..block_len {
193							check_eval(
194								oracle_label,
195								block_start + offset_after,
196								unshifted_poly.evaluate_on_hypercube(
197									block_start + offset_after - shift_offset,
198								)?,
199								poly.evaluate_on_hypercube(block_start + offset_after)?,
200							)?;
201						}
202					}
203					ShiftVariant::LogicalRight => {
204						for offset_after in 0..block_len - shift_offset {
205							check_eval(
206								oracle_label,
207								block_start + offset_after,
208								unshifted_poly.evaluate_on_hypercube(
209									block_start + offset_after + shift_offset,
210								)?,
211								poly.evaluate_on_hypercube(block_start + offset_after)?,
212							)?;
213						}
214						for offset_after in block_len - shift_offset..block_len {
215							check_eval(
216								oracle_label,
217								block_start + offset_after,
218								F::ZERO,
219								poly.evaluate_on_hypercube(block_start + offset_after)?,
220							)?;
221						}
222					}
223				}
224			}
225		}
226		MultilinearPolyVariant::Projected(projected) => {
227			let unprojected_poly = witness.get_multilin_poly(projected.id())?;
228			let partial_query =
229				binius_hal::make_portable_backend().multilinear_query(projected.values())?;
230			let projected_poly = unprojected_poly
231				.evaluate_partial(partial_query.to_ref(), projected.start_index())?;
232
233			for i in 0..1 << n_vars {
234				check_eval(
235					oracle_label,
236					i,
237					projected_poly.evaluate_on_hypercube(i)?,
238					poly.evaluate_on_hypercube(i)?,
239				)?;
240			}
241		}
242		MultilinearPolyVariant::ZeroPadded(padded) => {
243			let inner_id = padded.id();
244			let unpadded_poly = witness.get_multilin_poly(inner_id)?;
245			let n_pad_vars = padded.n_pad_vars();
246			let start_index = padded.start_index();
247			let nonzero_index = padded.nonzero_index();
248			let padded_poly = unpadded_poly.zero_pad(n_pad_vars, start_index, nonzero_index)?;
249			for i in 0..1 << n_pad_vars {
250				check_eval(
251					oracle_label,
252					i,
253					padded_poly.evaluate_on_hypercube(i)?,
254					poly.evaluate_on_hypercube(i)?,
255				)?;
256			}
257		}
258		MultilinearPolyVariant::Packed(packed) => {
259			let expected = witness.get_multilin_poly(packed.id())?;
260			let got = witness.get_multilin_poly(oracle.id())?;
261			if expected.packed_evals() != got.packed_evals() {
262				return Err(Error::PackedUnderlierMismatch {
263					oracle: oracle_label.into(),
264				});
265			}
266		}
267		MultilinearPolyVariant::Composite(composite_mle) => {
268			let inner_polys = composite_mle
269				.polys()
270				.map(|id| witness.get_multilin_poly(id))
271				.collect::<Result<Vec<_>, _>>()?;
272			let composite = MultilinearComposite::new(n_vars, composite_mle.c(), inner_polys)?;
273			for i in 0..1 << n_vars {
274				let got = poly.evaluate_on_hypercube(i)?;
275				let expected = composite.evaluate_on_hypercube(i)?;
276				check_eval(oracle_label, i, expected, got)?;
277			}
278		}
279	}
280	Ok(())
281}
282
283fn check_eval<F: TowerField>(
284	oracle_label: &str,
285	index: usize,
286	expected: F,
287	got: F,
288) -> Result<(), Error> {
289	if expected == got {
290		Ok(())
291	} else {
292		Err(Error::VirtualOracleEvalMismatch {
293			oracle: oracle_label.into(),
294			index,
295			reason: format!("Expected {expected}, got {got}"),
296		})
297	}
298}
299
300pub mod nonzerocheck {
301	use binius_field::{PackedField, TowerField};
302	use binius_math::MultilinearPoly;
303	use binius_maybe_rayon::prelude::*;
304	use binius_utils::bail;
305
306	use crate::{
307		oracle::{MultilinearOracleSet, OracleId},
308		protocols::sumcheck::Error,
309		witness::MultilinearExtensionIndex,
310	};
311
312	pub fn validate_witness<F, P>(
313		witness: &MultilinearExtensionIndex<P>,
314		oracles: &MultilinearOracleSet<P::Scalar>,
315		oracle_ids: &[OracleId],
316	) -> Result<(), Error>
317	where
318		P: PackedField<Scalar = F>,
319		F: TowerField,
320	{
321		oracle_ids.into_par_iter().try_for_each(|id| {
322			let multilinear = witness.get_multilin_poly(*id)?;
323			(0..(1 << multilinear.n_vars()))
324				.into_par_iter()
325				.try_for_each(|hypercube_index| {
326					if multilinear.evaluate_on_hypercube(hypercube_index)? == F::ZERO {
327						bail!(Error::NonzerocheckNaiveValidationFailure {
328							hypercube_index,
329							oracle: oracles[*id].label()
330						})
331					}
332					Ok(())
333				})
334		})
335	}
336}