binius_math/
binary_subspace.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::Deref;
4
5use binius_field::{BinaryField, BinaryField1b};
6
/// An $F_2$-linear subspace of a binary field.
///
/// The subspace is described by an ordered basis of elements of the field `F`. Because the basis
/// is ordered, the elements of the subspace inherit an ordering as well.
///
/// `Data` is the container that holds the basis. It only has to dereference to `[F]` and
/// defaults to `Vec<F>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinarySubspace<F, Data = Vec<F>>
where
	Data: Deref<Target = [F]>,
{
	basis: Data,
}
15
16impl<F: BinaryField, Data: Deref<Target = [F]>> BinarySubspace<F, Data> {
17	/// Creates a new subspace from a vector of ordered basis elements.
18	///
19	/// This constructor does not check that the basis elements are linearly independent.
20	pub const fn new_unchecked(basis: Data) -> Self {
21		Self { basis }
22	}
23
24	/// Creates a new subspace isomorphic to the given one.
25	pub fn isomorphic<FIso>(&self) -> BinarySubspace<FIso>
26	where
27		FIso: BinaryField + From<F>,
28	{
29		BinarySubspace {
30			basis: self.basis.iter().copied().map(Into::into).collect(),
31		}
32	}
33
34	/// Returns the dimension of the subspace.
35	pub fn dim(&self) -> usize {
36		self.basis.len()
37	}
38
39	/// Returns the slice of ordered basis elements.
40	pub fn basis(&self) -> &[F] {
41		&self.basis
42	}
43
44	pub fn get(&self, index: usize) -> F {
45		assert!(index < 1 << self.dim(), "precondition: index must be less than 2^dim");
46		self.basis
47			.iter()
48			.enumerate()
49			.map(|(i, &basis_i)| basis_i * BinaryField1b::from((index >> i) & 1 == 1))
50			.sum()
51	}
52
53	/// Returns an iterator over all elements of the subspace in order.
54	///
55	/// This has a limitation that the iterator only yields the first `2^usize::BITS` elements.
56	pub fn iter(&self) -> BinarySubspaceIterator<'_, F> {
57		BinarySubspaceIterator::new(&self.basis)
58	}
59}
60
61impl<F: BinaryField> BinarySubspace<F> {
62	/// Creates a new subspace of this binary subspace with the given dimension.
63	///
64	/// This creates a new sub-subspace using a prefix of the default $\mathbb{F}_2$ basis elements
65	/// of the field.
66	///
67	/// ## Preconditions
68	///
69	/// * `dim` must be at most `F::DEGREE`
70	pub fn with_dim(dim: usize) -> Self {
71		assert!(dim <= F::DEGREE, "precondition: dim must be at most F::DEGREE");
72		let basis = (0..dim).map(|i| F::basis(i)).collect();
73		Self { basis }
74	}
75
76	/// Creates a new subspace of this binary subspace with reduced dimension.
77	///
78	/// This creates a new sub-subspace using a prefix of the ordered basis elements.
79	///
80	/// ## Preconditions
81	///
82	/// * `dim` must be at most this subspace's dimension
83	pub fn reduce_dim(&self, dim: usize) -> Self {
84		assert!(dim <= self.dim(), "precondition: dim must be at most this subspace's dimension");
85		Self {
86			basis: self.basis[..dim].to_vec(),
87		}
88	}
89}
90
91/// Iterator over all elements of a binary subspace.
92///
93/// Each element is computed as a subset sum (XOR) of the basis elements.
94/// The iterator supports efficient `nth` operation without computing intermediate values.
95#[derive(Debug, Clone)]
96pub struct BinarySubspaceIterator<'a, F> {
97	basis: &'a [F],
98	index: usize,
99	next: Option<F>,
100}
101
102impl<'a, F: BinaryField> BinarySubspaceIterator<'a, F> {
103	pub fn new(basis: &'a [F]) -> Self {
104		assert!(basis.len() < usize::BITS as usize);
105		Self {
106			basis,
107			index: 0,
108			next: Some(F::ZERO),
109		}
110	}
111}
112
113impl<'a, F: BinaryField> Iterator for BinarySubspaceIterator<'a, F> {
114	type Item = F;
115
116	#[inline]
117	fn next(&mut self) -> Option<Self::Item> {
118		let ret = self.next?;
119
120		let mut next = ret;
121		let mut i = 0;
122		while (self.index >> i) & 1 == 1 {
123			next -= self.basis[i];
124			i += 1;
125		}
126		self.next = self.basis.get(i).map(|&basis_i| next + basis_i);
127
128		self.index += 1;
129		Some(ret)
130	}
131
132	fn size_hint(&self) -> (usize, Option<usize>) {
133		let last = 1 << self.basis.len();
134		let remaining = last - self.index;
135		(remaining, Some(remaining))
136	}
137
138	fn nth(&mut self, n: usize) -> Option<Self::Item> {
139		match self.index.checked_add(n) {
140			Some(new_index) if new_index < 1 << self.basis.len() => {
141				let new_next = BinarySubspace::new_unchecked(self.basis).get(new_index);
142
143				self.index = new_index;
144				self.next = Some(new_next);
145			}
146			_ => {
147				self.index = 1 << self.basis.len();
148				self.next = None;
149			}
150		}
151
152		self.next()
153	}
154}
155
156impl<'a, F: BinaryField> ExactSizeIterator for BinarySubspaceIterator<'a, F> {
157	fn len(&self) -> usize {
158		let last = 1 << self.basis.len();
159		last - self.index
160	}
161}
162
163impl<'a, F: BinaryField> std::iter::FusedIterator for BinarySubspaceIterator<'a, F> {}
164
165impl<F: BinaryField> Default for BinarySubspace<F> {
166	fn default() -> Self {
167		let basis = (0..F::DEGREE).map(|i| F::basis(i)).collect();
168		Self { basis }
169	}
170}
171
#[cfg(test)]
mod tests {
	use binius_field::{AESTowerField8b as B8, BinaryField128bGhash as B128, Field};

