use binius_core::{
constraint_system::channel::{Boundary, FlushDirection},
oracle::OracleId,
};
use binius_field::{
as_packed_field::PackScalar, packed::set_packed_slice, BinaryField1b, ExtensionField, Field,
TowerField,
};
use bytemuck::Pod;
use itertools::izip;
use crate::builder::ConstraintSystemBuilder;
/// Sets up a channel-based lookup of `lookup_values` into `table` and returns the boundary
/// that balances the channel.
///
/// What this function does, step by step:
///
/// 1. `lookup_values` is sent on a freshly added channel with count `lookup_values_count`.
/// 2. In prover mode (a witness is present), the number of occurrences of each of the first
///    `table_count` table entries among the first `lookup_values_count` lookup values is
///    counted (without checking that every lookup value is actually in the table).
/// 3. [`get_components`] commits `LOG_MAX_MULTIPLICITY` component columns; component `i` is
///    pulled from the channel with count `table_count` and flush multiplicity `1 << i`.
/// 4. A push [`Boundary`] carrying `balancer_value` is returned. Its multiplicity is
///    `((1 << LOG_MAX_MULTIPLICITY) - 1) * table_count - lookup_values_count`, i.e. the total
///    multiplicity pulled in step 3 minus the number of values sent in step 1.
///
/// The boundary is only returned; this function does not register it anywhere.
///
/// # Errors
///
/// * Propagates errors from the builder, the witness accessors, `count_multiplicities` and
///   [`get_components`].
/// * Returns an error when `LOG_MAX_MULTIPLICITY` is not smaller than `usize::BITS`, when the
///   pull capacity overflows `usize`, or when `lookup_values_count` exceeds that capacity.
///   These cases previously panicked in debug builds and wrapped around in release builds.
///
/// # Panics
///
/// In prover mode, panics if `table_count` or `lookup_values_count` exceeds the length of the
/// corresponding witness slice.
pub fn plain_lookup<U, F, FS, const LOG_MAX_MULTIPLICITY: usize>(
    builder: &mut ConstraintSystemBuilder<U, F>,
    table: OracleId,
    table_count: usize,
    balancer_value: FS,
    lookup_values: OracleId,
    lookup_values_count: usize,
) -> Result<Boundary<F>, anyhow::Error>
where
    U: PackScalar<F> + PackScalar<FS> + PackScalar<BinaryField1b> + Pod,
    F: TowerField + ExtensionField<FS>,
    FS: TowerField + Pod,
{
    let n_vars = builder.log_rows([table])?;
    debug_assert!(table_count <= 1 << n_vars);

    // Validate the balancing arithmetic before touching the builder, so that an impossible
    // configuration is reported as an error instead of a panic or a wrapped multiplicity.
    if LOG_MAX_MULTIPLICITY >= usize::BITS as usize {
        return Err(anyhow::anyhow!(
            "`LOG_MAX_MULTIPLICITY` ({}) must be smaller than the bit width of `usize`",
            LOG_MAX_MULTIPLICITY
        ));
    }
    // Largest multiplicity a single table row can absorb: all multiplicity bits set.
    let max_row_multiplicity = (1usize << LOG_MAX_MULTIPLICITY) - 1;
    let pull_capacity = max_row_multiplicity
        .checked_mul(table_count)
        .ok_or_else(|| {
            anyhow::anyhow!(
                "lookup capacity `((1 << LOG_MAX_MULTIPLICITY) - 1) * table_count` overflows `usize`"
            )
        })?;
    let balancer_pushes = pull_capacity
        .checked_sub(lookup_values_count)
        .ok_or_else(|| {
            anyhow::anyhow!(
                "lookup_values_count ({}) exceeds the lookup capacity ({}) of a table with {} rows",
                lookup_values_count,
                pull_capacity,
                table_count
            )
        })?;

    let channel = builder.add_channel();
    builder.send(channel, lookup_values_count, [lookup_values]);

    // Multiplicities can only be computed when witness data is available (prover mode).
    let mut multiplicities = None;
    if let Some(witness) = builder.witness() {
        let table_slice = witness.get::<FS>(table)?.as_slice::<FS>();
        let values_slice = witness.get::<FS>(lookup_values)?.as_slice::<FS>();
        multiplicities = Some(count_multiplicities(
            &table_slice[0..table_count],
            &values_slice[0..lookup_values_count],
            false,
        )?);
    }

    let components: [OracleId; LOG_MAX_MULTIPLICITY] =
        get_components::<_, _, FS, LOG_MAX_MULTIPLICITY>(
            builder,
            table,
            table_count,
            balancer_value,
            multiplicities,
        )?;

    // Component `i` stands for bit `i` of the per-row multiplicity, hence the weight `1 << i`.
    for (i, component) in components.into_iter().enumerate() {
        builder.flush_with_multiplicity(
            FlushDirection::Pull,
            channel,
            table_count,
            [component],
            1 << i,
        );
    }

    Ok(Boundary {
        values: vec![balancer_value.into()],
        channel_id: channel,
        direction: FlushDirection::Push,
        // `usize` is at most 64 bits wide on supported targets, so this cast is lossless.
        multiplicity: balancer_pushes as u64,
    })
}
/// Commits the multiplicity-bit columns and the matching component columns for a lookup table.
///
/// For every bit position `j < LOG_MAX_MULTIPLICITY` two columns with the same number of
/// variables as `table` are committed: a `BinaryField1b` column `bits[j]` and an `FS` column
/// `components[j]`. Each pair is tied to `table` by the zero constraint
/// `component - (bit * table + (1 - bit) * balancer_value)`, so a component row equals the
/// table row where the bit is one and `balancer_value` where the bit is zero.
///
/// In prover mode the witness is filled from `multiplicities`: bit `j` of row `i` is bit `j`
/// of `multiplicities[i]`. Only rows covered by both the table witness and `multiplicities`
/// are written.
/// NOTE(review): rows past `multiplicities.len()` keep whatever `new_column` initialises them
/// to - confirm that this satisfies the constraint when the table is padded.
///
/// # Errors
///
/// Returns an error if the builder calls fail, if a witness is present but `multiplicities`
/// is `None`, or if any multiplicity is not smaller than `1 << LOG_MAX_MULTIPLICITY`.
fn get_components<U, F, FS, const LOG_MAX_MULTIPLICITY: usize>(
    builder: &mut ConstraintSystemBuilder<U, F>,
    table: OracleId,
    table_count: usize,
    balancer_value: FS,
    multiplicities: Option<Vec<usize>>,
) -> Result<[OracleId; LOG_MAX_MULTIPLICITY], anyhow::Error>
where
    U: PackScalar<F> + PackScalar<FS> + PackScalar<BinaryField1b> + Pod,
    F: TowerField + ExtensionField<FS>,
    FS: TowerField + Pod,
{
    let n_vars = builder.log_rows([table])?;

    let bits: [OracleId; LOG_MAX_MULTIPLICITY] = builder
        .add_committed_multiple::<LOG_MAX_MULTIPLICITY>("bits", n_vars, BinaryField1b::TOWER_LEVEL);
    let components: [OracleId; LOG_MAX_MULTIPLICITY] = builder
        .add_committed_multiple::<LOG_MAX_MULTIPLICITY>("components", n_vars, FS::TOWER_LEVEL);

    if let Some(witness) = builder.witness() {
        // A prover must have been handed the per-row multiplicities.
        let row_counts = match multiplicities {
            Some(counts) => counts,
            None => return Err(anyhow::anyhow!("multiplicities empty for prover")),
        };
        debug_assert_eq!(table_count, row_counts.len());

        // Every multiplicity has to be representable with `LOG_MAX_MULTIPLICITY` bits.
        let all_representable = row_counts
            .iter()
            .all(|&count| count < (1 << LOG_MAX_MULTIPLICITY));
        if !all_representable {
            return Err(anyhow::anyhow!(
                "one or more multiplicities exceed `1 << LOG_MAX_MULTIPLICITY`"
            ));
        }

        let mut bit_cols = bits.map(|bit| witness.new_column::<BinaryField1b>(bit));
        let mut packed_bit_cols = bit_cols.each_mut().map(|bit_col| bit_col.packed());
        let mut component_cols = components.map(|component| witness.new_column::<FS>(component));
        let mut packed_component_cols = component_cols
            .each_mut()
            .map(|component_col| component_col.packed());
        let table_slice = witness.get::<FS>(table)?.as_slice::<FS>();

        for (row, (table_val, count)) in izip!(table_slice, row_counts).enumerate() {
            for j in 0..LOG_MAX_MULTIPLICITY {
                // A set bit selects the table entry; a clear bit selects the balancer value.
                let (bit_value, component_value) = if ((count >> j) & 1) == 1 {
                    (BinaryField1b::ONE, *table_val)
                } else {
                    (BinaryField1b::ZERO, balancer_value)
                };
                set_packed_slice(packed_bit_cols[j], row, bit_value);
                set_packed_slice(packed_component_cols[j], row, component_value);
            }
        }
    }

    // component == bit * table + (1 - bit) * balancer_value, over variables
    // (table, bit, component) in that order.
    let expression = {
        use binius_math::ArithExpr as Expr;
        let table_var = Expr::Var(0);
        let bit_var = Expr::Var(1);
        let component_var = Expr::Var(2);
        let when_set = bit_var.clone() * table_var;
        let when_clear = (Expr::one() - bit_var) * Expr::Const(balancer_value);
        component_var - (when_set + when_clear)
    };

    for i in 0..LOG_MAX_MULTIPLICITY {
        builder.assert_zero(
            format!("lookup_{i}"),
            [table, bits[i], components[i]],
            expression.convert_field(),
        );
    }

    Ok(components)
}
#[cfg(test)]
pub mod test_plain_lookup {
    use binius_field::BinaryField32b;
    use rayon::prelude::*;

