// binius_m3/gadgets/indexed_lookup/and.rs

use std::{iter, slice};

use binius_core::constraint_system::channel::ChannelId;
use binius_field::ext_basis;
use binius_math::{ArithCircuit, ArithExpr};

use crate::{
	builder::{
		B8, B32, B128, Col, IndexedLookup, TableBuilder, TableFiller, TableId, TableWitnessSegment,
		column::upcast_col,
	},
	gadgets::lookup::LookupProducer,
};

/// Lookup table for bitwise AND of two 8-bit values.
///
/// Holds the fixed, exhaustively enumerated table entries together with the
/// multiplicity machinery ([`LookupProducer`]) used to prove how often each
/// entry is read over the lookup channel.
pub struct BitAndLookup {
	/// Identifier of this table within the constraint system.
	pub table_id: TableId,
	// Entries in canonical index order — a fixed column generated by `bitand_circuit`.
	entries_ordered: Col<B32>,
	// Committed copy of the entries; constrained to be a permutation of
	// `entries_ordered` via balanced push/pull on the permutation channel.
	entries_sorted: Col<B32>,
	// Pushes each sorted entry with its read multiplicity onto the lookup channel.
	lookup_producer: LookupProducer,
}

/// Gadget constraining `output = in_a & in_b` on 8-bit values via a lookup.
///
// NOTE(review): `V` appears to be the column packing factor — confirm against
// the `Col` definition.
pub struct BitAnd<const V: usize = 1> {
	// First 8-bit operand column.
	in_a: Col<B8, V>,
	// Second 8-bit operand column.
	in_b: Col<B8, V>,
	/// Committed column holding the bitwise AND of `in_a` and `in_b`.
	pub output: Col<B8, V>,
	// Operands and output packed into one 32-bit value (see `merge_and_columns`);
	// this is the value read from the lookup channel.
	merged: Col<B32, V>,
}
44impl<const V: usize> BitAnd<V> {
46 pub fn new(
56 table: &mut TableBuilder,
57 lookup_chan: ChannelId,
58 in_a: Col<B8, V>,
59 in_b: Col<B8, V>,
60 ) -> Self {
61 let output = table.add_committed::<B8, V>("output");
62 let merged = merge_and_columns(table, in_a, in_b, output);
63 table.read(lookup_chan, [merged]);
64 Self {
65 in_a,
66 in_b,
67 output,
68 merged,
69 }
70 }
71
72 pub fn populate(&self, witness: &mut TableWitnessSegment) -> anyhow::Result<()> {
80 let in_a_col = witness.get_scalars(self.in_a)?;
81 let in_b_col = witness.get_scalars(self.in_b)?;
82 let mut output_col: std::cell::RefMut<'_, [B8]> = witness.get_scalars_mut(self.output)?;
83 let mut merged_col: std::cell::RefMut<'_, [B32]> = witness.get_scalars_mut(self.merged)?;
84
85 for i in 0..witness.size() {
86 let in_a = in_a_col[i].val();
87 let in_b = in_b_col[i].val();
88 let output = in_a & in_b;
89 output_col[i] = output.into();
90
91 merged_col[i] = merge_bitand_vals(in_a, in_b, output).into();
93 }
94 Ok(())
95 }
96}
97
98pub fn merge_and_columns<const V: usize>(
100 table: &mut TableBuilder,
101 in_a: Col<B8, V>,
102 in_b: Col<B8, V>,
103 output: Col<B8, V>,
104) -> Col<B32, V> {
105 table.add_computed(
106 "merged",
107 upcast_col(in_a)
108 + upcast_col(in_b) * ext_basis::<B32, B8>(1)
109 + upcast_col(output) * ext_basis::<B32, B8>(2),
110 )
111}
112
/// Packs an AND triple into one `u32`: bits 0..8 hold `in_a`, bits 8..16 hold
/// `in_b`, and bits 16..24 hold `output`.
pub fn merge_bitand_vals(in_a: u8, in_b: u8, output: u8) -> u32 {
	let lo = u32::from(in_a);
	let mid = u32::from(in_b) << 8;
	let hi = u32::from(output) << 16;
	lo | mid | hi
}

