// binius_math/multilinear/shift.rs

// Copyright 2025 Irreducible Inc.

3use binius_field::Field;
4use itertools::izip;
5
6/// Computes the multilinear extension of the logical left shift indicator at a point.
7///
8/// The shift indicator for logical left shift (SLL) evaluates to 1 when bit `i` of the output
9/// equals bit `j` of the input after shifting left by `s` positions. Specifically:
10/// - `sll_ind(i, j, s) = 1` if and only if `i = j + s` and `i < 2^k`
11/// - `sll_ind(i, j, s) = 0` otherwise
12///
13/// This function evaluates the multilinear extension of this indicator at the given point
14/// `(i, j, s)` where each coordinate is a field element.
15///
16/// # Arguments
17/// * `i` - Slice of field elements representing the output bit position (length k)
18/// * `j` - Slice of field elements representing the input bit position (length k)
19/// * `s` - Slice of field elements representing the shift amount (length k)
20///
21/// # Panics
22/// Panics if the slices don't all have the same length.
23pub fn sll_ind<F: Field>(i: &[F], j: &[F], s: &[F]) -> F {
24	assert_eq!(i.len(), j.len(), "i and j must have the same length");
25	assert_eq!(i.len(), s.len(), "i and s must have the same length");
26
27	// sll_ind(i, j, s) = srl_ind(j, i, s) by transposition
28	srl_ind(j, i, s)
29}
30
31/// Computes the multilinear extension of the logical right shift indicator at a point.
32///
33/// The shift indicator for logical right shift (srl) evaluates to 1 when bit `i` of the output
34/// equals bit `j` of the input after shifting right by `s` positions. Specifically:
35/// - `srl_ind(i, j, s) = 1` if and only if `j = i + s` and `j < 2^k`
36/// - `srl_ind(i, j, s) = 0` otherwise
37///
38/// This function evaluates the multilinear extension of this indicator at the given point
39/// `(i, j, s)` where each coordinate is a field element.
40///
41/// # Arguments
42/// * `i` - Slice of field elements representing the output bit position (length k)
43/// * `j` - Slice of field elements representing the input bit position (length k)
44/// * `s` - Slice of field elements representing the shift amount (length k)
45///
46/// # Panics
47/// Panics if the slices don't all have the same length.
48pub fn srl_ind<F: Field>(i: &[F], j: &[F], s: &[F]) -> F {
49	assert_eq!(i.len(), j.len(), "i and j must have the same length");
50	assert_eq!(i.len(), s.len(), "i and s must have the same length");
51
52	let (sigma, _sigma_prime) = eval_sigmas(i.len(), i, j, s);
53	sigma
54}
55
56/// Computes the multilinear extension of the arithmetic right shift indicator at a point.
57///
58/// The shift indicator for arithmetic right shift (sra) behaves like logical right shift,
59/// but propagates the sign bit (bit 2^k - 1) for positions shifted beyond the original value.
60/// Specifically:
61/// - `sra_ind(i, j, s) = srl_ind(i, j, s)` for normal shifted bits
62/// - Additionally, `sra_ind(i, 2^k - 1, s) = 1` for `i >= 2^k - s` (sign extension)
63///
64/// This function evaluates the multilinear extension of this indicator at the given point
65/// `(i, j, s)` where each coordinate is a field element.
66///
67/// # Arguments
68/// * `i` - Slice of field elements representing the output bit position (length k)
69/// * `j` - Slice of field elements representing the input bit position (length k)
70/// * `s` - Slice of field elements representing the shift amount (length k)
71///
72/// # Panics
73/// Panics if the slices don't all have the same length.
74pub fn sra_ind<F: Field>(i: &[F], j: &[F], s: &[F]) -> F {
75	assert_eq!(i.len(), j.len(), "i and j must have the same length");
76	assert_eq!(i.len(), s.len(), "i and s must have the same length");
77
78	let (sigma, _sigma_prime) = eval_sigmas(i.len(), i, j, s);
79	let phi = eval_phi(i.len(), i, s);
80	let j_prod = j.iter().product::<F>();
81	sigma + phi * j_prod
82}
83
84/// Computes the multilinear extension of the rotate right indicator at a point.
85///
86/// The shift indicator for rotate right (rotr) evaluates to 1 when bit `i` of the output
87/// equals bit `j` of the input after rotating right by `s` positions. Unlike logical shifts,
88/// bits that shift off one end wrap around to the other end. Specifically:
89/// - `rotr_ind(i, j, s) = 1` if and only if `j = (i + s) mod 2^k`
90/// - `rotr_ind(i, j, s) = 0` otherwise
91///
92/// This function evaluates the multilinear extension of this indicator at the given point
93/// `(i, j, s)` where each coordinate is a field element.
94///
95/// # Arguments
96/// * `i` - Slice of field elements representing the output bit position (length k)
97/// * `j` - Slice of field elements representing the input bit position (length k)
98/// * `s` - Slice of field elements representing the shift amount (length k)
99///
100/// # Panics
101/// Panics if the slices don't all have the same length.
102pub fn rotr_ind<F: Field>(i: &[F], j: &[F], s: &[F]) -> F {
103	assert_eq!(i.len(), j.len(), "i and j must have the same length");
104	assert_eq!(i.len(), s.len(), "i and s must have the same length");
105
106	let (sigma, sigma_prime) = eval_sigmas(i.len(), i, j, s);
107	sigma + sigma_prime
108}
109
110/// Evaluate the shift indicator helper polynomials, $\sigma, \sigma'$.
111///
112/// See section 4.6 of the writeup.
113fn eval_sigmas<F: Field>(n: usize, i: &[F], j: &[F], s: &[F]) -> (F, F) {
114	debug_assert_eq!(i.len(), n);
115	debug_assert_eq!(j.len(), n);
116	debug_assert_eq!(s.len(), n);
117
118	izip!(i, j, s).fold((F::ONE, F::ZERO), |(sigma, sigma_prime), (&i_k, &j_k, &s_k)| {
119		let next_sigma = (F::ONE + j_k + s_k + i_k * (F::ONE + s_k * (F::ONE + j_k))) * sigma
120			+ (F::ONE - i_k) * j_k * (F::ONE - s_k) * sigma_prime;
121		let next_sigma_prime = i_k * (F::ONE - j_k) * s_k * sigma
122			+ (i_k + s_k + j_k * (i_k + s_k * (F::ONE + i_k))) * sigma_prime;
123
124		(next_sigma, next_sigma_prime)
125	})
126}
127
128/// Evaluate the shift indicator helper polynomial $\phi$.
129///
130/// See section 4.6 of the writeup.
131fn eval_phi<F: Field>(n: usize, i: &[F], s: &[F]) -> F {
132	debug_assert_eq!(i.len(), n);
133	debug_assert_eq!(s.len(), n);
134
135	izip!(i, s).fold(F::ZERO, |phi, (&i_k, &s_k)| i_k * s_k + (i_k + s_k) * phi)
136}
137
#[cfg(test)]
mod tests {
	use proptest::prelude::*;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::{
		line::extrapolate_line_packed,
		test_utils::{B128, index_to_hypercube_point, random_scalars},
	};

