binius_m3/builder/indexed_lookup.rs
1// Copyright 2025 Irreducible Inc.
2
3use std::iter;
4
5use binius_core::constraint_system::channel::{Boundary, ChannelId, FlushDirection};
6use binius_field::{Field, PackedExtension, PackedField, TowerField};
7
8use super::{
9 B1, B8, B16, B32, B64, B128, constraint_system::ConstraintSystem, error::Error,
10 witness::WitnessIndex,
11};
12
13/// Indexed lookup tables are fixed-size tables where every entry is easily determined by its
14/// index.
15///
16/// Indexed lookup tables cover a large and useful class of tables, such as lookup tables for
17/// bitwise operations, addition of small integer values, multiplication, etc. The entry encodes
18/// an input and an output, where the index encodes the input. For example, a bitwise AND table
19/// would have 2 8-bit input values and one 8-bit output value. The index encodes the input by
20/// concatenating the 8-bit inputs into a 16-bit unsigned integer.
21///
22/// This trait helps to count the number of times a table, which is already filled, reads from a
23/// lookup table. See the documentation for [`tally`] for more information.
24pub trait IndexedLookup<F: TowerField> {
25 /// Binary logarithm of the number of table entries.
26 fn log_size(&self) -> usize;
27
28 /// Encode a table entry as a table index.
29 fn entry_to_index(&self, entry: &[F]) -> usize;
30
31 /// Decode a table index to an entry.
32 fn index_to_entry(&self, index: usize, entry: &mut [F]);
33}
34
35/// Determine the read counts of each entry in an indexed lookup table.
36///
37/// Before a lookup table witness can be filled, the number of times each entry is read must be
38/// known. Reads from indexed lookup tables are a special case where the counts are difficult to
39/// track during emulation, because the use of the lookup tables is an arithmetization detail. For
40/// example, emulation of the system model should not need to know whether integer additions within
41/// a table are constraint using zero constraints or a lookup table for the limbs. In most cases of
42/// practical interest, the lookup table is indexed.
43///
44/// The method to tally counts is to scan all tables in the constraint system and boundaries
45/// values, and identify those that pull from the lookup table's channel. Then we iterate over the
46/// values read from the table and count all the indices.
47///
48/// ## Returns
49///
50/// A vector of counts, whose length is equal to `1 << indexed_lookup.log_size()`.
51pub fn tally<P>(
52 cs: &ConstraintSystem<B128>,
53 // TODO: This doesn't actually need mutable access. But must of the WitnessIndex methods only
54 // allow mutable access.
55 witness: &mut WitnessIndex<P>,
56 boundaries: &[Boundary<B128>],
57 chan: ChannelId,
58 indexed_lookup: &impl IndexedLookup<B128>,
59) -> Result<Vec<u32>, Error>
60where
61 P: PackedField<Scalar = B128>
62 + PackedExtension<B1>
63 + PackedExtension<B8>
64 + PackedExtension<B16>
65 + PackedExtension<B32>
66 + PackedExtension<B64>
67 + PackedExtension<B128>,
68{
69 let mut counts = vec![0; 1 << indexed_lookup.log_size()];
70
71 // Tally counts from the tables
72 for table in &cs.tables {
73 if let Some(table_index) = witness.get_table(table.id()) {
74 for partition in table.partitions.values() {
75 for flush in &partition.flushes {
76 if flush.channel_id == chan && flush.direction == FlushDirection::Pull {
77 let table_size = table_index.size();
78 // TODO: This should be parallelized, which is pretty tricky.
79 let segment = table_index.full_segment();
80 let cols = flush
81 .columns
82 .iter()
83 .map(|&col_index| segment.get_dyn(col_index))
84 .collect::<Result<Vec<_>, _>>()?;
85
86 if !flush.selectors.is_empty() {
87 // TODO: check flush selectors
88 todo!("tally does not support selected table reads yet");
89 }
90
91 let mut elems = vec![B128::ZERO; cols.len()];
92 // It's important that this is only the unpacked table size(rows * values
93 // per row in the partition), not the full segment size. The entries
94 // after the table size are not flushed.
95 for i in 0..table_size * partition.values_per_row {
96 for (elem, col) in iter::zip(&mut elems, &cols) {
97 *elem = col.get(i);
98 }
99 let index = indexed_lookup.entry_to_index(&elems);
100 counts[index] += 1;
101 }
102 }
103 }
104 }
105 }
106 }
107
108 // Add in counts from boundaries
109 for boundary in boundaries {
110 if boundary.channel_id == chan && boundary.direction == FlushDirection::Pull {
111 let index = indexed_lookup.entry_to_index(&boundary.values);
112 counts[index] += 1;
113 }
114 }
115
116 Ok(counts)
117}