binius_core/transparent/
shift_ind.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{util::eq, Field, PackedFieldIndexable, TowerField};
4use binius_math::MultilinearExtension;
5use binius_utils::bail;
6
7use crate::{
8	oracle::ShiftVariant,
9	polynomial::{Error, MultivariatePoly},
10};
11
12/// Represents MLE of shift indicator $f_{b, o}(X, Y)$ on $2*b$ variables
13/// partially evaluated at $Y = r$
14///
15/// # Formal Definition
16/// Let $x, y \in \{0, 1\}^b$
17/// If ShiftVariant is CircularLeft:
18///     * $f(x, y) = 1$ if $\{y\} - \{o\} \equiv \{x\} (\text{mod } 2^b)$
19///     * $f(x, y) = 0$ otw
20///
/// Else if ShiftVariant is LogicalLeft:
///    * $f(x, y) = 1$ if $\{y\} - \{o\} = \{x\}$ (integer equality, no wrap-around)
///    * $f(x, y) = 0$ otherwise
///
/// Else, ShiftVariant is LogicalRight:
///    * $f(x, y) = 1$ if $\{y\} + \{o\} = \{x\}$ (integer equality, no wrap-around)
///    * $f(x, y) = 0$ otherwise
28///
29/// where:
30///    * $\{x\}$ is the integer representation of the hypercube point $x \in \{0, 1\}^b$,
///    * $b$ is the block size parameter,
32///    * $o$ is the shift offset parameter.
33///
34/// Observe $\forall x \in \{0, 1\}^b$, there is at most one $y \in \{0, 1\}^b$ s.t. $f(x, y) = 1$
35///
36/// # Intuition
37/// Consider the lexicographic ordering of each point on the $b$-variate hypercube into a $2^b$ length array.
38/// Thus, we can give each element on the hypercube a unique index $\in \{0, \ldots, 2^b - 1\}$
39/// Let $x, y \in \{0, 1\}^{b}$ be s.t. $\{x\} = i$ and $\{y\} = j$
40/// $f(x, y) = 1$ iff:
41///     * taking $o$ steps from $j$ gets you to $i$
42/// (wrap around if ShiftVariant is Circular + direction of steps depending on ShiftVariant's direction)
43///
/// # Note
/// CircularLeft corresponds to the shift indicator in Section 4.3 of [DP23].
/// LogicalLeft corresponds to the shift prime indicator in Section 4.3 of [DP23].
/// LogicalRight corresponds to the shift double prime indicator in Section 4.3 of [DP23].
///
/// [DP23]: https://eprint.iacr.org/2023/1784
50///
51/// # Example
52/// Let $b$ = 2, $o$ = 1, variant = CircularLeft.
53/// The hypercube points (0, 0), (1, 0), (0, 1), (1, 1) can be lexicographically
54/// ordered into an array [(0, 0), (1, 0), (0, 1), (1, 1)]
55/// Then, by considering the index of each hypercube point in the above array, we observe:
56///     * $f((0, 0), (1, 0)) = 1$ because $1 - 1 = 0$ mod $4$
57///     * $f((1, 0), (0, 1)) = 1$ because $2 - 1 = 1$ mod $4$
58///     * $f((0, 1), (1, 1)) = 1$ because $3 - 1 = 2$ mod $4$
59///     * $f((1, 1), (0, 0)) = 1$ because $0 - 1 = 3$ mod $4$
60/// and every other pair of $b$-variate hypercube points $x, y \in \{0, 1\}^{b}$ is s.t. f(x, y) = 0.
61/// Using these shift params, if f = [[a_i, b_i, c_i, d_i]_i], then shifted_f = [[b_i, c_i, d_i, a_i]_i]
62///
63/// # Example
64/// Let $b$ = 2, $o$ = 1, variant = LogicalLeft.
65/// The hypercube points (0, 0), (1, 0), (0, 1), (1, 1) can be lexicographically
66/// ordered into an array [(0, 0), (1, 0), (0, 1), (1, 1)]
67/// Then, by considering the index of each hypercube point in the above array, we observe:
68///     * $f((0, 0), (1, 0)) = 1$ because $1 - 1 = 0$
69///     * $f((1, 0), (0, 1)) = 1$ because $2 - 1 = 1$
70///     * $f((0, 1), (1, 1)) = 1$ because $3 - 1 = 2$
71/// and every other pair of $b$-variate hypercube points $x, y \in \{0, 1\}^{b}$ is s.t. f(x, y) = 0.
72/// Using these shift params, if f = [[a_i, b_i, c_i, d_i]_i], then shifted_f = [[b_i, c_i, d_i, 0]_i]
73///
74/// # Example
75/// Let $b$ = 2, $o$ = 1, variant = LogicalRight.
76/// The hypercube points (0, 0), (1, 0), (0, 1), (1, 1) can be lexicographically
77/// ordered into an array [(0, 0), (1, 0), (0, 1), (1, 1)]
78/// Then, by considering the index of each hypercube point in the above array, we observe:
79///     * $f((1, 0), (0, 0)) = 1$ because $0 + 1 = 1$
80///     * $f((0, 1), (1, 0)) = 1$ because $1 + 1 = 2$
81///     * $f((1, 1), (0, 1)) = 1$ because $2 + 1 = 3$
82/// and every other pair of $b$-variate hypercube points $x, y \in \{0, 1\}^{b}$ is s.t. f(x, y) = 0.
83/// Using these shift params, if f = [[a_i, b_i, c_i, d_i]_i], then shifted_f = [[0, a_i, b_i, c_i]_i]
#[derive(Debug, Clone)]
pub struct ShiftIndPartialEval<F: Field> {
	/// Block size $b$: the number of variables of the partially evaluated indicator
	/// and the required length of `r`.
	block_size: usize,
	/// Shift offset $o \in \{1, \ldots, 2^b - 1\}$; the range is validated on construction.
	shift_offset: usize,
	/// Shift variant selecting circular-left, logical-left or logical-right semantics.
	shift_variant: ShiftVariant,
	/// Partial evaluation point $r$ substituted for $Y$, typically the lowest $b$
	/// coordinates of a larger challenge point.
	r: Vec<F>,
}
96
97impl<F: Field> ShiftIndPartialEval<F> {
98	pub fn new(
99		block_size: usize,
100		shift_offset: usize,
101		shift_variant: ShiftVariant,
102		r: Vec<F>,
103	) -> Result<Self, Error> {
104		assert_valid_shift_ind_args(block_size, shift_offset, &r)?;
105		Ok(Self {
106			block_size,
107			shift_offset,
108			r,
109			shift_variant,
110		})
111	}
112
113	fn multilinear_extension_circular<P>(&self) -> Result<MultilinearExtension<P>, Error>
114	where
115		P: PackedFieldIndexable<Scalar = F>,
116	{
117		let (ps, pps) =
118			partial_evaluate_hypercube_impl::<P>(self.block_size, self.shift_offset, &self.r)?;
119		let values = ps
120			.iter()
121			.zip(pps)
122			.map(|(p, pp)| *p + pp)
123			.collect::<Vec<_>>();
124		Ok(MultilinearExtension::from_values(values)?)
125	}
126
127	fn multilinear_extension_logical_left<P>(&self) -> Result<MultilinearExtension<P>, Error>
128	where
129		P: PackedFieldIndexable<Scalar = F>,
130	{
131		let (ps, _) =
132			partial_evaluate_hypercube_impl::<P>(self.block_size, self.shift_offset, &self.r)?;
133		Ok(MultilinearExtension::from_values(ps)?)
134	}
135
136	fn multilinear_extension_logical_right<P>(&self) -> Result<MultilinearExtension<P>, Error>
137	where
138		P: PackedFieldIndexable<Scalar = F>,
139	{
140		let right_shift_offset = get_left_shift_offset(self.block_size, self.shift_offset);
141		let (_, pps) =
142			partial_evaluate_hypercube_impl::<P>(self.block_size, right_shift_offset, &self.r)?;
143		Ok(MultilinearExtension::from_values(pps)?)
144	}
145
146	/// Evaluates this partially evaluated circular shift indicator MLE $f(X, r)$
147	/// over the entire $b$-variate hypercube
148	pub fn multilinear_extension<P>(&self) -> Result<MultilinearExtension<P>, Error>
149	where
150		P: PackedFieldIndexable<Scalar = F>,
151	{
152		match self.shift_variant {
153			ShiftVariant::CircularLeft => self.multilinear_extension_circular(),
154			ShiftVariant::LogicalLeft => self.multilinear_extension_logical_left(),
155			ShiftVariant::LogicalRight => self.multilinear_extension_logical_right(),
156		}
157	}
158
159	/// Evaluates this partial circular shift indicator MLE $f(X, r)$ with $X=x$
160	fn evaluate_at_point(&self, x: &[F]) -> Result<F, Error> {
161		if x.len() != self.block_size {
162			bail!(Error::IncorrectQuerySize {
163				expected: self.block_size,
164			});
165		}
166
167		let left_shift_offset = match self.shift_variant {
168			ShiftVariant::CircularLeft | ShiftVariant::LogicalLeft => self.shift_offset,
169			ShiftVariant::LogicalRight => get_left_shift_offset(self.block_size, self.shift_offset),
170		};
171
172		let (p_res, pp_res) =
173			evaluate_shift_ind_help(self.block_size, left_shift_offset, x, &self.r)?;
174
175		match self.shift_variant {
176			ShiftVariant::CircularLeft => Ok(p_res + pp_res),
177			ShiftVariant::LogicalLeft => Ok(p_res),
178			ShiftVariant::LogicalRight => Ok(pp_res),
179		}
180	}
181}
182
/// Exposes the partially evaluated shift indicator through the generic
/// `MultivariatePoly` interface.
impl<F: TowerField> MultivariatePoly<F> for ShiftIndPartialEval<F> {
	/// Number of variables of $f(X, r)$, i.e. the block size $b$.
	fn n_vars(&self) -> usize {
		self.block_size
	}

