use binius_field::{util::eq, Field, PackedFieldIndexable, TowerField};
use binius_math::MultilinearExtension;
use binius_utils::bail;
use crate::{
oracle::ShiftVariant,
polynomial::{Error, MultivariatePoly},
};
/// Shift indicator polynomial that has been partially evaluated at a fixed point `r`.
///
/// Which indicator is represented (circular left, logical left or logical right) is
/// selected by `shift_variant`. When `F: TowerField` the type implements
/// `MultivariatePoly<F>` with `block_size` variables.
#[derive(Debug, Clone)]
pub struct ShiftIndPartialEval<F: Field> {
/// Number of variables of the polynomial; `r` must have exactly this many coordinates.
block_size: usize,
/// Shift amount; `new` only accepts values in `1..2^block_size`.
shift_offset: usize,
/// Which shift indicator this polynomial represents.
shift_variant: ShiftVariant,
/// Point at which the indicator has been partially evaluated (`block_size` coordinates).
r: Vec<F>,
}
impl<F: Field> ShiftIndPartialEval<F> {
    /// Creates the shift indicator for `shift_variant`, partially evaluated at `r`.
    ///
    /// # Errors
    ///
    /// Forwards the error of `assert_valid_shift_ind_args`: `r` must contain exactly
    /// `block_size` elements and `shift_offset` must lie in `1..2^block_size`.
    pub fn new(
        block_size: usize,
        shift_offset: usize,
        shift_variant: ShiftVariant,
        r: Vec<F>,
    ) -> Result<Self, Error> {
        assert_valid_shift_ind_args(block_size, shift_offset, &r).map(|()| Self {
            block_size,
            shift_offset,
            shift_variant,
            r,
        })
    }

    /// Hypercube evaluations of the circular-left indicator: the element-wise sum of
    /// both tables produced by `partial_evaluate_hypercube_impl`.
    fn multilinear_extension_circular<P>(&self) -> Result<MultilinearExtension<P>, Error>
    where
        P: PackedFieldIndexable<Scalar = F>,
    {
        let (no_wrap, wrapped) =
            partial_evaluate_hypercube_impl::<P>(self.block_size, self.shift_offset, &self.r)?;
        let summed: Vec<P> = no_wrap
            .into_iter()
            .zip(wrapped)
            .map(|(lhs, rhs)| lhs + rhs)
            .collect();
        let mle = MultilinearExtension::from_values(summed)?;
        Ok(mle)
    }

    /// Hypercube evaluations of the logical-left indicator: only the first table
    /// produced by `partial_evaluate_hypercube_impl` is kept.
    fn multilinear_extension_logical_left<P>(&self) -> Result<MultilinearExtension<P>, Error>
    where
        P: PackedFieldIndexable<Scalar = F>,
    {
        let (no_wrap, _) =
            partial_evaluate_hypercube_impl::<P>(self.block_size, self.shift_offset, &self.r)?;
        let mle = MultilinearExtension::from_values(no_wrap)?;
        Ok(mle)
    }

    /// Hypercube evaluations of the logical-right indicator.
    ///
    /// The shared routine is run with the complementary offset
    /// `2^block_size - shift_offset` and only its second table is kept.
    fn multilinear_extension_logical_right<P>(&self) -> Result<MultilinearExtension<P>, Error>
    where
        P: PackedFieldIndexable<Scalar = F>,
    {
        let complement_offset = get_left_shift_offset(self.block_size, self.shift_offset);
        let (_, wrapped) =
            partial_evaluate_hypercube_impl::<P>(self.block_size, complement_offset, &self.r)?;
        let mle = MultilinearExtension::from_values(wrapped)?;
        Ok(mle)
    }

    /// Materialises this polynomial as a `MultilinearExtension` over the packed field `P`,
    /// dispatching on the configured shift variant.
    ///
    /// # Errors
    ///
    /// Propagates any error from the hypercube evaluation or from
    /// `MultilinearExtension::from_values`.
    pub fn multilinear_extension<P>(&self) -> Result<MultilinearExtension<P>, Error>
    where
        P: PackedFieldIndexable<Scalar = F>,
    {
        match self.shift_variant {
            ShiftVariant::CircularLeft => self.multilinear_extension_circular(),
            ShiftVariant::LogicalLeft => self.multilinear_extension_logical_left(),
            ShiftVariant::LogicalRight => self.multilinear_extension_logical_right(),
        }
    }