    use super::*;
    use crate::transparent;

    /// Packs a multiplication claim `x * y = z` into one word: `x` in bits 0..8, `y` in bits
    /// 8..16 and `z` in bits 16..32.
    fn into_lookup_claim(x: u8, y: u8, z: u16) -> u32 {
        u32::from(x) | (u32::from(y) << 8) | (u32::from(z) << 16)
    }

    /// Builds the full 8-bit multiplication table, ordered by `x` and then by `y`.
    fn generate_u8_mul_table() -> Vec<u32> {
        (0..=u8::MAX)
            .flat_map(|x| {
                (0..=u8::MAX).map(move |y| into_lookup_claim(x, y, u16::from(x) * u16::from(y)))
            })
            .collect()
    }

    /// Overwrites every entry of `vals` with a randomly drawn, valid multiplication claim.
    fn generate_random_u8_mul_claims(vals: &mut [u32]) {
        use rand::Rng;

        vals.par_iter_mut().for_each(|claim| {
            let mut rng = rand::thread_rng();
            let x: u8 = rng.gen_range(0..=255u8);
            let y: u8 = rng.gen_range(0..=255u8);
            *claim = into_lookup_claim(x, y, u16::from(x) * u16::from(y));
        });
    }

    /// Wires a `plain_lookup` of `1 << log_lookup_count` random 8-bit multiplication claims
    /// against the full multiplication table and returns the resulting boundary.
    pub fn test_u8_mul_lookup<U, F, const LOG_MAX_MULTIPLICITY: usize>(
        builder: &mut ConstraintSystemBuilder<U, F>,
        log_lookup_count: usize,
    ) -> Result<Boundary<F>, anyhow::Error>
    where
        U: PackScalar<F> + PackScalar<BinaryField1b> + PackScalar<BinaryField32b> + Pod,
        F: TowerField + ExtensionField<BinaryField32b>,
    {
        let table_values = generate_u8_mul_table();
        let table = transparent::make_transparent(
            builder,
            "u8_mul_table",
            bytemuck::cast_slice::<_, BinaryField32b>(&table_values),
        )?;

        // The balancer is taken from the table itself (entry 99).
        let balancer_value = BinaryField32b::new(table_values[99]);

        let lookup_values =
            builder.add_committed("lookup_values", log_lookup_count, BinaryField32b::TOWER_LEVEL);
        let table_count = table_values.len();
        let lookup_values_count = 1 << log_lookup_count;

        if let Some(witness) = builder.witness() {
            let mut lookup_values_col = witness.new_column::<BinaryField32b>(lookup_values);
            let claims = lookup_values_col.as_mut_slice::<u32>();
            generate_random_u8_mul_claims(&mut claims[..lookup_values_count]);
        }

        plain_lookup::<U, F, BinaryField32b, LOG_MAX_MULTIPLICITY>(
            builder,
            table,
            table_count,
            balancer_value,
            lookup_values,
            lookup_values_count,
        )
    }
}
/// Counts, for every entry of `table`, how many times that entry occurs in `values`.
///
/// The result has one element per table entry, in table order. An entry that appears several
/// times in `table` receives the same count at each of its positions.
///
/// # Errors
///
/// When `check_inclusion` is `true`, returns an error naming the first element of `values`
/// that does not occur in `table`. When it is `false`, such elements are silently ignored.
fn count_multiplicities<T: Eq + std::hash::Hash + Clone + std::fmt::Debug>(
    table: &[T],
    values: &[T],
    check_inclusion: bool,
) -> Result<Vec<usize>, anyhow::Error> {
    use std::collections::{HashMap, HashSet};

    if check_inclusion {
        let known: HashSet<T> = table.iter().cloned().collect();
        for value in values {
            if !known.contains(value) {
                return Err(anyhow::anyhow!("value {:?} not in table", value));
            }
        }
    }

    // Tally every looked-up value once, then read the tallies back in table order.
    let mut occurrences: HashMap<&T, usize> = HashMap::new();
    for value in values {
        *occurrences.entry(value).or_default() += 1;
    }

    let mut multiplicities = Vec::with_capacity(table.len());
    for item in table {
        multiplicities.push(occurrences.get(item).copied().unwrap_or(0));
    }
    Ok(multiplicities)
}
#[cfg(test)]
mod count_multiplicity_tests {
    use super::*;

