use anyhow::{ensure, Result};
use binius_core::constraint_system::channel::ChannelId;
use binius_field::{ExtensionField, PackedExtension, PackedField, PackedSubfield, TowerField};
use itertools::Itertools;

use crate::builder::{Col, FlushOpts, TableBuilder, TableWitnessSegment, B1, B128};
10#[derive(Debug)]
16pub struct LookupProducer {
17 multiplicity_bits: Vec<Col<B1>>,
18}
19
20impl LookupProducer {
21 pub fn new<FSub>(
22 table: &mut TableBuilder,
23 chan: ChannelId,
24 value_cols: &[Col<FSub>],
25 n_multiplicity_bits: usize,
26 ) -> Self
27 where
28 B128: ExtensionField<FSub>,
29 FSub: TowerField,
30 {
31 let multiplicity_bits = (0..n_multiplicity_bits)
32 .map(|i| table.add_committed::<B1, 1>(format!("multiplicity_bits[{i}]")))
33 .collect::<Vec<_>>();
34
35 for (i, &multiplicity_col) in multiplicity_bits.iter().enumerate() {
36 table.push_with_opts(
37 chan,
38 value_cols.iter().copied(),
39 FlushOpts {
40 multiplicity: 1 << i,
41 selector: Some(multiplicity_col),
42 },
43 );
44 }
45
46 Self { multiplicity_bits }
47 }
48
49 pub fn populate<P>(
55 &self,
56 index: &mut TableWitnessSegment<P>,
57 counts: impl Iterator<Item = u32> + Clone,
58 ) -> Result<(), anyhow::Error>
59 where
60 P: PackedExtension<B1>,
61 P::Scalar: TowerField,
62 {
63 for (j, &multiplicity_col) in self.multiplicity_bits.iter().enumerate().take(32) {
65 let mut multiplicity_col = index.get_mut(multiplicity_col)?;
66 for (packed, counts) in multiplicity_col
67 .iter_mut()
68 .zip(&counts.clone().chunks(<PackedSubfield<P, B1>>::WIDTH))
69 {
70 for (i, count) in counts.enumerate() {
71 packed.set(i, B1::from((count >> j) & 1 == 1))
72 }
73 }
74 }
75 Ok(())
76 }
77}
78
#[cfg(test)]
mod tests {
    use std::{cmp::Reverse, iter};

    use binius_field::{arch::OptimalUnderlier128b, as_packed_field::PackedType};
    use bumpalo::Bump;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::*;
    use crate::builder::{test_utils::ClosureFiller, ConstraintSystem, Statement};

    /// Samples `n` entries of `table` uniformly at random, incrementing `counts[k]` every time
    /// entry `k` is chosen, and returns the sampled values in draw order.
    fn draw_lookups(
        rng: &mut StdRng,
        table: &[B128],
        counts: &mut [u32],
        n: usize,
    ) -> Vec<B128> {
        (0..n)
            .map(|_| {
                let idx = rng.gen_range(0..table.len());
                counts[idx] += 1;
                table[idx]
            })
            .collect()
    }

    /// End-to-end check of the producer gadget.
    ///
    /// A producer table holds 45 random values, and two consumer tables pull randomly chosen
    /// values from it. The resulting witness must validate against the compiled constraint
    /// system.
    #[test]
    fn test_basic_lookup_producer() {
        let mut cs = ConstraintSystem::new();
        let chan = cs.add_channel("values");

        // Producer: one value column plus 8 multiplicity bits (counts up to 255).
        let mut producer_table = cs.add_table("lookup");
        let producer_id = producer_table.id();
        let table_col = producer_table.add_committed::<B128, 1>("values");
        let producer = LookupProducer::new(&mut producer_table, chan, &[table_col], 8);

        // Consumers: each pulls one B128 value per row from the same channel.
        let mut consumer_a = cs.add_table("looker 1");
        let consumer_a_id = consumer_a.id();
        let consumer_a_col = consumer_a.add_committed::<B128, 1>("values");
        consumer_a.pull(chan, [consumer_a_col]);

        let mut consumer_b = cs.add_table("looker 2");
        let consumer_b_id = consumer_b.id();
        let consumer_b_col = consumer_b.add_committed::<B128, 1>("values");
        consumer_b.pull(chan, [consumer_b_col]);

        let table_len = 45;
        let (consumer_a_len, consumer_b_len) = (56, 67);

        // Draw order matters for reproducibility: table values first, then each consumer.
        let mut rng = StdRng::seed_from_u64(0);
        let table_values: Vec<B128> = (0..table_len).map(|_| B128::random(&mut rng)).collect();
        let mut counts = vec![0u32; table_len];
        let lookups_a = draw_lookups(&mut rng, &table_values, &mut counts, consumer_a_len);
        let lookups_b = draw_lookups(&mut rng, &table_values, &mut counts, consumer_b_len);

        // Pair each table value with its count, most-used rows first.
        let mut rows: Vec<(B128, u32)> = iter::zip(table_values, counts).collect();
        rows.sort_unstable_by_key(|&(_, count)| Reverse(count));

        let statement = Statement {
            boundaries: vec![],
            table_sizes: vec![table_len, consumer_a_len, consumer_b_len],
        };
        let allocator = Bump::new();
        let mut witness = cs
            .build_witness::<PackedType<OptimalUnderlier128b, B128>>(&allocator, &statement)
            .unwrap();

        witness
            .fill_table_sequential(
                &ClosureFiller::new(producer_id, |rows, segment| {
                    let mut dsts = segment.get_scalars_mut(table_col)?;
                    for (dst, (value, _)) in dsts.iter_mut().zip(rows) {
                        *dst = *value;
                    }
                    // Release the value column before the producer borrows the segment.
                    drop(dsts);
                    producer.populate(segment, rows.iter().map(|(_, count)| *count))?;
                    Ok(())
                }),
                &rows,
            )
            .unwrap();

        witness
            .fill_table_sequential(
                &ClosureFiller::new(consumer_a_id, |lookups, segment| {
                    let mut dsts = segment.get_scalars_mut(consumer_a_col)?;
                    for (dst, src) in dsts.iter_mut().zip(lookups) {
                        *dst = **src;
                    }
                    Ok(())
                }),
                &lookups_a,
            )
            .unwrap();

        witness
            .fill_table_sequential(
                &ClosureFiller::new(consumer_b_id, |lookups, segment| {
                    let mut dsts = segment.get_scalars_mut(consumer_b_col)?;
                    for (dst, src) in dsts.iter_mut().zip(lookups) {
                        *dst = **src;
                    }
                    Ok(())
                }),
                &lookups_b,
            )
            .unwrap();

        let ccs = cs.compile(&statement).unwrap();
        let witness = witness.into_multilinear_extension_index();

        binius_core::constraint_system::validate::validate_witness(&ccs, &[], &witness).unwrap();
    }
}