binius_core/constraint_system/
validate.rs1use binius_field::{BinaryField1b, PackedExtension, PackedField, TowerField};
4use binius_hal::ComputationBackendExt;
5use binius_math::MultilinearPoly;
6use binius_utils::bail;
7
8use super::{
9 channel::{self, Boundary},
10 error::Error,
11 ConstraintSystem,
12};
13use crate::{
14 oracle::{
15 ConstraintPredicate, MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant,
16 ShiftVariant,
17 },
18 polynomial::{
19 test_utils::decompose_index_to_hypercube_point, ArithCircuitPoly, MultilinearComposite,
20 },
21 protocols::sumcheck::prove::zerocheck,
22 witness::MultilinearExtensionIndex,
23};
24
25pub fn validate_witness<F, P>(
26 constraint_system: &ConstraintSystem<F>,
27 boundaries: &[Boundary<F>],
28 witness: &MultilinearExtensionIndex<'_, P>,
29) -> Result<(), Error>
30where
31 P: PackedField<Scalar = F> + PackedExtension<BinaryField1b>,
32 F: TowerField,
33{
34 for constraint_set in &constraint_system.table_constraints {
36 let multilinears = constraint_set
37 .oracle_ids
38 .iter()
39 .map(|id| witness.get_multilin_poly(*id))
40 .collect::<Result<Vec<_>, _>>()?;
41
42 let mut zero_claims = vec![];
43 for constraint in &constraint_set.constraints {
44 match constraint.predicate {
45 ConstraintPredicate::Zero => zero_claims.push((
46 constraint.name.clone(),
47 ArithCircuitPoly::with_n_vars(
48 multilinears.len(),
49 constraint.composition.clone(),
50 )?,
51 )),
52 ConstraintPredicate::Sum(_) => unimplemented!(),
53 }
54 }
55 zerocheck::validate_witness(&multilinears, &zero_claims)?;
56 }
57
58 nonzerocheck::validate_witness(
60 witness,
61 &constraint_system.oracles,
62 &constraint_system.non_zero_oracle_ids,
63 )?;
64
65 channel::validate_witness(
67 witness,
68 &constraint_system.flushes,
69 boundaries,
70 constraint_system.max_channel_id,
71 )?;
72
73 for oracle in constraint_system.oracles.iter() {
75 validate_virtual_oracle_witness(oracle, &constraint_system.oracles, witness)?;
76 }
77
78 Ok(())
79}
80
81pub fn validate_virtual_oracle_witness<F, P>(
82 oracle: MultilinearPolyOracle<F>,
83 oracles: &MultilinearOracleSet<F>,
84 witness: &MultilinearExtensionIndex<P>,
85) -> Result<(), Error>
86where
87 P: PackedField<Scalar = F>,
88 F: TowerField,
89{
90 let oracle_label = &oracle.label();
91 let n_vars = oracle.n_vars();
92 let poly = witness.get_multilin_poly(oracle.id())?;
93
94 if poly.n_vars() != n_vars {
95 bail!(Error::VirtualOracleNvarsMismatch {
96 oracle: oracle_label.into(),
97 oracle_num_vars: n_vars,
98 witness_num_vars: poly.n_vars(),
99 })
100 }
101
102 match oracle.variant {
103 MultilinearPolyVariant::Committed => {
104 }
106 MultilinearPolyVariant::Transparent(inner) => {
107 for i in 0..1 << n_vars {
108 let got = poly.evaluate_on_hypercube(i)?;
109 let expected = inner
110 .poly()
111 .evaluate(&decompose_index_to_hypercube_point(n_vars, i))?;
112 check_eval(oracle_label, i, expected, got)?;
113 }
114 }
115 MultilinearPolyVariant::LinearCombination(linear_combination) => {
116 let uncombined_polys = linear_combination
117 .polys()
118 .map(|id| witness.get_multilin_poly(id))
119 .collect::<Result<Vec<_>, _>>()?;
120 for i in 0..1 << n_vars {
121 let got = poly.evaluate_on_hypercube(i)?;
122 let expected = linear_combination
123 .coefficients()
124 .zip(uncombined_polys.iter())
125 .try_fold(linear_combination.offset(), |acc, (coeff, poly)| {
126 Ok::<F, Error>(acc + poly.evaluate_on_hypercube_and_scale(i, coeff)?)
127 })?;
128 check_eval(oracle_label, i, expected, got)?;
129 }
130 }
131 MultilinearPolyVariant::Repeating { id, .. } => {
132 let unrepeated_poly = witness.get_multilin_poly(id)?;
133 let unrepeated_n_vars = oracles.n_vars(id);
134 for i in 0..1 << n_vars {
135 let got = poly.evaluate_on_hypercube(i)?;
136 let expected =
137 unrepeated_poly.evaluate_on_hypercube(i % (1 << unrepeated_n_vars))?;
138 check_eval(oracle_label, i, expected, got)?;
139 }
140 }
141 MultilinearPolyVariant::Shifted(shifted) => {
142 let unshifted_poly = witness.get_multilin_poly(shifted.id())?;
143 let block_len = 1 << shifted.block_size();
144 let shift_offset = shifted.shift_offset();
145 for block_start in (0..1 << n_vars).step_by(block_len) {
146 match shifted.shift_variant() {
147 ShiftVariant::CircularLeft => {
148 for offset_after in 0..block_len {
149 check_eval(
150 oracle_label,
151 block_start + offset_after,
152 unshifted_poly.evaluate_on_hypercube(
153 block_start
154 + (offset_after + (block_len - shift_offset)) % block_len,
155 )?,
156 poly.evaluate_on_hypercube(block_start + offset_after)?,
157 )?;
158 }
159 }
160 ShiftVariant::LogicalLeft => {
161 for offset_after in 0..shift_offset {
162 check_eval(
163 oracle_label,
164 block_start + offset_after,
165 F::ZERO,
166 poly.evaluate_on_hypercube(block_start + offset_after)?,
167 )?;
168 }
169 for offset_after in shift_offset..block_len {
170 check_eval(
171 oracle_label,
172 block_start + offset_after,
173 unshifted_poly.evaluate_on_hypercube(
174 block_start + offset_after - shift_offset,
175 )?,
176 poly.evaluate_on_hypercube(block_start + offset_after)?,
177 )?;
178 }
179 }
180 ShiftVariant::LogicalRight => {
181 for offset_after in 0..block_len - shift_offset {
182 check_eval(
183 oracle_label,
184 block_start + offset_after,
185 unshifted_poly.evaluate_on_hypercube(
186 block_start + offset_after + shift_offset,
187 )?,
188 poly.evaluate_on_hypercube(block_start + offset_after)?,
189 )?;
190 }
191 for offset_after in block_len - shift_offset..block_len {
192 check_eval(
193 oracle_label,
194 block_start + offset_after,
195 F::ZERO,
196 poly.evaluate_on_hypercube(block_start + offset_after)?,
197 )?;
198 }
199 }
200 }
201 }
202 }
203 MultilinearPolyVariant::Projected(projected) => {
204 let unprojected_poly = witness.get_multilin_poly(projected.id())?;
205 let partial_query =
206 binius_hal::make_portable_backend().multilinear_query(projected.values())?;
207 let projected_poly = unprojected_poly
208 .evaluate_partial(partial_query.to_ref(), projected.start_index())?;
209
210 for i in 0..1 << n_vars {
211 check_eval(
212 oracle_label,
213 i,
214 projected_poly.evaluate_on_hypercube(i)?,
215 poly.evaluate_on_hypercube(i)?,
216 )?;
217 }
218 }
219 MultilinearPolyVariant::ZeroPadded(inner_id) => {
220 let unpadded_poly = witness.get_multilin_poly(inner_id)?;
221 for i in 0..1 << unpadded_poly.n_vars() {
222 check_eval(
223 oracle_label,
224 i,
225 unpadded_poly.evaluate_on_hypercube(i)?,
226 poly.evaluate_on_hypercube(i)?,
227 )?;
228 }
229 for i in 1 << unpadded_poly.n_vars()..1 << n_vars {
230 check_eval(oracle_label, i, F::ZERO, poly.evaluate_on_hypercube(i)?)?;
231 }
232 }
233 MultilinearPolyVariant::Packed(ref packed) => {
234 let expected = witness.get_multilin_poly(packed.id())?;
235 let got = witness.get_multilin_poly(oracle.id())?;
236 if expected.packed_evals() != got.packed_evals() {
237 return Err(Error::PackedUnderlierMismatch {
238 oracle: oracle_label.into(),
239 });
240 }
241 }
242 MultilinearPolyVariant::Composite(composite_mle) => {
243 let inner_polys = composite_mle
244 .polys()
245 .map(|id| witness.get_multilin_poly(id))
246 .collect::<Result<Vec<_>, _>>()?;
247 let composite = MultilinearComposite::new(n_vars, composite_mle.c(), inner_polys)?;
248 for i in 0..1 << n_vars {
249 let got = poly.evaluate_on_hypercube(i)?;
250 let expected = composite.evaluate_on_hypercube(i)?;
251 check_eval(oracle_label, i, expected, got)?;
252 }
253 }
254 }
255 Ok(())
256}
257
258fn check_eval<F: TowerField>(
259 oracle_label: &str,
260 index: usize,
261 expected: F,
262 got: F,
263) -> Result<(), Error> {
264 if expected == got {
265 Ok(())
266 } else {
267 Err(Error::VirtualOracleEvalMismatch {
268 oracle: oracle_label.into(),
269 index,
270 reason: format!("Expected {expected}, got {got}"),
271 })
272 }
273}
274
275pub mod nonzerocheck {
276 use binius_field::{PackedField, TowerField};
277 use binius_math::MultilinearPoly;
278 use binius_maybe_rayon::prelude::*;
279 use binius_utils::bail;
280
281 use crate::{
282 oracle::{MultilinearOracleSet, OracleId},
283 protocols::sumcheck::Error,
284 witness::MultilinearExtensionIndex,
285 };
286
287 pub fn validate_witness<F, P>(
288 witness: &MultilinearExtensionIndex<P>,
289 oracles: &MultilinearOracleSet<P::Scalar>,
290 oracle_ids: &[OracleId],
291 ) -> Result<(), Error>
292 where
293 P: PackedField<Scalar = F>,
294 F: TowerField,
295 {
296 oracle_ids.into_par_iter().try_for_each(|id| {
297 let multilinear = witness.get_multilin_poly(*id)?;
298 (0..(1 << multilinear.n_vars()))
299 .into_par_iter()
300 .try_for_each(|hypercube_index| {
301 if multilinear.evaluate_on_hypercube(hypercube_index)? == F::ZERO {
302 bail!(Error::NonzerocheckNaiveValidationFailure {
303 hypercube_index,
304 oracle: oracles.oracle(*id).label()
305 })
306 }
307 Ok(())
308 })
309 })
310 }
311}