    /// Evaluates the partially evaluated indicator at `x`.
    ///
    /// # Errors
    ///
    /// Returns `Error::IncorrectQuerySize` when `x` does not have `block_size`
    /// coordinates, and forwards any error from `evaluate_shift_ind_help`.
    fn evaluate_at_point(&self, x: &[F]) -> Result<F, Error> {
        if x.len() != self.block_size {
            bail!(Error::IncorrectQuerySize {
                expected: self.block_size,
            });
        }

        // A logical right shift is evaluated through the left-shift recurrence with the
        // complementary offset.
        let effective_offset = match self.shift_variant {
            ShiftVariant::CircularLeft | ShiftVariant::LogicalLeft => self.shift_offset,
            ShiftVariant::LogicalRight => get_left_shift_offset(self.block_size, self.shift_offset),
        };
        let (no_wrap, wrapped) =
            evaluate_shift_ind_help(self.block_size, effective_offset, x, &self.r)?;

        let value = match self.shift_variant {
            ShiftVariant::CircularLeft => no_wrap + wrapped,
            ShiftVariant::LogicalLeft => no_wrap,
            ShiftVariant::LogicalRight => wrapped,
        };
        Ok(value)
    }
}
impl<F: TowerField> MultivariatePoly<F> for ShiftIndPartialEval<F> {
    /// Number of variables; equal to the block size.
    fn n_vars(&self) -> usize {
        self.block_size
    }

    /// Reported degree; equal to the block size.
    fn degree(&self) -> usize {
        self.block_size
    }

    /// Evaluates the polynomial at `query` by delegating to `evaluate_at_point`.
    fn evaluate(&self, query: &[F]) -> Result<F, Error> {
        Self::evaluate_at_point(self, query)
    }

