binius_field/
packed_extension_ops.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_maybe_rayon::prelude::{
4	IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
5};
6
7use crate::{Error, ExtensionField, Field, PackedExtension, PackedField};
8
9pub fn ext_base_mul<PE: PackedExtension<F>, F: Field>(
10	lhs: &mut [PE],
11	rhs: &[PE::PackedSubfield],
12) -> Result<(), Error> {
13	ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| PE::cast_ext(lhs.cast_base() * broadcasted_rhs))
14}
15
16pub fn ext_base_mul_par<PE: PackedExtension<F>, F: Field>(
17	lhs: &mut [PE],
18	rhs: &[PE::PackedSubfield],
19) -> Result<(), Error> {
20	ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| {
21		PE::cast_ext(lhs.cast_base() * broadcasted_rhs)
22	})
23}
24
25/// # Safety
26///
27/// Width of PackedSubfield is >= the width of the field implementing PackedExtension.
28pub unsafe fn get_packed_subfields_at_pe_idx<PE: PackedExtension<F>, F: Field>(
29	packed_subfields: &[PE::PackedSubfield],
30	i: usize,
31) -> PE::PackedSubfield {
32	let bottom_most_scalar_idx = i * PE::WIDTH;
33	let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH;
34	let bottom_most_scalar_idx_within_packed_subfield =
35		bottom_most_scalar_idx % PE::PackedSubfield::WIDTH;
36	let block_idx = bottom_most_scalar_idx_within_packed_subfield / PE::WIDTH;
37
38	unsafe {
39		packed_subfields
40			.get_unchecked(bottom_most_scalar_idx_in_subfield_arr)
41			.spread_unchecked(PE::LOG_WIDTH, block_idx)
42	}
43}
44
45/// Refer to the functions above for examples of closures to pass
46/// Func takes in the following parameters
47///
48/// Note that this function overwrites the lhs buffer, copy that data before
49/// invoking this function if you need to use it elsewhere
50///
51/// lhs: PE::WIDTH extension field scalars
52///
53/// broadcasted_rhs: a broadcasted version of PE::WIDTH subfield scalars
54/// with each one occurring PE::PackedSubfield::WIDTH/PE::WIDTH times in  a row
55/// such that the bits of the broadcasted scalars align with the lhs scalars
56pub fn ext_base_op<PE, F, Func>(
57	lhs: &mut [PE],
58	rhs: &[PE::PackedSubfield],
59	op: Func,
60) -> Result<(), Error>
61where
62	PE: PackedExtension<F>,
63	F: Field,
64	Func: Fn(usize, PE, PE::PackedSubfield) -> PE,
65{
66	if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
67		return Err(Error::MismatchedLengths);
68	}
69
70	lhs.iter_mut().enumerate().for_each(|(i, lhs_elem)| {
71		// SAFETY: Width of PackedSubfield is always >= the width of the field implementing
72		// PackedExtension
73		let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
74
75		*lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
76	});
77	Ok(())
78}
79
80/// A multithreaded version of the function directly above, use for long arrays
81/// on the prover side
82pub fn ext_base_op_par<PE, F, Func>(
83	lhs: &mut [PE],
84	rhs: &[PE::PackedSubfield],
85	op: Func,
86) -> Result<(), Error>
87where
88	PE: PackedExtension<F>,
89	F: Field,
90	Func: Fn(usize, PE, PE::PackedSubfield) -> PE + std::marker::Sync,
91{
92	if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
93		return Err(Error::MismatchedLengths);
94	}
95
96	lhs.par_iter_mut().enumerate().for_each(|(i, lhs_elem)| {
97		// SAFETY: Width of PackedSubfield is always >= the width of the field implementing
98		// PackedExtension
99		let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
100
101		*lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
102	});
103
104	Ok(())
105}
106
#[cfg(test)]
mod tests {
	use proptest::prelude::*;

	use crate::{
		BinaryField8b, BinaryField16b, BinaryField128b, PackedBinaryField2x128b,
		PackedBinaryField16x16b, PackedBinaryField32x8b, ext_base_mul, ext_base_mul_par,
		packed::{get_packed_slice, pack_slice},
		underlier::WithUnderlier,
	};

	fn strategy_8b_scalars() -> impl Strategy<Value = [BinaryField8b; 32]> {
		any::<[<BinaryField8b as WithUnderlier>::Underlier; 32]>()
			.prop_map(|arr| arr.map(<BinaryField8b>::from_underlier))
	}

	fn strategy_16b_scalars() -> impl Strategy<Value = [BinaryField16b; 32]> {
		any::<[<BinaryField16b as WithUnderlier>::Underlier; 32]>()
			.prop_map(|arr| arr.map(<BinaryField16b>::from_underlier))
	}

	fn strategy_128b_scalars() -> impl Strategy<Value = [BinaryField128b; 32]> {
		any::<[<BinaryField128b as WithUnderlier>::Underlier; 32]>()
			.prop_map(|arr| arr.map(<BinaryField128b>::from_underlier))
	}

	#[test]
	fn test_base_ext_mul_mismatched_lengths() {
		// 32 subfield scalars fill one `PackedBinaryField32x8b`, which has to be paired with 16
		// `PackedBinaryField2x128b` elements; 30 extension scalars only fill 15 of them.
		let base_scalars = [BinaryField8b::from_underlier(3); 32];
		let ext_scalars = [BinaryField128b::from_underlier(5); 30];
		let base_packed = pack_slice::<PackedBinaryField32x8b>(&base_scalars);
		let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

		assert!(ext_base_mul(&mut ext_packed, &base_packed).is_err());
		assert!(ext_base_mul_par(&mut ext_packed, &base_packed).is_err());

		// The length check runs before any element is touched, so `lhs` must be unchanged.
		for (i, &ext) in ext_scalars.iter().enumerate() {
			assert_eq!(get_packed_slice(&ext_packed, i), ext);
		}
	}

	proptest! {
		#[test]
		fn test_base_ext_mul_8(base_scalars in strategy_8b_scalars(), ext_scalars in strategy_128b_scalars()){
			let base_packed = pack_slice::<PackedBinaryField32x8b>(&base_scalars);
			let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

			ext_base_mul(&mut ext_packed, &base_packed).unwrap();

			for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate(){
				assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
			}
		}

		#[test]
		fn test_base_ext_mul_16(base_scalars in strategy_16b_scalars(), ext_scalars in strategy_128b_scalars()){
			let base_packed = pack_slice::<PackedBinaryField16x16b>(&base_scalars);
			let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

			ext_base_mul(&mut ext_packed, &base_packed).unwrap();

			for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate(){
				assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
			}
		}

		#[test]
		fn test_base_ext_mul_par_8(base_scalars in strategy_8b_scalars(), ext_scalars in strategy_128b_scalars()){
			let base_packed = pack_slice::<PackedBinaryField32x8b>(&base_scalars);
			let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

			ext_base_mul_par(&mut ext_packed, &base_packed).unwrap();

			for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate(){
				assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
			}
		}

		#[test]
		fn test_base_ext_mul_par_16(base_scalars in strategy_16b_scalars(), ext_scalars in strategy_128b_scalars()){
			let base_packed = pack_slice::<PackedBinaryField16x16b>(&base_scalars);
			let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

			ext_base_mul_par(&mut ext_packed, &base_packed).unwrap();

			for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate(){
				assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
			}
		}
	}
}