1use binius_field::{util::eq, Field, PackedFieldIndexable, TowerField};
4use binius_math::MultilinearExtension;
5use binius_utils::bail;
6
7use crate::{
8 oracle::ShiftVariant,
9 polynomial::{Error, MultivariatePoly},
10};
11
/// Shift-indicator polynomial with one of its two argument blocks fixed to the point `r`.
///
/// Evaluating this polynomial at `x` runs the two-argument shift-indicator recurrence on
/// `(x, r)` and keeps the component selected by `shift_variant`. The polynomial has
/// `block_size` variables.
#[derive(Debug, Clone)]
pub struct ShiftIndPartialEval<F: Field> {
    /// Number of variables of the polynomial; also the required length of `r`.
    block_size: usize,
    /// Shift amount; the constructor only accepts values in `1..2^block_size`.
    shift_offset: usize,
    /// Which shift (circular left, logical left, logical right) the indicator describes.
    shift_variant: ShiftVariant,
    /// The fixed partial query point; holds exactly `block_size` field elements.
    r: Vec<F>,
}
96
97impl<F: Field> ShiftIndPartialEval<F> {
98 pub fn new(
99 block_size: usize,
100 shift_offset: usize,
101 shift_variant: ShiftVariant,
102 r: Vec<F>,
103 ) -> Result<Self, Error> {
104 assert_valid_shift_ind_args(block_size, shift_offset, &r)?;
105 Ok(Self {
106 block_size,
107 shift_offset,
108 r,
109 shift_variant,
110 })
111 }
112
113 fn multilinear_extension_circular<P>(&self) -> Result<MultilinearExtension<P>, Error>
114 where
115 P: PackedFieldIndexable<Scalar = F>,
116 {
117 let (ps, pps) =
118 partial_evaluate_hypercube_impl::<P>(self.block_size, self.shift_offset, &self.r)?;
119 let values = ps
120 .iter()
121 .zip(pps)
122 .map(|(p, pp)| *p + pp)
123 .collect::<Vec<_>>();
124 Ok(MultilinearExtension::from_values(values)?)
125 }
126
127 fn multilinear_extension_logical_left<P>(&self) -> Result<MultilinearExtension<P>, Error>
128 where
129 P: PackedFieldIndexable<Scalar = F>,
130 {
131 let (ps, _) =
132 partial_evaluate_hypercube_impl::<P>(self.block_size, self.shift_offset, &self.r)?;
133 Ok(MultilinearExtension::from_values(ps)?)
134 }
135
136 fn multilinear_extension_logical_right<P>(&self) -> Result<MultilinearExtension<P>, Error>
137 where
138 P: PackedFieldIndexable<Scalar = F>,
139 {
140 let right_shift_offset = get_left_shift_offset(self.block_size, self.shift_offset);
141 let (_, pps) =
142 partial_evaluate_hypercube_impl::<P>(self.block_size, right_shift_offset, &self.r)?;
143 Ok(MultilinearExtension::from_values(pps)?)
144 }
145
146 pub fn multilinear_extension<P>(&self) -> Result<MultilinearExtension<P>, Error>
149 where
150 P: PackedFieldIndexable<Scalar = F>,
151 {
152 match self.shift_variant {
153 ShiftVariant::CircularLeft => self.multilinear_extension_circular(),
154 ShiftVariant::LogicalLeft => self.multilinear_extension_logical_left(),
155 ShiftVariant::LogicalRight => self.multilinear_extension_logical_right(),
156 }
157 }
158
159 fn evaluate_at_point(&self, x: &[F]) -> Result<F, Error> {
161 if x.len() != self.block_size {
162 bail!(Error::IncorrectQuerySize {
163 expected: self.block_size,
164 });
165 }
166
167 let left_shift_offset = match self.shift_variant {
168 ShiftVariant::CircularLeft | ShiftVariant::LogicalLeft => self.shift_offset,
169 ShiftVariant::LogicalRight => get_left_shift_offset(self.block_size, self.shift_offset),
170 };
171
172 let (p_res, pp_res) =
173 evaluate_shift_ind_help(self.block_size, left_shift_offset, x, &self.r)?;
174
175 match self.shift_variant {
176 ShiftVariant::CircularLeft => Ok(p_res + pp_res),
177 ShiftVariant::LogicalLeft => Ok(p_res),
178 ShiftVariant::LogicalRight => Ok(pp_res),
179 }
180 }
181}
182
/// Exposes the partially evaluated shift indicator through the `MultivariatePoly` interface.
impl<F: TowerField> MultivariatePoly<F> for ShiftIndPartialEval<F> {
    /// Number of variables; equal to `block_size`.
    fn n_vars(&self) -> usize {
        self.block_size
    }

