1use std::ops::Deref;
4
5use binius_field::{BinaryField, BinaryField1b};
6
/// A subspace of a binary field described by a list of basis elements.
///
/// The subspace consists of every sum of a subset of the basis elements. `Data` is the
/// storage used for the basis: it defaults to an owned `Vec<F>`, but any type that
/// dereferences to `[F]` (for example a borrowed slice) can be used instead.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinarySubspace<F, Data = Vec<F>>
where
    Data: Deref<Target = [F]>,
{
    /// The basis elements spanning the subspace.
    basis: Data,
}
15
16impl<F: BinaryField, Data: Deref<Target = [F]>> BinarySubspace<F, Data> {
17 pub const fn new_unchecked(basis: Data) -> Self {
21 Self { basis }
22 }
23
24 pub fn isomorphic<FIso>(&self) -> BinarySubspace<FIso>
26 where
27 FIso: BinaryField + From<F>,
28 {
29 BinarySubspace {
30 basis: self.basis.iter().copied().map(Into::into).collect(),
31 }
32 }
33
34 pub fn dim(&self) -> usize {
36 self.basis.len()
37 }
38
39 pub fn basis(&self) -> &[F] {
41 &self.basis
42 }
43
44 pub fn get(&self, index: usize) -> F {
45 assert!(index < 1 << self.dim(), "precondition: index must be less than 2^dim");
46 self.basis
47 .iter()
48 .enumerate()
49 .map(|(i, &basis_i)| basis_i * BinaryField1b::from((index >> i) & 1 == 1))
50 .sum()
51 }
52
53 pub fn iter(&self) -> BinarySubspaceIterator<'_, F> {
57 BinarySubspaceIterator::new(&self.basis)
58 }
59}
60
61impl<F: BinaryField> BinarySubspace<F> {
62 pub fn with_dim(dim: usize) -> Self {
71 assert!(dim <= F::DEGREE, "precondition: dim must be at most F::DEGREE");
72 let basis = (0..dim).map(|i| F::basis(i)).collect();
73 Self { basis }
74 }
75
76 pub fn reduce_dim(&self, dim: usize) -> Self {
84 assert!(dim <= self.dim(), "precondition: dim must be at most this subspace's dimension");
85 Self {
86 basis: self.basis[..dim].to_vec(),
87 }
88 }
89}
90
/// Iterator over every element of a binary subspace, given by a borrowed basis.
///
/// Elements are produced in index order: the `i`-th item is the sum of the basis elements
/// selected by the set bits of `i`.
#[derive(Clone, Debug)]
pub struct BinarySubspaceIterator<'a, F> {
    /// Basis of the subspace being enumerated.
    basis: &'a [F],
    /// Index of the element that the next call to `next` will yield.
    index: usize,
    /// Element that the next call to `next` will yield; `None` once exhausted.
    next: Option<F>,
}
101
102impl<'a, F: BinaryField> BinarySubspaceIterator<'a, F> {
103 pub fn new(basis: &'a [F]) -> Self {
104 assert!(basis.len() < usize::BITS as usize);
105 Self {
106 basis,
107 index: 0,
108 next: Some(F::ZERO),
109 }
110 }
111}
112
113impl<'a, F: BinaryField> Iterator for BinarySubspaceIterator<'a, F> {
114 type Item = F;
115
116 #[inline]
117 fn next(&mut self) -> Option<Self::Item> {
118 let ret = self.next?;
119
120 let mut next = ret;
121 let mut i = 0;
122 while (self.index >> i) & 1 == 1 {
123 next -= self.basis[i];
124 i += 1;
125 }
126 self.next = self.basis.get(i).map(|&basis_i| next + basis_i);
127
128 self.index += 1;
129 Some(ret)
130 }
131
132 fn size_hint(&self) -> (usize, Option<usize>) {
133 let last = 1 << self.basis.len();
134 let remaining = last - self.index;
135 (remaining, Some(remaining))
136 }
137
138 fn nth(&mut self, n: usize) -> Option<Self::Item> {
139 match self.index.checked_add(n) {
140 Some(new_index) if new_index < 1 << self.basis.len() => {
141 let new_next = BinarySubspace::new_unchecked(self.basis).get(new_index);
142
143 self.index = new_index;
144 self.next = Some(new_next);
145 }
146 _ => {
147 self.index = 1 << self.basis.len();
148 self.next = None;
149 }
150 }
151
152 self.next()
153 }
154}
155
156impl<'a, F: BinaryField> ExactSizeIterator for BinarySubspaceIterator<'a, F> {
157 fn len(&self) -> usize {
158 let last = 1 << self.basis.len();
159 last - self.index
160 }
161}
162
163impl<'a, F: BinaryField> std::iter::FusedIterator for BinarySubspaceIterator<'a, F> {}
164
165impl<F: BinaryField> Default for BinarySubspace<F> {
166 fn default() -> Self {
167 let basis = (0..F::DEGREE).map(|i| F::basis(i)).collect();
168 Self { basis }
169 }
170}
171
// Unit tests for `BinarySubspace` and `BinarySubspaceIterator`, mostly over the 8-bit field
// `B8`, where canonical-basis subspace elements can be compared against raw bit patterns.
#[cfg(test)]
mod tests {
    use binius_field::{AESTowerField8b as B8, BinaryField128bGhash as B128, Field};

