1use std::{iter, slice};
8
9use binius_core::constraint_system::channel::ChannelId;
10use binius_field::{
11 PackedExtension, PackedFieldIndexable, ext_basis,
12 packed::{get_packed_slice, set_packed_slice},
13};
14use binius_math::{ArithCircuit, ArithExpr};
15
16use crate::{
17 builder::{
18 B1, B8, B32, B128, Col, IndexedLookup, TableBuilder, TableFiller, TableId,
19 TableWitnessSegment, upcast_col,
20 },
21 gadgets::lookup::LookupProducer,
22};
23
24pub struct Incr {
36 pub input: Col<B8>,
37 pub carry_in: Col<B1>,
38 pub output: Col<B8>,
39 pub carry_out: Col<B1>,
40 pub merged: Col<B32>,
41}
42
43impl Incr {
44 pub fn new(
55 table: &mut TableBuilder,
56 lookup_chan: ChannelId,
57 input: Col<B8>,
58 carry_in: Col<B1>,
59 ) -> Self {
60 let output = table.add_committed::<B8, 1>("output");
61 let carry_out = table.add_committed::<B1, 1>("carry_out");
62 let merged = merge_incr_cols(table, input, carry_in, output, carry_out);
63
64 table.pull(lookup_chan, [merged]);
65
66 Self {
67 input,
68 carry_in,
69 output,
70 carry_out,
71 merged,
72 }
73 }
74
75 pub fn populate<P>(&self, witness: &mut TableWitnessSegment<P>) -> anyhow::Result<()>
83 where
84 P: PackedFieldIndexable<Scalar = B128>
85 + PackedExtension<B1>
86 + PackedExtension<B8>
87 + PackedExtension<B32>,
88 {
89 let input = witness.get_as::<u8, _, 1>(self.input)?;
90 let carry_in = witness.get(self.carry_in)?;
91 let mut output = witness.get_mut_as::<u8, _, 1>(self.output)?;
92 let mut carry_out = witness.get_mut(self.carry_out)?;
93 let mut merged = witness.get_mut_as::<u32, _, 1>(self.merged)?;
94
95 for i in 0..witness.size() {
96 let input_i = input[i];
97 let carry_in_bit = bool::from(get_packed_slice(&carry_in, i).val());
98
99 let (output_i, carry_out_bit) = input_i.overflowing_add(carry_in_bit.into());
100 output[i] = output_i;
101 set_packed_slice(&mut carry_out, i, B1::from(carry_out_bit));
102 merged[i] = ((carry_out_bit as u32) << 17)
103 | ((carry_in_bit as u32) << 16)
104 | ((output_i as u32) << 8)
105 | input_i as u32;
106 }
107
108 Ok(())
109 }
110}
111
112pub struct IncrLooker {
114 pub input: Col<B8>,
115 pub carry_in: Col<B1>,
116 incr: Incr,
117}
118
119impl IncrLooker {
120 pub fn new(table: &mut TableBuilder, lookup_chan: ChannelId) -> Self {
122 let input = table.add_committed::<B8, 1>("input");
123 let carry_in = table.add_committed::<B1, 1>("carry_in");
124 let incr = Incr::new(table, lookup_chan, input, carry_in);
125
126 Self {
127 input,
128 carry_in,
129 incr,
130 }
131 }
132
133 pub fn populate<'a, P>(
135 &self,
136 witness: &mut TableWitnessSegment<P>,
137 events: impl Iterator<Item = &'a (u8, bool)>,
138 ) -> anyhow::Result<()>
139 where
140 P: PackedFieldIndexable<Scalar = B128>
141 + PackedExtension<B1>
142 + PackedExtension<B8>
143 + PackedExtension<B32>,
144 {
145 {
146 let mut input = witness.get_mut_as::<u8, _, 1>(self.input)?;
147 let mut carry_in = witness.get_mut(self.carry_in)?;
148
149 for (i, &(input_i, carry_in_bit)) in events.enumerate() {
150 input[i] = input_i;
151 set_packed_slice(&mut carry_in, i, B1::from(carry_in_bit));
152 }
153 }
154
155 self.incr.populate(witness)?;
156 Ok(())
157 }
158}
159
160pub struct IncrLookup {
162 table_id: TableId,
163 entries_ordered: Col<B32>,
164 entries_sorted: Col<B32>,
165 lookup_producer: LookupProducer,
166}
167
168impl IncrLookup {
169 pub fn new(
177 table: &mut TableBuilder,
178 chan: ChannelId,
179 permutation_chan: ChannelId,
180 n_multiplicity_bits: usize,
181 ) -> Self {
182 table.require_fixed_size(IncrIndexedLookup.log_size());
183
184 let entries_ordered = table.add_fixed("incr_lookup", incr_circuit());
186 let entries_sorted = table.add_committed::<B32, 1>("entries_sorted");
187
188 table.push(permutation_chan, [entries_ordered]);
190 table.pull(permutation_chan, [entries_sorted]);
191
192 let lookup_producer =
193 LookupProducer::new(table, chan, &[entries_sorted], n_multiplicity_bits);
194 Self {
195 table_id: table.id(),
196 entries_ordered,
197 entries_sorted,
198 lookup_producer,
199 }
200 }
201}
202
203impl TableFiller for IncrLookup {
205 type Event = (usize, u32);
207
208 fn id(&self) -> TableId {
209 self.table_id
210 }
211
212 fn fill(&self, rows: &[Self::Event], witness: &mut TableWitnessSegment) -> anyhow::Result<()> {
213 {
215 let mut col_data = witness.get_scalars_mut(self.entries_ordered)?;
216 let start_index = witness.index() << witness.log_size();
217 for (i, col_data_i) in col_data.iter_mut().enumerate() {
218 let mut entry_128b = B128::default();
219 IncrIndexedLookup.index_to_entry(start_index + i, slice::from_mut(&mut entry_128b));
220 *col_data_i = B32::try_from(entry_128b).expect("guaranteed by IncrIndexedLookup");
221 }
222 }
223
224 {
226 let mut entries_sorted = witness.get_scalars_mut(self.entries_sorted)?;
227 for (merged_i, &(index, _)) in iter::zip(&mut *entries_sorted, rows.iter()) {
228 let mut entry_128b = B128::default();
229 IncrIndexedLookup.index_to_entry(index, slice::from_mut(&mut entry_128b));
230 *merged_i = B32::try_from(entry_128b).expect("guaranteed by IncrIndexedLookup");
231 }
232 }
233
234 self.lookup_producer
235 .populate(witness, rows.iter().map(|&(_i, count)| count))?;
236 Ok(())
237 }
238}
239
/// Stateless description of the increment lookup table: maps between table indices and merged
/// entries through its [`IndexedLookup`] implementation.
pub struct IncrIndexedLookup;
242
243impl IndexedLookup<B128> for IncrIndexedLookup {
244 fn log_size(&self) -> usize {
246 8 + 1
248 }
249
250 fn entry_to_index(&self, entry: &[B128]) -> usize {
252 debug_assert_eq!(entry.len(), 1);
253 let merged_val = entry[0].val() as u32;
254 let input = merged_val & 0xFF;
255 let carry_in_bit = (merged_val >> 16) & 1 == 1;
256 (carry_in_bit as usize) << 8 | input as usize
257 }
258
259 fn index_to_entry(&self, index: usize, entry: &mut [B128]) {
261 debug_assert_eq!(entry.len(), 1);
262 let input = (index % (1 << 8)) as u8;
263 let carry_in_bit = (index >> 8) & 1 == 1;
264 let (output, carry_out_bit) = input.overflowing_add(carry_in_bit.into());
265 let entry_u32 = merge_incr_vals(input, carry_in_bit, output, carry_out_bit);
266 entry[0] = B32::new(entry_u32).into();
267 }
268}
269
270pub fn carry_in_circuit(i: usize) -> ArithExpr<B128> {
273 let mut circuit = ArithExpr::Var(8);
276 for var in 0..i {
277 circuit *= ArithExpr::Var(var)
278 }
279 circuit
280}
281
282pub fn incr_circuit() -> ArithCircuit<B128> {
285 let mut circuit = ArithExpr::zero();
288 for i in 0..8 {
289 circuit += ArithExpr::Var(i) * ArithExpr::Const(B128::from(1 << i));
290
291 let carry = carry_in_circuit(i);
292 circuit += (ArithExpr::Var(i) + carry.clone()) * ArithExpr::Const(B128::from(1 << (i + 8)));
293 }
294 circuit += ArithExpr::Var(8) * ArithExpr::Const(B128::from(1 << 16));
295 let carry_out = carry_in_circuit(8);
296 circuit += carry_out * ArithExpr::Const(B128::from(1 << 17));
297 circuit.into()
298}
299
300pub fn merge_incr_cols(
302 table: &mut TableBuilder,
303 input: Col<B8>,
304 carry_in: Col<B1>,
305 output: Col<B8>,
306 carry_out: Col<B1>,
307) -> Col<B32> {
308 let beta_1 = ext_basis::<B32, B8>(1);
309 let beta_2_0 = ext_basis::<B32, B8>(2);
310 let beta_2_1 = beta_2_0 * ext_basis::<B8, B1>(1);
311 table.add_computed(
312 "merged",
313 upcast_col(input)
314 + upcast_col(output) * beta_1
315 + upcast_col(carry_in) * beta_2_0
316 + upcast_col(carry_out) * beta_2_1,
317 )
318}
319
/// Packs the four values of an increment operation into one 32-bit word.
///
/// Bit layout of the result:
///
/// * bits `0..8`  - `input`
/// * bits `8..16` - `output`
/// * bit `16`     - `carry_in`
/// * bit `17`     - `carry_out`
pub fn merge_incr_vals(input: u8, carry_in: bool, output: u8, carry_out: bool) -> u32 {
    let input_bits = u32::from(input);
    let output_bits = u32::from(output) << 8;
    let carry_in_bit = u32::from(carry_in) << 16;
    let carry_out_bit = u32::from(carry_out) << 17;
    input_bits | output_bits | carry_in_bit | carry_out_bit
}
324
#[cfg(test)]
mod tests {
    use std::{cmp::Reverse, iter::repeat_with};