    /// Degree reported for this polynomial; this implementation returns `block_size`.
    fn degree(&self) -> usize {
        self.block_size
    }

    /// Evaluates at `query` by delegating to `evaluate_at_point`, which rejects queries
    /// whose length differs from `block_size`.
    fn evaluate(&self, query: &[F]) -> Result<F, Error> {
        self.evaluate_at_point(query)
    }

    /// Tower level of the field `F`, taken from `F::TOWER_LEVEL`.
    fn binary_tower_level(&self) -> usize {
        F::TOWER_LEVEL
    }
}
200
/// Returns `2^block_size - right_shift_offset`, the left-shift offset used in place of a
/// right shift by `right_shift_offset` within a block of `2^block_size` positions.
///
/// Callers are expected to pass `right_shift_offset <= 2^block_size`; a larger value
/// underflows the subtraction.
const fn get_left_shift_offset(block_size: usize, right_shift_offset: usize) -> usize {
    let block_len = 1 << block_size;
    block_len - right_shift_offset
}
205
206fn assert_valid_shift_ind_args<F: Field>(
208 block_size: usize,
209 shift_offset: usize,
210 partial_query_point: &[F],
211) -> Result<(), Error> {
212 if partial_query_point.len() != block_size {
213 bail!(Error::IncorrectQuerySize {
214 expected: block_size,
215 });
216 }
217 if shift_offset == 0 || shift_offset >= 1 << block_size {
218 bail!(Error::InvalidShiftOffset {
219 max_shift_offset: (1 << block_size) - 1,
220 shift_offset,
221 });
222 }
223
224 Ok(())
225}
226
227fn evaluate_shift_ind_help<F: Field>(
232 block_size: usize,
233 shift_offset: usize,
234 x: &[F],
235 y: &[F],
236) -> Result<(F, F), Error> {
237 if x.len() != block_size {
238 bail!(Error::IncorrectQuerySize {
239 expected: block_size,
240 });
241 }
242 assert_valid_shift_ind_args(block_size, shift_offset, y)?;
243
244 let (mut s_ind_p, mut s_ind_pp) = (F::ONE, F::ZERO);
245 let (mut temp_p, mut temp_pp) = (F::default(), F::default());
246 (0..block_size).for_each(|k| {
247 let o_k = shift_offset >> k;
248 let product = x[k] * y[k];
249 if o_k % 2 == 1 {
250 temp_p = (y[k] - product) * s_ind_p;
251 temp_pp = (x[k] - product) * s_ind_p + eq(x[k], y[k]) * s_ind_pp;
252 } else {
253 temp_p = eq(x[k], y[k]) * s_ind_p + (y[k] - product) * s_ind_pp;
254 temp_pp = (x[k] - product) * s_ind_pp;
255 }
256 s_ind_p = temp_p;
258 s_ind_pp = temp_pp;
259 });
260
261 Ok((s_ind_p, s_ind_pp))
262}
263
264fn partial_evaluate_hypercube_impl<P: PackedFieldIndexable>(
270 block_size: usize,
271 shift_offset: usize,
272 r: &[P::Scalar],
273) -> Result<(Vec<P>, Vec<P>), Error> {
274 assert_valid_shift_ind_args(block_size, shift_offset, r)?;
275 let mut s_ind_p = vec![P::one(); 1 << (block_size - P::LOG_WIDTH)];
276 let mut s_ind_pp = vec![P::zero(); 1 << (block_size - P::LOG_WIDTH)];
277
278 partial_evaluate_hypercube_with_buffers(
279 block_size.min(P::LOG_WIDTH),
280 shift_offset,
281 r,
282 P::unpack_scalars_mut(&mut s_ind_p),
283 P::unpack_scalars_mut(&mut s_ind_pp),
284 );
285 if block_size > P::LOG_WIDTH {
286 partial_evaluate_hypercube_with_buffers(
287 block_size - P::LOG_WIDTH,
288 shift_offset >> P::LOG_WIDTH,
289 &r[P::LOG_WIDTH..],
290 &mut s_ind_p,
291 &mut s_ind_pp,
292 );
293 }
294
295 Ok((s_ind_p, s_ind_pp))
296}
297
298fn partial_evaluate_hypercube_with_buffers<P: PackedFieldIndexable>(
299 block_size: usize,
300 shift_offset: usize,
301 r: &[P::Scalar],
302 s_ind_p: &mut [P],
303 s_ind_pp: &mut [P],
304) {
305 for k in 0..block_size {
306 if (shift_offset >> k) % 2 == 1 {
308 for i in 0..(1 << k) {
309 let mut pp_lo = s_ind_pp[i];
310 let mut pp_hi = pp_lo * r[k];
311
312 pp_lo -= pp_hi;
313
314 let p_lo = s_ind_p[i];
315 let p_hi = p_lo * r[k];
316 pp_hi += p_lo - p_hi; s_ind_pp[i] = pp_lo;
319 s_ind_pp[1 << k | i] = pp_hi;
320
321 s_ind_p[i] = p_hi;
322 s_ind_p[1 << k | i] = P::zero(); }
324 } else {
325 for i in 0..(1 << k) {
326 let mut p_lo = s_ind_p[i];
327 let p_hi = p_lo * r[k];
328 p_lo -= p_hi;
329
330 let pp_lo = s_ind_pp[i];
331 let pp_hi = pp_lo * (P::one() - r[k]);
332 p_lo += pp_lo - pp_hi;
333
334 s_ind_p[i] = p_lo;
335 s_ind_p[1 << k | i] = p_hi;
336
337 s_ind_pp[i] = P::zero(); s_ind_pp[1 << k | i] = pp_hi;
339 }
340 }
341 }
342}
343
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{BinaryField32b, PackedBinaryField4x32b};
    use binius_hal::{make_portable_backend, ComputationBackendExt};
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::polynomial::test_utils::decompose_index_to_hypercube_point;

    /// Shift offsets exercised by the consistency tests for a given block size.
    fn consistency_offsets(block_size: usize) -> [usize; 5] {
        [1, 2, 3, (1 << block_size) - 1, (1 << block_size) / 2]
    }

    /// Shift offsets exercised by the functionality tests for a given block size.
    fn functionality_offsets(block_size: usize) -> [usize; 5] {
        [
            1,
            3,
            (1 << block_size) - 1,
            (1 << block_size) - 2,
            (1 << (block_size - 1)),
        ]
    }

    /// Checks, at a random point, that the direct evaluation of the polynomial agrees
    /// with the evaluation of its materialized multilinear extension.
    fn assert_mle_agrees_with_direct_evaluation<F, P>(
        shift_variant: ShiftVariant,
        block_size: usize,
        shift_offset: usize,
    ) where
        F: TowerField,
        P: PackedFieldIndexable<Scalar = F>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let backend = make_portable_backend();
        let r = repeat_with(|| F::random(&mut rng))
            .take(block_size)
            .collect::<Vec<_>>();
        let eval_point = &repeat_with(|| F::random(&mut rng))
            .take(block_size)
            .collect::<Vec<_>>();

        // Direct evaluation of the polynomial.
        let shift_r_mvp =
            ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, r).unwrap();
        let eval_mvp = shift_r_mvp.evaluate(eval_point).unwrap();

        // Evaluation through the multilinear extension.
        let shift_r_mle = shift_r_mvp.multilinear_extension::<P>().unwrap();
        let multilin_query = backend.multilinear_query::<P>(eval_point).unwrap();
        let eval_mle = shift_r_mle.evaluate(&multilin_query).unwrap();

        assert_eq!(eval_mle, eval_mvp);
    }

    /// Runs the consistency check for one shift variant over all tested block sizes
    /// and offsets.
    fn run_consistency_suite(shift_variant: ShiftVariant) {
        for block_size in 2usize..=10 {
            for shift_offset in consistency_offsets(block_size) {
                assert_mle_agrees_with_direct_evaluation::<_, PackedBinaryField4x32b>(
                    shift_variant,
                    block_size,
                    shift_offset,
                );
            }
        }
    }

    /// Evaluates the polynomial on every pair of hypercube points `(r, x)` with indices
    /// `(i, j)` and checks that the value is one exactly when `expect_one(i, j)` holds,
    /// and zero otherwise.
    fn assert_indicator_on_hypercube<F: TowerField>(
        shift_variant: ShiftVariant,
        block_size: usize,
        shift_offset: usize,
        expect_one: impl Fn(usize, usize) -> bool,
    ) {
        for i in 0..(1usize << block_size) {
            let r = decompose_index_to_hypercube_point::<F>(block_size, i);
            let shift_r_mvp =
                ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, r).unwrap();
            for j in 0..(1usize << block_size) {
                let x = decompose_index_to_hypercube_point::<F>(block_size, j);
                let eval_mvp = shift_r_mvp.evaluate(&x).unwrap();
                let expected = if expect_one(i, j) { F::ONE } else { F::ZERO };
                assert_eq!(eval_mvp, expected);
            }
        }
    }

    #[test]
    fn test_circular_left_shift_consistency_schwartz_zippel() {
        run_consistency_suite(ShiftVariant::CircularLeft);
    }

    #[test]
    fn test_logical_left_shift_consistency_schwartz_zippel() {
        run_consistency_suite(ShiftVariant::LogicalLeft);
    }

    #[test]
    fn test_logical_right_shift_consistency_schwartz_zippel() {
        run_consistency_suite(ShiftVariant::LogicalRight);
    }

    #[test]
    fn test_circular_left_shift_functionality() {
        for block_size in 3usize..5 {
            let block_len = 1usize << block_size;
            for shift_offset in functionality_offsets(block_size) {
                assert_indicator_on_hypercube::<BinaryField32b>(
                    ShiftVariant::CircularLeft,
                    block_size,
                    shift_offset,
                    |i, j| (j + shift_offset) % block_len == i,
                );
            }
        }
    }

    #[test]
    fn test_logical_left_shift_functionality() {
        for block_size in 3usize..5 {
            for shift_offset in functionality_offsets(block_size) {
                assert_indicator_on_hypercube::<BinaryField32b>(
                    ShiftVariant::LogicalLeft,
                    block_size,
                    shift_offset,
                    |i, j| j + shift_offset == i,
                );
            }
        }
    }

    #[test]
    fn test_logical_right_shift_functionality() {
        for block_size in 3usize..5 {
            for shift_offset in functionality_offsets(block_size) {
                assert_indicator_on_hypercube::<BinaryField32b>(
                    ShiftVariant::LogicalRight,
                    block_size,
                    shift_offset,
                    |i, j| j >= shift_offset && j - shift_offset == i,
                );
            }
        }
    }
}