binius_core/constraint_system/
channel.rs1use std::collections::HashMap;
53
54use binius_field::{Field, PackedField, TowerField};
55use binius_macros::{DeserializeBytes, SerializeBytes};
56use binius_math::MultilinearPoly;
57use itertools::izip;
58
59use super::error::{Error, VerificationError};
60use crate::{oracle::OracleId, witness::MultilinearExtensionIndex};
61
/// Index of a channel. [`validate_witness`] accepts ids in `0..=max_channel_id`.
pub type ChannelId = usize;
63
/// One column of a [`Flush`]: either a witness oracle or a constant value.
#[derive(Debug, Clone, Copy, SerializeBytes, DeserializeBytes, PartialEq, Eq)]
pub enum OracleOrConst<F: Field> {
    /// Witness oracle with this id; during validation its hypercube evaluation at each row
    /// supplies the column value.
    Oracle(usize),
    /// Constant column: `base` is used as the value on every row.
    /// `tower_level` is not read in this module (presumably the tower level of `base` —
    /// TODO confirm against the prover side).
    Const { base: F, tower_level: usize },
}
/// A flush of witness columns into a channel, one tuple per (selected) hypercube row.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct Flush<F: TowerField> {
    /// Columns whose per-row values form the flushed tuple, in tuple order.
    pub oracles: Vec<OracleOrConst<F>>,
    /// Channel that receives the tuples.
    pub channel_id: ChannelId,
    /// Whether the tuples are pushed into or pulled from the channel.
    pub direction: FlushDirection,
    /// Oracles gating the flush: a row is skipped when any selector evaluates to zero on it.
    pub selectors: Vec<OracleId>,
    /// Number of times each flushed row's tuple is counted.
    pub multiplicity: u64,
}
77
/// A single explicit tuple flushed into a channel, as-is and without consulting the witness.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub struct Boundary<F: TowerField> {
    /// The tuple flushed into the channel.
    pub values: Vec<F>,
    /// Channel that receives the tuple.
    pub channel_id: ChannelId,
    /// Whether the tuple is pushed into or pulled from the channel.
    pub direction: FlushDirection,
    /// Number of times the tuple is counted.
    pub multiplicity: u64,
}
85
/// Direction of a flush. A push adds its multiplicity to a tuple's net count in the channel;
/// a pull subtracts it. A channel is balanced when every net count is zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum FlushDirection {
    Push,
    Pull,
}
91
92pub fn validate_witness<F, P>(
93 witness: &MultilinearExtensionIndex<P>,
94 flushes: &[Flush<F>],
95 boundaries: &[Boundary<F>],
96 max_channel_id: ChannelId,
97) -> Result<(), Error>
98where
99 P: PackedField<Scalar = F>,
100 F: TowerField,
101{
102 let mut channels = vec![Channel::<F>::new(); max_channel_id + 1];
103
104 for boundary in boundaries.iter().cloned() {
105 let Boundary {
106 channel_id,
107 values,
108 direction,
109 multiplicity,
110 } = boundary;
111 if channel_id > max_channel_id {
112 return Err(Error::ChannelIdOutOfRange {
113 max: max_channel_id,
114 got: channel_id,
115 });
116 }
117 channels[channel_id].flush(direction, multiplicity, values.clone())?;
118 }
119
120 for flush in flushes {
121 let &Flush {
122 ref oracles,
123 channel_id,
124 direction,
125 ref selectors,
126 multiplicity,
127 } = flush;
128
129 if channel_id > max_channel_id {
130 return Err(Error::ChannelIdOutOfRange {
131 max: max_channel_id,
132 got: channel_id,
133 });
134 }
135
136 let channel = &mut channels[channel_id];
137
138 let non_const_polys = oracles
140 .iter()
141 .filter_map(|&id| match id {
142 OracleOrConst::Oracle(oracle_id) => Some(witness.get_multilin_poly(oracle_id)),
143 _ => None,
144 })
145 .collect::<Result<Vec<_>, _>>()?;
146
147 let selector_polys = selectors
148 .iter()
149 .map(|selector| witness.get_multilin_poly(*selector))
150 .collect::<Result<Vec<_>, _>>()?;
151
152 let n_vars = non_const_polys
153 .first()
154 .map(|poly| poly.n_vars())
155 .unwrap_or(0);
156
157 for poly in &non_const_polys {
159 if poly.n_vars() != n_vars {
160 return Err(Error::ChannelFlushNvarsMismatch {
161 expected: n_vars,
162 got: poly.n_vars(),
163 });
164 }
165 }
166
167 for (&selector, selector_poly) in izip!(selectors, &selector_polys) {
169 if selector_poly.n_vars() != n_vars {
170 let id = oracles
171 .iter()
172 .copied()
173 .filter_map(|id| match id {
174 OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
175 _ => None,
176 })
177 .next()
178 .expect("non_const_polys is not empty");
179 return Err(Error::IncompatibleFlushSelector { id, selector });
180 }
181 }
182
183 for i in 0..1 << n_vars {
184 let selector_off = selector_polys.iter().any(|selector_poly| {
185 selector_poly
186 .evaluate_on_hypercube(i)
187 .expect(
188 "i in range 0..1 << n_vars; \
189 selector_poly checked above to have n_vars variables",
190 )
191 .is_zero()
192 });
193
194 if selector_off {
195 continue;
196 }
197
198 let values = oracles
199 .iter()
200 .copied()
201 .map(|id| match id {
202 OracleOrConst::Const { base, .. } => Ok(base),
203 OracleOrConst::Oracle(oracle_id) => witness
204 .get_multilin_poly(oracle_id)
205 .expect("Witness error would have been caught while checking variables.")
206 .evaluate_on_hypercube(i),
207 })
208 .collect::<Result<Vec<_>, _>>()?;
209 channel.flush(direction, multiplicity, values)?;
210 }
211 }
212
213 for (id, channel) in channels.iter().enumerate() {
214 if !channel.is_balanced() {
215 let unbalanced_flushes: Vec<_> = channel
216 .multiplicities
217 .iter()
218 .filter(|(_, &c)| c != 0i64)
219 .collect();
220
221 tracing::debug!("Channel {:?} unbalanced: {:?}", id, unbalanced_flushes);
222
223 return Err((VerificationError::ChannelUnbalanced { id }).into());
224 }
225 }
226
227 Ok(())
228}
229
/// Net-multiplicity bookkeeping for a single channel while validating a witness.
#[derive(Default, Debug, Clone)]
struct Channel<F: TowerField> {
    /// Tuple width fixed by the first flush; `None` until something has been flushed.
    width: Option<usize>,
    /// Net multiplicity of each tuple (pushes minus pulls); the channel is balanced when every
    /// entry is zero.
    multiplicities: HashMap<Vec<F>, i64>,
}
235
236impl<F: TowerField> Channel<F> {
237 fn new() -> Self {
238 Self::default()
239 }
240
241 fn _print_unbalanced_values(&self) {
242 for (key, val) in &self.multiplicities {
243 if *val != 0 {
244 println!("{key:?}: {val}");
245 }
246 }
247 }
248
249 fn flush(
250 &mut self,
251 direction: FlushDirection,
252 multiplicity: u64,
253 values: Vec<F>,
254 ) -> Result<(), Error> {
255 if self.width.is_none() {
256 self.width = Some(values.len());
257 } else if self.width.expect("checked for None above") != values.len() {
258 return Err(Error::ChannelFlushWidthMismatch {
259 expected: self.width.unwrap(),
260 got: values.len(),
261 });
262 }
263 *self.multiplicities.entry(values).or_default() += (multiplicity as i64)
264 * (match direction {
265 FlushDirection::Pull => -1i64,
266 FlushDirection::Push => 1i64,
267 });
268 Ok(())
269 }
270
271 fn is_balanced(&self) -> bool {
272 self.multiplicities.iter().all(|(_, m)| *m == 0)
273 }
274}
275
#[cfg(test)]
mod tests {
    use binius_field::BinaryField64b;

