binius_core/constraint_system/
channel.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3//! A channel allows communication between tables.
4//!
5//! Note that the channel is unordered - meaning that rows are not
6//! constrained to be in the same order when being pushed and pulled.
7//!
8//! The number of columns per channel must be fixed, but can be any
9//! positive integer. Column order is guaranteed, and column values within
10//! the same row must always stay together.
11//!
//! A channel only ensures that the inputs and outputs match, using a
//! multiset check. If you want any kind of ordering, you have to
//! use polynomial constraints to additionally constrain it.
15//!
16//! The example below shows a channel with width=2, with multiple inputs
17//! and outputs.
18//! ```txt
19//!                                       +-+-+
20//!                                       |C|D|
21//! +-+-+                           +---> +-+-+
22//! |A|B|                           |     |M|N|
23//! +-+-+                           |     +-+-+
24//! |C|D|                           |
25//! +-+-+  --+                      |     +-+-+
26//! |E|F|    |                      |     |I|J|
27//! +-+-+    |                      |     +-+-+
28//! |G|H|    |                      |     |W|X|
29//! +-+-+    |                      | +-> +-+-+
30//!          |                      | |   |A|B|
31//! +-+-+    +-> /¯\¯¯¯¯¯¯¯¯¯¯¯\  --+ |   +-+-+
32//! |I|J|       :   :           : ----+   |K|L|
33//! +-+-+  PUSH |   |  channel  |  PULL   +-+-+
34//! |K|L|       :   :           : ----+
35//! +-+-+    +-> \_/___________/  --+ |   +-+-+
36//! |M|N|    |                      | |   |U|V|
37//! +-+-+    |                      | |   +-+-+
38//! |O|P|    |                      | |   |G|H|
39//! +-+-+  --+                      | +-> +-+-+
40//! |Q|R|                           |     |E|F|
41//! +-+-+                           |     +-+-+
42//! |S|T|                           |     |Q|R|
43//! +-+-+                           |     +-+-+
44//! |U|V|                           |
45//! +-+-+                           |     +-+-+
46//! |W|X|                           |     |O|P|
47//! +-+-+                           +---> +-+-+
48//!                                       |S|T|
49//!                                       +-+-+
50//! ```
51
52use std::collections::HashMap;
53
54use binius_field::{Field, PackedField, TowerField};
55use binius_macros::{DeserializeBytes, SerializeBytes};
56use binius_math::MultilinearPoly;
57use itertools::izip;
58
59use super::error::{Error, VerificationError};
60use crate::{oracle::OracleId, witness::MultilinearExtensionIndex};
61
/// Identifier of a channel, used directly as an index into the per-channel state built by
/// [`validate_witness`].
pub type ChannelId = usize;
63
64#[derive(Debug, Clone, Copy, SerializeBytes, DeserializeBytes, PartialEq, Eq)]
65pub enum OracleOrConst<F: Field> {
66	Oracle(usize),
67	Const { base: F, tower_level: usize },
68}
69#[derive(Debug, Clone, SerializeBytes, DeserializeBytes)]
70pub struct Flush<F: TowerField> {
71	pub oracles: Vec<OracleOrConst<F>>,
72	pub channel_id: ChannelId,
73	pub direction: FlushDirection,
74	pub selectors: Vec<OracleId>,
75	pub multiplicity: u64,
76}
77
78#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
79pub struct Boundary<F: TowerField> {
80	pub values: Vec<F>,
81	pub channel_id: ChannelId,
82	pub direction: FlushDirection,
83	pub multiplicity: u64,
84}
85
86#[derive(Debug, Clone, Copy, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
87pub enum FlushDirection {
88	Push,
89	Pull,
90}
91
92pub fn validate_witness<F, P>(
93	witness: &MultilinearExtensionIndex<P>,
94	flushes: &[Flush<F>],
95	boundaries: &[Boundary<F>],
96	max_channel_id: ChannelId,
97) -> Result<(), Error>
98where
99	P: PackedField<Scalar = F>,
100	F: TowerField,
101{
102	let mut channels = vec![Channel::<F>::new(); max_channel_id + 1];
103
104	for boundary in boundaries.iter().cloned() {
105		let Boundary {
106			channel_id,
107			values,
108			direction,
109			multiplicity,
110		} = boundary;
111		if channel_id > max_channel_id {
112			return Err(Error::ChannelIdOutOfRange {
113				max: max_channel_id,
114				got: channel_id,
115			});
116		}
117		channels[channel_id].flush(direction, multiplicity, values.clone())?;
118	}
119
120	for flush in flushes {
121		let &Flush {
122			ref oracles,
123			channel_id,
124			direction,
125			ref selectors,
126			multiplicity,
127		} = flush;
128
129		if channel_id > max_channel_id {
130			return Err(Error::ChannelIdOutOfRange {
131				max: max_channel_id,
132				got: channel_id,
133			});
134		}
135
136		let channel = &mut channels[channel_id];
137
138		// We check the variables only of OracleOrConst::Oracle variant oracles being the same.
139		let non_const_polys = oracles
140			.iter()
141			.filter_map(|&id| match id {
142				OracleOrConst::Oracle(oracle_id) => Some(witness.get_multilin_poly(oracle_id)),
143				_ => None,
144			})
145			.collect::<Result<Vec<_>, _>>()?;
146
147		let selector_polys = selectors
148			.iter()
149			.map(|selector| witness.get_multilin_poly(*selector))
150			.collect::<Result<Vec<_>, _>>()?;
151
152		let n_vars = non_const_polys
153			.first()
154			.map(|poly| poly.n_vars())
155			.unwrap_or(0);
156
157		// Ensure that all the polys in a single flush have the same n_vars
158		for poly in &non_const_polys {
159			if poly.n_vars() != n_vars {
160				return Err(Error::ChannelFlushNvarsMismatch {
161					expected: n_vars,
162					got: poly.n_vars(),
163				});
164			}
165		}
166
167		// Check selector polynomials are compatible
168		for (&selector, selector_poly) in izip!(selectors, &selector_polys) {
169			if selector_poly.n_vars() != n_vars {
170				let id = oracles
171					.iter()
172					.copied()
173					.filter_map(|id| match id {
174						OracleOrConst::Oracle(oracle_id) => Some(oracle_id),
175						_ => None,
176					})
177					.next()
178					.expect("non_const_polys is not empty");
179				return Err(Error::IncompatibleFlushSelector { id, selector });
180			}
181		}
182
183		for i in 0..1 << n_vars {
184			let selector_off = selector_polys.iter().any(|selector_poly| {
185				selector_poly
186					.evaluate_on_hypercube(i)
187					.expect(
188						"i in range 0..1 << n_vars; \
189							selector_poly checked above to have n_vars variables",
190					)
191					.is_zero()
192			});
193
194			if selector_off {
195				continue;
196			}
197
198			let values = oracles
199				.iter()
200				.copied()
201				.map(|id| match id {
202					OracleOrConst::Const { base, .. } => Ok(base),
203					OracleOrConst::Oracle(oracle_id) => witness
204						.get_multilin_poly(oracle_id)
205						.expect("Witness error would have been caught while checking variables.")
206						.evaluate_on_hypercube(i),
207				})
208				.collect::<Result<Vec<_>, _>>()?;
209			channel.flush(direction, multiplicity, values)?;
210		}
211	}
212
213	for (id, channel) in channels.iter().enumerate() {
214		if !channel.is_balanced() {
215			let unbalanced_flushes: Vec<_> = channel
216				.multiplicities
217				.iter()
218				.filter(|(_, &c)| c != 0i64)
219				.collect();
220
221			tracing::debug!("Channel {:?} unbalanced: {:?}", id, unbalanced_flushes);
222
223			return Err((VerificationError::ChannelUnbalanced { id }).into());
224		}
225	}
226
227	Ok(())
228}
229
230#[derive(Default, Debug, Clone)]
231struct Channel<F: TowerField> {
232	width: Option<usize>,
233	multiplicities: HashMap<Vec<F>, i64>,
234}
235
236impl<F: TowerField> Channel<F> {
237	fn new() -> Self {
238		Self::default()
239	}
240
241	fn _print_unbalanced_values(&self) {
242		for (key, val) in &self.multiplicities {
243			if *val != 0 {
244				println!("{key:?}: {val}");
245			}
246		}
247	}
248
249	fn flush(
250		&mut self,
251		direction: FlushDirection,
252		multiplicity: u64,
253		values: Vec<F>,
254	) -> Result<(), Error> {
255		if self.width.is_none() {
256			self.width = Some(values.len());
257		} else if self.width.expect("checked for None above") != values.len() {
258			return Err(Error::ChannelFlushWidthMismatch {
259				expected: self.width.unwrap(),
260				got: values.len(),
261			});
262		}
263		*self.multiplicities.entry(values).or_default() += (multiplicity as i64)
264			* (match direction {
265				FlushDirection::Pull => -1i64,
266				FlushDirection::Push => 1i64,
267			});
268		Ok(())
269	}
270
271	fn is_balanced(&self) -> bool {
272		self.multiplicities.iter().all(|(_, m)| *m == 0)
273	}
274}
275
#[cfg(test)]
mod tests {
	use binius_field::BinaryField64b as F;

