binius_core/constraint_system/
channel.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3//! A channel allows communication between tables.
4//!
5//! Note that the channel is unordered - meaning that rows are not
6//! constrained to be in the same order when being pushed and pulled.
7//!
8//! The number of columns per channel must be fixed, but can be any
9//! positive integer. Column order is guaranteed, and column values within
10//! the same row must always stay together.
11//!
//! A channel only ensures, via a multiset check, that the pushed and the
//! pulled rows match. If you need any kind of ordering, you have to enforce
//! it additionally with polynomial constraints.
15//!
16//! The example below shows a channel with width=2, with multiple inputs
17//! and outputs.
18//! ```txt
19//!                                       +-+-+
20//!                                       |C|D|
21//! +-+-+                           +---> +-+-+
22//! |A|B|                           |     |M|N|
23//! +-+-+                           |     +-+-+
24//! |C|D|                           |
25//! +-+-+  --+                      |     +-+-+
26//! |E|F|    |                      |     |I|J|
27//! +-+-+    |                      |     +-+-+
28//! |G|H|    |                      |     |W|X|
29//! +-+-+    |                      | +-> +-+-+
30//!          |                      | |   |A|B|
31//! +-+-+    +-> /¯\¯¯¯¯¯¯¯¯¯¯¯\  --+ |   +-+-+
32//! |I|J|       :   :           : ----+   |K|L|
33//! +-+-+  PUSH |   |  channel  |  PULL   +-+-+
34//! |K|L|       :   :           : ----+
35//! +-+-+    +-> \_/___________/  --+ |   +-+-+
36//! |M|N|    |                      | |   |U|V|
37//! +-+-+    |                      | |   +-+-+
38//! |O|P|    |                      | |   |G|H|
39//! +-+-+  --+                      | +-> +-+-+
40//! |Q|R|                           |     |E|F|
41//! +-+-+                           |     +-+-+
42//! |S|T|                           |     |Q|R|
43//! +-+-+                           |     +-+-+
44//! |U|V|                           |
45//! +-+-+                           |     +-+-+
46//! |W|X|                           |     |O|P|
47//! +-+-+                           +---> +-+-+
48//!                                       |S|T|
49//!                                       +-+-+
50//! ```
51
52use std::collections::HashMap;
53
54use binius_field::{Field, PackedField, TowerField};
55use binius_macros::{DeserializeBytes, SerializeBytes};
56
57use super::error::{Error, VerificationError};
58use crate::{oracle::OracleId, witness::MultilinearExtensionIndex};
59
/// Index of a channel. `validate_witness` accepts ids from `0` up to and including its
/// `max_channel_id` argument.
pub type ChannelId = usize;
61
62#[derive(Debug, Clone, Copy, SerializeBytes, DeserializeBytes, PartialEq, Eq)]
63pub enum OracleOrConst<F: Field> {
64	Oracle(usize),
65	Const { base: F, tower_level: usize },
66}
67#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
68pub struct Flush<F: TowerField> {
69	pub oracles: Vec<OracleOrConst<F>>,
70	pub channel_id: ChannelId,
71	pub direction: FlushDirection,
72	pub selector: OracleId,
73	pub multiplicity: u64,
74}
75
76#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
77pub struct Boundary<F: TowerField> {
78	pub values: Vec<F>,
79	pub channel_id: ChannelId,
80	pub direction: FlushDirection,
81	pub multiplicity: u64,
82}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
85pub enum FlushDirection {
86	Push,
87	Pull,
88}
89
90pub fn validate_witness<F, P>(
91	witness: &MultilinearExtensionIndex<P>,
92	flushes: &[Flush<F>],
93	boundaries: &[Boundary<F>],
94	max_channel_id: ChannelId,
95) -> Result<(), Error>
96where
97	P: PackedField<Scalar = F>,
98	F: TowerField,
99{
100	let mut channels = vec![Channel::<F>::new(); max_channel_id + 1];
101
102	for boundary in boundaries.iter().cloned() {
103		let Boundary {
104			channel_id,
105			values,
106			direction,
107			multiplicity,
108		} = boundary;
109		if channel_id > max_channel_id {
110			return Err(Error::ChannelIdOutOfRange {
111				max: max_channel_id,
112				got: channel_id,
113			});
114		}
115		channels[channel_id].flush(direction, multiplicity, values.clone())?;
116	}
117
118	for flush in flushes {
119		let &Flush {
120			ref oracles,
121			channel_id,
122			direction,
123			selector,
124			multiplicity,
125		} = flush;
126
127		if channel_id > max_channel_id {
128			return Err(Error::ChannelIdOutOfRange {
129				max: max_channel_id,
130				got: channel_id,
131			});
132		}
133
134		let channel = &mut channels[channel_id];
135
136		//We check the variables only of OracleOrConst::Oracle variant oracles being the same.
137		let non_const_polys = oracles
138			.iter()
139			.filter_map(|&id| match id {
140				OracleOrConst::Oracle(oracle_id) => Some(witness.get_multilin_poly(oracle_id)),
141				_ => None,
142			})
143			.collect::<Result<Vec<_>, _>>()?;
144
145		let selector_poly = witness.get_multilin_poly(selector)?;
146
147		if let Some(first_poly) = non_const_polys.first() {
148			// Ensure that all the polys in a single flush have the same n_vars
149			let n_vars = first_poly.n_vars();
150			for poly in &non_const_polys {
151				if poly.n_vars() != n_vars {
152					return Err(Error::ChannelFlushNvarsMismatch {
153						expected: n_vars,
154						got: poly.n_vars(),
155					});
156				}
157			}
158
159			// Check selector polynomial is compatible
160			if selector_poly.n_vars() != n_vars {
161				let id = oracles
162					.iter()
163					.copied()
164					.filter_map(|id| match id {
165						OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
166						_ => None,
167					})
168					.next()
169					.expect("non_const_polys is not empty");
170				return Err(Error::IncompatibleFlushSelector { id, selector });
171			}
172		}
173
174		for i in 0..1 << selector_poly.n_vars() {
175			if selector_poly.evaluate_on_hypercube(i)?.is_zero() {
176				continue;
177			}
178			let values = oracles
179				.iter()
180				.copied()
181				.map(|id| match id {
182					OracleOrConst::Const { base, .. } => Ok(base),
183					OracleOrConst::Oracle(oracle_id) => witness
184						.get_multilin_poly(oracle_id)
185						.expect("Witness error would have been caught while checking variables.")
186						.evaluate_on_hypercube(i),
187				})
188				.collect::<Result<Vec<_>, _>>()?;
189			channel.flush(direction, multiplicity, values)?;
190		}
191	}
192
193	for (id, channel) in channels.iter().enumerate() {
194		if !channel.is_balanced() {
195			return Err((VerificationError::ChannelUnbalanced { id }).into());
196		}
197	}
198
199	Ok(())
200}
201
202#[derive(Default, Debug, Clone)]
203struct Channel<F: TowerField> {
204	width: Option<usize>,
205	multiplicities: HashMap<Vec<F>, i64>,
206}
207
208impl<F: TowerField> Channel<F> {
209	fn new() -> Self {
210		Self::default()
211	}
212
213	fn _print_unbalanced_values(&self) {
214		for (key, val) in &self.multiplicities {
215			if *val != 0 {
216				println!("{key:?}: {val}");
217			}
218		}
219	}
220
221	fn flush(
222		&mut self,
223		direction: FlushDirection,
224		multiplicity: u64,
225		values: Vec<F>,
226	) -> Result<(), Error> {
227		if self.width.is_none() {
228			self.width = Some(values.len());
229		} else if self.width.expect("checked for None above") != values.len() {
230			return Err(Error::ChannelFlushWidthMismatch {
231				expected: self.width.unwrap(),
232				got: values.len(),
233			});
234		}
235		*self.multiplicities.entry(values).or_default() += (multiplicity as i64)
236			* (match direction {
237				FlushDirection::Pull => -1i64,
238				FlushDirection::Push => 1i64,
239			});
240		Ok(())
241	}
242
243	fn is_balanced(&self) -> bool {
244		self.multiplicities.iter().all(|(_, m)| *m == 0)
245	}
246}
247
#[cfg(test)]
mod tests {
	use binius_field::BinaryField64b;

