binius_utils/
iter.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::iter::FusedIterator;
4
5pub trait IterExtensions: Iterator + Sized {
6	fn map_skippable<R, F>(self, f: F) -> SkippableMap<Self, F>
7	where
8		F: Fn(Self::Item) -> R,
9	{
10		SkippableMap::new(self, f)
11	}
12}
13
14impl<T: Iterator + Sized> IterExtensions for T {}
15
/// A map iterator that skips values when `nth` is called.
///
/// `std::iter::Map` guarantees that the function will be called for every value of the inner
/// iterator, which makes it impossible to implement `nth` in an efficient way. `SkippableMap`
/// closely follows the interface of `std::iter::Map`, except that `F` is required to be `Fn`
/// instead of `FnMut`, and elements stepped over by `nth` are never passed to `F`.
#[derive(Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct SkippableMap<I, F> {
	iter: I,
	func: F,
}

impl<I, F> SkippableMap<I, F> {
	/// Wraps `iter` so that each element it yields is mapped through `func`.
	const fn new(iter: I, func: F) -> Self {
		Self { iter, func }
	}
}

// Written by hand (as `std::iter::Map` does) so that `F` is not required to be `Debug`:
// closures never implement `Debug`, so a derived impl would be unusable in practice.
impl<I: std::fmt::Debug, F> std::fmt::Debug for SkippableMap<I, F> {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		f.debug_struct("SkippableMap")
			.field("iter", &self.iter)
			.finish_non_exhaustive()
	}
}
33
34impl<B, I: Iterator, F> Iterator for SkippableMap<I, F>
35where
36	F: Fn(I::Item) -> B,
37{
38	type Item = B;
39
40	#[inline]
41	fn next(&mut self) -> Option<B> {
42		self.iter.next().map(&self.func)
43	}
44
45	#[inline]
46	fn size_hint(&self) -> (usize, Option<usize>) {
47		self.iter.size_hint()
48	}
49
50	#[inline]
51	fn fold<Acc, G>(self, init: Acc, mut g: G) -> Acc
52	where
53		G: FnMut(Acc, Self::Item) -> Acc,
54	{
55		let func = self.func;
56		self.iter.fold(init, move |acc, elt| g(acc, func(elt)))
57	}
58
59	// Consider `advance_by` once it gets stabilised (https://github.com/rust-lang/rust/issues/77404)
60	fn nth(&mut self, n: usize) -> Option<Self::Item> {
61		self.iter.nth(n).map(&self.func)
62	}
63}
64
65impl<B, I: ExactSizeIterator, F> ExactSizeIterator for SkippableMap<I, F>
66where
67	F: Fn(I::Item) -> B,
68{
69	fn len(&self) -> usize {
70		self.iter.len()
71	}
72}
73
74impl<B, I: FusedIterator, F> FusedIterator for SkippableMap<I, F> where F: Fn(I::Item) -> B {}
75
#[cfg(test)]
mod tests {
	use std::cell::Cell;

	use super::*;

	#[test]
	fn test_map_skippable() {
		let vals = [1, 2, 3, 4, 5];
		let square = |v: &i32| v * v;

		// Observable behaviour must match that of `std::iter::Map`.
		let expected: Vec<_> = vals.iter().map(square).collect();
		let actual: Vec<_> = vals.iter().map_skippable(square).collect();
		assert_eq!(expected, actual);

		let weighted_sum = |acc: i32, x: i32| acc + 2 * x;
		assert_eq!(
			vals.iter().map(square).fold(0, weighted_sum),
			vals.iter().map_skippable(square).fold(0, weighted_sum)
		);

		assert_eq!(vals.iter().size_hint(), vals.iter().map_skippable(square).size_hint());

		// `nth` must apply the function only to the element it actually returns.
		let calls = Cell::new(0);
		let mut iter = vals.iter().map_skippable(|v| {
			calls.set(calls.get() + 1);
			v * v
		});
		assert_eq!(iter.nth(3), Some(16));
		assert_eq!(calls.get(), 1);
		assert_eq!(iter.nth(2), None);
		assert_eq!(calls.get(), 1);
	}
}
109}