binius_math/
binary_subspace.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::Deref;
4
5use binius_field::{BinaryField, BinaryField1b};
6
7use super::error::Error;
8
/// An $\mathbb{F}_2$-linear subspace of a binary field, described by an ordered basis.
///
/// The basis is stored in `Data`, any container that dereferences to a slice of field elements
/// (an owned `Vec<F>` by default). Because the basis elements are ordered, the elements of the
/// subspace are ordered as well.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinarySubspace<F, Data = Vec<F>>
where
	Data: Deref<Target = [F]>,
{
	/// Ordered basis elements spanning the subspace.
	basis: Data,
}
17
18impl<F: BinaryField, Data: Deref<Target = [F]>> BinarySubspace<F, Data> {
19	/// Creates a new subspace from a vector of ordered basis elements.
20	///
21	/// This constructor does not check that the basis elements are linearly independent.
22	pub const fn new_unchecked(basis: Data) -> Self {
23		Self { basis }
24	}
25
26	/// Creates a new subspace isomorphic to the given one.
27	pub fn isomorphic<FIso>(&self) -> BinarySubspace<FIso>
28	where
29		FIso: BinaryField + From<F>,
30	{
31		BinarySubspace {
32			basis: self.basis.iter().copied().map(Into::into).collect(),
33		}
34	}
35
36	/// Returns the dimension of the subspace.
37	pub fn dim(&self) -> usize {
38		self.basis.len()
39	}
40
41	/// Returns the slice of ordered basis elements.
42	pub fn basis(&self) -> &[F] {
43		&self.basis
44	}
45
46	pub fn get(&self, index: usize) -> F {
47		self.basis
48			.iter()
49			.enumerate()
50			.map(|(i, &basis_i)| basis_i * BinaryField1b::from((index >> i) & 1 == 1))
51			.sum()
52	}
53
54	pub fn get_checked(&self, index: usize) -> Result<F, Error> {
55		if index >= 1 << self.basis.len() {
56			return Err(Error::ArgumentRangeError {
57				arg: "index".to_string(),
58				range: 0..1 << self.basis.len(),
59			});
60		}
61		Ok(self.get(index))
62	}
63
64	/// Returns an iterator over all elements of the subspace in order.
65	///
66	/// This has a limitation that the iterator only yields the first `2^usize::BITS` elements.
67	pub fn iter(&self) -> BinarySubspaceIterator<'_, F> {
68		BinarySubspaceIterator::new(&self.basis)
69	}
70}
71
72impl<F: BinaryField> BinarySubspace<F> {
73	/// Creates a new subspace of this binary subspace with the given dimension.
74	///
75	/// This creates a new sub-subspace using a prefix of the default $\mathbb{F}_2$ basis elements
76	/// of the field.
77	///
78	/// ## Throws
79	///
80	/// * `Error::DomainSizeTooLarge` if `dim` is greater than this subspace's dimension.
81	pub fn with_dim(dim: usize) -> Result<Self, Error> {
82		let basis = (0..dim)
83			.map(|i| F::basis_checked(i).map_err(|_| Error::DomainSizeTooLarge))
84			.collect::<Result<_, _>>()?;
85		Ok(Self { basis })
86	}
87
88	/// Creates a new subspace of this binary subspace with reduced dimension.
89	///
90	/// This creates a new sub-subspace using a prefix of the ordered basis elements.
91	///
92	/// ## Throws
93	///
94	/// * `Error::DomainSizeTooLarge` if `dim` is greater than this subspace's dimension.
95	pub fn reduce_dim(&self, dim: usize) -> Result<Self, Error> {
96		if dim > self.dim() {
97			return Err(Error::DomainSizeTooLarge);
98		}
99		Ok(Self {
100			basis: self.basis[..dim].to_vec(),
101		})
102	}
103}
104
/// Iterator over the elements of a binary subspace, in index order.
///
/// Every element is a subset sum (XOR) of the basis elements. `nth` jumps straight to the
/// requested position instead of stepping through the elements in between.
#[derive(Debug, Clone)]
pub struct BinarySubspaceIterator<'basis, F> {
	/// Ordered basis elements of the subspace being iterated.
	basis: &'basis [F],
	/// Index of the element that the next call to `next` will yield.
	index: usize,
	/// The element at `index`, computed ahead of time; `None` once the iterator is exhausted.
	next: Option<F>,
}
115
116impl<'a, F: BinaryField> BinarySubspaceIterator<'a, F> {
117	pub fn new(basis: &'a [F]) -> Self {
118		assert!(basis.len() < usize::BITS as usize);
119		Self {
120			basis,
121			index: 0,
122			next: Some(F::ZERO),
123		}
124	}
125}
126
127impl<'a, F: BinaryField> Iterator for BinarySubspaceIterator<'a, F> {
128	type Item = F;
129
130	#[inline]
131	fn next(&mut self) -> Option<Self::Item> {
132		let ret = self.next?;
133
134		let mut next = ret;
135		let mut i = 0;
136		while (self.index >> i) & 1 == 1 {
137			next -= self.basis[i];
138			i += 1;
139		}
140		self.next = self.basis.get(i).map(|&basis_i| next + basis_i);
141
142		self.index += 1;
143		Some(ret)
144	}
145
146	fn size_hint(&self) -> (usize, Option<usize>) {
147		let last = 1 << self.basis.len();
148		let remaining = last - self.index;
149		(remaining, Some(remaining))
150	}
151
152	fn nth(&mut self, n: usize) -> Option<Self::Item> {
153		match self.index.checked_add(n) {
154			Some(new_index) if new_index < 1 << self.basis.len() => {
155				let new_next = BinarySubspace::new_unchecked(self.basis).get(new_index);
156
157				self.index = new_index;
158				self.next = Some(new_next);
159			}
160			_ => {
161				self.index = 1 << self.basis.len();
162				self.next = None;
163			}
164		}
165
166		self.next()
167	}
168}
169
170impl<'a, F: BinaryField> ExactSizeIterator for BinarySubspaceIterator<'a, F> {
171	fn len(&self) -> usize {
172		let last = 1 << self.basis.len();
173		last - self.index
174	}
175}
176
177impl<'a, F: BinaryField> std::iter::FusedIterator for BinarySubspaceIterator<'a, F> {}
178
179impl<F: BinaryField> Default for BinarySubspace<F> {
180	fn default() -> Self {
181		let basis = (0..F::DEGREE).map(|i| F::basis(i)).collect();
182		Self { basis }
183	}
184}
185
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use binius_field::{AESTowerField8b as B8, BinaryField128bGhash as B128, Field};

	use super::*;

	/// In the default 8-dimensional subspace, element `i` is the field element with value `i`.
	#[test]
	fn test_default_binary_subspace_iterates_elements() {
		let subspace = BinarySubspace::<B8>::default();
		for byte in 0..=u8::MAX {
			assert_eq!(subspace.get(usize::from(byte)), B8::new(byte));
		}
	}

	/// An index equal to the subspace size is rejected by `get_checked`.
	#[test]
	fn test_binary_subspace_range_error() {
		let outcome = BinarySubspace::<B8>::default().get_checked(256);
		assert_matches!(outcome, Err(Error::ArgumentRangeError { .. }));
	}

	/// The default subspace has the eight single-bit elements as its basis.
	#[test]
	fn test_default_binary_subspace() {
		let subspace = BinarySubspace::<B8>::default();
		assert_eq!(subspace.dim(), 8);
		assert_eq!(subspace.basis().len(), 8);

		let expected_basis: Vec<B8> = (0..8).map(|bit| B8::new(1 << bit)).collect();
		assert_eq!(subspace.basis(), expected_basis.as_slice());

		for (i, expected) in (0..=u8::MAX).enumerate() {
			assert_eq!(subspace.get(i), B8::new(expected));
		}
	}

	/// `with_dim` keeps a prefix of the default basis.
	#[test]
	fn test_with_dim_valid() {
		let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
		assert_eq!(subspace.dim(), 3);
		assert_eq!(subspace.basis().len(), 3);

		let expected_basis = [0b001, 0b010, 0b100].map(|bits| B8::new(bits));
		assert_eq!(subspace.basis(), expected_basis);

		for (i, expected) in (0b000u8..=0b111).enumerate() {
			assert_eq!(subspace.get(i), B8::new(expected));
		}
	}

	/// Requesting more dimensions than the field provides fails.
	#[test]
	fn test_with_dim_invalid() {
		assert_matches!(BinarySubspace::<B8>::with_dim(10), Err(Error::DomainSizeTooLarge));
	}

	/// `reduce_dim` keeps a prefix of the existing basis.
	#[test]
	fn test_reduce_dim_valid() {
		let reduced = BinarySubspace::<B8>::with_dim(6)
			.unwrap()
			.reduce_dim(4)
			.unwrap();
		assert_eq!(reduced.dim(), 4);
		assert_eq!(reduced.basis().len(), 4);

		let expected_basis = [0b0001, 0b0010, 0b0100, 0b1000].map(|bits| B8::new(bits));
		assert_eq!(reduced.basis(), expected_basis);

		for (i, expected) in (0u8..16).enumerate() {
			assert_eq!(reduced.get(i), B8::new(expected));
		}
	}

	/// `reduce_dim` cannot grow a subspace.
	#[test]
	fn test_reduce_dim_invalid() {
		let four_dim = BinarySubspace::<B8>::with_dim(4).unwrap();
		assert_matches!(four_dim.reduce_dim(6), Err(Error::DomainSizeTooLarge));
	}

	/// `isomorphic` converts every basis element and keeps the order.
	#[test]
	fn test_isomorphic_conversion() {
		let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
		let iso_subspace: BinarySubspace<B128> = subspace.isomorphic();
		assert_eq!(iso_subspace.dim(), 3);
		assert_eq!(iso_subspace.basis().len(), 3);

		let expected_basis = [0b001, 0b010, 0b100].map(|bits| B128::from(B8::new(bits)));
		assert_eq!(iso_subspace.basis(), expected_basis);
	}

	/// Every index below the subspace size is accepted by `get_checked`.
	#[test]
	fn test_get_checked_valid() {
		let subspace = BinarySubspace::<B8>::default();
		assert!((0..256).all(|i| subspace.get_checked(i).is_ok()));
	}

	/// The first index past the end is rejected by `get_checked`.
	#[test]
	fn test_get_checked_invalid() {
		let subspace = BinarySubspace::<B8>::default();
		let outcome = subspace.get_checked(256);
		assert_matches!(outcome, Err(Error::ArgumentRangeError { .. }));
	}

	/// Collecting the iterator yields all elements in index order.
	#[test]
	fn test_iterate_subspace() {
		let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
		let elements: Vec<_> = subspace.iter().collect();
		assert_eq!(elements.len(), 8);

		for (elem, expected) in elements.iter().zip(0b000u8..=0b111) {
			assert_eq!(*elem, B8::new(expected));
		}
	}

	/// The iterator and `get` agree on every element.
	#[test]
	fn test_iterator_matches_get() {
		let subspace = BinarySubspace::<B8>::with_dim(5).unwrap();

		for (elem, i) in subspace.iter().zip(0usize..) {
			assert_eq!(elem, subspace.get(i), "Mismatch at index {}", i);
		}
	}

	/// `nth` advances relative to the current position and returns `None` past the end.
	#[test]
	#[allow(clippy::iter_nth_zero)]
	fn test_iterator_nth() {
		let subspace = BinarySubspace::<B8>::with_dim(4).unwrap();

		// Each pair is (argument to `nth`, absolute index of the element it must return).
		let mut iter = subspace.iter();
		for (skip, absolute) in [(0, 0), (0, 1), (2, 4), (5, 10)] {
			assert_eq!(iter.nth(skip), Some(subspace.get(absolute)));
		}

		// Jumping to the last element leaves nothing behind it.
		let mut iter = subspace.iter();
		assert_eq!(iter.nth(15), Some(subspace.get(15)));
		assert_eq!(iter.nth(0), None);
	}

	/// `nth` lands on the right element and `next` continues from there.
	#[test]
	fn test_iterator_nth_skips_efficiently() {
		let subspace = BinarySubspace::<B8>::with_dim(6).unwrap();

		let mut jumping = subspace.iter();
		assert_eq!(jumping.nth(30), Some(subspace.get(30)));
		assert_eq!(jumping.next(), Some(subspace.get(31)));

		let mut far_jump = subspace.iter();
		assert_eq!(far_jump.nth(50), Some(subspace.get(50)));
	}

	/// `size_hint` tracks the exact number of remaining elements.
	#[test]
	fn test_iterator_size_hint() {
		let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
		let mut iter = subspace.iter();
		assert_eq!(iter.size_hint(), (8, Some(8)));

		iter.next();
		assert_eq!(iter.size_hint(), (7, Some(7)));

		iter.nth(3);
		assert_eq!(iter.size_hint(), (3, Some(3)));
	}

	/// `len` tracks the exact number of remaining elements.
	#[test]
	fn test_iterator_exact_size() {
		let subspace = BinarySubspace::<B8>::with_dim(4).unwrap();
		let mut iter = subspace.iter();
		assert_eq!(iter.len(), 16);

		iter.next();
		assert_eq!(iter.len(), 15);

		iter.nth(5);
		assert_eq!(iter.len(), 9);
	}

	/// A zero-dimensional subspace contains exactly one element: zero.
	#[test]
	fn test_iterator_empty_subspace() {
		let trivial = BinarySubspace::<B8>::with_dim(0).unwrap();
		let mut iter = trivial.iter();

		assert_eq!(iter.len(), 1);
		assert_eq!(iter.next(), Some(B8::ZERO));
		assert_eq!(iter.next(), None);
	}

	/// Iterating the full default subspace yields 256 elements matching `get`.
	#[test]
	fn test_iterator_full_iteration() {
		let subspace = BinarySubspace::<B8>::default();
		let collected: Vec<_> = subspace.iter().collect();
		assert_eq!(collected.len(), 256);

		for (i, &elem) in collected.iter().enumerate() {
			assert_eq!(elem, subspace.get(i));
		}
	}

	/// `next` and `nth` can be interleaved.
	#[test]
	fn test_iterator_partial_then_nth() {
		let subspace = BinarySubspace::<B8>::with_dim(5).unwrap();
		let mut iter = subspace.iter();

		// Step through the first three elements one at a time.
		for i in 0..3 {
			assert_eq!(iter.next(), Some(subspace.get(i)));
		}

		// Then skip ahead and keep stepping.
		assert_eq!(iter.nth(5), Some(subspace.get(8)));
		assert_eq!(iter.next(), Some(subspace.get(9)));
	}

	/// A cloned iterator continues from the same position as the original.
	#[test]
	fn test_iterator_clone() {
		let subspace = BinarySubspace::<B8>::with_dim(3).unwrap();
		let mut iter1 = subspace.iter();
		for _ in 0..2 {
			iter1.next();
		}

		let mut iter2 = iter1.clone();

		// Both iterators should produce the same remaining elements.
		assert_eq!(iter1.next(), iter2.next());
		let rest1: Vec<_> = iter1.collect();
		let rest2: Vec<_> = iter2.collect();
		assert_eq!(rest1, rest2);
	}
}
447}