    use binius_compute::cpu::alloc::CpuComputeAllocator;
    use binius_core::constraint_system::channel::{Boundary, FlushDirection};
    use binius_field::arch::OptimalUnderlier;
    use itertools::Itertools;
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;
    use crate::builder::{
        ConstraintSystem, WitnessIndex, tally,
        test_utils::{ClosureFiller, validate_system_witness},
    };

    /// Builds one increment lookup table and two looker tables that read from it, fills all
    /// three with witness data (plus a few boundary reads), and validates the resulting system.
    #[test]
    fn test_fixed_lookup_producer() {
        let mut cs = ConstraintSystem::new();
        let incr_lookup_chan = cs.add_channel("incr lookup");
        let incr_lookup_perm_chan = cs.add_channel("incr lookup permutation");

        let n_multiplicity_bits = 8;

        let mut incr_table = cs.add_table("increment");
        let incr_lookup = IncrLookup::new(
            &mut incr_table,
            incr_lookup_chan,
            incr_lookup_perm_chan,
            n_multiplicity_bits,
        );

        let mut looker_1 = cs.add_table("looker 1");
        let looker_1_id = looker_1.id();
        let incr_1 = IncrLooker::new(&mut looker_1, incr_lookup_chan);

        let mut looker_2 = cs.add_table("looker 2");
        let looker_2_id = looker_2.id();
        let incr_2 = IncrLooker::new(&mut looker_2, incr_lookup_chan);

        let looker_1_size = 5;
        let looker_2_size = 6;

        let mut compute_allocator = CpuComputeAllocator::new(1 << 12);
        let bump_allocator = compute_allocator.into_bump_allocator();
        let mut witness = WitnessIndex::new(&cs, &bump_allocator);

        // Deterministic source of random `(input, carry_in)` events shared by both lookers.
        let mut rng = StdRng::seed_from_u64(0);
        let mut random_events = |count: usize| -> Vec<(u8, bool)> {
            repeat_with(|| {
                let input = rng.random::<u8>();
                let carry_in_bit = rng.random_bool(0.5);
                (input, carry_in_bit)
            })
            .take(count)
            .collect()
        };

        let inputs_1 = random_events(looker_1_size);
        witness
            .fill_table_sequential(
                &ClosureFiller::new(looker_1_id, |inputs, segment| {
                    incr_1.populate(segment, inputs.iter())
                }),
                &inputs_1,
            )
            .unwrap();

        let inputs_2 = random_events(looker_2_size);
        witness
            .fill_table_sequential(
                &ClosureFiller::new(looker_2_id, |inputs, segment| {
                    incr_2.populate(segment, inputs.iter())
                }),
                &inputs_2,
            )
            .unwrap();

        // Extra reads of the table performed at the system boundary, one pull each.
        let boundaries = [
            (111, false, 111, false),
            (111, true, 112, false),
            (255, false, 255, false),
            (255, true, 0, true),
        ]
        .into_iter()
        .map(|(input, carry_in, output, carry_out)| Boundary {
            values: vec![B32::new(merge_incr_vals(input, carry_in, output, carry_out)).into()],
            direction: FlushDirection::Pull,
            channel_id: incr_lookup_chan,
            multiplicity: 1,
        })
        .collect::<Vec<_>>();

        // Count how often each table index is read.
        let counts =
            tally(&cs, &mut witness, &boundaries, incr_lookup_chan, &IncrIndexedLookup).unwrap();

        // Pair every count with its table index and order by descending count.
        let sorted_counts = counts
            .into_iter()
            .enumerate()
            .sorted_by_key(|(_, count)| Reverse(*count))
            .collect::<Vec<_>>();

        witness
            .fill_table_sequential(&incr_lookup, &sorted_counts)
            .unwrap();

        validate_system_witness::<OptimalUnderlier>(&cs, witness, boundaries);
    }
}