	use super::*;

	/// Builds a channel row from raw field-element values.
	fn row(raw: &[u64]) -> Vec<BinaryField64b> {
		raw.iter().copied().map(BinaryField64b::from).collect()
	}

	#[test]
	fn test_flush_push_single_row() {
		let mut channel = Channel::<BinaryField64b>::new();
		let values = row(&[1, 2]);

		// A lone push leaves the row with a net count of +1.
		channel
			.flush(FlushDirection::Push, 1, values.clone())
			.unwrap();

		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&values), Some(&1));
	}

	#[test]
	fn test_flush_pull_single_row() {
		let mut channel = Channel::<BinaryField64b>::new();
		let values = row(&[1, 2]);

		// A lone pull leaves the row with a net count of -1.
		channel
			.flush(FlushDirection::Pull, 1, values.clone())
			.unwrap();

		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&values), Some(&-1));
	}

	#[test]
	fn test_flush_push_pull_single_row() {
		let mut channel = Channel::<BinaryField64b>::new();
		let values = row(&[1, 2]);

		// A push followed by a pull of the same row cancels out.
		for direction in [FlushDirection::Push, FlushDirection::Pull] {
			channel.flush(direction, 1, values.clone()).unwrap();
		}

		assert!(channel.is_balanced());
		assert_eq!(
			channel
				.multiplicities
				.get(&values)
				.copied()
				.unwrap_or_default(),
			0
		);
	}

	#[test]
	fn test_flush_multiplicity() {
		let mut channel = Channel::<BinaryField64b>::new();
		let values = row(&[3, 4]);

		// One push of multiplicity 2 needs two pulls of multiplicity 1 to balance.
		channel
			.flush(FlushDirection::Push, 2, values.clone())
			.unwrap();

		channel
			.flush(FlushDirection::Pull, 1, values.clone())
			.unwrap();
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&values), Some(&1));

		channel
			.flush(FlushDirection::Pull, 1, values.clone())
			.unwrap();
		assert!(channel.is_balanced());
		assert_eq!(
			channel
				.multiplicities
				.get(&values)
				.copied()
				.unwrap_or_default(),
			0
		);
	}

	#[test]
	fn test_flush_width_mismatch() {
		let mut channel = Channel::<BinaryField64b>::new();

		// The first flush fixes the width at 2; a width-3 row must then be rejected.
		channel
			.flush(FlushDirection::Push, 1, row(&[1, 2]))
			.unwrap();

		match channel.flush(FlushDirection::Push, 1, row(&[3, 4, 5])) {
			Err(Error::ChannelFlushWidthMismatch { expected, got }) => {
				assert_eq!(expected, 2);
				assert_eq!(got, 3);
			}
			_ => panic!("Expected ChannelFlushWidthMismatch error"),
		}
	}

	#[test]
	fn test_flush_direction_effects() {
		let mut channel = Channel::<BinaryField64b>::new();
		let pushed = row(&[7, 8]);
		let pulled = row(&[9, 10]);

		channel
			.flush(FlushDirection::Push, 1, pushed.clone())
			.unwrap();
		channel
			.flush(FlushDirection::Pull, 1, pulled.clone())
			.unwrap();

		// Different rows were pushed and pulled, so nothing cancels.
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&pushed), Some(&1));
		assert_eq!(channel.multiplicities.get(&pulled), Some(&-1));
	}
}