	use super::*;

	/// `get` on the default 8-bit subspace maps every index to the element with the same bits.
	#[test]
	fn test_default_binary_subspace_iterates_elements() {
		let subspace = BinarySubspace::<B8>::default();
		for byte in 0..=u8::MAX {
			assert_eq!(subspace.get(usize::from(byte)), B8::new(byte));
		}
	}

	/// An index of `2^dim` is out of range and must trip the precondition.
	#[test]
	#[should_panic(expected = "precondition")]
	fn test_binary_subspace_range_error() {
		let full = BinarySubspace::<B8>::default();
		let _ = full.get(256);
	}

	/// The default subspace has the eight single-bit elements as its basis.
	#[test]
	fn test_default_binary_subspace() {
		let subspace = BinarySubspace::<B8>::default();
		assert_eq!(subspace.dim(), 8);
		assert_eq!(subspace.basis().len(), 8);

		let single_bits: [u8; 8] = [
			0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, 0b01000000,
			0b10000000,
		];
		assert_eq!(subspace.basis(), single_bits.map(B8::new));

		for value in 0..=u8::MAX {
			assert_eq!(subspace.get(usize::from(value)), B8::new(value));
		}
	}

	/// `with_dim` keeps a prefix of the default basis.
	#[test]
	fn test_with_dim_valid() {
		let subspace = BinarySubspace::<B8>::with_dim(3);
		assert_eq!(subspace.dim(), 3);
		assert_eq!(subspace.basis().len(), 3);

		assert_eq!(subspace.basis(), [0b001u8, 0b010, 0b100].map(B8::new));

		for value in 0b000..=0b111u8 {
			assert_eq!(subspace.get(usize::from(value)), B8::new(value));
		}
	}

	/// A dimension above the field degree must trip the precondition.
	#[test]
	#[should_panic(expected = "precondition")]
	fn test_with_dim_invalid() {
		let _ = BinarySubspace::<B8>::with_dim(10);
	}

	/// `reduce_dim` keeps a prefix of the existing basis.
	#[test]
	fn test_reduce_dim_valid() {
		let reduced = BinarySubspace::<B8>::with_dim(6).reduce_dim(4);
		assert_eq!(reduced.dim(), 4);
		assert_eq!(reduced.basis().len(), 4);

		assert_eq!(reduced.basis(), [0b0001u8, 0b0010, 0b0100, 0b1000].map(B8::new));

		for value in 0..16u8 {
			assert_eq!(reduced.get(usize::from(value)), B8::new(value));
		}
	}

	/// Growing the dimension through `reduce_dim` must trip the precondition.
	#[test]
	#[should_panic(expected = "precondition")]
	fn test_reduce_dim_invalid() {
		let small = BinarySubspace::<B8>::with_dim(4);
		let _ = small.reduce_dim(6);
	}

	/// `isomorphic` converts each basis element and keeps the order.
	#[test]
	fn test_isomorphic_conversion() {
		let subspace = BinarySubspace::<B8>::with_dim(3);
		let iso_subspace: BinarySubspace<B128> = subspace.isomorphic();
		assert_eq!(iso_subspace.dim(), 3);
		assert_eq!(iso_subspace.basis().len(), 3);

		let expected_basis = [0b001u8, 0b010, 0b100].map(|bits| B128::from(B8::new(bits)));
		assert_eq!(iso_subspace.basis(), expected_basis);
	}

