1use anyhow::{Result, ensure};
4use binius_core::constraint_system::channel::ChannelId;
5use binius_field::{ExtensionField, PackedExtension, PackedField, PackedSubfield, TowerField};
6use itertools::Itertools;
7
8use crate::builder::{B1, B128, Col, FlushOpts, TableBuilder, TableWitnessSegment};
9
/// Wires a lookup table's value columns into a channel with per-row multiplicities.
///
/// Holds one committed [`B1`] column per multiplicity bit; bit column `i` selects a
/// flush with multiplicity `1 << i`, so a single row can be pulled up to `2^n - 1`
/// times for `n` configured bits.
#[derive(Debug)]
pub struct LookupProducer {
	/// Bit-decomposition of each row's lookup count, least-significant bit first.
	multiplicity_bits: Vec<Col<B1>>,
}
19
20impl LookupProducer {
21 pub fn new<FSub>(
22 table: &mut TableBuilder,
23 chan: ChannelId,
24 value_cols: &[Col<FSub>],
25 n_multiplicity_bits: usize,
26 ) -> Self
27 where
28 B128: ExtensionField<FSub>,
29 FSub: TowerField,
30 {
31 let multiplicity_bits = (0..n_multiplicity_bits)
32 .map(|i| table.add_committed::<B1, 1>(format!("multiplicity_bits[{i}]")))
33 .collect::<Vec<_>>();
34
35 for (i, &multiplicity_col) in multiplicity_bits.iter().enumerate() {
36 table.push_with_opts(
37 chan,
38 value_cols.iter().copied(),
39 FlushOpts {
40 multiplicity: 1 << i,
41 selectors: vec![multiplicity_col],
42 },
43 );
44 }
45
46 Self { multiplicity_bits }
47 }
48
49 pub fn populate<P>(
55 &self,
56 index: &mut TableWitnessSegment<P>,
57 counts: impl Iterator<Item = u32> + Clone,
58 ) -> Result<(), anyhow::Error>
59 where
60 P: PackedExtension<B1>,
61 P::Scalar: TowerField,
62 {
63 if self.multiplicity_bits.len() < u32::BITS as usize {
64 for count in counts.clone() {
65 ensure!(
66 count < (1 << self.multiplicity_bits.len()) as u32,
67 "count {count} exceeds maximum configured multiplicity; \
68 try raising the multiplicity bits in the constraint system"
69 );
70 }
71 }
72
73 for (j, &multiplicity_col) in self.multiplicity_bits.iter().enumerate().take(32) {
75 let mut multiplicity_col = index.get_mut(multiplicity_col)?;
76 for (packed, counts) in multiplicity_col
77 .iter_mut()
78 .zip(&counts.clone().chunks(<PackedSubfield<P, B1>>::WIDTH))
79 {
80 for (i, count) in counts.enumerate() {
81 packed.set(i, B1::from((count >> j) & 1 == 1))
82 }
83 }
84 }
85 Ok(())
86 }
87}
88
#[cfg(test)]
mod tests {
	use std::{cmp::Reverse, iter, iter::repeat_with};

	use binius_compute::cpu::alloc::CpuComputeAllocator;
	use binius_field::{arch::OptimalUnderlier128b, as_packed_field::PackedType};
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;
	use crate::builder::{
		ConstraintSystem, WitnessIndex,
		test_utils::{ClosureFiller, validate_system_witness},
	};

	/// Builds a constraint system with one lookup table (8 multiplicity bits) and two
	/// looker tables pulling from the same channel, fills a complete witness, and
	/// passes the system and witness to `f` for validation.
	///
	/// When `no_zero_counts` is true, every lookup-table row is referenced at least
	/// once, so no multiplicity is zero.
	fn with_lookup_test_instance(
		no_zero_counts: bool,
		f: impl FnOnce(&ConstraintSystem<B128>, WitnessIndex<PackedType<OptimalUnderlier128b, B128>>),
	) {
		let mut cs = ConstraintSystem::new();
		let chan = cs.add_channel("values");

		// Lookup table: committed values pushed to the channel with per-row multiplicities.
		let mut lookup_table = cs.add_table("lookup");
		lookup_table.require_power_of_two_size();
		let lookup_table_id = lookup_table.id();
		let values_col = lookup_table.add_committed::<B128, 1>("values");
		let lookup_producer = LookupProducer::new(&mut lookup_table, chan, &[values_col], 8);

		// Two independent looker tables, each pulling its values from the channel.
		let mut looker_1 = cs.add_table("looker 1");
		let looker_1_id = looker_1.id();
		let looker_1_vals = looker_1.add_committed::<B128, 1>("values");
		looker_1.pull(chan, [looker_1_vals]);

		let mut looker_2 = cs.add_table("looker 2");
		let looker_2_id = looker_2.id();
		let looker_2_vals = looker_2.add_committed::<B128, 1>("values");
		looker_2.pull(chan, [looker_2_vals]);

		let lookup_table_size = 64;
		// Fixed seed keeps the fixture deterministic across runs.
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| B128::random(&mut rng))
			.take(lookup_table_size)
			.collect::<Vec<_>>();

