binius_m3/builder/
indexed_lookup.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::iter;
4
5use binius_core::constraint_system::channel::{Boundary, ChannelId, FlushDirection};
6use binius_field::{Field, PackedExtension, PackedField, TowerField};
7
8use super::{
9	B1, B8, B16, B32, B64, B128, constraint_system::ConstraintSystem, error::Error,
10	witness::WitnessIndex,
11};
12
13/// Indexed lookup tables are fixed-size tables where every entry is easily determined by its
14/// index.
15///
16/// Indexed lookup tables cover a large and useful class of tables, such as lookup tables for
17/// bitwise operations, addition of small integer values, multiplication, etc. The entry encodes
18/// an input and an output, where the index encodes the input. For example, a bitwise AND table
19/// would have 2 8-bit input values and one 8-bit output value. The index encodes the input by
20/// concatenating the 8-bit inputs into a 16-bit unsigned integer.
21///
22/// This trait helps to count the number of times a table, which is already filled, reads from a
23/// lookup table. See the documentation for [`tally`] for more information.
24pub trait IndexedLookup<F: TowerField> {
25	/// Binary logarithm of the number of table entries.
26	fn log_size(&self) -> usize;
27
28	/// Encode a table entry as a table index.
29	fn entry_to_index(&self, entry: &[F]) -> usize;
30
31	/// Decode a table index to an entry.
32	fn index_to_entry(&self, index: usize, entry: &mut [F]);
33}
34
35/// Determine the read counts of each entry in an indexed lookup table.
36///
37/// Before a lookup table witness can be filled, the number of times each entry is read must be
38/// known. Reads from indexed lookup tables are a special case where the counts are difficult to
39/// track during emulation, because the use of the lookup tables is an arithmetization detail. For
40/// example, emulation of the system model should not need to know whether integer additions within
41/// a table are constraint using zero constraints or a lookup table for the limbs. In most cases of
42/// practical interest, the lookup table is indexed.
43///
44/// The method to tally counts is to scan all tables in the constraint system and boundaries
45/// values, and identify those that pull from the lookup table's channel. Then we iterate over the
46/// values read from the table and count all the indices.
47///
48/// ## Returns
49///
50/// A vector of counts, whose length is equal to `1 << indexed_lookup.log_size()`.
51pub fn tally<P>(
52	cs: &ConstraintSystem<B128>,
53	// TODO: This doesn't actually need mutable access. But must of the WitnessIndex methods only
54	// allow mutable access.
55	witness: &mut WitnessIndex<P>,
56	boundaries: &[Boundary<B128>],
57	chan: ChannelId,
58	indexed_lookup: &impl IndexedLookup<B128>,
59) -> Result<Vec<u32>, Error>
60where
61	P: PackedField<Scalar = B128>
62		+ PackedExtension<B1>
63		+ PackedExtension<B8>
64		+ PackedExtension<B16>
65		+ PackedExtension<B32>
66		+ PackedExtension<B64>
67		+ PackedExtension<B128>,
68{
69	let mut counts = vec![0; 1 << indexed_lookup.log_size()];
70
71	// Tally counts from the tables
72	for table in &cs.tables {
73		if let Some(table_index) = witness.get_table(table.id()) {
74			for partition in table.partitions.values() {
75				for flush in &partition.flushes {
76					if flush.channel_id == chan && flush.direction == FlushDirection::Pull {
77						let table_size = table_index.size();
78						// TODO: This should be parallelized, which is pretty tricky.
79						let segment = table_index.full_segment();
80						let cols = flush
81							.columns
82							.iter()
83							.map(|&col_index| segment.get_dyn(col_index))
84							.collect::<Result<Vec<_>, _>>()?;
85
86						if !flush.selectors.is_empty() {
87							// TODO: check flush selectors
88							todo!("tally does not support selected table reads yet");
89						}
90
91						let mut elems = vec![B128::ZERO; cols.len()];
92						// It's important that this is only the unpacked table size(rows * values
93						// per row in the partition), not the full segment size. The entries
94						// after the table size are not flushed.
95						for i in 0..table_size * partition.values_per_row {
96							for (elem, col) in iter::zip(&mut elems, &cols) {
97								*elem = col.get(i);
98							}
99							let index = indexed_lookup.entry_to_index(&elems);
100							counts[index] += 1;
101						}
102					}
103				}
104			}
105		}
106	}
107
108	// Add in counts from boundaries
109	for boundary in boundaries {
110		if boundary.channel_id == chan && boundary.direction == FlushDirection::Pull {
111			let index = indexed_lookup.entry_to_index(&boundary.values);
112			counts[index] += 1;
113		}
114	}
115
116	Ok(counts)
117}