binius_core/constraint_system/
channel.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3//! A channel allows communication between tables.
4//!
//! Note that the channel is unordered: rows are not required to be
//! pulled in the same order in which they were pushed.
7//!
8//! The number of columns per channel must be fixed, but can be any
9//! positive integer. Column order is guaranteed, and column values within
10//! the same row must always stay together.
11//!
//! A channel only ensures that the inputs and outputs match, using a
//! multiset check. If you need any kind of ordering, you must enforce
//! it separately with additional polynomial constraints.
15//!
16//! The example below shows a channel with width=2, with multiple inputs
17//! and outputs.
18//! ```txt
19//!                                       +-+-+
20//!                                       |C|D|
21//! +-+-+                           +---> +-+-+
22//! |A|B|                           |     |M|N|
23//! +-+-+                           |     +-+-+
24//! |C|D|                           |
25//! +-+-+  --+                      |     +-+-+
26//! |E|F|    |                      |     |I|J|
27//! +-+-+    |                      |     +-+-+
28//! |G|H|    |                      |     |W|X|
29//! +-+-+    |                      | +-> +-+-+
30//!          |                      | |   |A|B|
31//! +-+-+    +-> /¯\¯¯¯¯¯¯¯¯¯¯¯\  --+ |   +-+-+
32//! |I|J|       :   :           : ----+   |K|L|
33//! +-+-+  PUSH |   |  channel  |  PULL   +-+-+
34//! |K|L|       :   :           : ----+
35//! +-+-+    +-> \_/___________/  --+ |   +-+-+
36//! |M|N|    |                      | |   |U|V|
37//! +-+-+    |                      | |   +-+-+
38//! |O|P|    |                      | |   |G|H|
39//! +-+-+  --+                      | +-> +-+-+
40//! |Q|R|                           |     |E|F|
41//! +-+-+                           |     +-+-+
42//! |S|T|                           |     |Q|R|
43//! +-+-+                           |     +-+-+
44//! |U|V|                           |
45//! +-+-+                           |     +-+-+
46//! |W|X|                           |     |O|P|
47//! +-+-+                           +---> +-+-+
48//!                                       |S|T|
49//!                                       +-+-+
50//! ```
51
52use std::collections::HashMap;
53
54use binius_field::{as_packed_field::PackScalar, underlier::UnderlierType, TowerField};
55use binius_macros::{DeserializeBytes, SerializeBytes};
56
57use super::error::{Error, VerificationError};
58use crate::{oracle::OracleId, witness::MultilinearExtensionIndex};
59
/// Identifier of a channel.
///
/// [`validate_witness`] uses it directly as an index into its per-channel state, so the valid
/// ids there are `0..=max_channel_id`.
pub type ChannelId = usize;

/// A flush of witness columns into a channel.
///
/// As checked by [`validate_witness`]: for every hypercube index at which `selector` evaluates to
/// a non-zero value, the evaluations of `oracles` at that index form one row, which is pushed to
/// or pulled from channel `channel_id` with the given `multiplicity`.
#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
pub struct Flush {
	/// Oracles providing the columns of each row, in column order; all must have the same
	/// number of variables.
	pub oracles: Vec<OracleId>,
	/// Channel the rows are flushed into.
	pub channel_id: ChannelId,
	/// Whether the rows are pushed to or pulled from the channel.
	pub direction: FlushDirection,
	/// Oracle gating the rows: hypercube indices where it evaluates to zero are skipped. Must
	/// have the same number of variables as `oracles`.
	pub selector: OracleId,
	/// Number of times each selected row is pushed or pulled.
	pub multiplicity: u64,
}

/// A row of explicit values flushed into a channel, rather than one read from witness oracles.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub struct Boundary<F: TowerField> {
	/// The row's column values, in column order.
	pub values: Vec<F>,
	/// Channel the row is flushed into.
	pub channel_id: ChannelId,
	/// Whether the row is pushed to or pulled from the channel.
	pub direction: FlushDirection,
	/// Number of times the row is pushed or pulled.
	pub multiplicity: u64,
}

