binius_math/multilinear/shift.rs
1// Copyright 2025 Irreducible Inc.
2
3use binius_field::Field;
4use itertools::izip;
5
6/// Computes the multilinear extension of the logical left shift indicator at a point.
7///
8/// The shift indicator for logical left shift (SLL) evaluates to 1 when bit `i` of the output
9/// equals bit `j` of the input after shifting left by `s` positions. Specifically:
10/// - `sll_ind(i, j, s) = 1` if and only if `i = j + s` and `i < 2^k`
11/// - `sll_ind(i, j, s) = 0` otherwise
12///
13/// This function evaluates the multilinear extension of this indicator at the given point
14/// `(i, j, s)` where each coordinate is a field element.
15///
16/// # Arguments
17/// * `i` - Slice of field elements representing the output bit position (length k)
18/// * `j` - Slice of field elements representing the input bit position (length k)
19/// * `s` - Slice of field elements representing the shift amount (length k)
20///
21/// # Panics
22/// Panics if the slices don't all have the same length.
23pub fn sll_ind<F: Field>(i: &[F], j: &[F], s: &[F]) -> F {
24 assert_eq!(i.len(), j.len(), "i and j must have the same length");
25 assert_eq!(i.len(), s.len(), "i and s must have the same length");
26
27 // sll_ind(i, j, s) = srl_ind(j, i, s) by transposition
28 srl_ind(j, i, s)
29}
30
31/// Computes the multilinear extension of the logical right shift indicator at a point.
32///
33/// The shift indicator for logical right shift (srl) evaluates to 1 when bit `i` of the output
34/// equals bit `j` of the input after shifting right by `s` positions. Specifically:
35/// - `srl_ind(i, j, s) = 1` if and only if `j = i + s` and `j < 2^k`
36/// - `srl_ind(i, j, s) = 0` otherwise
37///
38/// This function evaluates the multilinear extension of this indicator at the given point
39/// `(i, j, s)` where each coordinate is a field element.
40///
41/// # Arguments
42/// * `i` - Slice of field elements representing the output bit position (length k)
43/// * `j` - Slice of field elements representing the input bit position (length k)
44/// * `s` - Slice of field elements representing the shift amount (length k)
45///
46/// # Panics
47/// Panics if the slices don't all have the same length.
48pub fn srl_ind<F: Field>(i: &[F], j: &[F], s: &[F]) -> F {
49 assert_eq!(i.len(), j.len(), "i and j must have the same length");
50 assert_eq!(i.len(), s.len(), "i and s must have the same length");
51
52 let (sigma, _sigma_prime) = eval_sigmas(i.len(), i, j, s);
53 sigma
54}
55
56/// Computes the multilinear extension of the arithmetic right shift indicator at a point.
57///
58/// The shift indicator for arithmetic right shift (sra) behaves like logical right shift,
59/// but propagates the sign bit (bit 2^k - 1) for positions shifted beyond the original value.
60/// Specifically:
61/// - `sra_ind(i, j, s) = srl_ind(i, j, s)` for normal shifted bits
62/// - Additionally, `sra_ind(i, 2^k - 1, s) = 1` for `i >= 2^k - s` (sign extension)
63///
64/// This function evaluates the multilinear extension of this indicator at the given point
65/// `(i, j, s)` where each coordinate is a field element.
66///
67/// # Arguments
68/// * `i` - Slice of field elements representing the output bit position (length k)
69/// * `j` - Slice of field elements representing the input bit position (length k)
70/// * `s` - Slice of field elements representing the shift amount (length k)
71///
72/// # Panics
73/// Panics if the slices don't all have the same length.
74pub fn sra_ind<F: Field>(i: &[F], j: &[F], s: &[F]) -> F {
75 assert_eq!(i.len(), j.len(), "i and j must have the same length");
76 assert_eq!(i.len(), s.len(), "i and s must have the same length");
77
78 let (sigma, _sigma_prime) = eval_sigmas(i.len(), i, j, s);
79 let phi = eval_phi(i.len(), i, s);
80 let j_prod = j.iter().product::<F>();
81 sigma + phi * j_prod
82}
83
84/// Computes the multilinear extension of the rotate right indicator at a point.
85///
86/// The shift indicator for rotate right (rotr) evaluates to 1 when bit `i` of the output
87/// equals bit `j` of the input after rotating right by `s` positions. Unlike logical shifts,
88/// bits that shift off one end wrap around to the other end. Specifically:
89/// - `rotr_ind(i, j, s) = 1` if and only if `j = (i + s) mod 2^k`
90/// - `rotr_ind(i, j, s) = 0` otherwise
91///
92/// This function evaluates the multilinear extension of this indicator at the given point
93/// `(i, j, s)` where each coordinate is a field element.
94///
95/// # Arguments
96/// * `i` - Slice of field elements representing the output bit position (length k)
97/// * `j` - Slice of field elements representing the input bit position (length k)
98/// * `s` - Slice of field elements representing the shift amount (length k)
99///
100/// # Panics
101/// Panics if the slices don't all have the same length.
102pub fn rotr_ind<F: Field>(i: &[F], j: &[F], s: &[F]) -> F {
103 assert_eq!(i.len(), j.len(), "i and j must have the same length");
104 assert_eq!(i.len(), s.len(), "i and s must have the same length");
105
106 let (sigma, sigma_prime) = eval_sigmas(i.len(), i, j, s);
107 sigma + sigma_prime
108}
109
110/// Evaluate the shift indicator helper polynomials, $\sigma, \sigma'$.
111///
112/// See section 4.6 of the writeup.
113fn eval_sigmas<F: Field>(n: usize, i: &[F], j: &[F], s: &[F]) -> (F, F) {
114 debug_assert_eq!(i.len(), n);
115 debug_assert_eq!(j.len(), n);
116 debug_assert_eq!(s.len(), n);
117
118 izip!(i, j, s).fold((F::ONE, F::ZERO), |(sigma, sigma_prime), (&i_k, &j_k, &s_k)| {
119 let next_sigma = (F::ONE + j_k + s_k + i_k * (F::ONE + s_k * (F::ONE + j_k))) * sigma
120 + (F::ONE - i_k) * j_k * (F::ONE - s_k) * sigma_prime;
121 let next_sigma_prime = i_k * (F::ONE - j_k) * s_k * sigma
122 + (i_k + s_k + j_k * (i_k + s_k * (F::ONE + i_k))) * sigma_prime;
123
124 (next_sigma, next_sigma_prime)
125 })
126}
127
128/// Evaluate the shift indicator helper polynomial $\phi$.
129///
130/// See section 4.6 of the writeup.
131fn eval_phi<F: Field>(n: usize, i: &[F], s: &[F]) -> F {
132 debug_assert_eq!(i.len(), n);
133 debug_assert_eq!(s.len(), n);
134
135 izip!(i, s).fold(F::ZERO, |phi, (&i_k, &s_k)| i_k * s_k + (i_k + s_k) * phi)
136}
137
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        line::extrapolate_line_packed,
        test_utils::{B128, index_to_hypercube_point, random_scalars},
    };

    // Type alias for shift indicator functions
    type ShiftIndicatorFn<F> = fn(&[F], &[F], &[F]) -> F;

    // Every shift indicator defined in the parent module, for tests that apply to all of them.
    fn all_shift_inds() -> [ShiftIndicatorFn<B128>; 4] {
        [sll_ind, srl_ind, sra_ind, rotr_ind]
    }

    // Helper to test hypercube evaluation for any shift indicator on 6-variable inputs
    fn test_hypercube_evaluation<F: Field>(
        shift_fn: ShiftIndicatorFn<F>,
        i_idx: usize,
        j_idx: usize,
        s_idx: usize,
        expected_condition: bool,
    ) {
        let i = index_to_hypercube_point::<F>(6, i_idx);
        let j = index_to_hypercube_point::<F>(6, j_idx);
        let s = index_to_hypercube_point::<F>(6, s_idx);

        let result = shift_fn(&i, &j, &s);
        let expected = if expected_condition { F::ONE } else { F::ZERO };
        assert_eq!(result, expected);
    }

    // Helper to test multilinearity of a function across all variables
    fn test_multilinearity<F: Field>(f: impl Fn(&[F]) -> F, num_vars: usize) {
        let mut rng = StdRng::seed_from_u64(0);

        // Generate random base point
        let mut point = random_scalars(&mut rng, num_vars);

        // Test linearity in each coordinate
        for coord_idx in 0..num_vars {
            let z = point[coord_idx];

            point[coord_idx] = F::ZERO;
            let y0 = f(&point);
            point[coord_idx] = F::ONE;
            let y1 = f(&point);

            point[coord_idx] = z;
            let yz = f(&point);

            // Check linearity using extrapolate_line_packed
            assert_eq!(yz, extrapolate_line_packed(y0, y1, z));
        }
    }

    // Test sll_ind on hypercube points
    proptest! {
        #[test]
        fn test_sll_ind_hypercube(
            i_idx in 0usize..64,
            j_idx in 0usize..64,
            s_idx in 0usize..64,
        ) {
            test_hypercube_evaluation(
                sll_ind::<B128>,
                i_idx, j_idx, s_idx,
                i_idx == j_idx + s_idx
            );
        }
    }

    // Test srl_ind on hypercube points
    proptest! {
        #[test]
        fn test_srl_ind_hypercube(
            i_idx in 0usize..64,
            j_idx in 0usize..64,
            s_idx in 0usize..64,
        ) {
            test_hypercube_evaluation(
                srl_ind::<B128>,
                i_idx, j_idx, s_idx,
                j_idx == i_idx + s_idx
            );
        }
    }

    // Test sra_ind on hypercube points
    proptest! {
        #[test]
        fn test_sra_ind_hypercube(
            i_idx in 0usize..64,
            j_idx in 0usize..64,
            s_idx in 0usize..64,
        ) {
            test_hypercube_evaluation(
                sra_ind::<B128>,
                i_idx, j_idx, s_idx,
                j_idx == (i_idx + s_idx).min(63)
            );
        }
    }

    // Test rotr_ind on hypercube points
    proptest! {
        #[test]
        fn test_rotr_ind_hypercube(
            i_idx in 0usize..64,
            j_idx in 0usize..64,
            s_idx in 0usize..64,
        ) {
            test_hypercube_evaluation(
                rotr_ind::<B128>,
                i_idx, j_idx, s_idx,
                j_idx == (i_idx + s_idx) % 64
            );
        }
    }

    // Test multilinearity of all shift indicators
    #[test]
    fn test_shift_indicators_multilinearity() {
        for shift_fn in all_shift_inds() {
            test_multilinearity(
                |v| {
                    assert_eq!(v.len(), 9, "Expected 9 variables total");
                    let i = &v[0..3];
                    let j = &v[3..6];
                    let s = &v[6..9];
                    shift_fn(i, j, s)
                },
                9,
            );
        }
    }

    // With zero variables the only index is 0 and `0 = 0 + 0` holds for every shift kind,
    // so every indicator must evaluate to one.
    #[test]
    fn test_shift_indicators_zero_vars() {
        for shift_fn in all_shift_inds() {
            assert_eq!(shift_fn(&[], &[], &[]), B128::ONE);
        }
    }

    // The documented `# Panics` contract: mismatched `i`/`j` lengths are rejected.
    #[test]
    #[should_panic(expected = "i and j must have the same length")]
    fn test_srl_ind_rejects_mismatched_j() {
        let _ = srl_ind::<B128>(&[B128::ZERO; 2], &[B128::ZERO; 3], &[B128::ZERO; 2]);
    }

    // The documented `# Panics` contract: mismatched `i`/`s` lengths are rejected.
    #[test]
    #[should_panic(expected = "i and s must have the same length")]
    fn test_rotr_ind_rejects_mismatched_s() {
        let _ = rotr_ind::<B128>(&[B128::ZERO; 2], &[B128::ZERO; 2], &[B128::ZERO; 3]);
    }
}