    use super::*;

    // With the canonical basis of `B8`, `get(i)` is the field element whose bit pattern is `i`.
    #[test]
    fn test_default_binary_subspace_iterates_elements() {
        let subspace = BinarySubspace::<B8>::default();
        for i in 0..=255 {
            assert_eq!(subspace.get(i), B8::new(i as u8));
        }
    }

    // Index 256 is one past the last element of the 8-dimensional subspace.
    #[test]
    #[should_panic(expected = "precondition")]
    fn test_binary_subspace_range_error() {
        let subspace = BinarySubspace::<B8>::default();
        let _ = subspace.get(256);
    }

    // The default subspace has dimension `F::DEGREE` and uses the single-bit basis elements.
    #[test]
    fn test_default_binary_subspace() {
        let subspace = BinarySubspace::<B8>::default();
        assert_eq!(subspace.dim(), 8);
        assert_eq!(subspace.basis().len(), 8);

        assert_eq!(
            subspace.basis(),
            [
                B8::new(0b00000001),
                B8::new(0b00000010),
                B8::new(0b00000100),
                B8::new(0b00001000),
                B8::new(0b00010000),
                B8::new(0b00100000),
                B8::new(0b01000000),
                B8::new(0b10000000)
            ]
        );

        let expected_elements: [u8; 256] = (0..=255).collect::<Vec<_>>().try_into().unwrap();

        for (i, &expected) in expected_elements.iter().enumerate() {
            assert_eq!(subspace.get(i), B8::new(expected));
        }
    }

    // `with_dim(3)` keeps only the three lowest basis bits, giving elements 0..8.
    #[test]
    fn test_with_dim_valid() {
        let subspace = BinarySubspace::<B8>::with_dim(3);
        assert_eq!(subspace.dim(), 3);
        assert_eq!(subspace.basis().len(), 3);

        assert_eq!(subspace.basis(), [B8::new(0b001), B8::new(0b010), B8::new(0b100)]);

        let expected_elements: [u8; 8] = [0b000, 0b001, 0b010, 0b011, 0b100, 0b101, 0b110, 0b111];

        for (i, &expected) in expected_elements.iter().enumerate() {
            assert_eq!(subspace.get(i), B8::new(expected));
        }
    }

    // A dimension larger than the field degree (8 for `B8`) is rejected.
    #[test]
    #[should_panic(expected = "precondition")]
    fn test_with_dim_invalid() {
        let _ = BinarySubspace::<B8>::with_dim(10);
    }

    // Reducing keeps a prefix of the basis.
    #[test]
    fn test_reduce_dim_valid() {
        let subspace = BinarySubspace::<B8>::with_dim(6);
        let reduced = subspace.reduce_dim(4);
        assert_eq!(reduced.dim(), 4);
        assert_eq!(reduced.basis().len(), 4);

        assert_eq!(
            reduced.basis(),
            [
                B8::new(0b0001),
                B8::new(0b0010),
                B8::new(0b0100),
                B8::new(0b1000)
            ]
        );

        let expected_elements: [u8; 16] = (0..16).collect::<Vec<_>>().try_into().unwrap();

        for (i, &expected) in expected_elements.iter().enumerate() {
            assert_eq!(reduced.get(i), B8::new(expected));
        }
    }

    // `reduce_dim` cannot grow the subspace.
    #[test]
    #[should_panic(expected = "precondition")]
    fn test_reduce_dim_invalid() {
        let subspace = BinarySubspace::<B8>::with_dim(4);
        let _ = subspace.reduce_dim(6);
    }

    // `isomorphic` maps each basis element through `From<B8>` into `B128`.
    #[test]
    fn test_isomorphic_conversion() {
        let subspace = BinarySubspace::<B8>::with_dim(3);
        let iso_subspace: BinarySubspace<B128> = subspace.isomorphic();
        assert_eq!(iso_subspace.dim(), 3);
        assert_eq!(iso_subspace.basis().len(), 3);

        assert_eq!(
            iso_subspace.basis(),
            [
                B128::from(B8::new(0b001)),
                B128::from(B8::new(0b010)),
                B128::from(B8::new(0b100)),
            ]
        );
    }