/// Direction of a flush relative to the channel's multiset.
#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum FlushDirection {
	/// Adds rows to the channel; counted with positive multiplicity when checking balance.
	Push,
	/// Removes rows from the channel; counted with negative multiplicity when checking balance.
	Pull,
}
84
85pub fn validate_witness<U, F>(
86	witness: &MultilinearExtensionIndex<U, F>,
87	flushes: &[Flush],
88	boundaries: &[Boundary<F>],
89	max_channel_id: ChannelId,
90) -> Result<(), Error>
91where
92	U: UnderlierType + PackScalar<F>,
93	F: TowerField,
94{
95	let mut channels = vec![Channel::<F>::new(); max_channel_id + 1];
96
97	for boundary in boundaries.iter().cloned() {
98		let Boundary {
99			channel_id,
100			values,
101			direction,
102			multiplicity,
103		} = boundary;
104		if channel_id > max_channel_id {
105			return Err(Error::ChannelIdOutOfRange {
106				max: max_channel_id,
107				got: channel_id,
108			});
109		}
110		channels[channel_id].flush(direction, multiplicity, values.clone())?;
111	}
112
113	for flush in flushes {
114		let &Flush {
115			ref oracles,
116			channel_id,
117			direction,
118			selector,
119			multiplicity,
120		} = flush;
121
122		if channel_id > max_channel_id {
123			return Err(Error::ChannelIdOutOfRange {
124				max: max_channel_id,
125				got: channel_id,
126			});
127		}
128
129		let channel = &mut channels[channel_id];
130
131		let polys = oracles
132			.iter()
133			.map(|&id| witness.get_multilin_poly(id))
134			.collect::<Result<Vec<_>, _>>()?;
135
136		// Ensure that all the polys in a single flush have the same n_vars
137		if let Some(first_poly) = polys.first() {
138			let n_vars = first_poly.n_vars();
139			for poly in &polys {
140				if poly.n_vars() != n_vars {
141					return Err(Error::ChannelFlushNvarsMismatch {
142						expected: n_vars,
143						got: poly.n_vars(),
144					});
145				}
146			}
147
148			let selector_poly = witness.get_multilin_poly(selector)?;
149			// Check selector polynomial is compatible
150			if selector_poly.n_vars() != n_vars {
151				let id = oracles.first().copied().expect("polys is not empty");
152				return Err(Error::IncompatibleFlushSelector { id, selector });
153			}
154
155			for i in 0..1 << selector_poly.n_vars() {
156				if selector_poly.evaluate_on_hypercube(i)?.is_zero() {
157					continue;
158				}
159				let values = polys
160					.iter()
161					.map(|poly| poly.evaluate_on_hypercube(i))
162					.collect::<Result<Vec<_>, _>>()?;
163				channel.flush(direction, multiplicity, values)?;
164			}
165		}
166	}
167
168	for (id, channel) in channels.iter().enumerate() {
169		if !channel.is_balanced() {
170			return Err(VerificationError::ChannelUnbalanced { id }.into());
171		}
172	}
173
174	Ok(())
175}
176
177#[derive(Default, Debug, Clone)]
178struct Channel<F: TowerField> {
179	width: Option<usize>,
180	multiplicities: HashMap<Vec<F>, i64>,
181}
182
183impl<F: TowerField> Channel<F> {
184	fn new() -> Self {
185		Self::default()
186	}
187
188	fn _print_unbalanced_values(&self) {
189		for (key, val) in &self.multiplicities {
190			if *val != 0 {
191				println!("{key:?}: {val}");
192			}
193		}
194	}
195
196	fn flush(
197		&mut self,
198		direction: FlushDirection,
199		multiplicity: u64,
200		values: Vec<F>,
201	) -> Result<(), Error> {
202		if self.width.is_none() {
203			self.width = Some(values.len());
204		} else if self.width.expect("checked for None above") != values.len() {
205			return Err(Error::ChannelFlushWidthMismatch {
206				expected: self.width.unwrap(),
207				got: values.len(),
208			});
209		}
210		*self.multiplicities.entry(values).or_default() += (multiplicity as i64)
211			* (match direction {
212				FlushDirection::Pull => -1i64,
213				FlushDirection::Push => 1i64,
214			});
215		Ok(())
216	}
217
218	fn is_balanced(&self) -> bool {
219		self.multiplicities.iter().all(|(_, m)| *m == 0)
220	}
221}
222
#[cfg(test)]
mod tests {
	use binius_field::BinaryField64b;

