binius_core/constraint_system/
validate.rs1use binius_fast_compute::arith_circuit::ArithCircuitPoly;
4use binius_field::{BinaryField1b, PackedExtension, PackedField, TowerField};
5use binius_hal::ComputationBackendExt;
6use binius_math::MultilinearPoly;
7use binius_utils::bail;
8
9use super::{
10 ConstraintSystem,
11 channel::{self, Boundary},
12 error::Error,
13};
14use crate::{
15 oracle::{
16 ConstraintPredicate, MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant,
17 ShiftVariant,
18 },
19 polynomial::{MultilinearComposite, test_utils::decompose_index_to_hypercube_point},
20 protocols::sumcheck::prove::zerocheck,
21 witness::MultilinearExtensionIndex,
22};
23
24pub fn validate_witness<F, P>(
25 constraint_system: &ConstraintSystem<F>,
26 boundaries: &[Boundary<F>],
27 witness: &MultilinearExtensionIndex<'_, P>,
28) -> Result<(), Error>
29where
30 P: PackedField<Scalar = F> + PackedExtension<BinaryField1b>,
31 F: TowerField,
32{
33 for constraint_set in &constraint_system.table_constraints {
35 let multilinears = constraint_set
36 .oracle_ids
37 .iter()
38 .map(|id| witness.get_multilin_poly(*id))
39 .collect::<Result<Vec<_>, _>>()?;
40
41 let mut zero_claims = vec![];
42 for constraint in &constraint_set.constraints {
43 match constraint.predicate {
44 ConstraintPredicate::Zero => zero_claims.push((
45 constraint.name.clone(),
46 ArithCircuitPoly::with_n_vars(
47 multilinears.len(),
48 constraint.composition.clone(),
49 )?,
50 )),
51 ConstraintPredicate::Sum(_) => unimplemented!(),
52 }
53 }
54 zerocheck::validate_witness(&multilinears, &zero_claims)?;
55 }
56
57 nonzerocheck::validate_witness(
59 witness,
60 &constraint_system.oracles,
61 &constraint_system.non_zero_oracle_ids,
62 )?;
63
64 channel::validate_witness(
66 witness,
67 &constraint_system.flushes,
68 boundaries,
69 constraint_system.max_channel_id,
70 )?;
71
72 for oracle in constraint_system.oracles.polys() {
75 validate_virtual_oracle_witness(oracle, &constraint_system.oracles, witness)?;
76 }
77
78 Ok(())
79}
80
81pub fn validate_virtual_oracle_witness<F, P>(
82 oracle: &MultilinearPolyOracle<F>,
83 oracles: &MultilinearOracleSet<F>,
84 witness: &MultilinearExtensionIndex<P>,
85) -> Result<(), Error>
86where
87 P: PackedField<Scalar = F>,
88 F: TowerField,
89{
90 let oracle_label = &oracle.label();
91 let n_vars = oracle.n_vars();
92 let poly = witness.get_multilin_poly(oracle.id())?;
93
94 match oracle.variant {
95 MultilinearPolyVariant::Structured(_) if poly.n_vars() > n_vars => {
96 bail!(Error::VirtualOracleNvarsMismatch {
97 oracle: oracle_label.into(),
98 condition: '<',
99 oracle_num_vars: n_vars,
100 witness_num_vars: poly.n_vars(),
101 })
102 }
103 _ if poly.n_vars() != n_vars => {
104 bail!(Error::VirtualOracleNvarsMismatch {
105 oracle: oracle_label.into(),
106 condition: '=',
107 oracle_num_vars: n_vars,
108 witness_num_vars: poly.n_vars(),
109 })
110 }
111 _ => (),
112 }
113
114 match &oracle.variant {
115 MultilinearPolyVariant::Committed => {
116 }
119 MultilinearPolyVariant::Transparent(inner) => {
120 for i in 0..1 << n_vars {
121 let got = poly.evaluate_on_hypercube(i)?;
122 let expected = inner
123 .poly()
124 .evaluate(&decompose_index_to_hypercube_point(n_vars, i))?;
125 check_eval(oracle_label, i, expected, got)?;
126 }
127 }
128 MultilinearPolyVariant::Structured(inner) => {
129 for i in 0..1 << n_vars {
130 let got = poly.evaluate_on_hypercube(i)?;
131 let eval_point = decompose_index_to_hypercube_point(inner.n_vars(), i);
134 let expected = inner.evaluate(&eval_point)?;
135 check_eval(oracle_label, i, expected, got)?;
136 }
137 }
138 MultilinearPolyVariant::LinearCombination(linear_combination) => {
139 let uncombined_polys = linear_combination
140 .polys()
141 .map(|id| witness.get_multilin_poly(id))
142 .collect::<Result<Vec<_>, _>>()?;
143 for i in 0..1 << n_vars {
144 let got = poly.evaluate_on_hypercube(i)?;
145 let expected = linear_combination
146 .coefficients()
147 .zip(uncombined_polys.iter())
148 .try_fold(linear_combination.offset(), |acc, (coeff, poly)| {
149 Ok::<F, Error>(acc + poly.evaluate_on_hypercube_and_scale(i, coeff)?)
150 })?;
151 check_eval(oracle_label, i, expected, got)?;
152 }
153 }
154 MultilinearPolyVariant::Repeating { id, .. } => {
155 let unrepeated_poly = witness.get_multilin_poly(*id)?;
156 let unrepeated_n_vars = oracles.n_vars(*id);
157 for i in 0..1 << n_vars {
158 let got = poly.evaluate_on_hypercube(i)?;
159 let expected =
160 unrepeated_poly.evaluate_on_hypercube(i % (1 << unrepeated_n_vars))?;
161 check_eval(oracle_label, i, expected, got)?;
162 }
163 }
164 MultilinearPolyVariant::Shifted(shifted) => {
165 let unshifted_poly = witness.get_multilin_poly(shifted.id())?;
166 let block_len = 1 << shifted.block_size();
167 let shift_offset = shifted.shift_offset();
168 for block_start in (0..1 << n_vars).step_by(block_len) {
169 match shifted.shift_variant() {
170 ShiftVariant::CircularLeft => {
171 for offset_after in 0..block_len {
172 check_eval(
173 oracle_label,
174 block_start + offset_after,
175 unshifted_poly.evaluate_on_hypercube(
176 block_start
177 + (offset_after + (block_len - shift_offset)) % block_len,
178 )?,
179 poly.evaluate_on_hypercube(block_start + offset_after)?,
180 )?;
181 }
182 }
183 ShiftVariant::LogicalLeft => {
184 for offset_after in 0..shift_offset {
185 check_eval(
186 oracle_label,
187 block_start + offset_after,
188 F::ZERO,
189 poly.evaluate_on_hypercube(block_start + offset_after)?,
190 )?;
191 }
192 for offset_after in shift_offset..block_len {
193 check_eval(
194 oracle_label,
195 block_start + offset_after,
196 unshifted_poly.evaluate_on_hypercube(
197 block_start + offset_after - shift_offset,
198 )?,
199 poly.evaluate_on_hypercube(block_start + offset_after)?,
200 )?;
201 }
202 }
203 ShiftVariant::LogicalRight => {
204 for offset_after in 0..block_len - shift_offset {
205 check_eval(
206 oracle_label,
207 block_start + offset_after,
208 unshifted_poly.evaluate_on_hypercube(
209 block_start + offset_after + shift_offset,
210 )?,
211 poly.evaluate_on_hypercube(block_start + offset_after)?,
212 )?;
213 }
214 for offset_after in block_len - shift_offset..block_len {
215 check_eval(
216 oracle_label,
217 block_start + offset_after,
218 F::ZERO,
219 poly.evaluate_on_hypercube(block_start + offset_after)?,
220 )?;
221 }
222 }
223 }
224 }
225 }
226 MultilinearPolyVariant::Projected(projected) => {
227 let unprojected_poly = witness.get_multilin_poly(projected.id())?;
228 let partial_query =
229 binius_hal::make_portable_backend().multilinear_query(projected.values())?;
230 let projected_poly = unprojected_poly
231 .evaluate_partial(partial_query.to_ref(), projected.start_index())?;
232
233 for i in 0..1 << n_vars {
234 check_eval(
235 oracle_label,
236 i,
237 projected_poly.evaluate_on_hypercube(i)?,
238 poly.evaluate_on_hypercube(i)?,
239 )?;
240 }
241 }
242 MultilinearPolyVariant::ZeroPadded(padded) => {
243 let inner_id = padded.id();
244 let unpadded_poly = witness.get_multilin_poly(inner_id)?;
245 let n_pad_vars = padded.n_pad_vars();
246 let start_index = padded.start_index();
247 let nonzero_index = padded.nonzero_index();
248 let padded_poly = unpadded_poly.zero_pad(n_pad_vars, start_index, nonzero_index)?;
249 for i in 0..1 << n_pad_vars {
250 check_eval(
251 oracle_label,
252 i,
253 padded_poly.evaluate_on_hypercube(i)?,
254 poly.evaluate_on_hypercube(i)?,
255 )?;
256 }
257 }
258 MultilinearPolyVariant::Packed(packed) => {
259 let expected = witness.get_multilin_poly(packed.id())?;
260 let got = witness.get_multilin_poly(oracle.id())?;
261 if expected.packed_evals() != got.packed_evals() {
262 return Err(Error::PackedUnderlierMismatch {
263 oracle: oracle_label.into(),
264 });
265 }
266 }
267 MultilinearPolyVariant::Composite(composite_mle) => {
268 let inner_polys = composite_mle
269 .polys()
270 .map(|id| witness.get_multilin_poly(id))
271 .collect::<Result<Vec<_>, _>>()?;
272 let composite = MultilinearComposite::new(n_vars, composite_mle.c(), inner_polys)?;
273 for i in 0..1 << n_vars {
274 let got = poly.evaluate_on_hypercube(i)?;
275 let expected = composite.evaluate_on_hypercube(i)?;
276 check_eval(oracle_label, i, expected, got)?;
277 }
278 }
279 }
280 Ok(())
281}
282
283fn check_eval<F: TowerField>(
284 oracle_label: &str,
285 index: usize,
286 expected: F,
287 got: F,
288) -> Result<(), Error> {
289 if expected == got {
290 Ok(())
291 } else {
292 Err(Error::VirtualOracleEvalMismatch {
293 oracle: oracle_label.into(),
294 index,
295 reason: format!("Expected {expected}, got {got}"),
296 })
297 }
298}
299
300pub mod nonzerocheck {
301 use binius_field::{PackedField, TowerField};
302 use binius_math::MultilinearPoly;
303 use binius_maybe_rayon::prelude::*;
304 use binius_utils::bail;
305
306 use crate::{
307 oracle::{MultilinearOracleSet, OracleId},
308 protocols::sumcheck::Error,
309 witness::MultilinearExtensionIndex,
310 };
311
312 pub fn validate_witness<F, P>(
313 witness: &MultilinearExtensionIndex<P>,
314 oracles: &MultilinearOracleSet<P::Scalar>,
315 oracle_ids: &[OracleId],
316 ) -> Result<(), Error>
317 where
318 P: PackedField<Scalar = F>,
319 F: TowerField,
320 {
321 oracle_ids.into_par_iter().try_for_each(|id| {
322 let multilinear = witness.get_multilin_poly(*id)?;
323 (0..(1 << multilinear.n_vars()))
324 .into_par_iter()
325 .try_for_each(|hypercube_index| {
326 if multilinear.evaluate_on_hypercube(hypercube_index)? == F::ZERO {
327 bail!(Error::NonzerocheckNaiveValidationFailure {
328 hypercube_index,
329 oracle: oracles[*id].label()
330 })
331 }
332 Ok(())
333 })
334 })
335 }
336}