		// counts[i] tallies how many times values[i] is looked up across both lookers.
		let mut counts = vec![0u32; lookup_table_size];

		let looker_1_size = 56;
		let looker_2_size = 67;

		let mut look_indices = Vec::with_capacity(looker_1_size + looker_2_size);
		// Optionally reference every row once up front so that no count is zero;
		// the rest of the lookups are drawn uniformly at random.
		if no_zero_counts {
			look_indices.extend(0..lookup_table_size);
		}
		let remaining = look_indices.capacity() - look_indices.len();
		look_indices.extend(repeat_with(|| rng.random_range(0..lookup_table_size)).take(remaining));

		// Resolve indices to values, tallying each row's multiplicity as we go.
		let look_values = look_indices
			.into_iter()
			.map(|index| {
				counts[index] += 1;
				values[index]
			})
			.collect::<Vec<_>>();

		let (inputs_1, inputs_2) = look_values.split_at(looker_1_size);

		// Rows sorted by descending count — presumably so rows with nonzero
		// multiplicity come first in the table; TODO confirm against the filler's
		// expectations.
		let values_and_counts = iter::zip(values, counts)
			.sorted_unstable_by_key(|&(_val, count)| Reverse(count))
			.collect::<Vec<_>>();

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();
		let mut witness =
			WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

		// Fill the lookup table witness: the values column plus the multiplicity bits
		// via LookupProducer::populate.
		witness
			.fill_table_sequential(
				&ClosureFiller::new(lookup_table_id, |values_and_counts, witness| {
					{
						let mut values_col = witness.get_scalars_mut(values_col)?;
						for (dst, (val, _)) in iter::zip(&mut *values_col, values_and_counts) {
							*dst = *val;
						}
					}
					lookup_producer
						.populate(witness, values_and_counts.iter().map(|(_, count)| *count))?;
					Ok(())
				}),
				&values_and_counts,
			)
			.unwrap();

		// Fill the first looker table with its slice of looked-up values.
		witness
			.fill_table_sequential(
				&ClosureFiller::new(looker_1_id, |inputs_1, witness| {
					let mut looker_1_vals = witness.get_scalars_mut(looker_1_vals)?;
					for (dst, src) in iter::zip(&mut *looker_1_vals, inputs_1) {
						*dst = *src;
					}
					Ok(())
				}),
				inputs_1,
			)
			.unwrap();

		// And likewise the second looker table.
		witness
			.fill_table_sequential(
				&ClosureFiller::new(looker_2_id, |inputs_2, witness| {
					let mut looker_2_vals = witness.get_scalars_mut(looker_2_vals)?;
					for (dst, src) in iter::zip(&mut *looker_2_vals, inputs_2) {
						*dst = *src;
					}
					Ok(())
				}),
				inputs_2,
			)
			.unwrap();

		f(&cs, witness)
	}

	// Happy path: random counts, some of which may be zero.
	#[test]
	fn test_basic_lookup_producer() {
		with_lookup_test_instance(false, |cs, witness| {
			validate_system_witness::<OptimalUnderlier128b>(cs, witness, vec![])
		});
	}

	// Variant where every lookup-table row is referenced at least once.
	#[test]
	fn test_lookup_producer_no_zero_counts() {
		with_lookup_test_instance(true, |cs, witness| {
			validate_system_witness::<OptimalUnderlier128b>(cs, witness, vec![])
		});
	}

	// A count of 9 cannot be represented with a single multiplicity bit (max 1), so
	// populating the witness must fail with an error rather than silently truncate.
	#[test]
	fn test_lookup_overflows_max_multiplicity() {
		let mut cs = ConstraintSystem::new();
		let chan = cs.add_channel("values");

		let mut lookup_table = cs.add_table("lookup");
		lookup_table.require_power_of_two_size();
		let lookup_table_id = lookup_table.id();
		let values_col = lookup_table.add_committed::<B128, 1>("values");
		// Only one multiplicity bit: counts must be < 2.
		let lookup_producer = LookupProducer::new(&mut lookup_table, chan, &[values_col], 1);

		let lookup_table_size = 64;
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| B128::random(&mut rng))
			.take(lookup_table_size)
			.collect::<Vec<_>>();

		let counts = vec![9; lookup_table_size];
		let values_and_counts = iter::zip(values, counts).collect::<Vec<_>>();

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();
		let mut witness =
			WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);

		let result = witness.fill_table_sequential(
			&ClosureFiller::new(lookup_table_id, |values_and_counts, witness| {
				{
					let mut values_col = witness.get_scalars_mut(values_col)?;
					for (dst, (val, _)) in iter::zip(&mut *values_col, values_and_counts) {
						*dst = *val;
					}
				}
				lookup_producer
					.populate(witness, values_and_counts.iter().map(|(_, count)| *count))?;
				Ok(())
			}),
			&values_and_counts,
		);
		assert!(result.is_err());
	}
}