1use std::ops::Deref;
4
5use binius_field::{BinaryField, BinaryField1b};
6
7use super::error::Error;
8
/// A subspace of a binary field, described by a list of basis elements.
///
/// The element at index `i` is the sum of the basis elements selected by the set bits
/// of `i` (see [`BinarySubspace::get`]). `Data` is any container that dereferences to a
/// slice of basis elements; it defaults to an owned `Vec<F>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinarySubspace<F, Data = Vec<F>>
where
    Data: Deref<Target = [F]>,
{
    basis: Data,
}
17
18impl<F: BinaryField, Data: Deref<Target = [F]>> BinarySubspace<F, Data> {
19 pub const fn new_unchecked(basis: Data) -> Self {
23 Self { basis }
24 }
25
26 pub fn isomorphic<FIso>(&self) -> BinarySubspace<FIso>
28 where
29 FIso: BinaryField + From<F>,
30 {
31 BinarySubspace {
32 basis: self.basis.iter().copied().map(Into::into).collect(),
33 }
34 }
35
36 pub fn dim(&self) -> usize {
38 self.basis.len()
39 }
40
41 pub fn basis(&self) -> &[F] {
43 &self.basis
44 }
45
46 pub fn get(&self, index: usize) -> F {
47 self.basis
48 .iter()
49 .enumerate()
50 .map(|(i, &basis_i)| basis_i * BinaryField1b::from((index >> i) & 1 == 1))
51 .sum()
52 }
53
54 pub fn get_checked(&self, index: usize) -> Result<F, Error> {
55 if index >= 1 << self.basis.len() {
56 return Err(Error::ArgumentRangeError {
57 arg: "index".to_string(),
58 range: 0..1 << self.basis.len(),
59 });
60 }
61 Ok(self.get(index))
62 }
63
64 pub fn iter(&self) -> BinarySubspaceIterator<'_, F> {
68 BinarySubspaceIterator::new(&self.basis)
69 }
70}
71
72impl<F: BinaryField> BinarySubspace<F> {
73 pub fn with_dim(dim: usize) -> Result<Self, Error> {
82 let basis = (0..dim)
83 .map(|i| F::basis_checked(i).map_err(|_| Error::DomainSizeTooLarge))
84 .collect::<Result<_, _>>()?;
85 Ok(Self { basis })
86 }
87
88 pub fn reduce_dim(&self, dim: usize) -> Result<Self, Error> {
96 if dim > self.dim() {
97 return Err(Error::DomainSizeTooLarge);
98 }
99 Ok(Self {
100 basis: self.basis[..dim].to_vec(),
101 })
102 }
103}
104
/// Iterator over every element of a subspace, in index order.
///
/// Created by `BinarySubspace::iter`. The `n`-th item it yields equals
/// `BinarySubspace::get(n)`. Each successor is derived from the previous element by
/// adding or subtracting basis elements, rather than being recomputed from scratch.
#[derive(Debug, Clone)]
pub struct BinarySubspaceIterator<'a, F> {
    /// Basis of the subspace being enumerated.
    basis: &'a [F],
    /// Index of the element held in `next`; reaches `1 << basis.len()` once exhausted.
    index: usize,
    /// Element to return on the next call to `next`, or `None` once exhausted.
    next: Option<F>,
}
115
116impl<'a, F: BinaryField> BinarySubspaceIterator<'a, F> {
117 fn new(basis: &'a [F]) -> Self {
118 assert!(basis.len() < usize::BITS as usize);
119 Self {
120 basis,
121 index: 0,
122 next: Some(F::ZERO),
123 }
124 }
125}
126
127impl<'a, F: BinaryField> Iterator for BinarySubspaceIterator<'a, F> {
128 type Item = F;
129
130 fn next(&mut self) -> Option<Self::Item> {
131 let ret = self.next?;
132
133 let mut next = ret;
134 let mut i = 0;
135 while (self.index >> i) & 1 == 1 {
136 next -= self.basis[i];
137 i += 1;
138 }
139 self.next = self.basis.get(i).map(|&basis_i| next + basis_i);
140
141 self.index += 1;
142 Some(ret)
143 }
144
145 fn size_hint(&self) -> (usize, Option<usize>) {
146 let last = 1 << self.basis.len();
147 let remaining = last - self.index;
148 (remaining, Some(remaining))
149 }
150
151 fn nth(&mut self, n: usize) -> Option<Self::Item> {
152 match self.index.checked_add(n) {
153 Some(new_index) if new_index < 1 << self.basis.len() => {
154 let new_next = BinarySubspace::new_unchecked(self.basis).get(new_index);
155
156 self.index = new_index;
157 self.next = Some(new_next);
158 }
159 _ => {
160 self.index = 1 << self.basis.len();
161 self.next = None;
162 }
163 }
164
165 self.next()
166 }
167}
168
169impl<'a, F: BinaryField> ExactSizeIterator for BinarySubspaceIterator<'a, F> {
170 fn len(&self) -> usize {
171 let last = 1 << self.basis.len();
172 last - self.index
173 }
174}
175
176impl<'a, F: BinaryField> std::iter::FusedIterator for BinarySubspaceIterator<'a, F> {}
177
178impl<F: BinaryField> Default for BinarySubspace<F> {
179 fn default() -> Self {
180 let basis = (0..F::DEGREE).map(|i| F::basis(i)).collect();
181 Self { basis }
182 }
183}
184
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use binius_field::{AESTowerField8b as B8, BinaryField128bGhash as B128, Field};

    use super::*;

    /// The `B8` elements whose raw representations are `0, 1, ..., count - 1`.
    fn raw_b8(count: usize) -> Vec<B8> {
        (0..count)
            .map(|raw| B8::new(u8::try_from(raw).expect("raw_b8 supports at most 256 values")))
            .collect()
    }

    /// Asserts that `subspace.get(i) == expected[i]` for every position `i`.
    fn assert_elements(subspace: &BinarySubspace<B8>, expected: &[B8]) {
        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(subspace.get(idx), want);
        }
    }

    #[test]
    fn test_default_binary_subspace_iterates_elements() {
        let space = BinarySubspace::<B8>::default();
        assert_elements(&space, &raw_b8(256));
    }

    #[test]
    fn test_binary_subspace_range_error() {
        let space = BinarySubspace::<B8>::default();
        let out_of_range = space.get_checked(256);
        assert_matches!(out_of_range, Err(Error::ArgumentRangeError { .. }));
    }

    #[test]
    fn test_default_binary_subspace() {
        let space = BinarySubspace::<B8>::default();
        assert_eq!(space.dim(), 8);
        assert_eq!(space.basis().len(), 8);

        let unit_vectors: Vec<B8> = (0u32..8).map(|bit| B8::new(1u8 << bit)).collect();
        assert_eq!(space.basis(), unit_vectors.as_slice());

        assert_elements(&space, &raw_b8(256));
    }

    #[test]
    fn test_with_dim_valid() {
        let space = BinarySubspace::<B8>::with_dim(3).unwrap();
        assert_eq!(space.dim(), 3);
        assert_eq!(space.basis().len(), 3);

        let expected_basis = [0b001u8, 0b010, 0b100].map(B8::new);
        assert_eq!(space.basis(), expected_basis);

        assert_elements(&space, &raw_b8(8));
    }

    #[test]
    fn test_with_dim_invalid() {
        let too_wide = BinarySubspace::<B8>::with_dim(10);
        assert_matches!(too_wide, Err(Error::DomainSizeTooLarge));
    }

    #[test]
    fn test_reduce_dim_valid() {
        let reduced = BinarySubspace::<B8>::with_dim(6)
            .unwrap()
            .reduce_dim(4)
            .unwrap();
        assert_eq!(reduced.dim(), 4);
        assert_eq!(reduced.basis().len(), 4);

        let expected_basis = [0b0001u8, 0b0010, 0b0100, 0b1000].map(B8::new);
        assert_eq!(reduced.basis(), expected_basis);

        assert_elements(&reduced, &raw_b8(16));
    }

    #[test]
    fn test_reduce_dim_invalid() {
        let narrow = BinarySubspace::<B8>::with_dim(4).unwrap();
        assert_matches!(narrow.reduce_dim(6), Err(Error::DomainSizeTooLarge));
    }

    #[test]
    fn test_isomorphic_conversion() {
        let source = BinarySubspace::<B8>::with_dim(3).unwrap();
        let converted: BinarySubspace<B128> = source.isomorphic();
        assert_eq!(converted.dim(), 3);
        assert_eq!(converted.basis().len(), 3);

        let expected_basis = [0b001u8, 0b010, 0b100].map(|raw| B128::from(B8::new(raw)));
        assert_eq!(converted.basis(), expected_basis);
    }

    #[test]
    fn test_get_checked_valid() {
        let space = BinarySubspace::<B8>::default();
        assert!((0..256).all(|idx| space.get_checked(idx).is_ok()));
    }

    #[test]
    fn test_get_checked_invalid() {
        let space = BinarySubspace::<B8>::default();
        let out_of_range = space.get_checked(256);
        assert_matches!(out_of_range, Err(Error::ArgumentRangeError { .. }));
    }

    #[test]
    fn test_iterate_subspace() {
        let space = BinarySubspace::<B8>::with_dim(3).unwrap();
        let collected = space.iter().collect::<Vec<_>>();
        assert_eq!(collected.len(), 8);
        assert_eq!(collected, raw_b8(8));
    }

    #[test]
    fn test_iterator_matches_get() {
        let space = BinarySubspace::<B8>::with_dim(5).unwrap();
        space.iter().enumerate().for_each(|(idx, elem)| {
            assert_eq!(elem, space.get(idx), "Mismatch at index {}", idx);
        });
    }

    #[test]
    // `nth(0)` is deliberately exercised below to check the zero-skip edge case.
    #[allow(clippy::iter_nth_zero)]
    fn test_iterator_nth() {
        let space = BinarySubspace::<B8>::with_dim(4).unwrap();

        // After each call the position is one past the returned index, so the returned
        // indices are 0, 1, 2 + 2 = 4 and 5 + 5 = 10.
        let mut it = space.iter();
        for (skip, landed) in [(0, 0), (0, 1), (2, 4), (5, 10)] {
            assert_eq!(it.nth(skip), Some(space.get(landed)));
        }

        // Jumping straight to the last element leaves nothing after it.
        let mut it = space.iter();
        assert_eq!(it.nth(15), Some(space.get(15)));
        assert_eq!(it.nth(0), None);
    }

    #[test]
    fn test_iterator_nth_skips_efficiently() {
        let space = BinarySubspace::<B8>::with_dim(6).unwrap();

        let mut it = space.iter();
        assert_eq!(it.nth(30), Some(space.get(30)));
        assert_eq!(it.next(), Some(space.get(31)));

        assert_eq!(space.iter().nth(50), Some(space.get(50)));
    }

    #[test]
    fn test_iterator_size_hint() {
        let space = BinarySubspace::<B8>::with_dim(3).unwrap();
        let mut it = space.iter();

        assert_eq!(it.size_hint(), (8, Some(8)));
        let _ = it.next();
        assert_eq!(it.size_hint(), (7, Some(7)));
        let _ = it.nth(3);
        assert_eq!(it.size_hint(), (3, Some(3)));
    }

    #[test]
    fn test_iterator_exact_size() {
        let space = BinarySubspace::<B8>::with_dim(4).unwrap();
        let mut it = space.iter();

        assert_eq!(it.len(), 16);
        let _ = it.next();
        assert_eq!(it.len(), 15);
        let _ = it.nth(5);
        assert_eq!(it.len(), 9);
    }

    #[test]
    fn test_iterator_empty_subspace() {
        let trivial = BinarySubspace::<B8>::with_dim(0).unwrap();
        let mut it = trivial.iter();

        // A zero-dimensional subspace still contains exactly one element: zero.
        assert_eq!(it.len(), 1);
        assert_eq!(it.next(), Some(B8::ZERO));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn test_iterator_full_iteration() {
        let space = BinarySubspace::<B8>::default();
        let collected = space.iter().collect::<Vec<_>>();

        assert_eq!(collected.len(), 256);
        assert_elements(&space, &collected);
    }

    #[test]
    fn test_iterator_partial_then_nth() {
        let space = BinarySubspace::<B8>::with_dim(5).unwrap();
        let mut it = space.iter();

        for idx in 0..3 {
            assert_eq!(it.next(), Some(space.get(idx)));
        }

        assert_eq!(it.nth(5), Some(space.get(8)));
        assert_eq!(it.next(), Some(space.get(9)));
    }

    #[test]
    fn test_iterator_clone() {
        let space = BinarySubspace::<B8>::with_dim(3).unwrap();
        let mut original = space.iter();
        for _ in 0..2 {
            original.next();
        }

        let mut copy = original.clone();

        assert_eq!(original.next(), copy.next());
        assert_eq!(original.collect::<Vec<_>>(), copy.collect::<Vec<_>>());
    }
}