	/// Reports the block size $b$ as the degree.
	fn degree(&self) -> usize {
		self.block_size
	}

	/// Evaluates $f(X, r)$ at $X = $ `query`; fails if `query.len() != block_size`.
	fn evaluate(&self, query: &[F]) -> Result<F, Error> {
		self.evaluate_at_point(query)
	}

	/// Tower level of the field `F` the polynomial is defined over.
	fn binary_tower_level(&self) -> usize {
		F::TOWER_LEVEL
	}
}
200
/// Converts a right shift offset into the complementary left shift offset
/// $2^b - o$, where $b$ is `block_size` and $o$ is `right_shift_offset`.
const fn get_left_shift_offset(block_size: usize, right_shift_offset: usize) -> usize {
	let hypercube_size: usize = 1 << block_size;
	hypercube_size - right_shift_offset
}
205
206/// Checks validity of shift indicator arguments
207fn assert_valid_shift_ind_args<F: Field>(
208	block_size: usize,
209	shift_offset: usize,
210	partial_query_point: &[F],
211) -> Result<(), Error> {
212	if partial_query_point.len() != block_size {
213		bail!(Error::IncorrectQuerySize {
214			expected: block_size,
215		});
216	}
217	if shift_offset == 0 || shift_offset >= 1 << block_size {
218		bail!(Error::InvalidShiftOffset {
219			max_shift_offset: (1 << block_size) - 1,
220			shift_offset,
221		});
222	}
223
224	Ok(())
225}
226
227/// Evaluates the LogicalRightShift and LogicalLeftShift indicators at the point $(x, y)$
228///
229/// Requires length of x (and y) is block_size
230/// Requires shift offset is at most $2^b$ where $b$ is block_size
231fn evaluate_shift_ind_help<F: Field>(
232	block_size: usize,
233	shift_offset: usize,
234	x: &[F],
235	y: &[F],
236) -> Result<(F, F), Error> {
237	if x.len() != block_size {
238		bail!(Error::IncorrectQuerySize {
239			expected: block_size,
240		});
241	}
242	assert_valid_shift_ind_args(block_size, shift_offset, y)?;
243
244	let (mut s_ind_p, mut s_ind_pp) = (F::ONE, F::ZERO);
245	let (mut temp_p, mut temp_pp) = (F::default(), F::default());
246	(0..block_size).for_each(|k| {
247		let o_k = shift_offset >> k;
248		let product = x[k] * y[k];
249		if o_k % 2 == 1 {
250			temp_p = (y[k] - product) * s_ind_p;
251			temp_pp = (x[k] - product) * s_ind_p + eq(x[k], y[k]) * s_ind_pp;
252		} else {
253			temp_p = eq(x[k], y[k]) * s_ind_p + (y[k] - product) * s_ind_pp;
254			temp_pp = (x[k] - product) * s_ind_pp;
255		}
256		// roll over results
257		s_ind_p = temp_p;
258		s_ind_pp = temp_pp;
259	});
260
261	Ok((s_ind_p, s_ind_pp))
262}
263
264/// Evaluates the LogicalRightShift and LogicalLeftShift indicators over the entire hypercube
265///
266/// Total time is O(2^b) field operations (optimal in light of output size)
267/// Requires length of $r$ is exactly block_size
268/// Requires shift offset is at most $2^b$ where $b$ is block_size
269fn partial_evaluate_hypercube_impl<P: PackedFieldIndexable>(
270	block_size: usize,
271	shift_offset: usize,
272	r: &[P::Scalar],
273) -> Result<(Vec<P>, Vec<P>), Error> {
274	assert_valid_shift_ind_args(block_size, shift_offset, r)?;
275	let mut s_ind_p = vec![P::one(); 1 << (block_size - P::LOG_WIDTH)];
276	let mut s_ind_pp = vec![P::zero(); 1 << (block_size - P::LOG_WIDTH)];
277
278	partial_evaluate_hypercube_with_buffers(
279		block_size.min(P::LOG_WIDTH),
280		shift_offset,
281		r,
282		P::unpack_scalars_mut(&mut s_ind_p),
283		P::unpack_scalars_mut(&mut s_ind_pp),
284	);
285	if block_size > P::LOG_WIDTH {
286		partial_evaluate_hypercube_with_buffers(
287			block_size - P::LOG_WIDTH,
288			shift_offset >> P::LOG_WIDTH,
289			&r[P::LOG_WIDTH..],
290			&mut s_ind_p,
291			&mut s_ind_pp,
292		);
293	}
294
295	Ok((s_ind_p, s_ind_pp))
296}
297
298fn partial_evaluate_hypercube_with_buffers<P: PackedFieldIndexable>(
299	block_size: usize,
300	shift_offset: usize,
301	r: &[P::Scalar],
302	s_ind_p: &mut [P],
303	s_ind_pp: &mut [P],
304) {
305	for k in 0..block_size {
306		// complexity: just two multiplications per iteration!
307		if (shift_offset >> k) % 2 == 1 {
308			for i in 0..(1 << k) {
309				let mut pp_lo = s_ind_pp[i];
310				let mut pp_hi = pp_lo * r[k];
311
312				pp_lo -= pp_hi;
313
314				let p_lo = s_ind_p[i];
315				let p_hi = p_lo * r[k];
316				pp_hi += p_lo - p_hi; // * 1 - r
317
318				s_ind_pp[i] = pp_lo;
319				s_ind_pp[1 << k | i] = pp_hi;
320
321				s_ind_p[i] = p_hi;
322				s_ind_p[1 << k | i] = P::zero(); // clear upper half
323			}
324		} else {
325			for i in 0..(1 << k) {
326				let mut p_lo = s_ind_p[i];
327				let p_hi = p_lo * r[k];
328				p_lo -= p_hi;
329
330				let pp_lo = s_ind_pp[i];
331				let pp_hi = pp_lo * (P::one() - r[k]);
332				p_lo += pp_lo - pp_hi;
333
334				s_ind_p[i] = p_lo;
335				s_ind_p[1 << k | i] = p_hi;
336
337				s_ind_pp[i] = P::zero(); // clear lower half
338				s_ind_pp[1 << k | i] = pp_hi;
339			}
340		}
341	}
342}
343
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{BinaryField32b, PackedBinaryField4x32b};
	use binius_hal::{make_portable_backend, ComputationBackendExt};
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;
	use crate::polynomial::test_utils::decompose_index_to_hypercube_point;

	/// Checks that, for random `r` and a random evaluation point, the direct point
	/// evaluation agrees with evaluating the hypercube-expanded multilinear extension.
	fn check_consistency<F, P>(block_size: usize, shift_offset: usize, shift_variant: ShiftVariant)
	where
		F: TowerField,
		P: PackedFieldIndexable<Scalar = F>,
	{
		let mut rng = StdRng::seed_from_u64(0);
		let backend = make_portable_backend();
		let r = repeat_with(|| F::random(&mut rng))
			.take(block_size)
			.collect::<Vec<_>>();
		let eval_point = repeat_with(|| F::random(&mut rng))
			.take(block_size)
			.collect::<Vec<_>>();

		// Multivariate polynomial version
		let shift_r_mvp =
			ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, r).unwrap();
		let eval_mvp = shift_r_mvp.evaluate(&eval_point).unwrap();

		// MultilinearExtension version
		let shift_r_mle = shift_r_mvp.multilinear_extension::<P>().unwrap();
		let multilin_query = backend.multilinear_query::<P>(&eval_point).unwrap();
		let eval_mle = shift_r_mle.evaluate(&multilin_query).unwrap();

		assert_eq!(eval_mle, eval_mvp);
	}

	/// Runs the consistency check for one shift variant over the shared grid of block
	/// sizes and shift offsets.
	fn run_consistency_suite(shift_variant: ShiftVariant) {
		for block_size in 2..=10 {
			for shift_offset in [1, 2, 3, (1 << block_size) - 1, (1 << block_size) / 2] {
				check_consistency::<_, PackedBinaryField4x32b>(
					block_size,
					shift_offset,
					shift_variant,
				);
			}
		}
	}

	#[test]
	fn test_circular_left_shift_consistency_schwartz_zippel() {
		run_consistency_suite(ShiftVariant::CircularLeft);
	}

	#[test]
	fn test_logical_left_shift_consistency_schwartz_zippel() {
		run_consistency_suite(ShiftVariant::LogicalLeft);
	}

	#[test]
	fn test_logical_right_shift_consistency_schwartz_zippel() {
		run_consistency_suite(ShiftVariant::LogicalRight);
	}

	/// Checks the indicator on every pair of hypercube points: with `r` the point of
	/// index `y_index` and `x` the point of index `x_index`, the evaluation must be one
	/// exactly when the pair is related by the shift, and zero otherwise.
	fn check_functionality<F: TowerField>(
		block_size: usize,
		shift_offset: usize,
		shift_variant: ShiftVariant,
	) {
		let hypercube_size: usize = 1 << block_size;
		for y_index in 0..hypercube_size {
			let r = decompose_index_to_hypercube_point::<F>(block_size, y_index);
			let shift_r_mvp =
				ShiftIndPartialEval::new(block_size, shift_offset, shift_variant, r).unwrap();
			for x_index in 0..hypercube_size {
				let x = decompose_index_to_hypercube_point::<F>(block_size, x_index);
				let eval_mvp = shift_r_mvp.evaluate(&x).unwrap();

				let is_shift_pair = match shift_variant {
					ShiftVariant::CircularLeft => {
						(x_index + shift_offset) % hypercube_size == y_index
					}
					ShiftVariant::LogicalLeft => x_index + shift_offset == y_index,
					ShiftVariant::LogicalRight => {
						x_index >= shift_offset && x_index - shift_offset == y_index
					}
				};
				let expected = if is_shift_pair { F::ONE } else { F::ZERO };
				assert_eq!(eval_mvp, expected);
			}
		}
	}

	/// Runs the functionality check for one shift variant over the shared grid of block
	/// sizes and shift offsets.
	fn run_functionality_suite(shift_variant: ShiftVariant) {
		for block_size in 3..5 {
			for shift_offset in [
				1,
				3,
				(1 << block_size) - 1,
				(1 << block_size) - 2,
				(1 << (block_size - 1)),
			] {
				check_functionality::<BinaryField32b>(block_size, shift_offset, shift_variant);
			}
		}
	}

	#[test]
	fn test_circular_left_shift_functionality() {
		run_functionality_suite(ShiftVariant::CircularLeft);
	}

	#[test]
	fn test_logical_left_shift_functionality() {
		run_functionality_suite(ShiftVariant::LogicalLeft);
	}

	#[test]
	fn test_logical_right_shift_functionality() {
		run_functionality_suite(ShiftVariant::LogicalRight);
	}
}