binius_field/
tower_levels.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	array,
5	ops::{Add, AddAssign, Index, IndexMut},
6};
7
8use binius_utils::checked_arithmetics::checked_log_2;
9
10/// Public API for recursive algorithms over data represented as an array of its limbs
11/// E.g. an F_{2^128} element expressed as 8 chunks of 2 bytes
12/// or a 256-bit integer represented as 32 individual bytes
13///
14/// Join and split can be used to combine and split underlying data into upper and lower halves
15///
16/// This is mostly useful for recursively implementing arithmetic operations
17///
18/// These separate implementations are necessary to overcome the limitations of const generics in Rust.
19/// These implementations eliminate costly bounds checking that would otherwise be imposed by the compiler
20/// and allow easy inlining of recursive functions.
21pub trait TowerLevel: 'static {
22	// WIDTH is ALWAYS a power of 2
23	const WIDTH: usize;
24	const LOG_WIDTH: usize = checked_log_2(Self::WIDTH);
25
26	// The underlying Data should ALWAYS be a fixed-width array of T's
27	type Data<T>: AsMut<[T]>
28		+ AsRef<[T]>
29		+ Sized
30		+ Index<usize, Output = T>
31		+ IndexMut<usize, Output = T>;
32	type Base: TowerLevel;
33
34	// Split something of type Self::Data<T>into two equal halves
35	#[allow(clippy::type_complexity)]
36	fn split<T>(
37		data: &Self::Data<T>,
38	) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>);
39
40	// Split something of type Self::Data<T>into two equal mutable halves
41	#[allow(clippy::type_complexity)]
42	fn split_mut<T>(
43		data: &mut Self::Data<T>,
44	) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>);
45
46	// Join two equal-length arrays (the reverse of split)
47	fn join<T: Copy + Default>(
48		first: &<Self::Base as TowerLevel>::Data<T>,
49		second: &<Self::Base as TowerLevel>::Data<T>,
50	) -> Self::Data<T>;
51
52	// Fills an array of T's containing WIDTH elements
53	fn from_fn<T: Copy>(f: impl FnMut(usize) -> T) -> Self::Data<T>;
54
55	// Fills an array of T's containing WIDTH elements with T::default()
56	fn default<T: Copy + Default>() -> Self::Data<T> {
57		Self::from_fn(|_| T::default())
58	}
59}
60
61pub trait TowerLevelWithArithOps: TowerLevel {
62	#[inline(always)]
63	fn add_into<T: AddAssign + Copy>(
64		field_element: &Self::Data<T>,
65		destination: &mut Self::Data<T>,
66	) {
67		for i in 0..Self::WIDTH {
68			destination[i] += field_element[i];
69		}
70	}
71
72	#[inline(always)]
73	fn copy_into<T: Copy>(field_element: &Self::Data<T>, destination: &mut Self::Data<T>) {
74		for i in 0..Self::WIDTH {
75			destination[i] = field_element[i];
76		}
77	}
78
79	#[inline(always)]
80	fn sum<T: Copy + Add<Output = T>>(
81		field_element_a: &Self::Data<T>,
82		field_element_b: &Self::Data<T>,
83	) -> Self::Data<T> {
84		Self::from_fn(|i| field_element_a[i] + field_element_b[i])
85	}
86}
87
88impl<T: TowerLevel> TowerLevelWithArithOps for T {}
89
90pub struct TowerLevel64;
91
92impl TowerLevel for TowerLevel64 {
93	const WIDTH: usize = 64;
94
95	type Data<T> = [T; 64];
96	type Base = TowerLevel32;
97
98	#[inline(always)]
99	fn split<T>(
100		data: &Self::Data<T>,
101	) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
102		((data[0..32].try_into().unwrap()), (data[32..64].try_into().unwrap()))
103	}
104
105	#[inline(always)]
106	fn split_mut<T>(
107		data: &mut Self::Data<T>,
108	) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
109		let (chunk_1, chunk_2) = data.split_at_mut(32);
110
111		((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
112	}
113
114	#[inline(always)]
115	fn join<T: Copy + Default>(
116		left: &<Self::Base as TowerLevel>::Data<T>,
117		right: &<Self::Base as TowerLevel>::Data<T>,
118	) -> Self::Data<T> {
119		let mut result = [T::default(); 64];
120		result[..32].copy_from_slice(left);
121		result[32..].copy_from_slice(right);
122		result
123	}
124
125	#[inline(always)]
126	fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
127		array::from_fn(f)
128	}
129}
130
131pub struct TowerLevel32;
132
133impl TowerLevel for TowerLevel32 {
134	const WIDTH: usize = 32;
135
136	type Data<T> = [T; 32];
137	type Base = TowerLevel16;
138
139	#[inline(always)]
140	fn split<T>(
141		data: &Self::Data<T>,
142	) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
143		((data[0..16].try_into().unwrap()), (data[16..32].try_into().unwrap()))
144	}
145
146	#[inline(always)]
147	fn split_mut<T>(
148		data: &mut Self::Data<T>,
149	) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
150		let (chunk_1, chunk_2) = data.split_at_mut(16);
151
152		((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
153	}
154
155	#[inline(always)]
156	fn join<T: Copy + Default>(
157		left: &<Self::Base as TowerLevel>::Data<T>,
158		right: &<Self::Base as TowerLevel>::Data<T>,
159	) -> Self::Data<T> {
160		let mut result = [T::default(); 32];
161		result[..16].copy_from_slice(left);
162		result[16..].copy_from_slice(right);
163		result
164	}
165
166	#[inline(always)]
167	fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
168		array::from_fn(f)
169	}
170}
171
172pub struct TowerLevel16;
173
174impl TowerLevel for TowerLevel16 {
175	const WIDTH: usize = 16;
176
177	type Data<T> = [T; 16];
178	type Base = TowerLevel8;
179
180	#[inline(always)]
181	fn split<T>(
182		data: &Self::Data<T>,
183	) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
184		((data[0..8].try_into().unwrap()), (data[8..16].try_into().unwrap()))
185	}
186
187	#[inline(always)]
188	fn split_mut<T>(
189		data: &mut Self::Data<T>,
190	) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
191		let (chunk_1, chunk_2) = data.split_at_mut(8);
192
193		((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
194	}
195
196	#[inline(always)]
197	fn join<T: Copy + Default>(
198		left: &<Self::Base as TowerLevel>::Data<T>,
199		right: &<Self::Base as TowerLevel>::Data<T>,
200	) -> Self::Data<T> {
201		let mut result = [T::default(); 16];
202		result[..8].copy_from_slice(left);
203		result[8..].copy_from_slice(right);
204		result
205	}
206
207	#[inline(always)]
208	fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
209		array::from_fn(f)
210	}
211}
212
213pub struct TowerLevel8;
214
215impl TowerLevel for TowerLevel8 {
216	const WIDTH: usize = 8;
217
218	type Data<T> = [T; 8];
219	type Base = TowerLevel4;
220
221	#[inline(always)]
222	fn split<T>(
223		data: &Self::Data<T>,
224	) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
225		((data[0..4].try_into().unwrap()), (data[4..8].try_into().unwrap()))
226	}
227
228	#[inline(always)]
229	fn split_mut<T>(
230		data: &mut Self::Data<T>,
231	) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
232		let (chunk_1, chunk_2) = data.split_at_mut(4);
233
234		((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
235	}
236
237	#[inline(always)]
238	fn join<T: Copy + Default>(
239		left: &<Self::Base as TowerLevel>::Data<T>,
240		right: &<Self::Base as TowerLevel>::Data<T>,
241	) -> Self::Data<T> {
242		let mut result = [T::default(); 8];
243		result[..4].copy_from_slice(left);
244		result[4..].copy_from_slice(right);
245		result
246	}
247
248	#[inline(always)]
249	fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
250		array::from_fn(f)
251	}
252}
253
254pub struct TowerLevel4;
255
256impl TowerLevel for TowerLevel4 {
257	const WIDTH: usize = 4;
258
259	type Data<T> = [T; 4];
260	type Base = TowerLevel2;
261
262	#[inline(always)]
263	fn split<T>(
264		data: &Self::Data<T>,
265	) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
266		((data[0..2].try_into().unwrap()), (data[2..4].try_into().unwrap()))
267	}
268
269	#[inline(always)]
270	fn split_mut<T>(
271		data: &mut Self::Data<T>,
272	) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
273		let (chunk_1, chunk_2) = data.split_at_mut(2);
274
275		((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
276	}
277
278	#[inline(always)]
279	fn join<T: Copy + Default>(
280		left: &<Self::Base as TowerLevel>::Data<T>,
281		right: &<Self::Base as TowerLevel>::Data<T>,
282	) -> Self::Data<T> {
283		let mut result = [T::default(); 4];
284		result[..2].copy_from_slice(left);
285		result[2..].copy_from_slice(right);
286		result
287	}
288
289	#[inline(always)]
290	fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
291		array::from_fn(f)
292	}
293}
294
295pub struct TowerLevel2;
296
297impl TowerLevel for TowerLevel2 {
298	const WIDTH: usize = 2;
299
300	type Data<T> = [T; 2];
301	type Base = TowerLevel1;
302
303	#[inline(always)]
304	fn split<T>(
305		data: &Self::Data<T>,
306	) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
307		((data[0..1].try_into().unwrap()), (data[1..2].try_into().unwrap()))
308	}
309
310	#[inline(always)]
311	fn split_mut<T>(
312		data: &mut Self::Data<T>,
313	) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
314		let (chunk_1, chunk_2) = data.split_at_mut(1);
315
316		((chunk_1.try_into().unwrap()), (chunk_2.try_into().unwrap()))
317	}
318
319	#[inline(always)]
320	fn join<T: Copy + Default>(
321		left: &<Self::Base as TowerLevel>::Data<T>,
322		right: &<Self::Base as TowerLevel>::Data<T>,
323	) -> Self::Data<T> {
324		let mut result = [T::default(); 2];
325		result[..1].copy_from_slice(left);
326		result[1..].copy_from_slice(right);
327		result
328	}
329
330	#[inline(always)]
331	fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
332		array::from_fn(f)
333	}
334}
335
336pub struct TowerLevel1;
337
338impl TowerLevel for TowerLevel1 {
339	const WIDTH: usize = 1;
340
341	type Data<T> = [T; 1];
342	type Base = Self;
343
344	// Level 1 is the atomic unit of backing data and must not be split.
345
346	#[inline(always)]
347	fn split<T>(
348		_data: &Self::Data<T>,
349	) -> (&<Self::Base as TowerLevel>::Data<T>, &<Self::Base as TowerLevel>::Data<T>) {
350		unreachable!()
351	}
352
353	#[inline(always)]
354	fn split_mut<T>(
355		_data: &mut Self::Data<T>,
356	) -> (&mut <Self::Base as TowerLevel>::Data<T>, &mut <Self::Base as TowerLevel>::Data<T>) {
357		unreachable!()
358	}
359
360	#[inline(always)]
361	fn join<T: Copy + Default>(
362		_left: &<Self::Base as TowerLevel>::Data<T>,
363		_right: &<Self::Base as TowerLevel>::Data<T>,
364	) -> Self::Data<T> {
365		unreachable!()
366	}
367
368	#[inline(always)]
369	fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self::Data<T> {
370		array::from_fn(f)
371	}
372}
373
#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_log_width() {
		fn check_log_width<TL: TowerLevel>(expected: usize) {
			assert_eq!(TL::LOG_WIDTH, expected);
			assert_eq!(1usize << TL::LOG_WIDTH, TL::WIDTH);
		}

		check_log_width::<TowerLevel64>(6);
		check_log_width::<TowerLevel32>(5);
		check_log_width::<TowerLevel16>(4);
		check_log_width::<TowerLevel8>(3);
		check_log_width::<TowerLevel4>(2);
		check_log_width::<TowerLevel2>(1);
		check_log_width::<TowerLevel1>(0);
	}

	#[test]
	fn test_split_join() {
		fn check_split_and_join<TL: TowerLevel>() {
			// Each level's base must be exactly half as wide.
			assert_eq!(<TL::Base as TowerLevel>::WIDTH * 2, TL::WIDTH);

			let data = TL::from_fn(|i| i as u32);
			let (left, right) = TL::split(&data);
			assert_eq!(left.as_ref(), &data.as_ref()[..TL::WIDTH / 2]);
			assert_eq!(right.as_ref(), &data.as_ref()[TL::WIDTH / 2..]);

			let joined = TL::join(left, right);
			assert_eq!(joined.as_ref(), data.as_ref());
		}

		check_split_and_join::<TowerLevel64>();
		check_split_and_join::<TowerLevel32>();
		check_split_and_join::<TowerLevel16>();
		check_split_and_join::<TowerLevel8>();
		check_split_and_join::<TowerLevel4>();
		check_split_and_join::<TowerLevel2>();
	}

	#[test]
	fn test_split_mut_join() {
		fn check_mut_and_join<TL: TowerLevel>() {
			let mut data = TL::from_fn(|i| i as u32);
			let expected_left = data.as_ref()[..TL::WIDTH / 2].to_vec();
			let expected_right = data.as_ref()[TL::WIDTH / 2..].to_vec();
			let (left, right) = TL::split_mut(&mut data);

			assert_eq!(left.as_mut(), expected_left);
			assert_eq!(right.as_mut(), expected_right);

			// Writes through the halves must land in the corresponding limbs of `data`.
			left.as_mut()[0] = 1000;
			right.as_mut()[0] = 2000;

			let joined = TL::join(left, right);
			assert_eq!(joined.as_ref(), data.as_ref());
			assert_eq!(data.as_ref()[0], 1000);
			assert_eq!(data.as_ref()[TL::WIDTH / 2], 2000);
		}

		check_mut_and_join::<TowerLevel64>();
		check_mut_and_join::<TowerLevel32>();
		check_mut_and_join::<TowerLevel16>();
		check_mut_and_join::<TowerLevel8>();
		check_mut_and_join::<TowerLevel4>();
		check_mut_and_join::<TowerLevel2>();
	}

	#[test]
	#[should_panic]
	fn test_level1_split_panics() {
		let data = TowerLevel1::from_fn(|i| i as u32);
		let _ = TowerLevel1::split(&data);
	}

	#[test]
	fn test_from_fn() {
		fn check_from_fn<TL: TowerLevel>() {
			let data = TL::from_fn(|i| i as u32);
			let expected = (0..TL::WIDTH).map(|i| i as u32).collect::<Vec<_>>();
			assert_eq!(data.as_ref(), expected.as_slice());
		}

		check_from_fn::<TowerLevel64>();
		check_from_fn::<TowerLevel32>();
		check_from_fn::<TowerLevel16>();
		check_from_fn::<TowerLevel8>();
		check_from_fn::<TowerLevel4>();
		check_from_fn::<TowerLevel2>();
		check_from_fn::<TowerLevel1>();
	}

	#[test]
	fn test_default() {
		fn check_default<TL: TowerLevel>() {
			let data = TL::default::<u32>();
			let expected = vec![0u32; TL::WIDTH];
			assert_eq!(data.as_ref(), expected);
		}

		check_default::<TowerLevel64>();
		check_default::<TowerLevel32>();
		check_default::<TowerLevel16>();
		check_default::<TowerLevel8>();
		check_default::<TowerLevel4>();
		check_default::<TowerLevel2>();
		check_default::<TowerLevel1>();
	}

	#[test]
	fn test_add_into() {
		fn check_add_into<TL: TowerLevel>() {
			let a = TL::from_fn(|i| i as u32);
			// Start from a non-zero destination so the test distinguishes `+=` from `=`.
			let mut b = TL::from_fn(|i| 2 * i as u32);
			TL::add_into(&a, &mut b);
			let expected = TL::from_fn(|i| 3 * i as u32);
			assert_eq!(b.as_ref(), expected.as_ref());
		}

		check_add_into::<TowerLevel64>();
		check_add_into::<TowerLevel32>();
		check_add_into::<TowerLevel16>();
		check_add_into::<TowerLevel8>();
		check_add_into::<TowerLevel4>();
		check_add_into::<TowerLevel2>();
		check_add_into::<TowerLevel1>();
	}

	#[test]
	fn test_copy_into() {
		fn check_copy_into<TL: TowerLevel>() {
			let a = TL::from_fn(|i| i as u32);
			// Pre-fill the destination so every limb must actually be overwritten.
			let mut b = TL::from_fn(|_| u32::MAX);
			TL::copy_into(&a, &mut b);
			assert_eq!(b.as_ref(), a.as_ref());
		}

		check_copy_into::<TowerLevel64>();
		check_copy_into::<TowerLevel32>();
		check_copy_into::<TowerLevel16>();
		check_copy_into::<TowerLevel8>();
		check_copy_into::<TowerLevel4>();
		check_copy_into::<TowerLevel2>();
		check_copy_into::<TowerLevel1>();
	}

	#[test]
	fn test_sum() {
		fn check_sum<TL: TowerLevel>() {
			let a = TL::from_fn(|i| i as u32);
			let b = TL::from_fn(|i| 10 * i as u32);
			let sum_result = TL::sum(&a, &b);
			let expected = TL::from_fn(|i| 11 * (i as u32));
			assert_eq!(sum_result.as_ref(), expected.as_ref());
		}

		check_sum::<TowerLevel64>();
		check_sum::<TowerLevel32>();
		check_sum::<TowerLevel16>();
		check_sum::<TowerLevel8>();
		check_sum::<TowerLevel4>();
		check_sum::<TowerLevel2>();
		check_sum::<TowerLevel1>();
	}
}
504}