// binius_core/constraint_system/channel.rs

// Copyright 2024-2025 Irreducible Inc.

//! A channel allows communication between tables.
//!
//! Note that the channel is unordered, meaning that rows are not
//! constrained to be pushed and pulled in the same order.
//!
//! The number of columns per channel must be fixed, but can be any
//! positive integer. Column order is guaranteed, and column values within
//! the same row must always stay together.
//!
//! A channel only ensures that its inputs and outputs match, using a
//! multiset check. If you need any kind of ordering, you have to
//! enforce it separately with polynomial constraints.
//!
//! The example below shows a channel of width 2, with multiple inputs
//! and outputs.
//! ```txt
//!                                       +-+-+
//!                                       |C|D|
//! +-+-+                           +---> +-+-+
//! |A|B|                           |     |M|N|
//! +-+-+                           |     +-+-+
//! |C|D|                           |
//! +-+-+  --+                      |     +-+-+
//! |E|F|    |                      |     |I|J|
//! +-+-+    |                      |     +-+-+
//! |G|H|    |                      |     |W|X|
//! +-+-+    |                      | +-> +-+-+
//!          |                      | |   |A|B|
//! +-+-+    +-> /¯\¯¯¯¯¯¯¯¯¯¯¯\  --+ |   +-+-+
//! |I|J|       :   :           : ----+   |K|L|
//! +-+-+  PUSH |   |  channel  |  PULL   +-+-+
//! |K|L|       :   :           : ----+
//! +-+-+    +-> \_/___________/  --+ |   +-+-+
//! |M|N|    |                      | |   |U|V|
//! +-+-+    |                      | |   +-+-+
//! |O|P|    |                      | |   |G|H|
//! +-+-+  --+                      | +-> +-+-+
//! |Q|R|                           |     |E|F|
//! +-+-+                           |     +-+-+
//! |S|T|                           |     |Q|R|
//! +-+-+                           |     +-+-+
//! |U|V|                           |
//! +-+-+                           |     +-+-+
//! |W|X|                           |     |O|P|
//! +-+-+                           +---> +-+-+
//!                                       |S|T|
//!                                       +-+-+
//! ```
51
52use std::collections::HashMap;
53
54use binius_field::{Field, PackedField, TowerField};
55use binius_macros::{DeserializeBytes, SerializeBytes};
56use binius_math::MultilinearPoly;
57use itertools::izip;
58
59use super::error::{Error, VerificationError};
60use crate::{constraint_system::TableId, oracle::OracleId, witness::MultilinearExtensionIndex};
61
62pub type ChannelId = usize;
63
64#[derive(Debug, Clone, Copy, SerializeBytes, DeserializeBytes, PartialEq, Eq)]
65pub enum OracleOrConst<F: Field> {
66	Oracle(OracleId),
67	Const { base: F, tower_level: usize },
68}
69#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
70pub struct Flush<F: TowerField> {
71	pub table_id: TableId,
72	pub log_values_per_row: usize,
73	pub oracles: Vec<OracleOrConst<F>>,
74	pub channel_id: ChannelId,
75	pub direction: FlushDirection,
76	pub selectors: Vec<OracleId>,
77	pub multiplicity: u64,
78}
79
80#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
81pub struct Boundary<F: TowerField> {
82	pub values: Vec<F>,
83	pub channel_id: ChannelId,
84	pub direction: FlushDirection,
85	pub multiplicity: u64,
86}
87
88#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
89pub enum FlushDirection {
90	Push,
91	Pull,
92}
93
94pub fn validate_witness<F, P>(
95	witness: &MultilinearExtensionIndex<P>,
96	flushes: &[Flush<F>],
97	boundaries: &[Boundary<F>],
98	table_sizes: &[usize],
99	channel_count: usize,
100) -> Result<(), Error>
101where
102	P: PackedField<Scalar = F>,
103	F: TowerField,
104{
105	let mut channels = vec![Channel::<F>::new(); channel_count];
106	let max_channel_id = channel_count.saturating_sub(1);
107
108	for boundary in boundaries.iter().cloned() {
109		let Boundary {
110			channel_id,
111			values,
112			direction,
113			multiplicity,
114		} = boundary;
115		if channel_id > max_channel_id {
116			return Err(Error::ChannelIdOutOfRange {
117				max: max_channel_id,
118				got: channel_id,
119			});
120		}
121		channels[channel_id].flush(direction, multiplicity, values.clone())?;
122	}
123
124	for flush in flushes {
125		let &Flush {
126			ref oracles,
127			channel_id,
128			direction,
129			ref selectors,
130			multiplicity,
131			table_id,
132			log_values_per_row,
133		} = flush;
134
135		if channel_id > max_channel_id {
136			return Err(Error::ChannelIdOutOfRange {
137				max: max_channel_id,
138				got: channel_id,
139			});
140		}
141
142		let table_size = table_sizes[table_id];
143		if table_size == 0 {
144			continue;
145		}
146
147		// We check the variables only of OracleOrConst::Oracle variant oracles being the same.
148		let non_const_polys = oracles
149			.iter()
150			.filter_map(|&id| match id {
151				OracleOrConst::Oracle(oracle_id) => Some(witness.get_multilin_poly(oracle_id)),
152				_ => None,
153			})
154			.collect::<Result<Vec<_>, _>>()?;
155
156		let selector_polys = selectors
157			.iter()
158			.map(|selector| witness.get_multilin_poly(*selector))
159			.collect::<Result<Vec<_>, _>>()?;
160
161		let n_vars = non_const_polys
162			.first()
163			.map(|poly| poly.n_vars())
164			.unwrap_or(0);
165
166		// Ensure that all the polys in a single flush have the same n_vars
167		for poly in &non_const_polys {
168			if poly.n_vars() != n_vars {
169				return Err(Error::ChannelFlushNvarsMismatch {
170					expected: n_vars,
171					got: poly.n_vars(),
172				});
173			}
174		}
175
176		// Check selector polynomials are compatible
177		for (&selector, selector_poly) in izip!(selectors, &selector_polys) {
178			if selector_poly.n_vars() != n_vars {
179				let id = oracles
180					.iter()
181					.copied()
182					.filter_map(|id| match id {
183						OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
184						_ => None,
185					})
186					.next()
187					.expect("non_const_polys is not empty");
188				return Err(Error::IncompatibleFlushSelector { id, selector });
189			}
190		}
191
192		let values_per_row = 1 << log_values_per_row;
193		assert!(table_size * values_per_row <= 1 << n_vars);
194		for i in 0..table_size * values_per_row {
195			let selector_off = selector_polys.iter().any(|selector_poly| {
196				selector_poly
197					.evaluate_on_hypercube(i)
198					.expect(
199						"i in range 0..1 << n_vars; \
200							selector_poly checked above to have n_vars variables",
201					)
202					.is_zero()
203			});
204
205			if selector_off {
206				continue;
207			}
208
209			let values = oracles
210				.iter()
211				.copied()
212				.map(|id| match id {
213					OracleOrConst::Const { base, .. } => Ok(base),
214					OracleOrConst::Oracle(oracle_id) => witness
215						.get_multilin_poly(oracle_id)
216						.expect("Witness error would have been caught while checking variables.")
217						.evaluate_on_hypercube(i),
218				})
219				.collect::<Result<Vec<_>, _>>()?;
220			channels[channel_id].flush(direction, multiplicity, values)?;
221		}
222	}
223
224	for (id, channel) in channels.iter().enumerate() {
225		if !channel.is_balanced() {
226			let unbalanced_flushes: Vec<_> = channel
227				.multiplicities
228				.iter()
229				.filter(|(_, c)| **c != 0i64)
230				.collect();
231
232			tracing::debug!("Channel {:?} unbalanced: {:?}", id, unbalanced_flushes);
233
234			return Err((VerificationError::ChannelUnbalanced { id }).into());
235		}
236	}
237
238	Ok(())
239}
240
241#[derive(Default, Debug, Clone)]
242struct Channel<F: TowerField> {
243	width: Option<usize>,
244	multiplicities: HashMap<Vec<F>, i64>,
245}
246
247impl<F: TowerField> Channel<F> {
248	fn new() -> Self {
249		Self::default()
250	}
251
252	fn _print_unbalanced_values(&self) {
253		for (key, val) in &self.multiplicities {
254			if *val != 0 {
255				println!("{key:?}: {val}");
256			}
257		}
258	}
259
260	fn flush(
261		&mut self,
262		direction: FlushDirection,
263		multiplicity: u64,
264		values: Vec<F>,
265	) -> Result<(), Error> {
266		if self.width.is_none() {
267			self.width = Some(values.len());
268		} else if self.width.expect("checked for None above") != values.len() {
269			return Err(Error::ChannelFlushWidthMismatch {
270				expected: self.width.unwrap(),
271				got: values.len(),
272			});
273		}
274		*self.multiplicities.entry(values).or_default() += (multiplicity as i64)
275			* (match direction {
276				FlushDirection::Pull => -1i64,
277				FlushDirection::Push => 1i64,
278			});
279		Ok(())
280	}
281
282	fn is_balanced(&self) -> bool {
283		self.multiplicities.iter().all(|(_, m)| *m == 0)
284	}
285}
286
#[cfg(test)]
mod tests {
	use binius_field::BinaryField64b;