	use super::*;

	#[test]
	fn test_flush_push_single_row() {
		let mut channel = Channel::<F>::new();
		let row = vec![F::from(1), F::from(2)];

		assert!(channel.flush(FlushDirection::Push, 1, row.clone()).is_ok());

		// A lone push leaves the row with net multiplicity +1.
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&row), Some(&1));
	}

	#[test]
	fn test_flush_pull_single_row() {
		let mut channel = Channel::<F>::new();
		let row = vec![F::from(1), F::from(2)];

		assert!(channel.flush(FlushDirection::Pull, 1, row.clone()).is_ok());

		// A lone pull leaves the row with net multiplicity -1.
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&row), Some(&-1));
	}

	#[test]
	fn test_flush_push_pull_single_row() {
		let mut channel = Channel::<F>::new();
		let row = vec![F::from(1), F::from(2)];

		channel.flush(FlushDirection::Push, 1, row.clone()).unwrap();
		assert!(channel.flush(FlushDirection::Pull, 1, row.clone()).is_ok());

		// The pull cancels the push exactly.
		assert!(channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&row).copied().unwrap_or_default(), 0);
	}

	#[test]
	fn test_flush_multiplicity() {
		let mut channel = Channel::<F>::new();
		let row = vec![F::from(3), F::from(4)];

		// A single push with multiplicity 2...
		channel.flush(FlushDirection::Push, 2, row.clone()).unwrap();

		// ...is only cancelled after two single pulls.
		for expected in [1, 0] {
			channel.flush(FlushDirection::Pull, 1, row.clone()).unwrap();
			assert_eq!(channel.is_balanced(), expected == 0);
			assert_eq!(channel.multiplicities.get(&row).copied().unwrap_or(0), expected);
		}
	}

	#[test]
	fn test_flush_width_mismatch() {
		let mut channel = Channel::<F>::new();

		// The first flush fixes the width at 2.
		channel
			.flush(FlushDirection::Push, 1, vec![F::from(1), F::from(2)])
			.unwrap();

		// A 3-wide row must then be rejected with the offending widths.
		let result = channel.flush(
			FlushDirection::Push,
			1,
			vec![F::from(3), F::from(4), F::from(5)],
		);
		assert!(
			matches!(result, Err(Error::ChannelFlushWidthMismatch { expected: 2, got: 3 })),
			"Expected ChannelFlushWidthMismatch error"
		);
	}

	#[test]
	fn test_flush_direction_effects() {
		let mut channel = Channel::<F>::new();
		let pushed = vec![F::from(7), F::from(8)];
		let pulled = vec![F::from(9), F::from(10)];

		channel.flush(FlushDirection::Push, 1, pushed.clone()).unwrap();
		channel.flush(FlushDirection::Pull, 1, pulled.clone()).unwrap();

		// Different rows never cancel each other.
		assert!(!channel.is_balanced());
		assert_eq!(channel.multiplicities.get(&pushed), Some(&1));
		assert_eq!(channel.multiplicities.get(&pulled), Some(&-1));
	}
}