1use binius_field::{util::eq, Field, PackedField, TowerField};
4use binius_math::MultilinearExtension;
5use binius_utils::bail;
6
7use crate::{
8 oracle::ShiftVariant,
9 polynomial::{Error, MultivariatePoly},
10};
11
/// A shift indicator polynomial with its second argument fixed to the point `r`.
///
/// The result is a `block_size`-variate polynomial in `x`. Identifying hypercube points with
/// integers (as the tests in this file do), it evaluates on the hypercube to 1 exactly when:
///
/// - [`ShiftVariant::CircularLeft`]: `(x + shift_offset) mod 2^block_size == r`;
/// - [`ShiftVariant::LogicalLeft`]: `x + shift_offset == r`;
/// - [`ShiftVariant::LogicalRight`]: `x >= shift_offset` and `x - shift_offset == r`;
///
/// and to 0 otherwise.
#[derive(Debug, Clone)]
pub struct ShiftIndPartialEval<F: Field> {
    /// Number of variables of the partially evaluated polynomial (its `n_vars`).
    block_size: usize,
    /// Shift distance; `new` validates it to lie in `1..2^block_size`.
    shift_offset: usize,
    /// Which kind of shift the indicator encodes.
    shift_variant: ShiftVariant,
    /// The fixed point, of length `block_size` (validated by `new`).
    r: Vec<F>,
}
96
97impl<F: Field> ShiftIndPartialEval<F> {
98 pub fn new(
99 block_size: usize,
100 shift_offset: usize,
101 shift_variant: ShiftVariant,
102 r: Vec<F>,
103 ) -> Result<Self, Error> {
104 assert_valid_shift_ind_args(block_size, shift_offset, &r)?;
105 Ok(Self {
106 block_size,
107 shift_offset,
108 r,
109 shift_variant,
110 })
111 }
112
113 fn multilinear_extension_circular<P>(&self) -> Result<MultilinearExtension<P>, Error>
114 where
115 P: PackedField<Scalar = F>,
116 {
117 let (ps, pps) =
118 partial_evaluate_hypercube_impl::<P>(self.block_size, self.shift_offset, &self.r)?;
119 let values = ps
120 .iter()
121 .zip(pps)
122 .map(|(p, pp)| *p + pp)
123 .collect::<Vec<_>>();
124 Ok(MultilinearExtension::new(self.block_size, values)?)
125 }
126
127 fn multilinear_extension_logical_left<P>(&self) -> Result<MultilinearExtension<P>, Error>
128 where
129 P: PackedField<Scalar = F>,
130 {
131 let (ps, _) =
132 partial_evaluate_hypercube_impl::<P>(self.block_size, self.shift_offset, &self.r)?;
133 Ok(MultilinearExtension::new(self.block_size, ps)?)
134 }
135
136 fn multilinear_extension_logical_right<P>(&self) -> Result<MultilinearExtension<P>, Error>
137 where
138 P: PackedField<Scalar = F>,
139 {
140 let right_shift_offset = get_left_shift_offset(self.block_size, self.shift_offset);
141 let (_, pps) =
142 partial_evaluate_hypercube_impl::<P>(self.block_size, right_shift_offset, &self.r)?;
143 Ok(MultilinearExtension::new(self.block_size, pps)?)
144 }
145
146 pub fn multilinear_extension<P>(&self) -> Result<MultilinearExtension<P>, Error>
149 where
150 P: PackedField<Scalar = F>,
151 {
152 match self.shift_variant {
153 ShiftVariant::CircularLeft => self.multilinear_extension_circular(),
154 ShiftVariant::LogicalLeft => self.multilinear_extension_logical_left(),
155 ShiftVariant::LogicalRight => self.multilinear_extension_logical_right(),
156 }
157 }
158
159 fn evaluate_at_point(&self, x: &[F]) -> Result<F, Error> {
161 if x.len() != self.block_size {
162 bail!(Error::IncorrectQuerySize {
163 expected: self.block_size,
164 });
165 }
166
167 let left_shift_offset = match self.shift_variant {
168 ShiftVariant::CircularLeft | ShiftVariant::LogicalLeft => self.shift_offset,
169 ShiftVariant::LogicalRight => get_left_shift_offset(self.block_size, self.shift_offset),
170 };
171
172 let (p_res, pp_res) =
173 evaluate_shift_ind_help(self.block_size, left_shift_offset, x, &self.r)?;
174
175 match self.shift_variant {
176 ShiftVariant::CircularLeft => Ok(p_res + pp_res),
177 ShiftVariant::LogicalLeft => Ok(p_res),
178 ShiftVariant::LogicalRight => Ok(pp_res),
179 }
180 }
181}
182
183impl<F: TowerField> MultivariatePoly<F> for ShiftIndPartialEval<F> {
184 fn n_vars(&self) -> usize {
185 self.block_size
186 }
187
188 fn degree(&self) -> usize {
189 self.block_size
190 }
191
192 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
193 self.evaluate_at_point(query)
194 }
195
196 fn binary_tower_level(&self) -> usize {
197 F::TOWER_LEVEL
198 }
199}
200
/// Converts a right-shift offset into the equivalent left-shift offset modulo `2^block_size`,
/// i.e. returns `2^block_size - right_shift_offset`.
///
/// Within this file, callers pass `self.shift_offset`, which `ShiftIndPartialEval::new` has
/// validated to lie in `1..2^block_size`, so the subtraction cannot underflow there.
const fn get_left_shift_offset(block_size: usize, right_shift_offset: usize) -> usize {
    let hypercube_size = 1 << block_size;
    hypercube_size - right_shift_offset
}
205
206fn assert_valid_shift_ind_args<F: Field>(
208 block_size: usize,
209 shift_offset: usize,
210 partial_query_point: &[F],
211) -> Result<(), Error> {
212 if partial_query_point.len() != block_size {
213 bail!(Error::IncorrectQuerySize {
214 expected: block_size,
215 });
216 }
217 if shift_offset == 0 || shift_offset >= 1 << block_size {
218 bail!(Error::InvalidShiftOffset {
219 max_shift_offset: (1 << block_size) - 1,
220 shift_offset,
221 });
222 }
223
224 Ok(())
225}
226
227fn evaluate_shift_ind_help<F: Field>(
232 block_size: usize,
233 shift_offset: usize,
234 x: &[F],
235 y: &[F],
236) -> Result<(F, F), Error> {
237 if x.len() != block_size {
238 bail!(Error::IncorrectQuerySize {
239 expected: block_size,
240 });
241 }
242 assert_valid_shift_ind_args(block_size, shift_offset, y)?;
243
244 let (mut s_ind_p, mut s_ind_pp) = (F::ONE, F::ZERO);
245 let (mut temp_p, mut temp_pp) = (F::default(), F::default());
246 (0..block_size).for_each(|k| {
247 let o_k = shift_offset >> k;
248 let product = x[k] * y[k];
249 if o_k % 2 == 1 {
250 temp_p = (y[k] - product) * s_ind_p;
251 temp_pp = (x[k] - product) * s_ind_p + eq(x[k], y[k]) * s_ind_pp;
252 } else {
253 temp_p = eq(x[k], y[k]) * s_ind_p + (y[k] - product) * s_ind_pp;
254 temp_pp = (x[k] - product) * s_ind_pp;
255 }
256 s_ind_p = temp_p;
258 s_ind_pp = temp_pp;
259 });
260
261 Ok((s_ind_p, s_ind_pp))
262}
263
264fn partial_evaluate_hypercube_impl<P: PackedField>(
270 block_size: usize,
271 shift_offset: usize,
272 r: &[P::Scalar],
273) -> Result<(Vec<P>, Vec<P>), Error> {
274 assert_valid_shift_ind_args(block_size, shift_offset, r)?;
275 let mut s_ind_p = vec![P::one(); 1 << block_size.saturating_sub(P::LOG_WIDTH)];
276 let mut s_ind_pp = vec![P::zero(); 1 << block_size.saturating_sub(P::LOG_WIDTH)];
277
278 partial_evaluate_hypercube_with_buffers_within_packed(
279 block_size.min(P::LOG_WIDTH),
280 shift_offset,
281 r,
282 &mut s_ind_p[0],
283 &mut s_ind_pp[0],
284 );
285 if block_size > P::LOG_WIDTH {
286 partial_evaluate_hypercube_with_buffers(
287 block_size - P::LOG_WIDTH,
288 shift_offset >> P::LOG_WIDTH,
289 &r[P::LOG_WIDTH..],
290 &mut s_ind_p,
291 &mut s_ind_pp,
292 );
293 }
294
295 Ok((s_ind_p, s_ind_pp))
296}
297
298#[inline(always)]
299fn shift_offset_odd_step<P: PackedField>(s_ind_p: &P, s_ind_pp: &P, r_k: P) -> (P, P, P) {
300 let mut pp_lo = *s_ind_pp;
301 let mut pp_hi = pp_lo * r_k;
302
303 pp_lo -= pp_hi;
304
305 let p_lo = *s_ind_p;
306 let p_hi = p_lo * r_k;
307 pp_hi += p_lo - p_hi; (pp_lo, pp_hi, p_hi)
310}
311
312#[inline(always)]
313fn shift_offset_even_step<P: PackedField>(s_ind_p: &P, s_ind_pp: &P, r_k: P) -> (P, P, P) {
314 let mut p_lo = *s_ind_p;
315 let p_hi = p_lo * r_k;
316 p_lo -= p_hi;
317
318 let pp_lo = *s_ind_pp;
319 let pp_hi = pp_lo * (P::one() - r_k);
320 p_lo += pp_lo - pp_hi;
321
322 (p_lo, p_hi, pp_hi)
323}
324
325fn partial_evaluate_hypercube_with_buffers<P: PackedField>(
326 block_size: usize,
327 shift_offset: usize,
328 r: &[P::Scalar],
329 s_ind_p: &mut [P],
330 s_ind_pp: &mut [P],
331) {
332 for k in 0..block_size {
333 let r_k = P::broadcast(r[k]);
335
336 if (shift_offset >> k) % 2 == 1 {
338 for i in 0..(1 << k) {
339 let (pp_lo, pp_hi, p_hi) = shift_offset_odd_step(&s_ind_p[i], &s_ind_pp[i], r_k);
340
341 s_ind_pp[i] = pp_lo;
342 s_ind_pp[1 << k | i] = pp_hi;
343
344 s_ind_p[i] = p_hi;
345 s_ind_p[1 << k | i] = P::zero(); }
347 } else {
348 for i in 0..(1 << k) {
349 let (p_lo, p_hi, pp_hi) = shift_offset_even_step(&s_ind_p[i], &s_ind_pp[i], r_k);
350
351 s_ind_p[i] = p_lo;
352 s_ind_p[1 << k | i] = p_hi;
353
354 s_ind_pp[i] = P::zero(); s_ind_pp[1 << k | i] = pp_hi;
356 }
357 }
358 }
359}
360
361#[allow(clippy::needless_range_loop)]
365fn partial_evaluate_hypercube_with_buffers_within_packed<P: PackedField>(
366 block_size: usize,
367 shift_offset: usize,
368 r: &[P::Scalar],
369 s_ind_p: &mut P,
370 s_ind_pp: &mut P,
371) {
372 debug_assert!(block_size <= P::LOG_WIDTH);
373
374 for k in 0..block_size {
375 let r_k = P::broadcast(r[k]);
377
378 if (shift_offset >> k) % 2 == 1 {
380 let (pp_lo, pp_hi, p_hi) = shift_offset_odd_step(s_ind_p, s_ind_pp, r_k);
381
382 *s_ind_pp = PackedField::from_fn(|i| {
383 if i < 1 << k {
384 unsafe { pp_lo.get_unchecked(i) }
385 } else if i < 1 << (k + 1) {
386 unsafe { pp_hi.get_unchecked(i - (1 << k)) }
387 } else {
388 unsafe { s_ind_pp.get_unchecked(i) }
389 }
390 });
391 *s_ind_p = PackedField::from_fn(|i| {
392 if i < 1 << k {
393 unsafe { p_hi.get_unchecked(i) }
394 } else if i < 1 << (k + 1) {
395 Field::ZERO
396 } else {
397 unsafe { s_ind_p.get_unchecked(i) }
398 }
399 });
400 } else {
401 let (p_lo, p_hi, pp_hi) = shift_offset_even_step(s_ind_p, s_ind_pp, r_k);
402
403 *s_ind_p = PackedField::from_fn(|i| {
404 if i < 1 << k {
405 unsafe { p_lo.get_unchecked(i) }
406 } else if i < 1 << (k + 1) {
407 unsafe { p_hi.get_unchecked(i - (1 << k)) }
408 } else {
409 unsafe { s_ind_p.get_unchecked(i) }
410 }
411 });
412 *s_ind_pp = PackedField::from_fn(|i| {
413 if i < 1 << k {
414 Field::ZERO
415 } else if i < 1 << (k + 1) {
416 unsafe { pp_hi.get_unchecked(i - (1 << k)) }
417 } else {
418 unsafe { s_ind_pp.get_unchecked(i) }
419 }
420 });
421 }
422 }
423}
424
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{BinaryField32b, PackedBinaryField4x32b};
    use binius_hal::{make_portable_backend, ComputationBackendExt};
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::polynomial::test_utils::decompose_index_to_hypercube_point;

    /// Draws `n_vars` random field elements from `rng`.
    fn random_point<F: TowerField>(rng: &mut StdRng, n_vars: usize) -> Vec<F> {
        repeat_with(|| F::random(&mut *rng)).take(n_vars).collect()
    }

    /// Schwartz–Zippel check: evaluating the partial evaluation directly at a random point must
    /// agree with evaluating its multilinear extension (packed as `P`) at the same point.
    fn check_consistency<F: TowerField, P: PackedField<Scalar = F>>(
        shift_variant: ShiftVariant,
        block_size: usize,
        shift_offset: usize,
    ) {
        let mut rng = StdRng::seed_from_u64(0);
        let backend = make_portable_backend();
        let r = random_point::<F>(&mut rng, block_size);
        let eval_point = random_point::<F>(&mut rng, block_size);

        let partial_eval =
            ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, r).unwrap();
        let direct = partial_eval.evaluate(&eval_point).unwrap();

        let mle = partial_eval.multilinear_extension::<P>().unwrap();
        let query = backend.multilinear_query::<P>(&eval_point).unwrap();
        let via_mle = mle.evaluate(&query).unwrap();

        assert_eq!(via_mle, direct);
    }

    /// Offsets exercised by the consistency tests for a given block size.
    fn consistency_offsets(block_size: usize) -> [usize; 5] {
        [1, 2, 3, (1 << block_size) - 1, (1 << block_size) / 2]
    }

    /// Whether the indicator for `shift_variant` should be 1 at `r = i`, `x = j`, where both
    /// points are given as hypercube indices.
    fn indicator_is_set(
        shift_variant: ShiftVariant,
        block_size: usize,
        shift_offset: usize,
        i: usize,
        j: usize,
    ) -> bool {
        match shift_variant {
            ShiftVariant::CircularLeft => (j + shift_offset) % (1 << block_size) == i,
            ShiftVariant::LogicalLeft => j + shift_offset == i,
            ShiftVariant::LogicalRight => j >= shift_offset && j - shift_offset == i,
        }
    }

    /// Exhaustively checks the indicator over every pair of hypercube points.
    fn check_functionality<F: TowerField>(
        shift_variant: ShiftVariant,
        block_size: usize,
        shift_offset: usize,
    ) {
        let hypercube_size = 1 << block_size;
        for i in 0..hypercube_size {
            let r = decompose_index_to_hypercube_point::<F>(block_size, i);
            let partial_eval =
                ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, r).unwrap();
            for j in 0..hypercube_size {
                let x = decompose_index_to_hypercube_point::<F>(block_size, j);
                let expected = if indicator_is_set(shift_variant, block_size, shift_offset, i, j)
                {
                    F::ONE
                } else {
                    F::ZERO
                };
                assert_eq!(partial_eval.evaluate(&x).unwrap(), expected);
            }
        }
    }

    /// Offsets exercised by the functionality tests for a given block size.
    fn functionality_offsets(block_size: usize) -> [usize; 5] {
        [
            1,
            3,
            (1 << block_size) - 1,
            (1 << block_size) - 2,
            1 << (block_size - 1),
        ]
    }

    #[test]
    fn test_circular_left_shift_consistency_schwartz_zippel() {
        for block_size in 2..=10 {
            for shift_offset in consistency_offsets(block_size) {
                check_consistency::<_, PackedBinaryField4x32b>(
                    ShiftVariant::CircularLeft,
                    block_size,
                    shift_offset,
                );
            }
        }
    }

    #[test]
    fn test_logical_left_shift_consistency_schwartz_zippel() {
        for block_size in 2..=10 {
            for shift_offset in consistency_offsets(block_size) {
                check_consistency::<_, PackedBinaryField4x32b>(
                    ShiftVariant::LogicalLeft,
                    block_size,
                    shift_offset,
                );
            }
        }
    }

    #[test]
    fn test_logical_right_shift_consistency_schwartz_zippel() {
        for block_size in 2..=10 {
            for shift_offset in consistency_offsets(block_size) {
                check_consistency::<_, PackedBinaryField4x32b>(
                    ShiftVariant::LogicalRight,
                    block_size,
                    shift_offset,
                );
            }
        }
    }

    #[test]
    fn test_circular_left_shift_functionality() {
        for block_size in 3..5 {
            for shift_offset in functionality_offsets(block_size) {
                check_functionality::<BinaryField32b>(
                    ShiftVariant::CircularLeft,
                    block_size,
                    shift_offset,
                );
            }
        }
    }

    #[test]
    fn test_logical_left_shift_functionality() {
        for block_size in 3..5 {
            for shift_offset in functionality_offsets(block_size) {
                check_functionality::<BinaryField32b>(
                    ShiftVariant::LogicalLeft,
                    block_size,
                    shift_offset,
                );
            }
        }
    }

    #[test]
    fn test_logical_right_shift_functionality() {
        for block_size in 3..5 {
            for shift_offset in functionality_offsets(block_size) {
                check_functionality::<BinaryField32b>(
                    ShiftVariant::LogicalRight,
                    block_size,
                    shift_offset,
                );
            }
        }
    }
}