	// Type alias for shift indicator functions
	type ShiftIndicatorFn<F> = fn(&[F], &[F], &[F]) -> F;

	// Evaluates `shift_fn` at the 6-bit hypercube points encoding `i_idx`, `j_idx` and `s_idx`,
	// and checks that the result is 1 when `expected_condition` holds and 0 otherwise.
	fn test_hypercube_evaluation<F: Field>(
		shift_fn: ShiftIndicatorFn<F>,
		i_idx: usize,
		j_idx: usize,
		s_idx: usize,
		expected_condition: bool,
	) {
		let i = index_to_hypercube_point::<F>(6, i_idx);
		let j = index_to_hypercube_point::<F>(6, j_idx);
		let s = index_to_hypercube_point::<F>(6, s_idx);

		let result = shift_fn(&i, &j, &s);
		let expected = if expected_condition { F::ONE } else { F::ZERO };
		assert_eq!(result, expected);
	}

	// Checks, one coordinate at a time around a seeded random base point, that `f` agrees with
	// `extrapolate_line_packed(y0, y1, z)` built from its values at 0 and 1 (presumably the
	// line through (0, y0) and (1, y1) evaluated at z).
	fn test_multilinearity<F: Field>(f: impl Fn(&[F]) -> F, num_vars: usize) {
		let mut rng = StdRng::seed_from_u64(0);

		// Generate random base point
		let mut point = random_scalars(&mut rng, num_vars);

		// Test linearity in each coordinate
		for coord_idx in 0..num_vars {
			let z = point[coord_idx];

			point[coord_idx] = F::ZERO;
			let y0 = f(&point);
			point[coord_idx] = F::ONE;
			let y1 = f(&point);

			point[coord_idx] = z;
			let yz = f(&point);

			// Check linearity using extrapolate_line_packed
			assert_eq!(yz, extrapolate_line_packed(y0, y1, z));
		}
	}