    /// Each table entry gets its occurrence count; unused entries get zero.
    #[test]
    fn test_basic_functionality() {
        let table = [1, 2, 3, 4];
        let values = [1, 2, 2, 3, 3, 3];
        let counts = count_multiplicities(&table, &values, true).unwrap();
        assert_eq!(counts, [1, 2, 3, 0]);
    }

    /// With nothing looked up, every table entry has multiplicity zero.
    #[test]
    fn test_empty_values() {
        let table = [1, 2, 3];
        let values: [i32; 0] = [];
        let counts = count_multiplicities(&table, &values, true).unwrap();
        assert_eq!(counts, [0, 0, 0]);
    }

    /// An empty table yields an empty result when inclusion is not checked.
    #[test]
    fn test_empty_table() {
        let table: [i32; 0] = [];
        let values = [1, 2, 3];
        let counts = count_multiplicities(&table, &values, false).unwrap();
        assert!(counts.is_empty());
    }

    /// With inclusion checking on, a value missing from the table is reported by name.
    #[test]
    fn test_value_not_in_table() {
        let table = [1, 2, 3];
        let values = [1, 4, 2];
        let error = count_multiplicities(&table, &values, true).unwrap_err();
        assert_eq!(error.to_string(), "value 4 not in table");
    }

    /// A duplicated table entry receives the same count at each of its positions.
    #[test]
    fn test_duplicates_in_table() {
        let table = [1, 1, 2, 3];
        let values = [1, 2, 2, 3, 3, 3];
        let counts = count_multiplicities(&table, &values, true).unwrap();
        assert_eq!(counts, [1, 1, 2, 3]);
    }

    /// The helper works for any hashable element type, not just integers.
    #[test]
    fn test_non_integer_values() {
        let table = ["a", "b", "c"];
        let values = ["a", "b", "b", "c", "c", "c"];
        let counts = count_multiplicities(&table, &values, true).unwrap();
        assert_eq!(counts, [1, 2, 3]);
    }
}