binius_core/constraint_system/
channel.rs1use std::collections::HashMap;
53
54use binius_field::{Field, PackedField, TowerField};
55use binius_macros::{DeserializeBytes, SerializeBytes};
56
57use super::error::{Error, VerificationError};
58use crate::{oracle::OracleId, witness::MultilinearExtensionIndex};
59
/// Identifier of a channel. [`validate_witness`] uses it directly as an index into its
/// per-channel state, so valid ids are `0..=max_channel_id`.
pub type ChannelId = usize;
61
/// One column of a [`Flush`]: either an oracle whose values are read from the witness, or a
/// constant value.
#[derive(Debug, Clone, Copy, SerializeBytes, DeserializeBytes, PartialEq, Eq)]
pub enum OracleOrConst<F: Field> {
    /// Id of an oracle whose multilinear polynomial is looked up in the witness index and
    /// evaluated on each selected hypercube row.
    Oracle(usize),
    /// A constant that contributes the same value on every selected row.
    Const {
        /// The constant field element that is flushed.
        base: F,
        /// Tower level associated with the constant. It is not read by the validation code in
        /// this module (presumably the tower height of the field `base` belongs to; confirm).
        tower_level: usize,
    },
}
/// A flush of a tuple of columns into a channel, applied on every hypercube row where the
/// `selector` oracle is nonzero.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct Flush<F: TowerField> {
    /// Columns whose per-row values form the flushed tuple, in this order.
    pub oracles: Vec<OracleOrConst<F>>,
    /// Channel that receives the tuples.
    pub channel_id: ChannelId,
    /// Whether tuples are pushed into or pulled from the channel.
    pub direction: FlushDirection,
    /// Oracle that selects which rows are flushed; rows where it evaluates to zero are skipped.
    pub selector: OracleId,
    /// Number of times each selected row's tuple is flushed.
    pub multiplicity: u64,
}
75
/// A single, explicitly given tuple flushed into a channel. It does not depend on any witness
/// data.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub struct Boundary<F: TowerField> {
    /// The tuple of field elements that is flushed.
    pub values: Vec<F>,
    /// Channel that receives the tuple.
    pub channel_id: ChannelId,
    /// Whether the tuple is pushed into or pulled from the channel.
    pub direction: FlushDirection,
    /// Number of times the tuple is flushed.
    pub multiplicity: u64,
}
83
/// Direction of a flush. A balanced channel has, for every tuple, as many pushes as pulls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum FlushDirection {
    /// Adds the multiplicity to the tuple's net count.
    Push,
    /// Subtracts the multiplicity from the tuple's net count.
    Pull,
}
89
90pub fn validate_witness<F, P>(
91 witness: &MultilinearExtensionIndex<P>,
92 flushes: &[Flush<F>],
93 boundaries: &[Boundary<F>],
94 max_channel_id: ChannelId,
95) -> Result<(), Error>
96where
97 P: PackedField<Scalar = F>,
98 F: TowerField,
99{
100 let mut channels = vec![Channel::<F>::new(); max_channel_id + 1];
101
102 for boundary in boundaries.iter().cloned() {
103 let Boundary {
104 channel_id,
105 values,
106 direction,
107 multiplicity,
108 } = boundary;
109 if channel_id > max_channel_id {
110 return Err(Error::ChannelIdOutOfRange {
111 max: max_channel_id,
112 got: channel_id,
113 });
114 }
115 channels[channel_id].flush(direction, multiplicity, values.clone())?;
116 }
117
118 for flush in flushes {
119 let &Flush {
120 ref oracles,
121 channel_id,
122 direction,
123 selector,
124 multiplicity,
125 } = flush;
126
127 if channel_id > max_channel_id {
128 return Err(Error::ChannelIdOutOfRange {
129 max: max_channel_id,
130 got: channel_id,
131 });
132 }
133
134 let channel = &mut channels[channel_id];
135
136 let non_const_polys = oracles
138 .iter()
139 .filter_map(|&id| match id {
140 OracleOrConst::Oracle(oracle_id) => Some(witness.get_multilin_poly(oracle_id)),
141 _ => None,
142 })
143 .collect::<Result<Vec<_>, _>>()?;
144
145 let selector_poly = witness.get_multilin_poly(selector)?;
146
147 if let Some(first_poly) = non_const_polys.first() {
148 let n_vars = first_poly.n_vars();
150 for poly in &non_const_polys {
151 if poly.n_vars() != n_vars {
152 return Err(Error::ChannelFlushNvarsMismatch {
153 expected: n_vars,
154 got: poly.n_vars(),
155 });
156 }
157 }
158
159 if selector_poly.n_vars() != n_vars {
161 let id = oracles
162 .iter()
163 .copied()
164 .filter_map(|id| match id {
165 OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
166 _ => None,
167 })
168 .next()
169 .expect("non_const_polys is not empty");
170 return Err(Error::IncompatibleFlushSelector { id, selector });
171 }
172 }
173
174 for i in 0..1 << selector_poly.n_vars() {
175 if selector_poly.evaluate_on_hypercube(i)?.is_zero() {
176 continue;
177 }
178 let values = oracles
179 .iter()
180 .copied()
181 .map(|id| match id {
182 OracleOrConst::Const { base, .. } => Ok(base),
183 OracleOrConst::Oracle(oracle_id) => witness
184 .get_multilin_poly(oracle_id)
185 .expect("Witness error would have been caught while checking variables.")
186 .evaluate_on_hypercube(i),
187 })
188 .collect::<Result<Vec<_>, _>>()?;
189 channel.flush(direction, multiplicity, values)?;
190 }
191 }
192
193 for (id, channel) in channels.iter().enumerate() {
194 if !channel.is_balanced() {
195 return Err((VerificationError::ChannelUnbalanced { id }).into());
196 }
197 }
198
199 Ok(())
200}
201
202#[derive(Default, Debug, Clone)]
203struct Channel<F: TowerField> {
204 width: Option<usize>,
205 multiplicities: HashMap<Vec<F>, i64>,
206}
207
208impl<F: TowerField> Channel<F> {
209 fn new() -> Self {
210 Self::default()
211 }
212
213 fn _print_unbalanced_values(&self) {
214 for (key, val) in &self.multiplicities {
215 if *val != 0 {
216 println!("{key:?}: {val}");
217 }
218 }
219 }
220
221 fn flush(
222 &mut self,
223 direction: FlushDirection,
224 multiplicity: u64,
225 values: Vec<F>,
226 ) -> Result<(), Error> {
227 if self.width.is_none() {
228 self.width = Some(values.len());
229 } else if self.width.expect("checked for None above") != values.len() {
230 return Err(Error::ChannelFlushWidthMismatch {
231 expected: self.width.unwrap(),
232 got: values.len(),
233 });
234 }
235 *self.multiplicities.entry(values).or_default() += (multiplicity as i64)
236 * (match direction {
237 FlushDirection::Pull => -1i64,
238 FlushDirection::Push => 1i64,
239 });
240 Ok(())
241 }
242
243 fn is_balanced(&self) -> bool {
244 self.multiplicities.iter().all(|(_, m)| *m == 0)
245 }
246}
247
#[cfg(test)]
mod tests {
    use binius_field::BinaryField64b;