    // The iterator yields all 2^dim elements in index order.
    #[test]
    fn test_iterate_subspace() {
        let subspace = BinarySubspace::<B8>::with_dim(3);
        let elements: Vec<_> = subspace.iter().collect();
        assert_eq!(elements.len(), 8);

        let expected_elements: [u8; 8] = [0b000, 0b001, 0b010, 0b011, 0b100, 0b101, 0b110, 0b111];

        for (i, &expected) in expected_elements.iter().enumerate() {
            assert_eq!(elements[i], B8::new(expected));
        }
    }

    #[test]
    fn test_iterator_matches_get() {
        let subspace = BinarySubspace::<B8>::with_dim(5);

        for (i, elem) in subspace.iter().enumerate() {
            // The incrementally computed element must agree with direct evaluation.
            assert_eq!(elem, subspace.get(i), "Mismatch at index {}", i);
        }
    }

    #[test]
    // `nth(0)` is called deliberately to exercise the `nth` override rather than `next`.
    #[allow(clippy::iter_nth_zero)]
    fn test_iterator_nth() {
        let subspace = BinarySubspace::<B8>::with_dim(4);

        let mut iter = subspace.iter();
        // Each `nth(n)` skips `n` elements and then yields the next one.
        assert_eq!(iter.nth(0), Some(subspace.get(0)));
        assert_eq!(iter.nth(0), Some(subspace.get(1)));
        assert_eq!(iter.nth(2), Some(subspace.get(4)));
        assert_eq!(iter.nth(5), Some(subspace.get(10)));

        let mut iter = subspace.iter();
        // Jumping to the last element leaves the iterator exhausted.
        assert_eq!(iter.nth(15), Some(subspace.get(15)));
        assert_eq!(iter.nth(0), None);
    }

    #[test]
    fn test_iterator_nth_skips_efficiently() {
        let subspace = BinarySubspace::<B8>::with_dim(6);

        let mut iter = subspace.iter();
        // After a jump, `next` must continue from the jumped-to position.
        assert_eq!(iter.nth(30), Some(subspace.get(30)));
        assert_eq!(iter.next(), Some(subspace.get(31)));

        let mut iter = subspace.iter();
        // A fresh iterator can jump straight to a late index.
        assert_eq!(iter.nth(50), Some(subspace.get(50)));
    }

    // `size_hint` is exact and shrinks with both `next` and `nth`.
    #[test]
    fn test_iterator_size_hint() {
        let subspace = BinarySubspace::<B8>::with_dim(3);
        let mut iter = subspace.iter();

        assert_eq!(iter.size_hint(), (8, Some(8)));
        iter.next();
        assert_eq!(iter.size_hint(), (7, Some(7)));
        iter.nth(3);
        assert_eq!(iter.size_hint(), (3, Some(3)));
    }

    // `len` tracks the remaining count across `next` and `nth`.
    #[test]
    fn test_iterator_exact_size() {
        let subspace = BinarySubspace::<B8>::with_dim(4);
        let mut iter = subspace.iter();

        assert_eq!(iter.len(), 16);
        iter.next();
        assert_eq!(iter.len(), 15);
        iter.nth(5);
        assert_eq!(iter.len(), 9);
    }

    #[test]
    fn test_iterator_empty_subspace() {
        let subspace = BinarySubspace::<B8>::with_dim(0);
        let mut iter = subspace.iter();

        assert_eq!(iter.len(), 1);
        // A zero-dimensional subspace still contains exactly one element: zero.
        assert_eq!(iter.next(), Some(B8::ZERO));
        assert_eq!(iter.next(), None);
    }

    // Full iteration over all 256 elements of `B8` agrees with `get`.
    #[test]
    fn test_iterator_full_iteration() {
        let subspace = BinarySubspace::<B8>::default();
        let collected: Vec<_> = subspace.iter().collect();

        assert_eq!(collected.len(), 256);
        for (i, elem) in collected.iter().enumerate() {
            assert_eq!(*elem, subspace.get(i));
        }
    }

    #[test]
    fn test_iterator_partial_then_nth() {
        let subspace = BinarySubspace::<B8>::with_dim(5);
        let mut iter = subspace.iter();

        assert_eq!(iter.next(), Some(subspace.get(0)));
        // Advance a few steps with `next` before mixing in `nth`.
        assert_eq!(iter.next(), Some(subspace.get(1)));
        assert_eq!(iter.next(), Some(subspace.get(2)));

        assert_eq!(iter.nth(5), Some(subspace.get(8)));
        // `next` resumes right after the element returned by `nth`.
        assert_eq!(iter.next(), Some(subspace.get(9)));
    }

    #[test]
    fn test_iterator_clone() {
        let subspace = BinarySubspace::<B8>::with_dim(3);
        let mut iter1 = subspace.iter();

        iter1.next();
        iter1.next();

        let mut iter2 = iter1.clone();

        assert_eq!(iter1.next(), iter2.next());
        // A clone is fully independent and yields the same remaining sequence.
        assert_eq!(iter1.collect::<Vec<_>>(), iter2.collect::<Vec<_>>());
    }
}