binius_core/constraint_system/
validate.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_fast_compute::arith_circuit::ArithCircuitPoly;
4use binius_field::{BinaryField1b, PackedExtension, PackedField, TowerField};
5use binius_hal::ComputationBackendExt;
6use binius_math::MultilinearPoly;
7use binius_utils::bail;
8
9use super::{
10	ConstraintSystem,
11	channel::{self, Boundary},
12	error::Error,
13};
14use crate::{
15	oracle::{
16		ConstraintPredicate, MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant,
17		ShiftVariant,
18	},
19	polynomial::{MultilinearComposite, test_utils::decompose_index_to_hypercube_point},
20	protocols::sumcheck::prove::zerocheck,
21	witness::MultilinearExtensionIndex,
22};
23
24pub fn validate_witness<F, P>(
25	constraint_system: &ConstraintSystem<F>,
26	boundaries: &[Boundary<F>],
27	table_sizes: &[usize],
28	witness: &MultilinearExtensionIndex<'_, P>,
29) -> Result<(), Error>
30where
31	P: PackedField<Scalar = F> + PackedExtension<BinaryField1b>,
32	F: TowerField,
33{
34	constraint_system.check_table_sizes(table_sizes)?;
35
36	let ConstraintSystem {
37		oracles: unsized_oracles,
38		table_constraints,
39		non_zero_oracle_ids,
40		flushes,
41		channel_count,
42		table_size_specs: _,
43		exponents: _,
44	} = constraint_system;
45
46	let oracles = unsized_oracles.instantiate(table_sizes)?;
47
48	// Check the constraint sets
49	for constraint_set in table_constraints {
50		if table_sizes[constraint_set.table_id] == 0 {
51			continue;
52		}
53		if constraint_set.oracle_ids.is_empty() {
54			continue;
55		}
56
57		let multilinears = constraint_set
58			.oracle_ids
59			.iter()
60			.map(|id| witness.get_multilin_poly(*id))
61			.collect::<Result<Vec<_>, _>>()?;
62
63		let mut zero_claims = vec![];
64		for constraint in &constraint_set.constraints {
65			match constraint.predicate {
66				ConstraintPredicate::Zero => zero_claims.push((
67					constraint.name.clone(),
68					ArithCircuitPoly::with_n_vars(
69						multilinears.len(),
70						constraint.composition.clone(),
71					)?,
72				)),
73				ConstraintPredicate::Sum(_) => unimplemented!(),
74			}
75		}
76		zerocheck::validate_witness(&multilinears, &zero_claims)?;
77	}
78
79	// Check that nonzero oracles are non-zero over the entire hypercube
80	nonzerocheck::validate_witness(witness, &oracles, non_zero_oracle_ids)?;
81
82	// Check that the channels balance with flushes and boundaries
83	channel::validate_witness(witness, flushes, boundaries, table_sizes, *channel_count)?;
84
85	// Check consistency of virtual oracle witnesses (eg. that shift polynomials are actually
86	// shifts).
87	for oracle in oracles.polys() {
88		validate_virtual_oracle_witness(oracle, &oracles, witness)?;
89	}
90
91	Ok(())
92}
93
94pub fn validate_virtual_oracle_witness<F, P>(
95	oracle: &MultilinearPolyOracle<F>,
96	oracles: &MultilinearOracleSet<F>,
97	witness: &MultilinearExtensionIndex<P>,
98) -> Result<(), Error>
99where
100	P: PackedField<Scalar = F>,
101	F: TowerField,
102{
103	let oracle_label = &oracle.label();
104	let n_vars = oracle.n_vars();
105	let poly = witness.get_multilin_poly(oracle.id())?;
106
107	match oracle.variant {
108		MultilinearPolyVariant::Structured(_) if poly.n_vars() > n_vars => {
109			bail!(Error::VirtualOracleNvarsMismatch {
110				oracle: oracle_label.into(),
111				condition: '<',
112				oracle_num_vars: n_vars,
113				witness_num_vars: poly.n_vars(),
114			})
115		}
116		_ if poly.n_vars() != n_vars => {
117			bail!(Error::VirtualOracleNvarsMismatch {
118				oracle: oracle_label.into(),
119				condition: '=',
120				oracle_num_vars: n_vars,
121				witness_num_vars: poly.n_vars(),
122			})
123		}
124		_ => (),
125	}
126
127	match &oracle.variant {
128		MultilinearPolyVariant::Committed => {
129			// Committed oracles don't need to be checked as they are allowed to contain any data
130			// here
131		}
132		MultilinearPolyVariant::Transparent(inner) => {
133			for i in 0..1 << n_vars {
134				let got = poly.evaluate_on_hypercube(i)?;
135				let expected = inner
136					.poly()
137					.evaluate(&decompose_index_to_hypercube_point(n_vars, i))?;
138				check_eval(oracle_label, i, expected, got)?;
139			}
140		}
141		MultilinearPolyVariant::Structured(inner) => {
142			for i in 0..1 << n_vars {
143				let got = poly.evaluate_on_hypercube(i)?;
144				// NOTE: that the arith circuit can have more query parameters than `n_vars`, that's
145				//       the reason we use the arith circuit's n_vars here.
146				let eval_point = decompose_index_to_hypercube_point(inner.n_vars(), i);
147				let expected = inner.evaluate(&eval_point)?;
148				check_eval(oracle_label, i, expected, got)?;
149			}
150		}
151		MultilinearPolyVariant::LinearCombination(linear_combination) => {
152			let uncombined_polys = linear_combination
153				.polys()
154				.map(|id| witness.get_multilin_poly(id))
155				.collect::<Result<Vec<_>, _>>()?;
156			for i in 0..1 << n_vars {
157				let got = poly.evaluate_on_hypercube(i)?;
158				let expected = linear_combination
159					.coefficients()
160					.zip(uncombined_polys.iter())
161					.try_fold(linear_combination.offset(), |acc, (coeff, poly)| {
162						Ok::<F, Error>(acc + poly.evaluate_on_hypercube_and_scale(i, coeff)?)
163					})?;
164				check_eval(oracle_label, i, expected, got)?;
165			}
166		}
167		MultilinearPolyVariant::Repeating { id, .. } => {
168			let unrepeated_poly = witness.get_multilin_poly(*id)?;
169			let unrepeated_n_vars = oracles.n_vars(*id);
170			for i in 0..1 << n_vars {
171				let got = poly.evaluate_on_hypercube(i)?;
172				let expected =
173					unrepeated_poly.evaluate_on_hypercube(i % (1 << unrepeated_n_vars))?;
174				check_eval(oracle_label, i, expected, got)?;
175			}
176		}
177		MultilinearPolyVariant::Shifted(shifted) => {
178			let unshifted_poly = witness.get_multilin_poly(shifted.id())?;
179			let block_len = 1 << shifted.block_size();
180			let shift_offset = shifted.shift_offset();
181			for block_start in (0..1 << n_vars).step_by(block_len) {
182				match shifted.shift_variant() {
183					ShiftVariant::CircularLeft => {
184						for offset_after in 0..block_len {
185							check_eval(
186								oracle_label,
187								block_start + offset_after,
188								unshifted_poly.evaluate_on_hypercube(
189									block_start
190										+ (offset_after + (block_len - shift_offset)) % block_len,
191								)?,
192								poly.evaluate_on_hypercube(block_start + offset_after)?,
193							)?;
194						}
195					}
196					ShiftVariant::LogicalLeft => {
197						for offset_after in 0..shift_offset {
198							check_eval(
199								oracle_label,
200								block_start + offset_after,
201								F::ZERO,
202								poly.evaluate_on_hypercube(block_start + offset_after)?,
203							)?;
204						}
205						for offset_after in shift_offset..block_len {
206							check_eval(
207								oracle_label,
208								block_start + offset_after,
209								unshifted_poly.evaluate_on_hypercube(
210									block_start + offset_after - shift_offset,
211								)?,
212								poly.evaluate_on_hypercube(block_start + offset_after)?,
213							)?;
214						}
215					}
216					ShiftVariant::LogicalRight => {
217						for offset_after in 0..block_len - shift_offset {
218							check_eval(
219								oracle_label,
220								block_start + offset_after,
221								unshifted_poly.evaluate_on_hypercube(
222									block_start + offset_after + shift_offset,
223								)?,
224								poly.evaluate_on_hypercube(block_start + offset_after)?,
225							)?;
226						}
227						for offset_after in block_len - shift_offset..block_len {
228							check_eval(
229								oracle_label,
230								block_start + offset_after,
231								F::ZERO,
232								poly.evaluate_on_hypercube(block_start + offset_after)?,
233							)?;
234						}
235					}
236				}
237			}
238		}
239		MultilinearPolyVariant::Projected(projected) => {
240			let unprojected_poly = witness.get_multilin_poly(projected.id())?;
241			let partial_query =
242				binius_hal::make_portable_backend().multilinear_query(projected.values())?;
243			let projected_poly = unprojected_poly
244				.evaluate_partial(partial_query.to_ref(), projected.start_index())?;
245
246			for i in 0..1 << n_vars {
247				check_eval(
248					oracle_label,
249					i,
250					projected_poly.evaluate_on_hypercube(i)?,
251					poly.evaluate_on_hypercube(i)?,
252				)?;
253			}
254		}
255		MultilinearPolyVariant::ZeroPadded(padded) => {
256			let inner_id = padded.id();
257			let unpadded_poly = witness.get_multilin_poly(inner_id)?;
258			let n_pad_vars = padded.n_pad_vars();
259			let start_index = padded.start_index();
260			let nonzero_index = padded.nonzero_index();
261			let padded_poly = unpadded_poly.zero_pad(n_pad_vars, start_index, nonzero_index)?;
262			for i in 0..1 << n_pad_vars {
263				check_eval(
264					oracle_label,
265					i,
266					padded_poly.evaluate_on_hypercube(i)?,
267					poly.evaluate_on_hypercube(i)?,
268				)?;
269			}
270		}
271		MultilinearPolyVariant::Packed(packed) => {
272			let expected = witness.get_multilin_poly(packed.id())?;
273			let got = witness.get_multilin_poly(oracle.id())?;
274			if expected.packed_evals() != got.packed_evals() {
275				return Err(Error::PackedUnderlierMismatch {
276					oracle: oracle_label.into(),
277				});
278			}
279		}
280		MultilinearPolyVariant::Composite(composite_mle) => {
281			let inner_polys = composite_mle
282				.polys()
283				.map(|id| witness.get_multilin_poly(id))
284				.collect::<Result<Vec<_>, _>>()?;
285			let composite = MultilinearComposite::new(n_vars, composite_mle.c(), inner_polys)?;
286			for i in 0..1 << n_vars {
287				let got = poly.evaluate_on_hypercube(i)?;
288				let expected = composite.evaluate_on_hypercube(i)?;
289				check_eval(oracle_label, i, expected, got)?;
290			}
291		}
292	}
293	Ok(())
294}
295
296fn check_eval<F: TowerField>(
297	oracle_label: &str,
298	index: usize,
299	expected: F,
300	got: F,
301) -> Result<(), Error> {
302	if expected == got {
303		Ok(())
304	} else {
305		Err(Error::VirtualOracleEvalMismatch {
306			oracle: oracle_label.into(),
307			index,
308			reason: format!("Expected {expected}, got {got}"),
309		})
310	}
311}
312
313pub mod nonzerocheck {
314	use binius_field::{PackedField, TowerField};
315	use binius_math::MultilinearPoly;
316	use binius_maybe_rayon::prelude::*;
317	use binius_utils::bail;
318
319	use crate::{
320		oracle::{MultilinearOracleSet, OracleId},
321		protocols::sumcheck::Error,
322		witness::MultilinearExtensionIndex,
323	};
324
325	pub fn validate_witness<F, P>(
326		witness: &MultilinearExtensionIndex<P>,
327		oracles: &MultilinearOracleSet<P::Scalar>,
328		oracle_ids: &[OracleId],
329	) -> Result<(), Error>
330	where
331		P: PackedField<Scalar = F>,
332		F: TowerField,
333	{
334		oracle_ids.into_par_iter().try_for_each(|id| {
335			if oracles.is_zero_sized(*id) {
336				return Ok(());
337			}
338			let multilinear = witness.get_multilin_poly(*id)?;
339			(0..(1 << multilinear.n_vars()))
340				.into_par_iter()
341				.try_for_each(|hypercube_index| {
342					if multilinear.evaluate_on_hypercube(hypercube_index)? == F::ZERO {
343						bail!(Error::NonzerocheckNaiveValidationFailure {
344							hypercube_index,
345							oracle: oracles[*id].label()
346						})
347					}
348					Ok(())
349				})
350		})
351	}
352}