	/// Collecting the iterator yields all elements in index order.
	#[test]
	fn test_iterate_subspace() {
		let subspace = BinarySubspace::<B8>::with_dim(3);
		let elements: Vec<_> = subspace.iter().collect();
		assert_eq!(elements.len(), 8);

		for value in 0b000..=0b111u8 {
			assert_eq!(elements[usize::from(value)], B8::new(value));
		}
	}

	/// The iterator and `get` agree at every position.
	#[test]
	fn test_iterator_matches_get() {
		let subspace = BinarySubspace::<B8>::with_dim(5);

		subspace.iter().enumerate().for_each(|(i, elem)| {
			assert_eq!(elem, subspace.get(i), "Mismatch at index {}", i);
		});
	}

	/// `nth` skips the requested number of elements relative to the current position.
	#[test]
	#[allow(clippy::iter_nth_zero)]
	fn test_iterator_nth() {
		let subspace = BinarySubspace::<B8>::with_dim(4);

		// Each pair is (elements to skip, index of the element expected back).
		let mut iter = subspace.iter();
		for (skip, expected_index) in [(0, 0), (0, 1), (2, 4), (5, 10)] {
			assert_eq!(iter.nth(skip), Some(subspace.get(expected_index)));
		}

		// Landing on the last element leaves nothing behind it.
		let mut iter = subspace.iter();
		assert_eq!(iter.nth(15), Some(subspace.get(15)));
		assert_eq!(iter.nth(0), None);
	}

	/// `nth` lands on the right element and iteration continues from there.
	#[test]
	fn test_iterator_nth_skips_efficiently() {
		let subspace = BinarySubspace::<B8>::with_dim(6);

		// Jump into the middle, then step once.
		let mut jumping = subspace.iter();
		assert_eq!(jumping.nth(30), Some(subspace.get(30)));
		assert_eq!(jumping.next(), Some(subspace.get(31)));

		// A large jump from a fresh iterator.
		let mut fresh = subspace.iter();
		assert_eq!(fresh.nth(50), Some(subspace.get(50)));
	}

	/// `size_hint` is exact and shrinks as elements are consumed.
	#[test]
	fn test_iterator_size_hint() {
		let subspace = BinarySubspace::<B8>::with_dim(3);
		let mut iter = subspace.iter();
		assert_eq!(iter.size_hint(), (8, Some(8)));

		let _ = iter.next();
		assert_eq!(iter.size_hint(), (7, Some(7)));

		let _ = iter.nth(3);
		assert_eq!(iter.size_hint(), (3, Some(3)));
	}

	/// `len` reports the number of remaining elements.
	#[test]
	fn test_iterator_exact_size() {
		let subspace = BinarySubspace::<B8>::with_dim(4);
		let mut iter = subspace.iter();
		assert_eq!(iter.len(), 16);

		let _ = iter.next();
		assert_eq!(iter.len(), 15);

		let _ = iter.nth(5);
		assert_eq!(iter.len(), 9);
	}

	/// The zero-dimensional subspace contains exactly one element: zero.
	#[test]
	fn test_iterator_empty_subspace() {
		let trivial = BinarySubspace::<B8>::with_dim(0);
		let mut iter = trivial.iter();

		assert_eq!(iter.len(), 1);
		assert_eq!(iter.next(), Some(B8::ZERO));
		assert_eq!(iter.next(), None);
	}

	/// Iterating the full 8-bit subspace visits all 256 elements in `get` order.
	#[test]
	fn test_iterator_full_iteration() {
		let subspace = BinarySubspace::<B8>::default();
		let collected: Vec<_> = subspace.iter().collect();

		assert_eq!(collected.len(), 256);
		for (i, &elem) in collected.iter().enumerate() {
			assert_eq!(elem, subspace.get(i));
		}
	}

	/// `nth` counts from the current position after some elements were taken with `next`.
	#[test]
	fn test_iterator_partial_then_nth() {
		let subspace = BinarySubspace::<B8>::with_dim(5);
		let mut iter = subspace.iter();

		// Step through the first three elements one by one.
		for index in 0..3 {
			assert_eq!(iter.next(), Some(subspace.get(index)));
		}

		// Skip five elements (indices 3 to 7) and continue from there.
		assert_eq!(iter.nth(5), Some(subspace.get(8)));
		assert_eq!(iter.next(), Some(subspace.get(9)));
	}

	/// A cloned iterator continues independently from the same position.
	#[test]
	fn test_iterator_clone() {
		let subspace = BinarySubspace::<B8>::with_dim(3);
		let mut original = subspace.iter();
		for _ in 0..2 {
			original.next();
		}

		let mut copy = original.clone();

		// Both iterators must agree on the next element and on everything after it.
		assert_eq!(original.next(), copy.next());
		let rest_of_original: Vec<_> = original.collect();
		let rest_of_copy: Vec<_> = copy.collect();
		assert_eq!(rest_of_original, rest_of_copy);
	}
}