	use super::*;

	/// Fresh, empty channel over 64-bit binary field elements.
	fn empty_channel() -> Channel<BinaryField64b> {
		Channel::new()
	}

	#[test]
	fn test_flush_push_single_row() {
		let mut channel = empty_channel();

		// A lone push leaves the row with net multiplicity +1.
		let row = vec![BinaryField64b::from(1), BinaryField64b::from(2)];
		let outcome = channel.flush(FlushDirection::Push, 1, row.clone());

		assert!(outcome.is_ok());
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&row), Some(&1));
	}

	#[test]
	fn test_flush_pull_single_row() {
		let mut channel = empty_channel();

		// A lone pull leaves the row with net multiplicity -1.
		let row = vec![BinaryField64b::from(1), BinaryField64b::from(2)];
		let outcome = channel.flush(FlushDirection::Pull, 1, row.clone());

		assert!(outcome.is_ok());
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&row), Some(&-1));
	}

	#[test]
	fn test_flush_push_pull_single_row() {
		let mut channel = empty_channel();

		// Pushing and then pulling the same row cancels out.
		let row = vec![BinaryField64b::from(1), BinaryField64b::from(2)];
		channel
			.flush(FlushDirection::Push, 1, row.clone())
			.unwrap();
		let outcome = channel.flush(FlushDirection::Pull, 1, row.clone());

		assert!(outcome.is_ok());
		assert!(channel.is_balanced());
		assert!(matches!(channel.multiplicities.get(&row), None | Some(&0)));
	}

	#[test]
	fn test_flush_multiplicity() {
		let mut channel = empty_channel();
		let row = vec![BinaryField64b::from(3), BinaryField64b::from(4)];

		// One push of multiplicity 2 ...
		channel
			.flush(FlushDirection::Push, 2, row.clone())
			.unwrap();

		// ... is only half undone by a single pull.
		channel
			.flush(FlushDirection::Pull, 1, row.clone())
			.unwrap();
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&row), Some(&1));

		// A second pull finishes the job.
		channel
			.flush(FlushDirection::Pull, 1, row.clone())
			.unwrap();
		assert!(channel.is_balanced());
		assert!(matches!(channel.multiplicities.get(&row), None | Some(&0)));
	}

	#[test]
	fn test_flush_width_mismatch() {
		let mut channel = empty_channel();

		// The first flush fixes the width at 2 ...
		let narrow_row = vec![BinaryField64b::from(1), BinaryField64b::from(2)];
		channel.flush(FlushDirection::Push, 1, narrow_row).unwrap();

		// ... so a width-3 row must be rejected.
		let wide_row = vec![
			BinaryField64b::from(3),
			BinaryField64b::from(4),
			BinaryField64b::from(5),
		];
		match channel.flush(FlushDirection::Push, 1, wide_row) {
			Err(Error::ChannelFlushWidthMismatch { expected, got }) => {
				assert_eq!(expected, 2);
				assert_eq!(got, 3);
			}
			_ => panic!("Expected ChannelFlushWidthMismatch error"),
		}
	}

	#[test]
	fn test_flush_direction_effects() {
		let mut channel = empty_channel();
		let pushed = vec![BinaryField64b::from(7), BinaryField64b::from(8)];
		let pulled = vec![BinaryField64b::from(9), BinaryField64b::from(10)];

		channel
			.flush(FlushDirection::Push, 1, pushed.clone())
			.unwrap();
		channel
			.flush(FlushDirection::Pull, 1, pulled.clone())
			.unwrap();

		// Different rows never cancel each other out.
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&pushed), Some(&1));
		assert_eq!(channel.multiplicities.get(&pulled), Some(&-1));
	}
}