binius_utils/
sorting.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use itertools::Itertools;
4
5/// Returns whether the given values are sorted in ascending order.
6pub fn is_sorted_ascending<T: PartialOrd + Clone>(values: impl Iterator<Item = T>) -> bool {
7	!values.tuple_windows().any(|(a, b)| a > b)
8}
9
/// Stably sorts objects by a `usize` key and reports where each one came from.
///
/// Every object yielded by `objs` is tagged with its position in the input, the tagged objects
/// are sorted by `key` (ascending, or descending when requested), and the tags and objects are
/// then returned as two parallel vectors.
///
/// # Arguments
///
/// * `objs`: The objects to sort, consumed by value.
/// * `key`: Maps a reference to an object to the `usize` it is ordered by.
/// * `descending`: When `true`, larger keys come first; otherwise smaller keys come first.
///
/// # Returns
///
/// A tuple `(indices, sorted)` in which `sorted[i]` is the object that was at position
/// `indices[i]` of the input.
///
/// # Note
///
/// The sort is stable in both directions: objects with equal keys keep their relative input
/// order, which keeps the result deterministic for objects that compare as equal.
pub fn stable_sort<T>(
	objs: impl IntoIterator<Item = T>,
	key: impl Fn(&T) -> usize,
	descending: bool,
) -> (Vec<usize>, Vec<T>) {
	// Remember each object's input position so the applied permutation can be returned.
	let mut tagged: Vec<(usize, T)> = objs.into_iter().enumerate().collect();
	// NOTE: Important to use stable sorting for prover-verifier consistency!
	if descending {
		// Wrapping the key in `Reverse` flips the order while leaving ties in input order.
		tagged.sort_by_key(|(_, obj)| std::cmp::Reverse(key(obj)));
	} else {
		tagged.sort_by_key(|(_, obj)| key(obj));
	}
	tagged.into_iter().unzip()
}
44
/// Puts a sorted collection back into the order described by its original indices.
///
/// Each object from `sorted_objs` is paired with the index at the same position in
/// `original_indices`; the pairs are then ordered by ascending index and the objects returned.
///
/// # Arguments
///
/// - `original_indices`: For each sorted object, the position it occupied before sorting.
/// - `sorted_objs`: The objects in their sorted order.
///
/// # Returns
///
/// A vector of the objects arranged by ascending original index.
///
/// # Note
///
/// The inputs are expected to have equal length and `original_indices` is expected to hold
/// unique, valid indices. Neither is checked: pairing stops at the shorter input, and objects
/// sharing an index keep their relative order.
pub fn unsort<T>(
	original_indices: impl IntoIterator<Item = usize>,
	sorted_objs: impl IntoIterator<Item = T>,
) -> Vec<T> {
	let mut pairs: Vec<(usize, T)> = original_indices.into_iter().zip(sorted_objs).collect();
	// Stable ordering by the recorded position undoes the earlier permutation.
	pairs.sort_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));

	let mut restored = Vec::with_capacity(pairs.len());
	for (_, obj) in pairs {
		restored.push(obj);
	}
	restored
}
74
#[cfg(test)]
mod tests {
	use super::*;

	/// Exercises `stable_sort` in both directions on input with repeated keys, checking the
	/// returned index permutation as well as the sorted objects.
	#[test]
	fn test_stable_sort() {
		let fruits = vec![
			("apple", 3),
			("banana", 2),
			("cherry", 3),
			("date", 1),
			("elderberry", 2),
		];

		// Order by the numeric second field of each tuple.
		let by_count = |entry: &(&str, usize)| entry.1;

		// (descending flag, expected original indices, expected sorted objects)
		let cases: Vec<(bool, Vec<usize>, Vec<(&str, usize)>)> = vec![
			(
				false,
				vec![3, 1, 4, 0, 2],
				vec![
					("date", 1),
					("banana", 2),
					("elderberry", 2),
					("apple", 3),
					("cherry", 3),
				],
			),
			(
				true,
				vec![0, 2, 1, 4, 3],
				vec![
					("apple", 3),
					("cherry", 3),
					("banana", 2),
					("elderberry", 2),
					("date", 1),
				],
			),
		];

		for (descending, expected_indices, expected_items) in cases {
			let (indices, sorted_items) = stable_sort(fruits.clone(), by_count, descending);
			assert_eq!(indices, expected_indices);
			assert_eq!(sorted_items, expected_items);
		}
	}
}