binius_utils/
sorting.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// Copyright 2024-2025 Irreducible Inc.

use itertools::Itertools;

/// Checks that an iterator yields its items in non-decreasing order.
///
/// Adjacent items are inspected pairwise. The scan stops with `false` at the first pair `(a, b)`
/// for which `a > b` holds; if no such pair exists the result is `true`. An iterator with fewer
/// than two items has no adjacent pairs and is therefore reported as sorted.
pub fn is_sorted_ascending<T: PartialOrd + Clone>(values: impl Iterator<Item = T>) -> bool {
	for (prev, next) in values.tuple_windows() {
		// A single out-of-order neighbour pair is enough to reject the whole sequence.
		if prev > next {
			return false;
		}
	}
	true
}

/// Sorts objects by an integer key with a stable sort, remembering where each one came from.
///
/// Every object is paired with its position in `objs`, and the pairs are then ordered by the value
/// `key` returns for the object, ascending or descending depending on `descending`. Objects whose
/// keys are equal keep their relative input order in both directions.
///
/// # Arguments
///
/// * `objs`: The objects to sort; the iterable is consumed.
/// * `key`: Maps a reference to an object to the `usize` it is ordered by.
/// * `descending`: When `true`, larger keys come first; otherwise smaller keys come first.
///
/// # Returns
///
/// A tuple `(original_indices, sorted_objs)` in which `sorted_objs[i]` is the object that was at
/// position `original_indices[i]` of `objs`.
///
/// # Note
///
/// The sort has to stay stable so that objects with equal keys are always ordered the same way.
pub fn stable_sort<T>(
	objs: impl IntoIterator<Item = T>,
	key: impl Fn(&T) -> usize,
	descending: bool,
) -> (Vec<usize>, Vec<T>) {
	// Tag each object with its input position before any reordering happens.
	let mut tagged: Vec<(usize, T)> = objs.into_iter().enumerate().collect();
	// NOTE: Important to use stable sorting for prover-verifier consistency!
	if descending {
		// `Reverse` flips the key comparison only, so equal keys still keep their input order.
		tagged.sort_by_key(|(_, obj)| std::cmp::Reverse(key(obj)));
	} else {
		tagged.sort_by_key(|(_, obj)| key(obj));
	}
	// Split the (index, object) pairs into the two parallel vectors that are returned.
	tagged.into_iter().unzip()
}

/// Puts a sorted collection of objects back into the order given by their original indices.
///
/// The i-th index is paired with the i-th object, the pairs are ordered by index, and the objects
/// are returned in that order.
///
/// # Arguments
///
/// - `original_indices`: The index each object had before it was sorted, listed in sorted order.
/// - `sorted_objs`: The sorted objects, in the same order as `original_indices`.
///
/// # Returns
///
/// A vector holding the objects arranged by ascending original index.
///
/// # Note
///
/// Both inputs are expected to have the same length, and `original_indices` is expected to hold
/// unique, valid indices. Neither is checked: if one input is longer, its surplus items are
/// ignored because the two inputs are zipped together.
pub fn unsort<T>(
	original_indices: impl IntoIterator<Item = usize>,
	sorted_objs: impl IntoIterator<Item = T>,
) -> Vec<T> {
	// Pair every object with the position it originally occupied.
	let mut tagged: Vec<(usize, T)> =
		Iterator::zip(original_indices.into_iter(), sorted_objs).collect();
	// Order the pairs by original position only; the objects themselves are never compared.
	tagged.sort_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
	// Drop the positions and keep just the objects.
	tagged.into_iter().map(|pair| pair.1).collect()
}

#[cfg(test)]
mod tests {
	use super::*;

	/// Exercises `stable_sort` in both directions, checking the returned indices and objects.
	#[test]
	fn test_stable_sort() {
		// Keys 2 and 3 each occur twice, so the tie-breaking behaviour is actually exercised.
		let fruits = vec![
			("apple", 3),
			("banana", 2),
			("cherry", 3),
			("date", 1),
			("elderberry", 2),
		];

		// The number in the second tuple slot is the sort key.
		let by_count = |entry: &(&str, usize)| entry.1;

		// Ascending order: tied entries must stay in their input order.
		let expected_asc = vec![
			("date", 1),
			("banana", 2),
			("elderberry", 2),
			("apple", 3),
			("cherry", 3),
		];
		let (asc_indices, asc_items) = stable_sort(fruits.clone(), by_count, false);
		assert_eq!(asc_indices, vec![3, 1, 4, 0, 2]); // Positions the entries had in `fruits`
		assert_eq!(asc_items, expected_asc);

		// Descending order: tied entries must still stay in their input order.
		let expected_desc = vec![
			("apple", 3),
			("cherry", 3),
			("banana", 2),
			("elderberry", 2),
			("date", 1),
		];
		let (desc_indices, desc_items) = stable_sort(fruits, by_count, true);
		assert_eq!(desc_indices, vec![0, 2, 1, 4, 3]); // Positions the entries had in `fruits`
		assert_eq!(desc_items, expected_desc);
	}
}