    /// Tower level of the coefficient field `F`.
    fn binary_tower_level(&self) -> usize {
        <F as TowerField>::TOWER_LEVEL
    }
}
/// Returns the complementary offset `2^block_size - right_shift_offset`.
///
/// Callers use it to express a logical right shift through the left-shift recurrence.
/// `right_shift_offset` must not exceed `2^block_size`, otherwise the subtraction
/// underflows.
fn get_left_shift_offset(block_size: usize, right_shift_offset: usize) -> usize {
    let block_len = 1usize << block_size;
    block_len - right_shift_offset
}
/// Checks the arguments shared by all shift-indicator routines.
///
/// # Errors
///
/// * `Error::IncorrectQuerySize` when `partial_query_point` does not have exactly
///   `block_size` coordinates.
/// * `Error::InvalidShiftOffset` when `shift_offset` is zero or not below `2^block_size`.
fn assert_valid_shift_ind_args<F: Field>(
    block_size: usize,
    shift_offset: usize,
    partial_query_point: &[F],
) -> Result<(), Error> {
    let point_len = partial_query_point.len();
    if point_len != block_size {
        bail!(Error::IncorrectQuerySize {
            expected: block_size,
        });
    }

    // Exclusive upper bound on the offset: the number of points in the block.
    let offset_bound: usize = 1 << block_size;
    if !(1..offset_bound).contains(&shift_offset) {
        bail!(Error::InvalidShiftOffset {
            max_shift_offset: offset_bound - 1,
            shift_offset,
        });
    }

    Ok(())
}
/// Evaluates the two components of the shift indicator at the point `(x, y)`.
///
/// The coordinates are processed from index 0 upwards, mirroring a bit-by-bit addition
/// of `shift_offset` to `x`. The first returned component tracks the case where no carry
/// is pending after the last bit, the second the case where one is. Callers combine or
/// select these components depending on the shift variant.
///
/// # Errors
///
/// Returns `Error::IncorrectQuerySize` when `x` does not have `block_size` coordinates,
/// and forwards the validation errors of `assert_valid_shift_ind_args` for `y` and
/// `shift_offset`.
fn evaluate_shift_ind_help<F: Field>(
    block_size: usize,
    shift_offset: usize,
    x: &[F],
    y: &[F],
) -> Result<(F, F), Error> {
    if x.len() != block_size {
        bail!(Error::IncorrectQuerySize {
            expected: block_size,
        });
    }
    assert_valid_shift_ind_args(block_size, shift_offset, y)?;

    // Both slices hold exactly `block_size` coordinates here, so the zip visits every bit.
    let components = x.iter().zip(y).enumerate().fold(
        (F::ONE, F::ZERO),
        |(no_carry, carry), (k, (&x_k, &y_k))| {
            let product = x_k * y_k;
            let x_only = x_k - product;
            let y_only = y_k - product;
            if ((shift_offset >> k) & 1) == 1 {
                (
                    y_only * no_carry,
                    x_only * no_carry + eq(x_k, y_k) * carry,
                )
            } else {
                (
                    eq(x_k, y_k) * no_carry + y_only * carry,
                    x_only * carry,
                )
            }
        },
    );

    Ok(components)
}
/// Evaluates both components of the shift indicator, partially evaluated at `r`, over
/// the whole `block_size`-variate hypercube.
///
/// The low `min(block_size, P::LOG_WIDTH)` variables are expanded on the unpacked scalar
/// view of the buffers; any remaining variables are expanded on whole packed elements.
///
/// The returned buffers always contain at least one packed element. When
/// `block_size < P::LOG_WIDTH` only the first `2^block_size` scalars of that element are
/// written by the recurrence; the remaining scalars keep their initial values (one in
/// the first buffer, zero in the second).
///
/// # Errors
///
/// Forwards the validation errors of `assert_valid_shift_ind_args`.
fn partial_evaluate_hypercube_impl<P: PackedFieldIndexable>(
    block_size: usize,
    shift_offset: usize,
    r: &[P::Scalar],
) -> Result<(Vec<P>, Vec<P>), Error> {
    assert_valid_shift_ind_args(block_size, shift_offset, r)?;

    // `saturating_sub` keeps the length well-defined when the block is smaller than one
    // packed element; a plain subtraction would underflow `usize` in that case.
    let packed_len = 1usize << block_size.saturating_sub(P::LOG_WIDTH);
    let mut s_ind_p = vec![P::one(); packed_len];
    let mut s_ind_pp = vec![P::zero(); packed_len];

    // Phase 1: variables that live inside a single packed element, handled per scalar.
    partial_evaluate_hypercube_with_buffers(
        block_size.min(P::LOG_WIDTH),
        shift_offset,
        r,
        P::unpack_scalars_mut(&mut s_ind_p),
        P::unpack_scalars_mut(&mut s_ind_pp),
    );

    // Phase 2: the remaining high variables, handled one packed element at a time.
    if block_size > P::LOG_WIDTH {
        partial_evaluate_hypercube_with_buffers(
            block_size - P::LOG_WIDTH,
            shift_offset >> P::LOG_WIDTH,
            &r[P::LOG_WIDTH..],
            &mut s_ind_p,
            &mut s_ind_pp,
        );
    }

    Ok((s_ind_p, s_ind_pp))
}
/// Expands the two shift-indicator tables in place, one variable per outer iteration.
///
/// Before step `k` the first `2^k` entries of each buffer are in use; step `k` reads
/// entry `lo` and writes entries `lo` and `2^k | lo`, doubling the used prefix. The
/// update rule depends on bit `k` of `shift_offset`.
///
/// Both buffers must hold at least `2^block_size` entries and `r` at least `block_size`
/// entries, otherwise indexing panics.
fn partial_evaluate_hypercube_with_buffers<P: PackedFieldIndexable>(
    block_size: usize,
    shift_offset: usize,
    r: &[P::Scalar],
    s_ind_p: &mut [P],
    s_ind_pp: &mut [P],
) {
    for k in 0..block_size {
        let offset_bit_set = ((shift_offset >> k) & 1) == 1;
        let r_k = r[k];
        let half = 1usize << k;

        for lo in 0..half {
            let hi = half | lo;
            let p_prev = s_ind_p[lo];
            let pp_prev = s_ind_pp[lo];

            if offset_bit_set {
                let pp_scaled = pp_prev * r_k;
                let p_scaled = p_prev * r_k;

                s_ind_pp[lo] = pp_prev - pp_scaled;
                s_ind_pp[hi] = pp_scaled + (p_prev - p_scaled);

                s_ind_p[lo] = p_scaled;
                // The upper half of the first table vanishes when the offset bit is set.
                s_ind_p[hi] = P::zero();
            } else {
                let p_scaled = p_prev * r_k;
                let pp_kept = pp_prev * (P::one() - r_k);

                s_ind_p[lo] = (p_prev - p_scaled) + (pp_prev - pp_kept);
                s_ind_p[hi] = p_scaled;

                // The lower half of the second table vanishes when the offset bit is clear.
                s_ind_pp[lo] = P::zero();
                s_ind_pp[hi] = pp_kept;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{BinaryField32b, PackedBinaryField4x32b};
    use binius_hal::{make_portable_backend, ComputationBackendExt};
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::polynomial::test_utils::decompose_index_to_hypercube_point;

    /// Checks that direct evaluation and evaluation through the materialised multilinear
    /// extension agree at a random point, for the given shift variant.
    fn check_consistency<F, P>(block_size: usize, shift_offset: usize, shift_variant: ShiftVariant)
    where
        F: TowerField,
        P: PackedFieldIndexable<Scalar = F>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let backend = make_portable_backend();

        let r: Vec<F> = repeat_with(|| F::random(&mut rng))
            .take(block_size)
            .collect();
        let eval_point: Vec<F> = repeat_with(|| F::random(&mut rng))
            .take(block_size)
            .collect();

        let shift_r_mvp =
            ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, r).unwrap();
        let eval_mvp = shift_r_mvp.evaluate(&eval_point).unwrap();

        let shift_r_mle = shift_r_mvp.multilinear_extension::<P>().unwrap();
        let multilin_query = backend.multilinear_query::<P>(&eval_point).unwrap();
        let eval_mle = shift_r_mle.evaluate(&multilin_query).unwrap();

        assert_eq!(eval_mle, eval_mvp);
    }

    /// Runs the consistency check over the block sizes and offsets shared by all variants.
    fn run_consistency_suite(shift_variant: ShiftVariant) {
        for block_size in 2..=10 {
            for shift_offset in [1, 2, 3, (1 << block_size) - 1, (1 << block_size) / 2] {
                check_consistency::<_, PackedBinaryField4x32b>(
                    block_size,
                    shift_offset,
                    shift_variant,
                );
            }
        }
    }

    #[test]
    fn test_circular_left_shift_consistency_schwartz_zippel() {
        run_consistency_suite(ShiftVariant::CircularLeft);
    }

    #[test]
    fn test_logical_left_shift_consistency_schwartz_zippel() {
        run_consistency_suite(ShiftVariant::LogicalLeft);
    }

    #[test]
    fn test_logical_right_shift_consistency_schwartz_zippel() {
        run_consistency_suite(ShiftVariant::LogicalRight);
    }

    /// Exhaustively checks the indicator on every pair of hypercube points: with `r`
    /// decomposed from index `i` and the query from index `j`, the value must be one
    /// exactly when `i` is the shifted image of `j`, and zero otherwise.
    fn check_functionality<F: TowerField>(
        block_size: usize,
        shift_offset: usize,
        shift_variant: ShiftVariant,
    ) {
        let n_points = 1usize << block_size;
        for i in 0..n_points {
            let r = decompose_index_to_hypercube_point::<F>(block_size, i);
            let shift_r_mvp =
                ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, r).unwrap();

            for j in 0..n_points {
                let x = decompose_index_to_hypercube_point::<F>(block_size, j);
                let eval_mvp = shift_r_mvp.evaluate(&x).unwrap();

                let is_image = match shift_variant {
                    ShiftVariant::CircularLeft => (j + shift_offset) % n_points == i,
                    ShiftVariant::LogicalLeft => j + shift_offset == i,
                    ShiftVariant::LogicalRight => j >= shift_offset && j - shift_offset == i,
                };
                let expected = if is_image { F::ONE } else { F::ZERO };
                assert_eq!(eval_mvp, expected);
            }
        }
    }

    /// Runs the exhaustive check over the block sizes and offsets shared by all variants.
    fn run_functionality_suite(shift_variant: ShiftVariant) {
        for block_size in 3..5 {
            for shift_offset in [
                1,
                3,
                (1 << block_size) - 1,
                (1 << block_size) - 2,
                (1 << (block_size - 1)),
            ] {
                check_functionality::<BinaryField32b>(block_size, shift_offset, shift_variant);
            }
        }
    }

    #[test]
    fn test_circular_left_shift_functionality() {
        run_functionality_suite(ShiftVariant::CircularLeft);
    }

    #[test]
    fn test_logical_left_shift_functionality() {
        run_functionality_suite(ShiftVariant::LogicalLeft);
    }

    #[test]
    fn test_logical_right_shift_functionality() {
        run_functionality_suite(ShiftVariant::LogicalRight);
    }
}