binius_utils/sorting.rs
1// Copyright 2024-2025 Irreducible Inc.
2
3use itertools::Itertools;
4
/// Reports whether `values` never decreases, i.e. no element compares strictly greater than the
/// element immediately after it.
///
/// Empty and single-element iterators are considered sorted, and equal neighbours are allowed.
/// Only a strict `>` between neighbours counts as a violation, so adjacent pairs that are
/// incomparable under `PartialOrd` (for example, pairs involving `f64::NAN`) do not make the
/// sequence unsorted. Iteration stops at the first violation found.
pub fn is_sorted_ascending<T: PartialOrd + Clone>(values: impl Iterator<Item = T>) -> bool {
    let mut previous: Option<T> = None;
    for current in values {
        if matches!(&previous, Some(prev) if *prev > current) {
            return false;
        }
        previous = Some(current);
    }
    true
}
9
/// Stably sorts `objs` by `key`, ascending by default or descending when `descending` is set.
///
/// Each item is tagged with its position in the input before sorting. Because the sort is
/// stable, items whose keys are equal keep their input order in the output, in both directions.
///
/// # Arguments
///
/// * `objs`: The items to sort.
/// * `key`: Maps an item to the `usize` value it is ordered by.
/// * `descending`: When `true`, larger keys come first; otherwise smaller keys come first.
///
/// # Returns
///
/// `(indices, sorted)`, where `sorted` holds the items in sorted order and `indices[i]` is the
/// input position of `sorted[i]`. Passing both vectors to [`unsort`] recovers the input order.
///
/// # Note
///
/// Stability matters here: prover and verifier must derive exactly the same ordering, including
/// among items whose keys tie.
pub fn stable_sort<T>(
    objs: impl IntoIterator<Item = T>,
    key: impl Fn(&T) -> usize,
    descending: bool,
) -> (Vec<usize>, Vec<T>) {
    let mut tagged: Vec<(usize, T)> = objs.into_iter().enumerate().collect();

    // `sort_by_key` is a stable sort. `Reverse` flips the key order without disturbing the
    // relative order of equal keys. Do not replace this with an unstable sort.
    if descending {
        tagged.sort_by_key(|(_, obj)| std::cmp::Reverse(key(obj)));
    } else {
        tagged.sort_by_key(|(_, obj)| key(obj));
    }

    tagged.into_iter().unzip()
}
46
/// Puts items back into their pre-sort order, using the positions recorded when they were sorted.
///
/// `original_indices[i]` is the original position of the `i`-th element of `sorted_objs`. This is
/// the inverse of the permutation produced by `stable_sort`.
///
/// # Arguments
///
/// - `original_indices`: The original position of each sorted item.
/// - `sorted_objs`: The items in their sorted order.
///
/// # Returns
///
/// The items reordered by ascending original position.
///
/// # Note
///
/// Both inputs are expected to have the same length, and `original_indices` is expected to hold
/// unique values. The inputs are paired with `zip`, so surplus elements of the longer input are
/// silently dropped. With duplicate indices, items sharing an index keep their relative order from
/// `sorted_objs`.
pub fn unsort<T>(
    original_indices: impl IntoIterator<Item = usize>,
    sorted_objs: impl IntoIterator<Item = T>,
) -> Vec<T> {
    let mut by_position: Vec<(usize, T)> = Vec::new();
    by_position.extend(original_indices.into_iter().zip(sorted_objs));

    // Stable sort on the recorded positions only; the payloads are never compared.
    by_position.sort_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));

    let mut restored = Vec::with_capacity(by_position.len());
    for (_, obj) in by_position {
        restored.push(obj);
    }
    restored
}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// `is_sorted_ascending` accepts empty, singleton and non-decreasing input, and rejects any
    /// strict descent between neighbours.
    #[test]
    fn test_is_sorted_ascending() {
        assert!(is_sorted_ascending(std::iter::empty::<u32>()));
        assert!(is_sorted_ascending(vec![7].into_iter()));
        assert!(is_sorted_ascending(vec![1, 2, 2, 3].into_iter()));
        assert!(!is_sorted_ascending(vec![1, 3, 2].into_iter()));
    }

    #[test]
    fn test_stable_sort() {
        let items = vec![
            ("apple", 3),
            ("banana", 2),
            ("cherry", 3),
            ("date", 1),
            ("elderberry", 2),
        ];

        // Sort by the second element of the tuple (the number) in ascending order
        let key = |item: &(&str, usize)| item.1;

        let (indices_asc, sorted_items_asc) = stable_sort(items.clone(), key, false);
        assert_eq!(indices_asc, vec![3, 1, 4, 0, 2]); // Expected original indices
        assert_eq!(
            sorted_items_asc,
            vec![
                ("date", 1),
                ("banana", 2),
                ("elderberry", 2),
                ("apple", 3),
                ("cherry", 3)
            ]
        );

        // Sort by the second element of the tuple (the number) in descending order
        let (indices_desc, sorted_items_desc) = stable_sort(items, key, true);
        assert_eq!(indices_desc, vec![0, 2, 1, 4, 3]); // Expected original indices
        assert_eq!(
            sorted_items_desc,
            vec![
                ("apple", 3),
                ("cherry", 3),
                ("banana", 2),
                ("elderberry", 2),
                ("date", 1)
            ]
        );
    }

    /// Items with identical keys must keep input order in both directions.
    #[test]
    fn test_stable_sort_all_equal_keys() {
        for descending in [false, true] {
            let (indices, sorted) = stable_sort(vec![5u8, 5, 5], |x: &u8| usize::from(*x), descending);
            assert_eq!(indices, vec![0, 1, 2]);
            assert_eq!(sorted, vec![5, 5, 5]);
        }
    }

    /// `unsort` must exactly invert `stable_sort`, in both directions and on empty input.
    #[test]
    fn test_unsort_inverts_stable_sort() {
        let items = vec![
            ("apple", 3),
            ("banana", 2),
            ("cherry", 3),
            ("date", 1),
            ("elderberry", 2),
        ];
        let key = |item: &(&str, usize)| item.1;

        for descending in [false, true] {
            let (indices, sorted) = stable_sort(items.clone(), key, descending);
            assert_eq!(unsort(indices, sorted), items);
        }

        let (indices, sorted) = stable_sort(Vec::<u8>::new(), |x: &u8| usize::from(*x), false);
        assert!(indices.is_empty());
        assert!(unsort(indices, sorted).is_empty());
    }
}