binius_core/constraint_system/
validate.rs1use binius_fast_compute::arith_circuit::ArithCircuitPoly;
4use binius_field::{BinaryField1b, PackedExtension, PackedField, TowerField};
5use binius_hal::ComputationBackendExt;
6use binius_math::MultilinearPoly;
7use binius_utils::bail;
8
9use super::{
10 ConstraintSystem,
11 channel::{self, Boundary},
12 error::Error,
13};
14use crate::{
15 oracle::{
16 ConstraintPredicate, MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant,
17 ShiftVariant,
18 },
19 polynomial::{MultilinearComposite, test_utils::decompose_index_to_hypercube_point},
20 protocols::sumcheck::prove::zerocheck,
21 witness::MultilinearExtensionIndex,
22};
23
24pub fn validate_witness<F, P>(
25 constraint_system: &ConstraintSystem<F>,
26 boundaries: &[Boundary<F>],
27 table_sizes: &[usize],
28 witness: &MultilinearExtensionIndex<'_, P>,
29) -> Result<(), Error>
30where
31 P: PackedField<Scalar = F> + PackedExtension<BinaryField1b>,
32 F: TowerField,
33{
34 constraint_system.check_table_sizes(table_sizes)?;
35
36 let ConstraintSystem {
37 oracles: unsized_oracles,
38 table_constraints,
39 non_zero_oracle_ids,
40 flushes,
41 channel_count,
42 table_size_specs: _,
43 exponents: _,
44 } = constraint_system;
45
46 let oracles = unsized_oracles.instantiate(table_sizes)?;
47
48 for constraint_set in table_constraints {
50 if table_sizes[constraint_set.table_id] == 0 {
51 continue;
52 }
53 if constraint_set.oracle_ids.is_empty() {
54 continue;
55 }
56
57 let multilinears = constraint_set
58 .oracle_ids
59 .iter()
60 .map(|id| witness.get_multilin_poly(*id))
61 .collect::<Result<Vec<_>, _>>()?;
62
63 let mut zero_claims = vec![];
64 for constraint in &constraint_set.constraints {
65 match constraint.predicate {
66 ConstraintPredicate::Zero => zero_claims.push((
67 constraint.name.clone(),
68 ArithCircuitPoly::with_n_vars(
69 multilinears.len(),
70 constraint.composition.clone(),
71 )?,
72 )),
73 ConstraintPredicate::Sum(_) => unimplemented!(),
74 }
75 }
76 zerocheck::validate_witness(&multilinears, &zero_claims)?;
77 }
78
79 nonzerocheck::validate_witness(witness, &oracles, non_zero_oracle_ids)?;
81
82 channel::validate_witness(witness, flushes, boundaries, table_sizes, *channel_count)?;
84
85 for oracle in oracles.polys() {
88 validate_virtual_oracle_witness(oracle, &oracles, witness)?;
89 }
90
91 Ok(())
92}
93
94pub fn validate_virtual_oracle_witness<F, P>(
95 oracle: &MultilinearPolyOracle<F>,
96 oracles: &MultilinearOracleSet<F>,
97 witness: &MultilinearExtensionIndex<P>,
98) -> Result<(), Error>
99where
100 P: PackedField<Scalar = F>,
101 F: TowerField,
102{
103 let oracle_label = &oracle.label();
104 let n_vars = oracle.n_vars();
105 let poly = witness.get_multilin_poly(oracle.id())?;
106
107 match oracle.variant {
108 MultilinearPolyVariant::Structured(_) if poly.n_vars() > n_vars => {
109 bail!(Error::VirtualOracleNvarsMismatch {
110 oracle: oracle_label.into(),
111 condition: '<',
112 oracle_num_vars: n_vars,
113 witness_num_vars: poly.n_vars(),
114 })
115 }
116 _ if poly.n_vars() != n_vars => {
117 bail!(Error::VirtualOracleNvarsMismatch {
118 oracle: oracle_label.into(),
119 condition: '=',
120 oracle_num_vars: n_vars,
121 witness_num_vars: poly.n_vars(),
122 })
123 }
124 _ => (),
125 }
126
127 match &oracle.variant {
128 MultilinearPolyVariant::Committed => {
129 }
132 MultilinearPolyVariant::Transparent(inner) => {
133 for i in 0..1 << n_vars {
134 let got = poly.evaluate_on_hypercube(i)?;
135 let expected = inner
136 .poly()
137 .evaluate(&decompose_index_to_hypercube_point(n_vars, i))?;
138 check_eval(oracle_label, i, expected, got)?;
139 }
140 }
141 MultilinearPolyVariant::Structured(inner) => {
142 for i in 0..1 << n_vars {
143 let got = poly.evaluate_on_hypercube(i)?;
144 let eval_point = decompose_index_to_hypercube_point(inner.n_vars(), i);
147 let expected = inner.evaluate(&eval_point)?;
148 check_eval(oracle_label, i, expected, got)?;
149 }
150 }
151 MultilinearPolyVariant::LinearCombination(linear_combination) => {
152 let uncombined_polys = linear_combination
153 .polys()
154 .map(|id| witness.get_multilin_poly(id))
155 .collect::<Result<Vec<_>, _>>()?;
156 for i in 0..1 << n_vars {
157 let got = poly.evaluate_on_hypercube(i)?;
158 let expected = linear_combination
159 .coefficients()
160 .zip(uncombined_polys.iter())
161 .try_fold(linear_combination.offset(), |acc, (coeff, poly)| {
162 Ok::<F, Error>(acc + poly.evaluate_on_hypercube_and_scale(i, coeff)?)
163 })?;
164 check_eval(oracle_label, i, expected, got)?;
165 }
166 }
167 MultilinearPolyVariant::Repeating { id, .. } => {
168 let unrepeated_poly = witness.get_multilin_poly(*id)?;
169 let unrepeated_n_vars = oracles.n_vars(*id);
170 for i in 0..1 << n_vars {
171 let got = poly.evaluate_on_hypercube(i)?;
172 let expected =
173 unrepeated_poly.evaluate_on_hypercube(i % (1 << unrepeated_n_vars))?;
174 check_eval(oracle_label, i, expected, got)?;
175 }
176 }
177 MultilinearPolyVariant::Shifted(shifted) => {
178 let unshifted_poly = witness.get_multilin_poly(shifted.id())?;
179 let block_len = 1 << shifted.block_size();
180 let shift_offset = shifted.shift_offset();
181 for block_start in (0..1 << n_vars).step_by(block_len) {
182 match shifted.shift_variant() {
183 ShiftVariant::CircularLeft => {
184 for offset_after in 0..block_len {
185 check_eval(
186 oracle_label,
187 block_start + offset_after,
188 unshifted_poly.evaluate_on_hypercube(
189 block_start
190 + (offset_after + (block_len - shift_offset)) % block_len,
191 )?,
192 poly.evaluate_on_hypercube(block_start + offset_after)?,
193 )?;
194 }
195 }
196 ShiftVariant::LogicalLeft => {
197 for offset_after in 0..shift_offset {
198 check_eval(
199 oracle_label,
200 block_start + offset_after,
201 F::ZERO,
202 poly.evaluate_on_hypercube(block_start + offset_after)?,
203 )?;
204 }
205 for offset_after in shift_offset..block_len {
206 check_eval(
207 oracle_label,
208 block_start + offset_after,
209 unshifted_poly.evaluate_on_hypercube(
210 block_start + offset_after - shift_offset,
211 )?,
212 poly.evaluate_on_hypercube(block_start + offset_after)?,
213 )?;
214 }
215 }
216 ShiftVariant::LogicalRight => {
217 for offset_after in 0..block_len - shift_offset {
218 check_eval(
219 oracle_label,
220 block_start + offset_after,
221 unshifted_poly.evaluate_on_hypercube(
222 block_start + offset_after + shift_offset,
223 )?,
224 poly.evaluate_on_hypercube(block_start + offset_after)?,
225 )?;
226 }
227 for offset_after in block_len - shift_offset..block_len {
228 check_eval(
229 oracle_label,
230 block_start + offset_after,
231 F::ZERO,
232 poly.evaluate_on_hypercube(block_start + offset_after)?,
233 )?;
234 }
235 }
236 }
237 }
238 }
239 MultilinearPolyVariant::Projected(projected) => {
240 let unprojected_poly = witness.get_multilin_poly(projected.id())?;
241 let partial_query =
242 binius_hal::make_portable_backend().multilinear_query(projected.values())?;
243 let projected_poly = unprojected_poly
244 .evaluate_partial(partial_query.to_ref(), projected.start_index())?;
245
246 for i in 0..1 << n_vars {
247 check_eval(
248 oracle_label,
249 i,
250 projected_poly.evaluate_on_hypercube(i)?,
251 poly.evaluate_on_hypercube(i)?,
252 )?;
253 }
254 }
255 MultilinearPolyVariant::ZeroPadded(padded) => {
256 let inner_id = padded.id();
257 let unpadded_poly = witness.get_multilin_poly(inner_id)?;
258 let n_pad_vars = padded.n_pad_vars();
259 let start_index = padded.start_index();
260 let nonzero_index = padded.nonzero_index();
261 let padded_poly = unpadded_poly.zero_pad(n_pad_vars, start_index, nonzero_index)?;
262 for i in 0..1 << n_pad_vars {
263 check_eval(
264 oracle_label,
265 i,
266 padded_poly.evaluate_on_hypercube(i)?,
267 poly.evaluate_on_hypercube(i)?,
268 )?;
269 }
270 }
271 MultilinearPolyVariant::Packed(packed) => {
272 let expected = witness.get_multilin_poly(packed.id())?;
273 let got = witness.get_multilin_poly(oracle.id())?;
274 if expected.packed_evals() != got.packed_evals() {
275 return Err(Error::PackedUnderlierMismatch {
276 oracle: oracle_label.into(),
277 });
278 }
279 }
280 MultilinearPolyVariant::Composite(composite_mle) => {
281 let inner_polys = composite_mle
282 .polys()
283 .map(|id| witness.get_multilin_poly(id))
284 .collect::<Result<Vec<_>, _>>()?;
285 let composite = MultilinearComposite::new(n_vars, composite_mle.c(), inner_polys)?;
286 for i in 0..1 << n_vars {
287 let got = poly.evaluate_on_hypercube(i)?;
288 let expected = composite.evaluate_on_hypercube(i)?;
289 check_eval(oracle_label, i, expected, got)?;
290 }
291 }
292 }
293 Ok(())
294}
295
296fn check_eval<F: TowerField>(
297 oracle_label: &str,
298 index: usize,
299 expected: F,
300 got: F,
301) -> Result<(), Error> {
302 if expected == got {
303 Ok(())
304 } else {
305 Err(Error::VirtualOracleEvalMismatch {
306 oracle: oracle_label.into(),
307 index,
308 reason: format!("Expected {expected}, got {got}"),
309 })
310 }
311}
312
313pub mod nonzerocheck {
314 use binius_field::{PackedField, TowerField};
315 use binius_math::MultilinearPoly;
316 use binius_maybe_rayon::prelude::*;
317 use binius_utils::bail;
318
319 use crate::{
320 oracle::{MultilinearOracleSet, OracleId},
321 protocols::sumcheck::Error,
322 witness::MultilinearExtensionIndex,
323 };
324
325 pub fn validate_witness<F, P>(
326 witness: &MultilinearExtensionIndex<P>,
327 oracles: &MultilinearOracleSet<P::Scalar>,
328 oracle_ids: &[OracleId],
329 ) -> Result<(), Error>
330 where
331 P: PackedField<Scalar = F>,
332 F: TowerField,
333 {
334 oracle_ids.into_par_iter().try_for_each(|id| {
335 if oracles.is_zero_sized(*id) {
336 return Ok(());
337 }
338 let multilinear = witness.get_multilin_poly(*id)?;
339 (0..(1 << multilinear.n_vars()))
340 .into_par_iter()
341 .try_for_each(|hypercube_index| {
342 if multilinear.evaluate_on_hypercube(hypercube_index)? == F::ZERO {
343 bail!(Error::NonzerocheckNaiveValidationFailure {
344 hypercube_index,
345 oracle: oracles[*id].label()
346 })
347 }
348 Ok(())
349 })
350 })
351 }
352}