1use std::ops::Deref;
4
5use binius_field::{BinaryField, BinaryField1b};
6
7use super::error::Error;
8
/// A GF(2)-linear subspace of the binary field `F`, stored as a list of basis elements.
///
/// Elements of the subspace are sums of subsets of the basis. `Data` is any container that
/// dereferences to `[F]`, so the basis may be owned (the default, `Vec<F>`) or borrowed
/// (e.g. `&[F]`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinarySubspace<F, Data = Vec<F>>
where
    Data: Deref<Target = [F]>,
{
    basis: Data,
}
17
18impl<F: BinaryField, Data: Deref<Target = [F]>> BinarySubspace<F, Data> {
19 pub const fn new_unchecked(basis: Data) -> Self {
23 Self { basis }
24 }
25
26 pub fn isomorphic<FIso>(&self) -> BinarySubspace<FIso>
28 where
29 FIso: BinaryField + From<F>,
30 {
31 BinarySubspace {
32 basis: self.basis.iter().copied().map(Into::into).collect(),
33 }
34 }
35
36 pub fn dim(&self) -> usize {
38 self.basis.len()
39 }
40
41 pub fn basis(&self) -> &[F] {
43 &self.basis
44 }
45
46 pub fn get(&self, index: usize) -> F {
47 self.basis
48 .iter()
49 .enumerate()
50 .map(|(i, &basis_i)| basis_i * BinaryField1b::from((index >> i) & 1 == 1))
51 .sum()
52 }
53
54 pub fn get_checked(&self, index: usize) -> Result<F, Error> {
55 if index >= 1 << self.basis.len() {
56 return Err(Error::ArgumentRangeError {
57 arg: "index".to_string(),
58 range: 0..1 << self.basis.len(),
59 });
60 }
61 Ok(self.get(index))
62 }
63
64 pub fn iter(&self) -> BinarySubspaceIterator<'_, F> {
68 BinarySubspaceIterator::new(&self.basis)
69 }
70}
71
72impl<F: BinaryField> BinarySubspace<F> {
73 pub fn with_dim(dim: usize) -> Result<Self, Error> {
82 let basis = (0..dim)
83 .map(|i| F::basis_checked(i).map_err(|_| Error::DomainSizeTooLarge))
84 .collect::<Result<_, _>>()?;
85 Ok(Self { basis })
86 }
87
88 pub fn reduce_dim(&self, dim: usize) -> Result<Self, Error> {
96 if dim > self.dim() {
97 return Err(Error::DomainSizeTooLarge);
98 }
99 Ok(Self {
100 basis: self.basis[..dim].to_vec(),
101 })
102 }
103}
104
/// Iterator over every element of a binary subspace, in index order.
///
/// The element at index `i` is the sum of the basis elements selected by the set bits of `i`.
/// Successive elements are derived from the previous one rather than recomputed from scratch.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct BinarySubspaceIterator<'a, F> {
    /// Basis elements spanning the subspace being enumerated.
    basis: &'a [F],
    /// Index of the element that `next` will yield (`2^basis.len()` once exhausted).
    index: usize,
    /// Element that `next` will yield, or `None` once the iterator is exhausted.
    next: Option<F>,
}
115
116impl<'a, F: BinaryField> BinarySubspaceIterator<'a, F> {
117 pub fn new(basis: &'a [F]) -> Self {
118 assert!(basis.len() < usize::BITS as usize);
119 Self {
120 basis,
121 index: 0,
122 next: Some(F::ZERO),
123 }
124 }
125}
126
127impl<'a, F: BinaryField> Iterator for BinarySubspaceIterator<'a, F> {
128 type Item = F;
129
130 #[inline]
131 fn next(&mut self) -> Option<Self::Item> {
132 let ret = self.next?;
133
134 let mut next = ret;
135 let mut i = 0;
136 while (self.index >> i) & 1 == 1 {
137 next -= self.basis[i];
138 i += 1;
139 }
140 self.next = self.basis.get(i).map(|&basis_i| next + basis_i);
141
142 self.index += 1;
143 Some(ret)
144 }
145
146 fn size_hint(&self) -> (usize, Option<usize>) {
147 let last = 1 << self.basis.len();
148 let remaining = last - self.index;
149 (remaining, Some(remaining))
150 }
151
152 fn nth(&mut self, n: usize) -> Option<Self::Item> {
153 match self.index.checked_add(n) {
154 Some(new_index) if new_index < 1 << self.basis.len() => {
155 let new_next = BinarySubspace::new_unchecked(self.basis).get(new_index);
156
157 self.index = new_index;
158 self.next = Some(new_next);
159 }
160 _ => {
161 self.index = 1 << self.basis.len();
162 self.next = None;
163 }
164 }
165
166 self.next()
167 }
168}
169
170impl<'a, F: BinaryField> ExactSizeIterator for BinarySubspaceIterator<'a, F> {
171 fn len(&self) -> usize {
172 let last = 1 << self.basis.len();
173 last - self.index
174 }
175}
176
177impl<'a, F: BinaryField> std::iter::FusedIterator for BinarySubspaceIterator<'a, F> {}
178
179impl<F: BinaryField> Default for BinarySubspace<F> {
180 fn default() -> Self {
181 let basis = (0..F::DEGREE).map(|i| F::basis(i)).collect();
182 Self { basis }
183 }
184}
185
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use binius_field::{AESTowerField8b as B8, BinaryField128bGhash as B128, Field};

    use super::*;

    #[test]
    fn test_default_binary_subspace_iterates_elements() {
        // The default subspace of B8 is all of B8, enumerated in canonical order.
        let subspace = BinarySubspace::<B8>::default();
        assert_eq!(subspace.iter().len(), 256);
        for (i, elem) in subspace.iter().enumerate() {
            assert_eq!(elem, B8::new(i as u8));
        }
    }

    #[test]
    fn test_binary_subspace_range_error() {
        // The error must name the argument and report the full valid index range.
        let subspace = BinarySubspace::<B8>::default();
        assert_matches!(
            subspace.get_checked(256),
            Err(Error::ArgumentRangeError { arg, range }) if arg == "index" && range == (0..256)
        );
    }

    #[test]
    fn test_default_binary_subspace() {
        let subspace = BinarySubspace::<B8>::default();
        assert_eq!(subspace.dim(), 8);
        assert_eq!(subspace.basis().len(), 8);

        assert_eq!(
            subspace.basis(),
            [
                B8::new(0b00000001),
                B8::new(0b00000010),
                B8::new(0b00000100),
                B8::new(0b00001000),
                B8::new(0b00010000),
                B8::new(0b00100000),
                B8::new(0b01000000),
                B8::new(0b10000000)
            ]
        );

        for i in 0..=255u8 {
            assert_eq!(subspace.get(usize::from(i)), B8::new(i));
        }
    }

    #[test]
    fn test_with_dim_valid() {
        let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
        assert_eq!(subspace.dim(), 3);
        assert_eq!(subspace.basis().len(), 3);

        assert_eq!(subspace.basis(), [B8::new(0b001), B8::new(0b010), B8::new(0b100)]);

        let expected_elements: [u8; 8] = [0b000, 0b001, 0b010, 0b011, 0b100, 0b101, 0b110, 0b111];

        for (i, &expected) in expected_elements.iter().enumerate() {
            assert_eq!(subspace.get(i), B8::new(expected));
        }
    }

    #[test]
    fn test_with_dim_full_matches_default() {
        // Using every basis element of B8 yields the same subspace as `default()`.
        assert_eq!(BinarySubspace::<B8>::with_dim(8).unwrap(), BinarySubspace::<B8>::default());
    }

    #[test]
    fn test_with_dim_invalid() {
        let result = BinarySubspace::<B8>::with_dim(10);
        assert_matches!(result, Err(Error::DomainSizeTooLarge));
    }

    #[test]
    fn test_reduce_dim_valid() {
        let subspace = BinarySubspace::<B8>::with_dim(6).unwrap();
        let reduced = subspace.reduce_dim(4).unwrap();
        assert_eq!(reduced.dim(), 4);
        assert_eq!(reduced.basis().len(), 4);

        assert_eq!(
            reduced.basis(),
            [
                B8::new(0b0001),
                B8::new(0b0010),
                B8::new(0b0100),
                B8::new(0b1000)
            ]
        );

        for i in 0..16u8 {
            assert_eq!(reduced.get(usize::from(i)), B8::new(i));
        }
    }

    #[test]
    fn test_reduce_dim_boundaries() {
        let subspace = BinarySubspace::<B8>::with_dim(4).unwrap();
        // Reducing to the current dimension is the identity; reducing to 0 is the zero space.
        assert_eq!(subspace.reduce_dim(4).unwrap(), subspace);
        assert_eq!(subspace.reduce_dim(0).unwrap().dim(), 0);
    }

    #[test]
    fn test_reduce_dim_invalid() {
        let subspace = BinarySubspace::<B8>::with_dim(4).unwrap();
        let result = subspace.reduce_dim(6);
        assert_matches!(result, Err(Error::DomainSizeTooLarge));
    }

    #[test]
    fn test_isomorphic_conversion() {
        let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
        let iso_subspace: BinarySubspace<B128> = subspace.isomorphic();
        assert_eq!(iso_subspace.dim(), 3);
        assert_eq!(iso_subspace.basis().len(), 3);

        assert_eq!(
            iso_subspace.basis(),
            [
                B128::from(B8::new(0b001)),
                B128::from(B8::new(0b010)),
                B128::from(B8::new(0b100)),
            ]
        );
    }

    #[test]
    fn test_get_checked_valid() {
        let subspace = BinarySubspace::<B8>::default();
        for i in 0..256 {
            assert_eq!(subspace.get_checked(i).unwrap(), subspace.get(i));
        }
    }

    #[test]
    fn test_get_checked_invalid() {
        let subspace = BinarySubspace::<B8>::default();
        assert_matches!(subspace.get_checked(256), Err(Error::ArgumentRangeError { .. }));
    }

    #[test]
    fn test_iterate_subspace() {
        let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
        let elements: Vec<_> = subspace.iter().collect();
        assert_eq!(elements.len(), 8);

        let expected_elements: [u8; 8] = [0b000, 0b001, 0b010, 0b011, 0b100, 0b101, 0b110, 0b111];

        for (i, &expected) in expected_elements.iter().enumerate() {
            assert_eq!(elements[i], B8::new(expected));
        }
    }

    #[test]
    fn test_iterator_matches_get() {
        let subspace = BinarySubspace::<B8>::with_dim(5).unwrap();

        for (i, elem) in subspace.iter().enumerate() {
            assert_eq!(elem, subspace.get(i), "Mismatch at index {}", i);
        }
    }

    #[test]
    #[allow(clippy::iter_nth_zero)] // `nth(0)` is exercised deliberately.
    fn test_iterator_nth() {
        let subspace = BinarySubspace::<B8>::with_dim(4).unwrap();

        let mut iter = subspace.iter();
        assert_eq!(iter.nth(0), Some(subspace.get(0)));
        assert_eq!(iter.nth(0), Some(subspace.get(1)));
        assert_eq!(iter.nth(2), Some(subspace.get(4)));
        assert_eq!(iter.nth(5), Some(subspace.get(10)));

        let mut iter = subspace.iter();
        assert_eq!(iter.nth(15), Some(subspace.get(15)));
        assert_eq!(iter.nth(0), None);
    }

    #[test]
    fn test_iterator_nth_skips_efficiently() {
        // Large skips must land on the right element and leave the iterator positioned so that
        // `next` continues from there.
        let subspace = BinarySubspace::<B8>::with_dim(6).unwrap();

        let mut iter = subspace.iter();
        assert_eq!(iter.nth(30), Some(subspace.get(30)));
        assert_eq!(iter.next(), Some(subspace.get(31)));

        let mut iter = subspace.iter();
        assert_eq!(iter.nth(50), Some(subspace.get(50)));
    }

    #[test]
    fn test_iterator_size_hint() {
        let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
        let mut iter = subspace.iter();

        assert_eq!(iter.size_hint(), (8, Some(8)));
        iter.next();
        assert_eq!(iter.size_hint(), (7, Some(7)));
        iter.nth(3);
        assert_eq!(iter.size_hint(), (3, Some(3)));
    }

    #[test]
    fn test_iterator_exact_size() {
        let subspace = BinarySubspace::<B8>::with_dim(4).unwrap();
        let mut iter = subspace.iter();

        assert_eq!(iter.len(), 16);
        iter.next();
        assert_eq!(iter.len(), 15);
        iter.nth(5);
        assert_eq!(iter.len(), 9);
    }

    #[test]
    fn test_iterator_empty_subspace() {
        let subspace = BinarySubspace::<B8>::with_dim(0).unwrap();
        let mut iter = subspace.iter();

        // The zero-dimensional subspace still contains exactly one element: zero.
        assert_eq!(iter.len(), 1);
        assert_eq!(iter.next(), Some(B8::ZERO));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_iterator_full_iteration() {
        let subspace = BinarySubspace::<B8>::default();
        let collected: Vec<_> = subspace.iter().collect();

        assert_eq!(collected.len(), 256);
        for (i, elem) in collected.iter().enumerate() {
            assert_eq!(*elem, subspace.get(i));
        }
    }

    #[test]
    fn test_iterator_partial_then_nth() {
        let subspace = BinarySubspace::<B8>::with_dim(5).unwrap();
        let mut iter = subspace.iter();

        assert_eq!(iter.next(), Some(subspace.get(0)));
        assert_eq!(iter.next(), Some(subspace.get(1)));
        assert_eq!(iter.next(), Some(subspace.get(2)));

        assert_eq!(iter.nth(5), Some(subspace.get(8)));
        assert_eq!(iter.next(), Some(subspace.get(9)));
    }

    #[test]
    fn test_iterator_clone() {
        let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
        let mut iter1 = subspace.iter();

        iter1.next();
        iter1.next();

        let mut iter2 = iter1.clone();

        assert_eq!(iter1.next(), iter2.next());
        assert_eq!(iter1.collect::<Vec<_>>(), iter2.collect::<Vec<_>>());
    }
}