118pub fn bitand_circuit() -> ArithCircuit<B128> {
121 let mut circuit = ArithExpr::zero();
124 for i in 0..8 {
125 circuit += ArithExpr::Var(i) * ArithExpr::Const(B32::new(1 << i));
126 circuit += ArithExpr::Var(i + 8) * ArithExpr::Const(B32::new(1 << (i + 8)));
127 circuit +=
128 ArithExpr::Var(i) * ArithExpr::Var(i + 8) * ArithExpr::Const(B32::new(1 << (i + 16)));
129 }
130 ArithCircuit::<B32>::from(circuit)
131 .try_convert_field()
132 .expect("And circuit should convert to B128")
133}
134
135impl BitAndLookup {
136 pub fn new(
144 table: &mut TableBuilder,
145 chan: ChannelId,
146 permutation_chan: ChannelId,
147 n_multiplicity_bits: usize,
148 ) -> Self {
149 table.require_fixed_size(BitAndIndexedLookup.log_size());
150
151 let entries_ordered = table.add_fixed("bitand_lookup", bitand_circuit());
153 let entries_sorted = table.add_committed::<B32, 1>("entries_sorted");
154
155 table.push(permutation_chan, [entries_ordered]);
157 table.pull(permutation_chan, [entries_sorted]);
158
159 let lookup_producer =
160 LookupProducer::new(table, chan, &[entries_sorted], n_multiplicity_bits);
161 Self {
162 table_id: table.id(),
163 entries_ordered,
164 entries_sorted,
165 lookup_producer,
166 }
167 }
168}
169
/// Looker table that performs 8-bit AND by reading entries from the shared
/// lookup channel.
pub struct BitAndLooker {
	/// First operand column.
	pub in_a: Col<B8>,
	/// Second operand column.
	pub in_b: Col<B8>,
	// Gadget constraining the AND result through the lookup channel.
	and: BitAnd,
}

181impl BitAndLooker {
182 pub fn new(table: &mut TableBuilder, lookup_chan: ChannelId) -> Self {
184 let in_a = table.add_committed::<B8, 1>("in_a");
185 let in_b = table.add_committed::<B8, 1>("in_b");
186 let and = BitAnd::new(table, lookup_chan, in_a, in_b);
188 Self { in_a, in_b, and }
189 }
190
191 pub fn populate<'a>(
193 &self,
194 witness: &'a mut TableWitnessSegment,
195 inputs: impl Iterator<Item = &'a (u8, u8)> + Clone,
196 ) -> anyhow::Result<()> {
197 {
198 let mut in_a_col: std::cell::RefMut<'_, [u8]> = witness.get_mut_as(self.in_a)?;
199 let mut in_b_col: std::cell::RefMut<'_, [u8]> = witness.get_mut_as(self.in_b)?;
200
201 for (i, &(in_a, in_b)) in inputs.enumerate() {
202 in_a_col[i] = in_a;
203 in_b_col[i] = in_b;
204 }
205 }
206
207 self.and.populate(witness)?;
208 Ok(())
209 }
210}
211
/// Stateless descriptor of the AND lookup: maps between table indices and
/// packed `B128` entries.
pub struct BitAndIndexedLookup;

215impl IndexedLookup<B128> for BitAndIndexedLookup {
216 fn log_size(&self) -> usize {
218 16
219 }
220
221 fn entry_to_index(&self, entry: &[B128]) -> usize {
223 debug_assert_eq!(entry.len(), 1, "AndLookup entry must be a single B128 field");
224 let merged_val = entry[0].val() as u32;
225 (merged_val & 0xFFFF) as usize
226 }
227
228 fn index_to_entry(&self, index: usize, entry: &mut [B128]) {
230 debug_assert_eq!(entry.len(), 1, "AndLookup entry must be a single B128 field");
231 let in_a = index & 0xFF;
232 let in_b = (index >> 8) & 0xFF;
233 let output = in_a & in_b;
234 let merged = merge_bitand_vals(in_a as u8, in_b as u8, output as u8);
235 entry[0] = B128::from(merged as u128);
236 }
237}
238
239impl TableFiller for BitAndLookup {
241 type Event = (usize, u32);
243
244 fn id(&self) -> TableId {
245 self.table_id
246 }
247
248 fn fill(&self, rows: &[Self::Event], witness: &mut TableWitnessSegment) -> anyhow::Result<()> {
249 {
251 let mut col_data = witness.get_scalars_mut(self.entries_ordered)?;
252 let start_index = witness.index() << witness.log_size();
253 for (i, col_data_i) in col_data.iter_mut().enumerate() {
254 let mut entry_128b = B128::default();
255 BitAndIndexedLookup
256 .index_to_entry(start_index + i, slice::from_mut(&mut entry_128b));
257 *col_data_i = B32::try_from(entry_128b).expect("guaranteed by BitAndIndexedLookup");
258 }
259 }
260
261 {
263 let mut entries_sorted = witness.get_scalars_mut(self.entries_sorted)?;
264 for (merged_i, &(index, _)) in iter::zip(&mut *entries_sorted, rows) {
265 let mut entry_128b = B128::default();
266 BitAndIndexedLookup.index_to_entry(index, slice::from_mut(&mut entry_128b));
267 *merged_i = B32::try_from(entry_128b).expect("guaranteed by BitAndIndexedLookup");
268 }
269 }
270
271 self.lookup_producer
272 .populate(witness, rows.iter().map(|&(_i, count)| count))?;
273 Ok(())
274 }
275}
276
#[cfg(test)]
mod tests {
	use std::{cmp::Reverse, iter::repeat_with};

