1use binius_field::{Field, PackedField, TowerField, util::eq};
4use binius_math::MultilinearExtension;
5use binius_utils::bail;
6
7use crate::{
8 oracle::ShiftVariant,
9 polynomial::{Error, MultivariatePoly},
10};
11
12#[derive(Debug, Clone)]
89pub struct ShiftIndPartialEval<F: Field> {
90 block_size: usize,
92 shift_offset: usize,
94 shift_variant: ShiftVariant,
96 r: Vec<F>,
99}
100
101impl<F: Field> ShiftIndPartialEval<F> {
102 pub fn new(
103 block_size: usize,
104 shift_offset: usize,
105 shift_variant: ShiftVariant,
106 r: Vec<F>,
107 ) -> Result<Self, Error> {
108 assert_valid_shift_ind_args(block_size, shift_offset, &r)?;
109 Ok(Self {
110 block_size,
111 shift_offset,
112 r,
113 shift_variant,
114 })
115 }
116
117 fn multilinear_extension_circular<P>(&self) -> Result<MultilinearExtension<P>, Error>
118 where
119 P: PackedField<Scalar = F>,
120 {
121 let (ps, pps) =
122 partial_evaluate_hypercube_impl::<P>(self.block_size, self.shift_offset, &self.r)?;
123 let values = ps
124 .iter()
125 .zip(pps)
126 .map(|(p, pp)| *p + pp)
127 .collect::<Vec<_>>();
128 Ok(MultilinearExtension::new(self.block_size, values)?)
129 }
130
131 fn multilinear_extension_logical_left<P>(&self) -> Result<MultilinearExtension<P>, Error>
132 where
133 P: PackedField<Scalar = F>,
134 {
135 let (ps, _) =
136 partial_evaluate_hypercube_impl::<P>(self.block_size, self.shift_offset, &self.r)?;
137 Ok(MultilinearExtension::new(self.block_size, ps)?)
138 }
139
140 fn multilinear_extension_logical_right<P>(&self) -> Result<MultilinearExtension<P>, Error>
141 where
142 P: PackedField<Scalar = F>,
143 {
144 let right_shift_offset = get_left_shift_offset(self.block_size, self.shift_offset);
145 let (_, pps) =
146 partial_evaluate_hypercube_impl::<P>(self.block_size, right_shift_offset, &self.r)?;
147 Ok(MultilinearExtension::new(self.block_size, pps)?)
148 }
149
150 pub fn multilinear_extension<P>(&self) -> Result<MultilinearExtension<P>, Error>
153 where
154 P: PackedField<Scalar = F>,
155 {
156 match self.shift_variant {
157 ShiftVariant::CircularLeft => self.multilinear_extension_circular(),
158 ShiftVariant::LogicalLeft => self.multilinear_extension_logical_left(),
159 ShiftVariant::LogicalRight => self.multilinear_extension_logical_right(),
160 }
161 }
162
163 fn evaluate_at_point(&self, x: &[F]) -> Result<F, Error> {
165 if x.len() != self.block_size {
166 bail!(Error::IncorrectQuerySize {
167 expected: self.block_size,
168 actual: x.len(),
169 });
170 }
171
172 let left_shift_offset = match self.shift_variant {
173 ShiftVariant::CircularLeft | ShiftVariant::LogicalLeft => self.shift_offset,
174 ShiftVariant::LogicalRight => get_left_shift_offset(self.block_size, self.shift_offset),
175 };
176
177 let (p_res, pp_res) =
178 evaluate_shift_ind_help(self.block_size, left_shift_offset, x, &self.r)?;
179
180 match self.shift_variant {
181 ShiftVariant::CircularLeft => Ok(p_res + pp_res),
182 ShiftVariant::LogicalLeft => Ok(p_res),
183 ShiftVariant::LogicalRight => Ok(pp_res),
184 }
185 }
186}
187
188impl<F: TowerField> MultivariatePoly<F> for ShiftIndPartialEval<F> {
189 fn n_vars(&self) -> usize {
190 self.block_size
191 }
192
193 fn degree(&self) -> usize {
194 self.block_size
195 }
196
197 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
198 self.evaluate_at_point(query)
199 }
200
201 fn binary_tower_level(&self) -> usize {
202 F::TOWER_LEVEL
203 }
204}
205
/// Converts a right-shift offset into the equivalent left-shift offset modulo
/// `2^block_size`: shifting right by `o` equals shifting left by `2^block_size - o`.
///
/// Callers must pass `right_shift_offset <= 2^block_size`.
const fn get_left_shift_offset(block_size: usize, right_shift_offset: usize) -> usize {
    let modulus: usize = 1 << block_size;
    modulus - right_shift_offset
}
210
211fn assert_valid_shift_ind_args<F: Field>(
213 block_size: usize,
214 shift_offset: usize,
215 partial_query_point: &[F],
216) -> Result<(), Error> {
217 if partial_query_point.len() != block_size {
218 bail!(Error::IncorrectQuerySize {
219 expected: block_size,
220 actual: partial_query_point.len(),
221 });
222 }
223 if shift_offset == 0 || shift_offset >= 1 << block_size {
224 bail!(Error::InvalidShiftOffset {
225 max_shift_offset: (1 << block_size) - 1,
226 shift_offset,
227 });
228 }
229
230 Ok(())
231}
232
233fn evaluate_shift_ind_help<F: Field>(
238 block_size: usize,
239 shift_offset: usize,
240 x: &[F],
241 y: &[F],
242) -> Result<(F, F), Error> {
243 if x.len() != block_size {
244 bail!(Error::IncorrectQuerySize {
245 expected: block_size,
246 actual: x.len(),
247 });
248 }
249 assert_valid_shift_ind_args(block_size, shift_offset, y)?;
250
251 let (mut s_ind_p, mut s_ind_pp) = (F::ONE, F::ZERO);
252 let (mut temp_p, mut temp_pp) = (F::default(), F::default());
253 (0..block_size).for_each(|k| {
254 let o_k = shift_offset >> k;
255 let product = x[k] * y[k];
256 if o_k % 2 == 1 {
257 temp_p = (y[k] - product) * s_ind_p;
258 temp_pp = (x[k] - product) * s_ind_p + eq(x[k], y[k]) * s_ind_pp;
259 } else {
260 temp_p = eq(x[k], y[k]) * s_ind_p + (y[k] - product) * s_ind_pp;
261 temp_pp = (x[k] - product) * s_ind_pp;
262 }
263 s_ind_p = temp_p;
265 s_ind_pp = temp_pp;
266 });
267
268 Ok((s_ind_p, s_ind_pp))
269}
270
271fn partial_evaluate_hypercube_impl<P: PackedField>(
277 block_size: usize,
278 shift_offset: usize,
279 r: &[P::Scalar],
280) -> Result<(Vec<P>, Vec<P>), Error> {
281 assert_valid_shift_ind_args(block_size, shift_offset, r)?;
282 let mut s_ind_p = vec![P::one(); 1 << block_size.saturating_sub(P::LOG_WIDTH)];
283 let mut s_ind_pp = vec![P::zero(); 1 << block_size.saturating_sub(P::LOG_WIDTH)];
284
285 partial_evaluate_hypercube_with_buffers_within_packed(
286 block_size.min(P::LOG_WIDTH),
287 shift_offset,
288 r,
289 &mut s_ind_p[0],
290 &mut s_ind_pp[0],
291 );
292 if block_size > P::LOG_WIDTH {
293 partial_evaluate_hypercube_with_buffers(
294 block_size - P::LOG_WIDTH,
295 shift_offset >> P::LOG_WIDTH,
296 &r[P::LOG_WIDTH..],
297 &mut s_ind_p,
298 &mut s_ind_pp,
299 );
300 }
301
302 Ok((s_ind_p, s_ind_pp))
303}
304
305#[inline(always)]
306fn shift_offset_odd_step<P: PackedField>(s_ind_p: &P, s_ind_pp: &P, r_k: P) -> (P, P, P) {
307 let mut pp_lo = *s_ind_pp;
308 let mut pp_hi = pp_lo * r_k;
309
310 pp_lo -= pp_hi;
311
312 let p_lo = *s_ind_p;
313 let p_hi = p_lo * r_k;
314 pp_hi += p_lo - p_hi; (pp_lo, pp_hi, p_hi)
317}
318
319#[inline(always)]
320fn shift_offset_even_step<P: PackedField>(s_ind_p: &P, s_ind_pp: &P, r_k: P) -> (P, P, P) {
321 let mut p_lo = *s_ind_p;
322 let p_hi = p_lo * r_k;
323 p_lo -= p_hi;
324
325 let pp_lo = *s_ind_pp;
326 let pp_hi = pp_lo * (P::one() - r_k);
327 p_lo += pp_lo - pp_hi;
328
329 (p_lo, p_hi, pp_hi)
330}
331
332fn partial_evaluate_hypercube_with_buffers<P: PackedField>(
333 block_size: usize,
334 shift_offset: usize,
335 r: &[P::Scalar],
336 s_ind_p: &mut [P],
337 s_ind_pp: &mut [P],
338) {
339 for k in 0..block_size {
340 let r_k = P::broadcast(r[k]);
342
343 if (shift_offset >> k) % 2 == 1 {
345 for i in 0..(1 << k) {
346 let (pp_lo, pp_hi, p_hi) = shift_offset_odd_step(&s_ind_p[i], &s_ind_pp[i], r_k);
347
348 s_ind_pp[i] = pp_lo;
349 s_ind_pp[1 << k | i] = pp_hi;
350
351 s_ind_p[i] = p_hi;
352 s_ind_p[1 << k | i] = P::zero(); }
354 } else {
355 for i in 0..(1 << k) {
356 let (p_lo, p_hi, pp_hi) = shift_offset_even_step(&s_ind_p[i], &s_ind_pp[i], r_k);
357
358 s_ind_p[i] = p_lo;
359 s_ind_p[1 << k | i] = p_hi;
360
361 s_ind_pp[i] = P::zero(); s_ind_pp[1 << k | i] = pp_hi;
363 }
364 }
365 }
366}
367
368#[allow(clippy::needless_range_loop)]
372fn partial_evaluate_hypercube_with_buffers_within_packed<P: PackedField>(
373 block_size: usize,
374 shift_offset: usize,
375 r: &[P::Scalar],
376 s_ind_p: &mut P,
377 s_ind_pp: &mut P,
378) {
379 debug_assert!(block_size <= P::LOG_WIDTH);
380
381 for k in 0..block_size {
382 let r_k = P::broadcast(r[k]);
384
385 if (shift_offset >> k) % 2 == 1 {
387 let (pp_lo, pp_hi, p_hi) = shift_offset_odd_step(s_ind_p, s_ind_pp, r_k);
388
389 *s_ind_pp = PackedField::from_fn(|i| {
390 if i < 1 << k {
391 unsafe { pp_lo.get_unchecked(i) }
392 } else if i < 1 << (k + 1) {
393 unsafe { pp_hi.get_unchecked(i - (1 << k)) }
394 } else {
395 unsafe { s_ind_pp.get_unchecked(i) }
396 }
397 });
398 *s_ind_p = PackedField::from_fn(|i| {
399 if i < 1 << k {
400 unsafe { p_hi.get_unchecked(i) }
401 } else if i < 1 << (k + 1) {
402 Field::ZERO
403 } else {
404 unsafe { s_ind_p.get_unchecked(i) }
405 }
406 });
407 } else {
408 let (p_lo, p_hi, pp_hi) = shift_offset_even_step(s_ind_p, s_ind_pp, r_k);
409
410 *s_ind_p = PackedField::from_fn(|i| {
411 if i < 1 << k {
412 unsafe { p_lo.get_unchecked(i) }
413 } else if i < 1 << (k + 1) {
414 unsafe { p_hi.get_unchecked(i - (1 << k)) }
415 } else {
416 unsafe { s_ind_p.get_unchecked(i) }
417 }
418 });
419 *s_ind_pp = PackedField::from_fn(|i| {
420 if i < 1 << k {
421 Field::ZERO
422 } else if i < 1 << (k + 1) {
423 unsafe { pp_hi.get_unchecked(i - (1 << k)) }
424 } else {
425 unsafe { s_ind_pp.get_unchecked(i) }
426 }
427 });
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{BinaryField32b, PackedBinaryField4x32b};
    use binius_hal::{ComputationBackendExt, make_portable_backend};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::polynomial::test_utils::decompose_index_to_hypercube_point;

    /// Draws `len` independent random field elements from `rng`.
    fn random_point<F: Field>(rng: &mut StdRng, len: usize) -> Vec<F> {
        repeat_with(|| F::random(&mut *rng)).take(len).collect()
    }

    /// Schwartz–Zippel check: the direct evaluation at a random point must equal the
    /// evaluation of the materialised multilinear extension (packed as `P`) at that point.
    fn check_consistency<F: TowerField, P: PackedField<Scalar = F>>(
        shift_variant: ShiftVariant,
        block_size: usize,
        shift_offset: usize,
    ) {
        let mut rng = StdRng::seed_from_u64(0);
        let backend = make_portable_backend();
        let r = random_point::<F>(&mut rng, block_size);
        let eval_point = random_point::<F>(&mut rng, block_size);

        let indicator =
            ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, r).unwrap();
        let direct_eval = indicator.evaluate(&eval_point).unwrap();

        let mle = indicator.multilinear_extension::<P>().unwrap();
        let query = backend.multilinear_query::<P>(&eval_point).unwrap();
        let mle_eval = mle.evaluate(&query).unwrap();

        assert_eq!(mle_eval, direct_eval);
    }

    /// Exhaustive hypercube check: for every boolean `r` (index `y`) and `x` (index `x`), the
    /// indicator must be one exactly when `is_shift_pair(x, y)` holds, and zero otherwise.
    fn check_functionality<F: TowerField>(
        shift_variant: ShiftVariant,
        block_size: usize,
        shift_offset: usize,
        is_shift_pair: impl Fn(usize, usize) -> bool,
    ) {
        for y in 0..(1 << block_size) {
            let r = decompose_index_to_hypercube_point::<F>(block_size, y);
            let indicator =
                ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, r).unwrap();
            for x in 0..(1 << block_size) {
                let point = decompose_index_to_hypercube_point::<F>(block_size, x);
                let expected = if is_shift_pair(x, y) { F::ONE } else { F::ZERO };
                assert_eq!(indicator.evaluate(&point).unwrap(), expected);
            }
        }
    }

    /// Offsets exercised by the random-point consistency tests.
    fn consistency_offsets(block_size: usize) -> [usize; 5] {
        [1, 2, 3, (1 << block_size) - 1, (1 << block_size) / 2]
    }

    /// Offsets exercised by the exhaustive functionality tests.
    fn functionality_offsets(block_size: usize) -> [usize; 5] {
        [
            1,
            3,
            (1 << block_size) - 1,
            (1 << block_size) - 2,
            1 << (block_size - 1),
        ]
    }

    #[test]
    fn test_circular_left_shift_consistency_schwartz_zippel() {
        for block_size in 2..=10 {
            for offset in consistency_offsets(block_size) {
                check_consistency::<_, PackedBinaryField4x32b>(
                    ShiftVariant::CircularLeft,
                    block_size,
                    offset,
                );
            }
        }
    }

    #[test]
    fn test_logical_left_shift_consistency_schwartz_zippel() {
        for block_size in 2..=10 {
            for offset in consistency_offsets(block_size) {
                check_consistency::<_, PackedBinaryField4x32b>(
                    ShiftVariant::LogicalLeft,
                    block_size,
                    offset,
                );
            }
        }
    }

    #[test]
    fn test_logical_right_shift_consistency_schwartz_zippel() {
        for block_size in 2..=10 {
            for offset in consistency_offsets(block_size) {
                check_consistency::<_, PackedBinaryField4x32b>(
                    ShiftVariant::LogicalRight,
                    block_size,
                    offset,
                );
            }
        }
    }

    #[test]
    fn test_circular_left_shift_functionality() {
        for block_size in 3..5 {
            for offset in functionality_offsets(block_size) {
                check_functionality::<BinaryField32b>(
                    ShiftVariant::CircularLeft,
                    block_size,
                    offset,
                    |x, y| (x + offset) % (1 << block_size) == y,
                );
            }
        }
    }

    #[test]
    fn test_logical_left_shift_functionality() {
        for block_size in 3..5 {
            for offset in functionality_offsets(block_size) {
                check_functionality::<BinaryField32b>(
                    ShiftVariant::LogicalLeft,
                    block_size,
                    offset,
                    |x, y| x + offset == y,
                );
            }
        }
    }

    #[test]
    fn test_logical_right_shift_functionality() {
        for block_size in 3..5 {
            for offset in functionality_offsets(block_size) {
                check_functionality::<BinaryField32b>(
                    ShiftVariant::LogicalRight,
                    block_size,
                    offset,
                    |x, y| x >= offset && x - offset == y,
                );
            }
        }
    }
}