binius_core/constraint_system/
channel.rs1use std::collections::HashMap;
53
54use binius_field::{Field, PackedField, TowerField};
55use binius_macros::{DeserializeBytes, SerializeBytes};
56use binius_math::MultilinearPoly;
57use itertools::izip;
58
59use super::error::{Error, VerificationError};
60use crate::{constraint_system::TableId, oracle::OracleId, witness::MultilinearExtensionIndex};
61
/// Positional index of a channel; valid ids are `0..channel_count` as passed
/// to [`validate_witness`].
pub type ChannelId = usize;
63
64#[derive(Debug, Clone, Copy, SerializeBytes, DeserializeBytes, PartialEq, Eq)]
65pub enum OracleOrConst<F: Field> {
66 Oracle(OracleId),
67 Const { base: F, tower_level: usize },
68}
69#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
70pub struct Flush<F: TowerField> {
71 pub table_id: TableId,
72 pub log_values_per_row: usize,
73 pub oracles: Vec<OracleOrConst<F>>,
74 pub channel_id: ChannelId,
75 pub direction: FlushDirection,
76 pub selectors: Vec<OracleId>,
77 pub multiplicity: u64,
78}
79
80#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
81pub struct Boundary<F: TowerField> {
82 pub values: Vec<F>,
83 pub channel_id: ChannelId,
84 pub direction: FlushDirection,
85 pub multiplicity: u64,
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
89pub enum FlushDirection {
90 Push,
91 Pull,
92}
93
94pub fn validate_witness<F, P>(
95 witness: &MultilinearExtensionIndex<P>,
96 flushes: &[Flush<F>],
97 boundaries: &[Boundary<F>],
98 table_sizes: &[usize],
99 channel_count: usize,
100) -> Result<(), Error>
101where
102 P: PackedField<Scalar = F>,
103 F: TowerField,
104{
105 let mut channels = vec![Channel::<F>::new(); channel_count];
106 let max_channel_id = channel_count.saturating_sub(1);
107
108 for boundary in boundaries.iter().cloned() {
109 let Boundary {
110 channel_id,
111 values,
112 direction,
113 multiplicity,
114 } = boundary;
115 if channel_id > max_channel_id {
116 return Err(Error::ChannelIdOutOfRange {
117 max: max_channel_id,
118 got: channel_id,
119 });
120 }
121 channels[channel_id].flush(direction, multiplicity, values.clone())?;
122 }
123
124 for flush in flushes {
125 let &Flush {
126 ref oracles,
127 channel_id,
128 direction,
129 ref selectors,
130 multiplicity,
131 table_id,
132 log_values_per_row,
133 } = flush;
134
135 if channel_id > max_channel_id {
136 return Err(Error::ChannelIdOutOfRange {
137 max: max_channel_id,
138 got: channel_id,
139 });
140 }
141
142 let table_size = table_sizes[table_id];
143 if table_size == 0 {
144 continue;
145 }
146
147 let non_const_polys = oracles
149 .iter()
150 .filter_map(|&id| match id {
151 OracleOrConst::Oracle(oracle_id) => Some(witness.get_multilin_poly(oracle_id)),
152 _ => None,
153 })
154 .collect::<Result<Vec<_>, _>>()?;
155
156 let selector_polys = selectors
157 .iter()
158 .map(|selector| witness.get_multilin_poly(*selector))
159 .collect::<Result<Vec<_>, _>>()?;
160
161 let n_vars = non_const_polys
162 .first()
163 .map(|poly| poly.n_vars())
164 .unwrap_or(0);
165
166 for poly in &non_const_polys {
168 if poly.n_vars() != n_vars {
169 return Err(Error::ChannelFlushNvarsMismatch {
170 expected: n_vars,
171 got: poly.n_vars(),
172 });
173 }
174 }
175
176 for (&selector, selector_poly) in izip!(selectors, &selector_polys) {
178 if selector_poly.n_vars() != n_vars {
179 let id = oracles
180 .iter()
181 .copied()
182 .filter_map(|id| match id {
183 OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
184 _ => None,
185 })
186 .next()
187 .expect("non_const_polys is not empty");
188 return Err(Error::IncompatibleFlushSelector { id, selector });
189 }
190 }
191
192 let values_per_row = 1 << log_values_per_row;
193 assert!(table_size * values_per_row <= 1 << n_vars);
194 for i in 0..table_size * values_per_row {
195 let selector_off = selector_polys.iter().any(|selector_poly| {
196 selector_poly
197 .evaluate_on_hypercube(i)
198 .expect(
199 "i in range 0..1 << n_vars; \
200 selector_poly checked above to have n_vars variables",
201 )
202 .is_zero()
203 });
204
205 if selector_off {
206 continue;
207 }
208
209 let values = oracles
210 .iter()
211 .copied()
212 .map(|id| match id {
213 OracleOrConst::Const { base, .. } => Ok(base),
214 OracleOrConst::Oracle(oracle_id) => witness
215 .get_multilin_poly(oracle_id)
216 .expect("Witness error would have been caught while checking variables.")
217 .evaluate_on_hypercube(i),
218 })
219 .collect::<Result<Vec<_>, _>>()?;
220 channels[channel_id].flush(direction, multiplicity, values)?;
221 }
222 }
223
224 for (id, channel) in channels.iter().enumerate() {
225 if !channel.is_balanced() {
226 let unbalanced_flushes: Vec<_> = channel
227 .multiplicities
228 .iter()
229 .filter(|(_, c)| **c != 0i64)
230 .collect();
231
232 tracing::debug!("Channel {:?} unbalanced: {:?}", id, unbalanced_flushes);
233
234 return Err((VerificationError::ChannelUnbalanced { id }).into());
235 }
236 }
237
238 Ok(())
239}
240
241#[derive(Default, Debug, Clone)]
242struct Channel<F: TowerField> {
243 width: Option<usize>,
244 multiplicities: HashMap<Vec<F>, i64>,
245}
246
247impl<F: TowerField> Channel<F> {
248 fn new() -> Self {
249 Self::default()
250 }
251
252 fn _print_unbalanced_values(&self) {
253 for (key, val) in &self.multiplicities {
254 if *val != 0 {
255 println!("{key:?}: {val}");
256 }
257 }
258 }
259
260 fn flush(
261 &mut self,
262 direction: FlushDirection,
263 multiplicity: u64,
264 values: Vec<F>,
265 ) -> Result<(), Error> {
266 if self.width.is_none() {
267 self.width = Some(values.len());
268 } else if self.width.expect("checked for None above") != values.len() {
269 return Err(Error::ChannelFlushWidthMismatch {
270 expected: self.width.unwrap(),
271 got: values.len(),
272 });
273 }
274 *self.multiplicities.entry(values).or_default() += (multiplicity as i64)
275 * (match direction {
276 FlushDirection::Pull => -1i64,
277 FlushDirection::Push => 1i64,
278 });
279 Ok(())
280 }
281
282 fn is_balanced(&self) -> bool {
283 self.multiplicities.iter().all(|(_, m)| *m == 0)
284 }
285}
286
#[cfg(test)]
mod tests {
    use binius_field::BinaryField64b;

    use super::*;

    type F = BinaryField64b;

    /// Net count recorded for `row`, with an absent entry read as zero.
    fn net(channel: &Channel<F>, row: &[F]) -> i64 {
        channel.multiplicities.get(row).copied().unwrap_or(0)
    }

    #[test]
    fn test_flush_push_single_row() {
        let mut channel = Channel::<F>::new();
        let row = vec![F::from(1), F::from(2)];

        assert!(channel.flush(FlushDirection::Push, 1, row.clone()).is_ok());
        assert!(!channel.is_balanced());
        assert_eq!(net(&channel, &row), 1);
    }

    #[test]
    fn test_flush_pull_single_row() {
        let mut channel = Channel::<F>::new();
        let row = vec![F::from(1), F::from(2)];

        assert!(channel.flush(FlushDirection::Pull, 1, row.clone()).is_ok());
        assert!(!channel.is_balanced());
        assert_eq!(net(&channel, &row), -1);
    }

    #[test]
    fn test_flush_push_pull_single_row() {
        let mut channel = Channel::<F>::new();
        let row = vec![F::from(1), F::from(2)];

        channel.flush(FlushDirection::Push, 1, row.clone()).unwrap();
        assert!(channel.flush(FlushDirection::Pull, 1, row.clone()).is_ok());
        assert!(channel.is_balanced());
        assert_eq!(net(&channel, &row), 0);
    }

    #[test]
    fn test_flush_multiplicity() {
        let mut channel = Channel::<F>::new();
        let row = vec![F::from(3), F::from(4)];

        channel.flush(FlushDirection::Push, 2, row.clone()).unwrap();
        channel.flush(FlushDirection::Pull, 1, row.clone()).unwrap();
        assert!(!channel.is_balanced());
        assert_eq!(net(&channel, &row), 1);

        channel.flush(FlushDirection::Pull, 1, row.clone()).unwrap();
        assert!(channel.is_balanced());
        assert_eq!(net(&channel, &row), 0);
    }

    #[test]
    fn test_flush_width_mismatch() {
        let mut channel = Channel::<F>::new();
        channel
            .flush(FlushDirection::Push, 1, vec![F::from(1), F::from(2)])
            .unwrap();

        let wider = vec![F::from(3), F::from(4), F::from(5)];
        match channel.flush(FlushDirection::Push, 1, wider) {
            Err(Error::ChannelFlushWidthMismatch { expected, got }) => {
                assert_eq!(expected, 2);
                assert_eq!(got, 3);
            }
            _ => panic!("Expected ChannelFlushWidthMismatch error"),
        }
    }

    #[test]
    fn test_flush_direction_effects() {
        let mut channel = Channel::<F>::new();
        let pushed = vec![F::from(7), F::from(8)];
        let pulled = vec![F::from(9), F::from(10)];

        channel.flush(FlushDirection::Push, 1, pushed.clone()).unwrap();
        channel.flush(FlushDirection::Pull, 1, pulled.clone()).unwrap();

        assert!(!channel.is_balanced());
        assert_eq!(net(&channel, &pushed), 1);
        assert_eq!(net(&channel, &pulled), -1);
    }
}