	use super::*;

	/// Builds a channel row from plain integers, one field element per column.
	fn row(xs: &[u64]) -> Vec<BinaryField64b> {
		xs.iter().copied().map(BinaryField64b::from).collect()
	}

	#[test]
	fn test_flush_push_single_row() {
		let mut channel = Channel::<BinaryField64b>::new();
		let values = row(&[1, 2]);

		// A lone push leaves the row with a net count of +1.
		assert!(channel
			.flush(FlushDirection::Push, 1, values.clone())
			.is_ok());
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&values), Some(&1));
	}

	#[test]
	fn test_flush_pull_single_row() {
		let mut channel = Channel::<BinaryField64b>::new();
		let values = row(&[1, 2]);

		// A lone pull leaves the row with a net count of -1.
		assert!(channel
			.flush(FlushDirection::Pull, 1, values.clone())
			.is_ok());
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&values), Some(&-1));
	}

	#[test]
	fn test_flush_push_pull_single_row() {
		let mut channel = Channel::<BinaryField64b>::new();
		let values = row(&[1, 2]);

		// A push followed by a matching pull cancels out.
		channel
			.flush(FlushDirection::Push, 1, values.clone())
			.unwrap();
		assert!(channel
			.flush(FlushDirection::Pull, 1, values.clone())
			.is_ok());

		assert!(channel.is_balanced());
		assert_eq!(
			channel
				.multiplicities
				.get(&values)
				.copied()
				.unwrap_or_default(),
			0
		);
	}

	#[test]
	fn test_flush_multiplicity() {
		let mut channel = Channel::<BinaryField64b>::new();
		let values = row(&[3, 4]);

		// One flush pushes the row twice; pulling it once leaves one copy outstanding.
		channel
			.flush(FlushDirection::Push, 2, values.clone())
			.unwrap();
		channel
			.flush(FlushDirection::Pull, 1, values.clone())
			.unwrap();
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&values), Some(&1));

		// Pulling the remaining copy balances the channel.
		channel
			.flush(FlushDirection::Pull, 1, values.clone())
			.unwrap();
		assert!(channel.is_balanced());
		assert_eq!(
			channel
				.multiplicities
				.get(&values)
				.copied()
				.unwrap_or_default(),
			0
		);
	}

	#[test]
	fn test_flush_width_mismatch() {
		let mut channel = Channel::<BinaryField64b>::new();

		// The first flush fixes the width at two columns...
		channel
			.flush(FlushDirection::Push, 1, row(&[1, 2]))
			.unwrap();

		// ...so a three-column row must be rejected.
		let result = channel.flush(FlushDirection::Push, 1, row(&[3, 4, 5]));
		assert!(result.is_err());
		match result {
			Err(Error::ChannelFlushWidthMismatch { expected, got }) => {
				assert_eq!(expected, 2);
				assert_eq!(got, 3);
			}
			_ => panic!("Expected ChannelFlushWidthMismatch error"),
		}
	}

	#[test]
	fn test_flush_direction_effects() {
		let mut channel = Channel::<BinaryField64b>::new();
		let pushed = row(&[7, 8]);
		let pulled = row(&[9, 10]);

		// Pushing one row and pulling a different one does not cancel out.
		channel
			.flush(FlushDirection::Push, 1, pushed.clone())
			.unwrap();
		channel
			.flush(FlushDirection::Pull, 1, pulled.clone())
			.unwrap();

		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&pushed), Some(&1));
		assert_eq!(channel.multiplicities.get(&pulled), Some(&-1));
	}
}