// File: binius_math/rows_batch.rs
//
// Copyright 2025 Irreducible Inc.

3use std::ops::{Bound, RangeBounds};
4
5/// This struct represents a batch of rows, each row having the same length equal to `row_len`.
6pub struct RowsBatch<'a, T> {
7	rows: Vec<&'a [T]>,
8	row_len: usize,
9}
10
11impl<'a, T> RowsBatch<'a, T> {
12	/// Create a new `RowsBatch` from a vector of rows and the given row length
13	///
14	/// # Panics
15	/// In case if any of the rows has a length different from `row_len`.
16	#[inline]
17	pub fn new(rows: Vec<&'a [T]>, row_len: usize) -> Self {
18		for row in &rows {
19			assert_eq!(row.len(), row_len);
20		}
21
22		Self { rows, row_len }
23	}
24
25	/// Create a new `RowsBatch` from an iterator of rows and the given row length.
26	///
27	/// # Panics
28	/// In case if any of the rows has a length less than `row_len`.
29	#[inline]
30	pub fn new_from_iter(rows: impl IntoIterator<Item = &'a [T]>, row_len: usize) -> Self {
31		let rows = rows.into_iter().map(|x| &x[..row_len]).collect();
32		Self { rows, row_len }
33	}
34
35	#[inline(always)]
36	pub fn get_ref(&self) -> RowsBatchRef<'_, T> {
37		RowsBatchRef {
38			rows: self.rows.as_slice(),
39			row_len: self.row_len,
40			offset: 0,
41		}
42	}
43}
44
45/// This struct is similar to `RowsBatch`, but it holds a reference to a slice of rows
46/// (instead of an owned vector) and an offset in each row.
47///
48/// It is guaranteed that all rows in `self.rows` have a length of at least `row_len + offset`. The
49/// effective row returned by `row` method is `&self.rows[index][self.offset..self.offset +
50/// self.row_len]`. Unfortunately, due to lifetime issues, we can't unify `RowsBatch` and
51/// `RowsBatchRef` into a single generic struct parameterized by the container type.
52pub struct RowsBatchRef<'a, T> {
53	rows: &'a [&'a [T]],
54	row_len: usize,
55	offset: usize,
56}
57
58impl<'a, T> RowsBatchRef<'a, T> {
59	/// Create a new `RowsBatchRef` from a slice of rows and the given row length.
60	///
61	/// # Panics
62	/// In case if any of the rows has a length smaller than `row_len`.
63	#[inline]
64	pub fn new(rows: &'a [&'a [T]], row_len: usize) -> Self {
65		Self::new_with_offset(rows, 0, row_len)
66	}
67
68	/// Create a new `RowsBatchRef` from a slice of rows and the given row length and offset.
69	///
70	/// # Panics
71	/// In case if any of the rows has a length smaller than from `offset + row_len`.
72	#[inline]
73	pub fn new_with_offset(rows: &'a [&'a [T]], offset: usize, row_len: usize) -> Self {
74		for row in rows {
75			assert!(offset + row_len <= row.len());
76		}
77
78		Self {
79			rows,
80			row_len,
81			offset,
82		}
83	}
84
85	/// Create a new `RowsBatchRef` from a slice of rows and the given row length and offset.
86	///
87	/// # Safety
88	/// This function is unsafe because it does not check if the rows have the enough length.
89	/// It is the caller's responsibility to ensure that `row_len` is less than or equal
90	/// to the length of each row.
91	pub unsafe fn new_unchecked(rows: &'a [&'a [T]], row_len: usize) -> Self {
92		unsafe { Self::new_with_offset_unchecked(rows, 0, row_len) }
93	}
94
95	/// Create a new `RowsBatchRef` from a slice of rows and the given row length and offset.
96	///
97	/// # Safety
98	/// This function is unsafe because it does not check if the rows have the enough length.
99	/// It is the caller's responsibility to ensure that `offset + row_len` is less than or equal
100	/// to the length of each row.
101	pub unsafe fn new_with_offset_unchecked(
102		rows: &'a [&'a [T]],
103		offset: usize,
104		row_len: usize,
105	) -> Self {
106		Self {
107			rows,
108			row_len,
109			offset,
110		}
111	}
112
113	#[inline]
114	pub fn iter(&self) -> impl Iterator<Item = &'a [T]> + '_ {
115		self.rows.as_ref().iter().copied()
116	}
117
118	#[inline(always)]
119	pub fn row(&self, index: usize) -> &'a [T] {
120		&self.rows[index][self.offset..self.offset + self.row_len]
121	}
122
123	#[inline(always)]
124	pub fn n_rows(&self) -> usize {
125		self.rows.as_ref().len()
126	}
127
128	#[inline(always)]
129	pub fn row_len(&self) -> usize {
130		self.row_len
131	}
132
133	#[inline(always)]
134	pub fn is_empty(&self) -> bool {
135		self.rows.as_ref().is_empty()
136	}
137
138	/// Returns a new `RowsBatch` with the specified rows selected by the given indices.
139	pub fn map(&self, indices: impl AsRef<[usize]>) -> RowsBatch<'a, T> {
140		let rows = indices.as_ref().iter().map(|&i| self.rows[i]).collect();
141
142		RowsBatch {
143			rows,
144			row_len: self.row_len,
145		}
146	}
147
148	/// Returns a new `RowsBatchRef` with the specified columns selected by the given indices range.
149	pub fn columns_subrange(&self, range: impl RangeBounds<usize>) -> Self {
150		let start = match range.start_bound() {
151			Bound::Included(&start) => start,
152			Bound::Excluded(&start) => start + 1,
153			Bound::Unbounded => 0,
154		};
155		let end = match range.end_bound() {
156			Bound::Included(&end) => end + 1,
157			Bound::Excluded(&end) => end,
158			Bound::Unbounded => self.row_len,
159		};
160
161		assert!(start <= end);
162		assert!(end <= self.row_len);
163
164		Self {
165			rows: self.rows,
166			row_len: end - start,
167			offset: self.offset + start,
168		}
169	}
170}