	use binius_compute::cpu::alloc::CpuComputeAllocator;
	use binius_core::constraint_system::channel::{Boundary, FlushDirection};
	use binius_field::arch::OptimalUnderlier;
	use itertools::Itertools;
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;
	use crate::builder::{
		ConstraintSystem, WitnessIndex, tally,
		test_utils::{ClosureFiller, validate_system_witness},
	};

	/// End-to-end check: a looker table reads random AND entries, boundary
	/// pulls add more reads, and the lookup table's multiplicities are
	/// tallied and the full system witness validated.
	#[test]
	fn test_and_lookup() {
		let mut cs: ConstraintSystem<B128> = ConstraintSystem::new();
		let lookup_chan = cs.add_channel("lookup");
		let permutation_chan = cs.add_channel("permutation");
		let mut and_table = cs.add_table("bitand_lookup");
		let n_multiplicity_bits = 8;

		let bitand_lookup =
			BitAndLookup::new(&mut and_table, lookup_chan, permutation_chan, n_multiplicity_bits);
		let mut and_looker = cs.add_table("bitand_looker");

		let bitand_1 = BitAndLooker::new(&mut and_looker, lookup_chan);

		let looker_1_size = 5;
		let looker_id = and_looker.id();

		let mut allocator = CpuComputeAllocator::new(1 << 16);
		let allocator = allocator.into_bump_allocator();
		let mut witness = WitnessIndex::new(&cs, &allocator);

		// Random operand pairs with a deterministic seed for reproducibility.
		let mut rng = StdRng::seed_from_u64(0);
		let inputs_1 = repeat_with(|| {
			let in_a = rng.random::<u8>();
			let in_b = rng.random::<u8>();
			(in_a, in_b)
		})
		.take(looker_1_size)
		.collect::<Vec<_>>();

		witness
			.fill_table_parallel(
				&ClosureFiller::new(looker_id, |inputs, segment| {
					bitand_1.populate(segment, inputs.iter())
				}),
				&inputs_1,
			)
			.unwrap();

		// Extra reads supplied directly as channel boundaries (pulls), not
		// through any table.
		let boundary_reads = (0..5)
			.map(|_| {
				let in_a = rng.random::<u8>();
				let in_b = rng.random::<u8>();
				merge_bitand_vals(in_a, in_b, in_a & in_b)
			})
			.collect::<Vec<_>>();

		let boundaries = boundary_reads
			.into_iter()
			.map(|val| Boundary {
				values: vec![B32::new(val).into()],
				direction: FlushDirection::Pull,
				channel_id: lookup_chan,
				multiplicity: 1,
			})
			.collect::<Vec<_>>();

		// Count how many times each of the 2^16 entries was read across the
		// looker table and the boundaries.
		let counts =
			tally(&cs, &mut witness, &boundaries, lookup_chan, &BitAndIndexedLookup).unwrap();

		// Feed the lookup-table filler entries sorted by descending
		// multiplicity, as (index, count) events.
		let sorted_counts = counts
			.into_iter()
			.enumerate()
			.sorted_by_key(|(_, count)| Reverse(*count))
			.collect::<Vec<_>>();

		witness
			.fill_table_parallel(&bitand_lookup, &sorted_counts)
			.unwrap();

		validate_system_witness::<OptimalUnderlier>(&cs, witness, boundaries);
	}
}