	// Test sll_ind on hypercube points
	proptest! {
		#[test]
		fn test_sll_ind_hypercube(
			i_idx in 0usize..64,
			j_idx in 0usize..64,
			s_idx in 0usize..64,
		) {
			test_hypercube_evaluation(
				sll_ind::<B128>,
				i_idx, j_idx, s_idx,
				i_idx == j_idx + s_idx
			);
		}
	}

	// Test srl_ind on hypercube points
	proptest! {
		#[test]
		fn test_srl_ind_hypercube(
			i_idx in 0usize..64,
			j_idx in 0usize..64,
			s_idx in 0usize..64,
		) {
			test_hypercube_evaluation(
				srl_ind::<B128>,
				i_idx, j_idx, s_idx,
				j_idx == i_idx + s_idx
			);
		}
	}

	// Test sra_ind on hypercube points
	proptest! {
		#[test]
		fn test_sra_ind_hypercube(
			i_idx in 0usize..64,
			j_idx in 0usize..64,
			s_idx in 0usize..64,
		) {
			test_hypercube_evaluation(
				sra_ind::<B128>,
				i_idx, j_idx, s_idx,
				j_idx == (i_idx + s_idx).min(63)
			);
		}
	}

	// Test rotr_ind on hypercube points
	proptest! {
		#[test]
		fn test_rotr_ind_hypercube(
			i_idx in 0usize..64,
			j_idx in 0usize..64,
			s_idx in 0usize..64,
		) {
			test_hypercube_evaluation(
				rotr_ind::<B128>,
				i_idx, j_idx, s_idx,
				j_idx == (i_idx + s_idx) % 64
			);
		}
	}

	// Test multilinearity of all shift indicators
	#[test]
	fn test_shift_indicators_multilinearity() {
		// Every indicator is checked with 3-bit i, j and s (9 variables in total).
		let shift_inds: [ShiftIndicatorFn<B128>; _] = [sll_ind, srl_ind, sra_ind, rotr_ind];
		for shift_fn in shift_inds {
			test_multilinearity(
				|v| {
					assert_eq!(v.len(), 9, "Expected 9 variables total");
					let i = &v[0..3];
					let j = &v[3..6];
					let s = &v[6..9];
					shift_fn(i, j, s)
				},
				9,
			);
		}
	}

	// With zero bits there is exactly one position (0) and one shift (0), so every indicator
	// must evaluate to 1 on empty inputs.
	#[test]
	fn test_shift_indicators_zero_bits() {
		let shift_inds: [ShiftIndicatorFn<B128>; 4] = [sll_ind, srl_ind, sra_ind, rotr_ind];
		for shift_fn in shift_inds {
			assert_eq!(shift_fn(&[], &[], &[]), B128::ONE);
		}
	}

	// The documented `# Panics` contract: mismatched `i`/`j` lengths are rejected.
	#[test]
	#[should_panic(expected = "i and j must have the same length")]
	fn test_mismatched_i_j_lengths_panic() {
		let short = [B128::ZERO; 2];
		let long = [B128::ZERO; 3];
		let _ = srl_ind(&short, &long, &short);
	}

	// The documented `# Panics` contract: an `s` of the wrong length is rejected.
	#[test]
	#[should_panic(expected = "i and s must have the same length")]
	fn test_mismatched_s_length_panics() {
		let short = [B128::ZERO; 2];
		let long = [B128::ZERO; 3];
		let _ = srl_ind(&short, &short, &long);
	}
}