    use super::*;

    /// Short local name for the field used throughout these channel tests.
    type B64 = BinaryField64b;

    #[test]
    fn test_flush_push_single_row() {
        let row = vec![B64::from(1), B64::from(2)];
        let mut channel = Channel::<B64>::new();

        let outcome = channel.flush(FlushDirection::Push, 1, row.clone());

        assert!(outcome.is_ok());
        assert!(!channel.is_balanced());
        assert_eq!(channel.multiplicities[&row], 1);
    }

    #[test]
    fn test_flush_pull_single_row() {
        let row = vec![B64::from(1), B64::from(2)];
        let mut channel = Channel::<B64>::new();

        let outcome = channel.flush(FlushDirection::Pull, 1, row.clone());

        assert!(outcome.is_ok());
        assert!(!channel.is_balanced());
        assert_eq!(channel.multiplicities[&row], -1);
    }

    #[test]
    fn test_flush_push_pull_single_row() {
        let row = vec![B64::from(1), B64::from(2)];
        let mut channel = Channel::<B64>::new();

        channel.flush(FlushDirection::Push, 1, row.clone()).unwrap();
        let outcome = channel.flush(FlushDirection::Pull, 1, row.clone());

        assert!(outcome.is_ok());
        assert!(channel.is_balanced());
        assert_eq!(channel.multiplicities.get(&row).copied().unwrap_or(0), 0);
    }

    #[test]
    fn test_flush_multiplicity() {
        let row = vec![B64::from(3), B64::from(4)];
        let mut channel = Channel::<B64>::new();

        // Two pushes' worth in, one pull out: one copy still outstanding.
        channel.flush(FlushDirection::Push, 2, row.clone()).unwrap();
        channel.flush(FlushDirection::Pull, 1, row.clone()).unwrap();
        assert!(!channel.is_balanced());
        assert_eq!(channel.multiplicities[&row], 1);

        // The second pull settles it.
        channel.flush(FlushDirection::Pull, 1, row.clone()).unwrap();
        assert!(channel.is_balanced());
        assert_eq!(channel.multiplicities.get(&row).copied().unwrap_or(0), 0);
    }

    #[test]
    fn test_flush_width_mismatch() {
        let mut channel = Channel::<B64>::new();
        channel
            .flush(FlushDirection::Push, 1, vec![B64::from(1), B64::from(2)])
            .unwrap();

        let wider_row = vec![B64::from(3), B64::from(4), B64::from(5)];
        let outcome = channel.flush(FlushDirection::Push, 1, wider_row);

        assert!(outcome.is_err());
        let Err(Error::ChannelFlushWidthMismatch { expected, got }) = outcome else {
            panic!("Expected ChannelFlushWidthMismatch error");
        };
        assert_eq!(expected, 2);
        assert_eq!(got, 3);
    }

    #[test]
    fn test_flush_direction_effects() {
        let pushed = vec![B64::from(7), B64::from(8)];
        let pulled = vec![B64::from(9), B64::from(10)];
        let mut channel = Channel::<B64>::new();

        channel.flush(FlushDirection::Push, 1, pushed.clone()).unwrap();
        channel.flush(FlushDirection::Pull, 1, pulled.clone()).unwrap();

        assert!(!channel.is_balanced());
        assert_eq!(channel.multiplicities[&pushed], 1);
        assert_eq!(channel.multiplicities[&pulled], -1);
    }
}