    use super::*;

    #[test]
    fn test_flush_push_single_row() {
        let mut channel = Channel::<BinaryField64b>::new();
        let tuple = vec![BinaryField64b::from(1), BinaryField64b::from(2)];

        let outcome = channel.flush(FlushDirection::Push, 1, tuple.clone());

        assert!(outcome.is_ok());
        // A lone push leaves the tuple outstanding.
        assert!(!channel.is_balanced());
        assert_eq!(channel.multiplicities.get(&tuple), Some(&1));
    }

    #[test]
    fn test_flush_pull_single_row() {
        let mut channel = Channel::<BinaryField64b>::new();
        let tuple = vec![BinaryField64b::from(1), BinaryField64b::from(2)];

        let outcome = channel.flush(FlushDirection::Pull, 1, tuple.clone());

        assert!(outcome.is_ok());
        // A lone pull drives the net count negative.
        assert!(!channel.is_balanced());
        assert_eq!(channel.multiplicities.get(&tuple), Some(&-1));
    }

    #[test]
    fn test_flush_push_pull_single_row() {
        let mut channel = Channel::<BinaryField64b>::new();
        let tuple = vec![BinaryField64b::from(1), BinaryField64b::from(2)];

        channel
            .flush(FlushDirection::Push, 1, tuple.clone())
            .unwrap();
        let outcome = channel.flush(FlushDirection::Pull, 1, tuple.clone());

        assert!(outcome.is_ok());
        assert!(channel.is_balanced());
        assert_eq!(channel.multiplicities.get(&tuple).copied().unwrap_or(0), 0);
    }

    #[test]
    fn test_flush_multiplicity() {
        let mut channel = Channel::<BinaryField64b>::new();
        let tuple = vec![BinaryField64b::from(3), BinaryField64b::from(4)];

        channel
            .flush(FlushDirection::Push, 2, tuple.clone())
            .unwrap();
        channel
            .flush(FlushDirection::Pull, 1, tuple.clone())
            .unwrap();

        // One of the two pushed copies is still outstanding.
        assert!(!channel.is_balanced());
        assert_eq!(channel.multiplicities.get(&tuple), Some(&1));

        channel
            .flush(FlushDirection::Pull, 1, tuple.clone())
            .unwrap();

        // The second pull cancels the remaining copy.
        assert!(channel.is_balanced());
        assert_eq!(channel.multiplicities.get(&tuple).copied().unwrap_or(0), 0);
    }

    #[test]
    fn test_flush_width_mismatch() {
        let mut channel = Channel::<BinaryField64b>::new();

        let narrow = vec![BinaryField64b::from(1), BinaryField64b::from(2)];
        channel.flush(FlushDirection::Push, 1, narrow).unwrap();

        let wide = vec![
            BinaryField64b::from(3),
            BinaryField64b::from(4),
            BinaryField64b::from(5),
        ];
        match channel.flush(FlushDirection::Push, 1, wide) {
            Err(Error::ChannelFlushWidthMismatch { expected, got }) => {
                assert_eq!(expected, 2);
                assert_eq!(got, 3);
            }
            _ => panic!("Expected ChannelFlushWidthMismatch error"),
        }
    }

    #[test]
    fn test_flush_direction_effects() {
        let mut channel = Channel::<BinaryField64b>::new();
        let pushed = vec![BinaryField64b::from(7), BinaryField64b::from(8)];
        let pulled = vec![BinaryField64b::from(9), BinaryField64b::from(10)];

        channel
            .flush(FlushDirection::Push, 1, pushed.clone())
            .unwrap();
        channel
            .flush(FlushDirection::Pull, 1, pulled.clone())
            .unwrap();

        // Different tuples never cancel each other out.
        assert!(!channel.is_balanced());
        assert_eq!(channel.multiplicities.get(&pushed), Some(&1));
        assert_eq!(channel.multiplicities.get(&pulled), Some(&-1));
    }
}