binius_core/constraint_system/
channel.rs1use std::collections::HashMap;
53
54use binius_field::{as_packed_field::PackScalar, underlier::UnderlierType, TowerField};
55use binius_macros::{DeserializeBytes, SerializeBytes};
56
57use super::error::{Error, VerificationError};
58use crate::{oracle::OracleId, witness::MultilinearExtensionIndex};
59
/// Identifier of a channel.
///
/// `validate_witness` uses it as a zero-based index into its per-channel
/// table, so valid ids lie in `0..=max_channel_id`.
pub type ChannelId = usize;
61
62#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
63pub struct Flush {
64 pub oracles: Vec<OracleId>,
65 pub channel_id: ChannelId,
66 pub direction: FlushDirection,
67 pub selector: OracleId,
68 pub multiplicity: u64,
69}
70
71#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
72pub struct Boundary<F: TowerField> {
73 pub values: Vec<F>,
74 pub channel_id: ChannelId,
75 pub direction: FlushDirection,
76 pub multiplicity: u64,
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
80pub enum FlushDirection {
81 Push,
82 Pull,
83}
84
85pub fn validate_witness<U, F>(
86 witness: &MultilinearExtensionIndex<U, F>,
87 flushes: &[Flush],
88 boundaries: &[Boundary<F>],
89 max_channel_id: ChannelId,
90) -> Result<(), Error>
91where
92 U: UnderlierType + PackScalar<F>,
93 F: TowerField,
94{
95 let mut channels = vec![Channel::<F>::new(); max_channel_id + 1];
96
97 for boundary in boundaries.iter().cloned() {
98 let Boundary {
99 channel_id,
100 values,
101 direction,
102 multiplicity,
103 } = boundary;
104 if channel_id > max_channel_id {
105 return Err(Error::ChannelIdOutOfRange {
106 max: max_channel_id,
107 got: channel_id,
108 });
109 }
110 channels[channel_id].flush(direction, multiplicity, values.clone())?;
111 }
112
113 for flush in flushes {
114 let &Flush {
115 ref oracles,
116 channel_id,
117 direction,
118 selector,
119 multiplicity,
120 } = flush;
121
122 if channel_id > max_channel_id {
123 return Err(Error::ChannelIdOutOfRange {
124 max: max_channel_id,
125 got: channel_id,
126 });
127 }
128
129 let channel = &mut channels[channel_id];
130
131 let polys = oracles
132 .iter()
133 .map(|&id| witness.get_multilin_poly(id))
134 .collect::<Result<Vec<_>, _>>()?;
135
136 if let Some(first_poly) = polys.first() {
138 let n_vars = first_poly.n_vars();
139 for poly in &polys {
140 if poly.n_vars() != n_vars {
141 return Err(Error::ChannelFlushNvarsMismatch {
142 expected: n_vars,
143 got: poly.n_vars(),
144 });
145 }
146 }
147
148 let selector_poly = witness.get_multilin_poly(selector)?;
149 if selector_poly.n_vars() != n_vars {
151 let id = oracles.first().copied().expect("polys is not empty");
152 return Err(Error::IncompatibleFlushSelector { id, selector });
153 }
154
155 for i in 0..1 << selector_poly.n_vars() {
156 if selector_poly.evaluate_on_hypercube(i)?.is_zero() {
157 continue;
158 }
159 let values = polys
160 .iter()
161 .map(|poly| poly.evaluate_on_hypercube(i))
162 .collect::<Result<Vec<_>, _>>()?;
163 channel.flush(direction, multiplicity, values)?;
164 }
165 }
166 }
167
168 for (id, channel) in channels.iter().enumerate() {
169 if !channel.is_balanced() {
170 return Err(VerificationError::ChannelUnbalanced { id }.into());
171 }
172 }
173
174 Ok(())
175}
176
177#[derive(Default, Debug, Clone)]
178struct Channel<F: TowerField> {
179 width: Option<usize>,
180 multiplicities: HashMap<Vec<F>, i64>,
181}
182
183impl<F: TowerField> Channel<F> {
184 fn new() -> Self {
185 Self::default()
186 }
187
188 fn _print_unbalanced_values(&self) {
189 for (key, val) in &self.multiplicities {
190 if *val != 0 {
191 println!("{key:?}: {val}");
192 }
193 }
194 }
195
196 fn flush(
197 &mut self,
198 direction: FlushDirection,
199 multiplicity: u64,
200 values: Vec<F>,
201 ) -> Result<(), Error> {
202 if self.width.is_none() {
203 self.width = Some(values.len());
204 } else if self.width.expect("checked for None above") != values.len() {
205 return Err(Error::ChannelFlushWidthMismatch {
206 expected: self.width.unwrap(),
207 got: values.len(),
208 });
209 }
210 *self.multiplicities.entry(values).or_default() += (multiplicity as i64)
211 * (match direction {
212 FlushDirection::Pull => -1i64,
213 FlushDirection::Push => 1i64,
214 });
215 Ok(())
216 }
217
218 fn is_balanced(&self) -> bool {
219 self.multiplicities.iter().all(|(_, m)| *m == 0)
220 }
221}
222
#[cfg(test)]
mod tests {
    use binius_field::BinaryField64b;

    use super::*;

    /// A single push leaves the row with a net multiplicity of +1.
    #[test]
    fn test_flush_push_single_row() {
        let row = vec![BinaryField64b::from(1), BinaryField64b::from(2)];
        let mut ch = Channel::<BinaryField64b>::new();

        let outcome = ch.flush(FlushDirection::Push, 1, row.clone());

        assert!(outcome.is_ok());
        assert!(!ch.is_balanced());
        assert_eq!(ch.multiplicities.get(&row).copied(), Some(1));
    }

    /// A single pull leaves the row with a net multiplicity of -1.
    #[test]
    fn test_flush_pull_single_row() {
        let row = vec![BinaryField64b::from(1), BinaryField64b::from(2)];
        let mut ch = Channel::<BinaryField64b>::new();

        let outcome = ch.flush(FlushDirection::Pull, 1, row.clone());

        assert!(outcome.is_ok());
        assert!(!ch.is_balanced());
        assert_eq!(ch.multiplicities.get(&row).copied(), Some(-1));
    }

    /// A push followed by a pull of the same row balances the channel.
    #[test]
    fn test_flush_push_pull_single_row() {
        let row = vec![BinaryField64b::from(1), BinaryField64b::from(2)];
        let mut ch = Channel::<BinaryField64b>::new();

        ch.flush(FlushDirection::Push, 1, row.clone()).unwrap();
        let outcome = ch.flush(FlushDirection::Pull, 1, row.clone());

        assert!(outcome.is_ok());
        assert!(ch.is_balanced());
        assert_eq!(ch.multiplicities.get(&row).copied().unwrap_or(0), 0);
    }

    /// Multiplicities accumulate: a push of 2 needs two pulls of 1 to cancel.
    #[test]
    fn test_flush_multiplicity() {
        let row = vec![BinaryField64b::from(3), BinaryField64b::from(4)];
        let mut ch = Channel::<BinaryField64b>::new();

        ch.flush(FlushDirection::Push, 2, row.clone()).unwrap();
        ch.flush(FlushDirection::Pull, 1, row.clone()).unwrap();

        assert!(!ch.is_balanced());
        assert_eq!(ch.multiplicities.get(&row).copied(), Some(1));

        ch.flush(FlushDirection::Pull, 1, row.clone()).unwrap();

        assert!(ch.is_balanced());
        assert_eq!(ch.multiplicities.get(&row).copied().unwrap_or(0), 0);
    }

    /// Flushing a row of a different width than the first one is rejected.
    #[test]
    fn test_flush_width_mismatch() {
        let mut ch = Channel::<BinaryField64b>::new();

        let narrow = vec![BinaryField64b::from(1), BinaryField64b::from(2)];
        ch.flush(FlushDirection::Push, 1, narrow).unwrap();

        let wide = vec![
            BinaryField64b::from(3),
            BinaryField64b::from(4),
            BinaryField64b::from(5),
        ];
        let outcome = ch.flush(FlushDirection::Push, 1, wide);

        assert!(outcome.is_err());
        match outcome {
            Err(Error::ChannelFlushWidthMismatch { expected, got }) => {
                assert_eq!(expected, 2);
                assert_eq!(got, 3);
            }
            _ => panic!("Expected ChannelFlushWidthMismatch error"),
        }
    }

    /// Distinct rows are tallied independently, so they cannot cancel out.
    #[test]
    fn test_flush_direction_effects() {
        let mut ch = Channel::<BinaryField64b>::new();

        let pushed = vec![BinaryField64b::from(7), BinaryField64b::from(8)];
        ch.flush(FlushDirection::Push, 1, pushed.clone()).unwrap();

        let pulled = vec![BinaryField64b::from(9), BinaryField64b::from(10)];
        ch.flush(FlushDirection::Pull, 1, pulled.clone()).unwrap();

        assert!(!ch.is_balanced());
        assert_eq!(ch.multiplicities.get(&pushed).copied(), Some(1));
        assert_eq!(ch.multiplicities.get(&pulled).copied(), Some(-1));
    }
}