binius_math/
binary_subspace.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::Deref;
4
5use binius_field::{BinaryField, BinaryField1b};
6
7use super::error::Error;
8
/// A linear subspace over $\mathbb{F}_2$ that lives inside a binary field.
///
/// The subspace is described by an ordered list of basis elements. That ordering induces an
/// ordering on the elements of the subspace: the element at position `i` is the sum of the basis
/// elements selected by the set bits of `i`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinarySubspace<F, Data: Deref<Target = [F]> = Vec<F>> {
	/// The ordered basis. Linear independence is not verified by `new_unchecked`.
	basis: Data,
}
17
18impl<F: BinaryField, Data: Deref<Target = [F]>> BinarySubspace<F, Data> {
19	/// Creates a new subspace from a vector of ordered basis elements.
20	///
21	/// This constructor does not check that the basis elements are linearly independent.
22	pub const fn new_unchecked(basis: Data) -> Self {
23		Self { basis }
24	}
25
26	/// Creates a new subspace isomorphic to the given one.
27	pub fn isomorphic<FIso>(&self) -> BinarySubspace<FIso>
28	where
29		FIso: BinaryField + From<F>,
30	{
31		BinarySubspace {
32			basis: self.basis.iter().copied().map(Into::into).collect(),
33		}
34	}
35
36	/// Returns the dimension of the subspace.
37	pub fn dim(&self) -> usize {
38		self.basis.len()
39	}
40
41	/// Returns the slice of ordered basis elements.
42	pub fn basis(&self) -> &[F] {
43		&self.basis
44	}
45
46	pub fn get(&self, index: usize) -> F {
47		self.basis
48			.iter()
49			.enumerate()
50			.map(|(i, &basis_i)| basis_i * BinaryField1b::from((index >> i) & 1 == 1))
51			.sum()
52	}
53
54	pub fn get_checked(&self, index: usize) -> Result<F, Error> {
55		if index >= 1 << self.basis.len() {
56			return Err(Error::ArgumentRangeError {
57				arg: "index".to_string(),
58				range: 0..1 << self.basis.len(),
59			});
60		}
61		Ok(self.get(index))
62	}
63
64	/// Returns an iterator over all elements of the subspace in order.
65	///
66	/// This has a limitation that the iterator only yields the first `2^usize::BITS` elements.
67	pub fn iter(&self) -> BinarySubspaceIterator<'_, F> {
68		BinarySubspaceIterator::new(&self.basis)
69	}
70}
71
72impl<F: BinaryField> BinarySubspace<F> {
73	/// Creates a new subspace of this binary subspace with the given dimension.
74	///
75	/// This creates a new sub-subspace using a prefix of the default $\mathbb{F}_2$ basis elements
76	/// of the field.
77	///
78	/// ## Throws
79	///
80	/// * `Error::DomainSizeTooLarge` if `dim` is greater than this subspace's dimension.
81	pub fn with_dim(dim: usize) -> Result<Self, Error> {
82		let basis = (0..dim)
83			.map(|i| F::basis_checked(i).map_err(|_| Error::DomainSizeTooLarge))
84			.collect::<Result<_, _>>()?;
85		Ok(Self { basis })
86	}
87
88	/// Creates a new subspace of this binary subspace with reduced dimension.
89	///
90	/// This creates a new sub-subspace using a prefix of the ordered basis elements.
91	///
92	/// ## Throws
93	///
94	/// * `Error::DomainSizeTooLarge` if `dim` is greater than this subspace's dimension.
95	pub fn reduce_dim(&self, dim: usize) -> Result<Self, Error> {
96		if dim > self.dim() {
97			return Err(Error::DomainSizeTooLarge);
98		}
99		Ok(Self {
100			basis: self.basis[..dim].to_vec(),
101		})
102	}
103}
104
/// Iterator over every element of a binary subspace, in subspace order.
///
/// Each element is a subset sum of the basis elements. Advancing by one updates the previously
/// yielded element incrementally instead of recomputing the sum. `nth` jumps straight to the
/// requested position without visiting the skipped elements.
#[derive(Debug, Clone)]
pub struct BinarySubspaceIterator<'a, F> {
	/// The ordered basis of the subspace being iterated.
	basis: &'a [F],
	/// Number of elements consumed so far, which is also the position of `next`.
	index: usize,
	/// The element the next call to `next` returns, or `None` once the iterator is exhausted.
	next: Option<F>,
}
115
116impl<'a, F: BinaryField> BinarySubspaceIterator<'a, F> {
117	fn new(basis: &'a [F]) -> Self {
118		assert!(basis.len() < usize::BITS as usize);
119		Self {
120			basis,
121			index: 0,
122			next: Some(F::ZERO),
123		}
124	}
125}
126
127impl<'a, F: BinaryField> Iterator for BinarySubspaceIterator<'a, F> {
128	type Item = F;
129
130	fn next(&mut self) -> Option<Self::Item> {
131		let ret = self.next?;
132
133		let mut next = ret;
134		let mut i = 0;
135		while (self.index >> i) & 1 == 1 {
136			next -= self.basis[i];
137			i += 1;
138		}
139		self.next = self.basis.get(i).map(|&basis_i| next + basis_i);
140
141		self.index += 1;
142		Some(ret)
143	}
144
145	fn size_hint(&self) -> (usize, Option<usize>) {
146		let last = 1 << self.basis.len();
147		let remaining = last - self.index;
148		(remaining, Some(remaining))
149	}
150
151	fn nth(&mut self, n: usize) -> Option<Self::Item> {
152		match self.index.checked_add(n) {
153			Some(new_index) if new_index < 1 << self.basis.len() => {
154				let new_next = BinarySubspace::new_unchecked(self.basis).get(new_index);
155
156				self.index = new_index;
157				self.next = Some(new_next);
158			}
159			_ => {
160				self.index = 1 << self.basis.len();
161				self.next = None;
162			}
163		}
164
165		self.next()
166	}
167}
168
169impl<'a, F: BinaryField> ExactSizeIterator for BinarySubspaceIterator<'a, F> {
170	fn len(&self) -> usize {
171		let last = 1 << self.basis.len();
172		last - self.index
173	}
174}
175
176impl<'a, F: BinaryField> std::iter::FusedIterator for BinarySubspaceIterator<'a, F> {}
177
178impl<F: BinaryField> Default for BinarySubspace<F> {
179	fn default() -> Self {
180		let basis = (0..F::DEGREE).map(|i| F::basis(i)).collect();
181		Self { basis }
182	}
183}
184
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use binius_field::{AESTowerField8b as B8, BinaryField128bGhash as B128, Field};

	use super::*;

	/// The bit patterns `0..count` as `B8` elements, i.e. the expected contents, in order, of a
	/// canonical `B8` subspace with `count` elements.
	fn canonical_elements(count: usize) -> Vec<B8> {
		(0..count).map(|i| B8::new(i as u8)).collect()
	}

	/// The first `dim` canonical basis vectors `1, 2, 4, …` of `B8`.
	fn canonical_basis(dim: usize) -> Vec<B8> {
		(0..dim).map(|i| B8::new(1 << i)).collect()
	}

	#[test]
	fn test_default_binary_subspace_iterates_elements() {
		let subspace = BinarySubspace::<B8>::default();
		for (i, expected) in canonical_elements(256).into_iter().enumerate() {
			assert_eq!(subspace.get(i), expected);
		}
	}

	#[test]
	fn test_binary_subspace_range_error() {
		let out_of_range = BinarySubspace::<B8>::default().get_checked(256);
		assert_matches!(out_of_range, Err(Error::ArgumentRangeError { .. }));
	}

	#[test]
	fn test_default_binary_subspace() {
		let subspace = BinarySubspace::<B8>::default();
		assert_eq!(subspace.dim(), 8);
		assert_eq!(subspace.basis().len(), 8);
		assert_eq!(subspace.basis(), canonical_basis(8).as_slice());

		let elements: Vec<B8> = (0..256).map(|i| subspace.get(i)).collect();
		assert_eq!(elements, canonical_elements(256));
	}

	#[test]
	fn test_with_dim_valid() {
		let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
		assert_eq!(subspace.dim(), 3);
		assert_eq!(subspace.basis().len(), 3);
		assert_eq!(subspace.basis(), canonical_basis(3).as_slice());

		let elements: Vec<B8> = (0..8).map(|i| subspace.get(i)).collect();
		assert_eq!(elements, canonical_elements(8));
	}

	#[test]
	fn test_with_dim_invalid() {
		assert_matches!(BinarySubspace::<B8>::with_dim(10), Err(Error::DomainSizeTooLarge));
	}

	#[test]
	fn test_reduce_dim_valid() {
		let reduced = BinarySubspace::<B8>::with_dim(6)
			.unwrap()
			.reduce_dim(4)
			.unwrap();
		assert_eq!(reduced.dim(), 4);
		assert_eq!(reduced.basis().len(), 4);
		assert_eq!(reduced.basis(), canonical_basis(4).as_slice());

		let elements: Vec<B8> = (0..16).map(|i| reduced.get(i)).collect();
		assert_eq!(elements, canonical_elements(16));
	}

	#[test]
	fn test_reduce_dim_invalid() {
		let subspace = BinarySubspace::<B8>::with_dim(4).unwrap();
		assert_matches!(subspace.reduce_dim(6), Err(Error::DomainSizeTooLarge));
	}

	#[test]
	fn test_isomorphic_conversion() {
		let iso_subspace: BinarySubspace<B128> =
			BinarySubspace::<B8>::with_dim(3).unwrap().isomorphic();
		assert_eq!(iso_subspace.dim(), 3);
		assert_eq!(iso_subspace.basis().len(), 3);

		let expected: Vec<B128> = canonical_basis(3).into_iter().map(B128::from).collect();
		assert_eq!(iso_subspace.basis(), expected.as_slice());
	}

	#[test]
	fn test_get_checked_valid() {
		let subspace = BinarySubspace::<B8>::default();
		assert!((0..256).all(|i| subspace.get_checked(i).is_ok()));
	}

	#[test]
	fn test_get_checked_invalid() {
		let subspace = BinarySubspace::<B8>::default();
		let result = subspace.get_checked(256);
		assert_matches!(result, Err(Error::ArgumentRangeError { .. }));
	}

	#[test]
	fn test_iterate_subspace() {
		let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
		let elements: Vec<_> = subspace.iter().collect();
		assert_eq!(elements.len(), 8);
		assert_eq!(elements, canonical_elements(8));
	}

	#[test]
	fn test_iterator_matches_get() {
		let subspace = BinarySubspace::<B8>::with_dim(5).unwrap();

		// Sequential iteration must agree with random access at every position.
		for (i, elem) in subspace.iter().enumerate() {
			assert_eq!(elem, subspace.get(i), "Mismatch at index {}", i);
		}
	}

	#[test]
	// `nth(0)` is exercised on purpose to check it behaves like `next`.
	#[allow(clippy::iter_nth_zero)]
	fn test_iterator_nth() {
		let subspace = BinarySubspace::<B8>::with_dim(4).unwrap();

		// Each `nth(skip)` consumes `skip + 1` elements and yields the element at `position`.
		let mut iter = subspace.iter();
		for (skip, position) in [(0, 0), (0, 1), (2, 4), (5, 10)] {
			assert_eq!(iter.nth(skip), Some(subspace.get(position)));
		}

		// Landing on the last element leaves nothing behind it.
		let mut iter = subspace.iter();
		assert_eq!(iter.nth(15), Some(subspace.get(15)));
		assert_eq!(iter.nth(0), None);
	}

	#[test]
	fn test_iterator_nth_skips_efficiently() {
		let subspace = BinarySubspace::<B8>::with_dim(6).unwrap();

		// Jump directly to a position, then continue sequentially.
		let mut iter = subspace.iter();
		assert_eq!(iter.nth(30), Some(subspace.get(30)));
		assert_eq!(iter.next(), Some(subspace.get(31)));

		// A large skip from a fresh iterator.
		assert_eq!(subspace.iter().nth(50), Some(subspace.get(50)));
	}

	#[test]
	fn test_iterator_size_hint() {
		let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
		let mut iter = subspace.iter();

		assert_eq!(iter.size_hint(), (8, Some(8)));
		iter.next();
		assert_eq!(iter.size_hint(), (7, Some(7)));
		iter.nth(3);
		assert_eq!(iter.size_hint(), (3, Some(3)));
	}

	#[test]
	fn test_iterator_exact_size() {
		let subspace = BinarySubspace::<B8>::with_dim(4).unwrap();
		let mut iter = subspace.iter();

		assert_eq!(iter.len(), 16);
		iter.next();
		assert_eq!(iter.len(), 15);
		iter.nth(5);
		assert_eq!(iter.len(), 9);
	}

	#[test]
	fn test_iterator_empty_subspace() {
		// A zero-dimensional subspace contains exactly one element: zero.
		let subspace = BinarySubspace::<B8>::with_dim(0).unwrap();
		let mut iter = subspace.iter();

		assert_eq!(iter.len(), 1);
		assert_eq!(iter.next(), Some(B8::ZERO));
		assert_eq!(iter.next(), None);
	}

	#[test]
	fn test_iterator_full_iteration() {
		let subspace = BinarySubspace::<B8>::default();
		let collected: Vec<_> = subspace.iter().collect();
		assert_eq!(collected.len(), 256);

		let by_index: Vec<_> = (0..256).map(|i| subspace.get(i)).collect();
		assert_eq!(collected, by_index);
	}

	#[test]
	fn test_iterator_partial_then_nth() {
		let subspace = BinarySubspace::<B8>::with_dim(5).unwrap();
		let mut iter = subspace.iter();

		// Consume a few elements one at a time...
		for position in 0..3 {
			assert_eq!(iter.next(), Some(subspace.get(position)));
		}

		// ...then skip ahead and resume sequential iteration.
		assert_eq!(iter.nth(5), Some(subspace.get(8)));
		assert_eq!(iter.next(), Some(subspace.get(9)));
	}

	#[test]
	fn test_iterator_clone() {
		let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
		let mut original = subspace.iter();
		original.next();
		original.next();

		// A clone taken mid-iteration continues from the same position.
		let mut copy = original.clone();
		assert_eq!(original.next(), copy.next());
		assert_eq!(original.collect::<Vec<_>>(), copy.collect::<Vec<_>>());
	}
}