// binius_utils/sorting.rs
// Copyright 2024-2025 Irreducible Inc.
2
3use itertools::Itertools;
4
5/// Returns whether the given values are sorted in ascending order.
6pub fn is_sorted_ascending<T: PartialOrd + Clone>(values: impl Iterator<Item = T>) -> bool {
7 !values.tuple_windows().any(|(a, b)| a > b)
8}
9
/// Stably sorts objects by a `usize` key, remembering where each one came from.
///
/// Every object is paired with its position in the input. The pairs are ordered by `key`: largest
/// key first when `descending` is `true`, smallest key first otherwise. The pairs are then split
/// back apart. Objects with equal keys keep their relative input order in both directions, because
/// the underlying slice sort is stable.
///
/// # Arguments
///
/// * `objs`: The objects to sort.
/// * `key`: Maps a reference to an object to the `usize` value it is sorted by.
/// * `descending`: When `true`, larger keys come first; when `false`, smaller keys come first.
///
/// # Returns
///
/// `(original_indices, sorted_objs)`, where `original_indices[i]` is the input position of
/// `sorted_objs[i]`. Passing both vectors to [`unsort`] restores the input order.
///
/// # Note
///
/// Stability matters: ties must be broken the same way everywhere this ordering is recomputed
/// (e.g. by both prover and verifier). A stable sort always breaks ties by input position.
pub fn stable_sort<T>(
    objs: impl IntoIterator<Item = T>,
    key: impl Fn(&T) -> usize,
    descending: bool,
) -> (Vec<usize>, Vec<T>) {
    let mut tagged: Vec<(usize, T)> = objs.into_iter().enumerate().collect();

    // NOTE: `sort_by` is a stable sort, which is essential for prover-verifier consistency!
    // Reversing the ascending comparison (rather than the inputs) keeps ties in input order for
    // both sort directions.
    tagged.sort_by(|(_, lhs), (_, rhs)| {
        let ascending = key(lhs).cmp(&key(rhs));
        if descending {
            ascending.reverse()
        } else {
            ascending
        }
    });

    tagged.into_iter().unzip()
}
44
/// Puts objects back into their original order, undoing a sort such as [`stable_sort`].
///
/// Each object is paired with the original index at the same position. The pairs are sorted by
/// that index, and the objects are returned in the resulting order. When the indices are a
/// permutation of `0..n`, this inverts that permutation.
///
/// # Arguments
///
/// - `original_indices`: The original position of each object, e.g. the first vector returned
///   by [`stable_sort`].
/// - `sorted_objs`: The objects in their current (sorted) order.
///
/// # Returns
///
/// The objects, ordered by ascending original index.
///
/// # Note
///
/// Callers must pass sequences of equal length whose indices are unique and valid. None of this is
/// checked. The two sequences are zipped, so any surplus elements of the longer one are silently
/// dropped.
pub fn unsort<T>(
    original_indices: impl IntoIterator<Item = usize>,
    sorted_objs: impl IntoIterator<Item = T>,
) -> Vec<T> {
    let mut by_position: Vec<(usize, T)> = original_indices
        .into_iter()
        .zip(sorted_objs)
        .collect();

    // Stable sort on the original index alone.
    by_position.sort_by_key(|&(original_position, _)| original_position);

    let mut restored = Vec::with_capacity(by_position.len());
    for (_, obj) in by_position {
        restored.push(obj);
    }
    restored
}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: (name, sort key) pairs with duplicate keys, so stability is observable.
    fn sample_items() -> Vec<(&'static str, usize)> {
        vec![
            ("apple", 3),
            ("banana", 2),
            ("cherry", 3),
            ("date", 1),
            ("elderberry", 2),
        ]
    }

    #[test]
    fn test_is_sorted_ascending() {
        // Empty and single-element sequences are trivially sorted.
        assert!(is_sorted_ascending(std::iter::empty::<u32>()));
        assert!(is_sorted_ascending([7].into_iter()));
        // Equal neighbours are allowed (non-strict order).
        assert!(is_sorted_ascending([1, 2, 2, 3].into_iter()));
        assert!(!is_sorted_ascending([1, 3, 2].into_iter()));
        assert!(!is_sorted_ascending([2, 1].into_iter()));
    }

    #[test]
    fn test_stable_sort() {
        let items = sample_items();

        // Sort by the second element of the tuple (the number) in ascending order
        let key = |item: &(&str, usize)| item.1;

        let (indices_asc, sorted_items_asc) = stable_sort(items.clone(), key, false);
        assert_eq!(indices_asc, vec![3, 1, 4, 0, 2]); // Expected original indices
        assert_eq!(
            sorted_items_asc,
            vec![
                ("date", 1),
                ("banana", 2),
                ("elderberry", 2),
                ("apple", 3),
                ("cherry", 3)
            ]
        );

        // Sort by the second element of the tuple (the number) in descending order
        let (indices_desc, sorted_items_desc) = stable_sort(items, key, true);
        assert_eq!(indices_desc, vec![0, 2, 1, 4, 3]); // Expected original indices
        assert_eq!(
            sorted_items_desc,
            vec![
                ("apple", 3),
                ("cherry", 3),
                ("banana", 2),
                ("elderberry", 2),
                ("date", 1)
            ]
        );
    }

    #[test]
    fn test_stable_sort_empty() {
        let (indices, sorted): (Vec<usize>, Vec<usize>) =
            stable_sort(Vec::new(), |x: &usize| *x, false);
        assert!(indices.is_empty());
        assert!(sorted.is_empty());
    }

    #[test]
    fn test_unsort_inverts_stable_sort() {
        let items = sample_items();
        let key = |item: &(&str, usize)| item.1;

        // Round trip in both directions must reproduce the input exactly.
        for descending in [false, true] {
            let (indices, sorted) = stable_sort(items.clone(), key, descending);
            assert_eq!(unsort(indices, sorted), items);
        }
    }

    #[test]
    fn test_unsort_explicit_permutation() {
        assert_eq!(
            unsort(vec![2, 0, 1], vec!['c', 'a', 'b']),
            vec!['a', 'b', 'c']
        );
        assert_eq!(
            unsort(Vec::<usize>::new(), Vec::<char>::new()),
            Vec::<char>::new()
        );
    }
}