binius_field/
packed_extension_ops.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_maybe_rayon::prelude::{
4	IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
5};
6
7use crate::{Error, ExtensionField, Field, PackedExtension, PackedField};
8
9pub fn ext_base_mul<PE: PackedExtension<F>, F: Field>(
10	lhs: &mut [PE],
11	rhs: &[PE::PackedSubfield],
12) -> Result<(), Error> {
13	ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| PE::cast_ext(lhs.cast_base() * broadcasted_rhs))
14}
15
16pub fn ext_base_mul_par<PE: PackedExtension<F>, F: Field>(
17	lhs: &mut [PE],
18	rhs: &[PE::PackedSubfield],
19) -> Result<(), Error> {
20	ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| {
21		PE::cast_ext(lhs.cast_base() * broadcasted_rhs)
22	})
23}
24
25/// # Safety
26///
27/// Width of PackedSubfield is >= the width of the field implementing PackedExtension.
28pub unsafe fn get_packed_subfields_at_pe_idx<PE: PackedExtension<F>, F: Field>(
29	packed_subfields: &[PE::PackedSubfield],
30	i: usize,
31) -> PE::PackedSubfield {
32	let bottom_most_scalar_idx = i * PE::WIDTH;
33	let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH;
34	let bottom_most_scalar_idx_within_packed_subfield =
35		bottom_most_scalar_idx % PE::PackedSubfield::WIDTH;
36	let block_idx = bottom_most_scalar_idx_within_packed_subfield / PE::WIDTH;
37
38	packed_subfields
39		.get_unchecked(bottom_most_scalar_idx_in_subfield_arr)
40		.spread_unchecked(PE::LOG_WIDTH, block_idx)
41}
42
43/// Refer to the functions above for examples of closures to pass
44/// Func takes in the following parameters
45///
46/// Note that this function overwrites the lhs buffer, copy that data before
47/// invoking this function if you need to use it elsewhere
48///
49/// lhs: PE::WIDTH extension field scalars
50///
51/// broadcasted_rhs: a broadcasted version of PE::WIDTH subfield scalars
52/// with each one occurring PE::PackedSubfield::WIDTH/PE::WIDTH times in  a row
53/// such that the bits of the broadcasted scalars align with the lhs scalars
54pub fn ext_base_op<PE, F, Func>(
55	lhs: &mut [PE],
56	rhs: &[PE::PackedSubfield],
57	op: Func,
58) -> Result<(), Error>
59where
60	PE: PackedExtension<F>,
61	F: Field,
62	Func: Fn(usize, PE, PE::PackedSubfield) -> PE,
63{
64	if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
65		return Err(Error::MismatchedLengths);
66	}
67
68	lhs.iter_mut().enumerate().for_each(|(i, lhs_elem)| {
69		// SAFETY: Width of PackedSubfield is always >= the width of the field implementing PackedExtension
70		let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
71
72		*lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
73	});
74	Ok(())
75}
76
77/// A multithreaded version of the funcion directly above, use for long arrays
78/// on the prover side
79pub fn ext_base_op_par<PE, F, Func>(
80	lhs: &mut [PE],
81	rhs: &[PE::PackedSubfield],
82	op: Func,
83) -> Result<(), Error>
84where
85	PE: PackedExtension<F>,
86	F: Field,
87	Func: Fn(usize, PE, PE::PackedSubfield) -> PE + std::marker::Sync,
88{
89	if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
90		return Err(Error::MismatchedLengths);
91	}
92
93	lhs.par_iter_mut().enumerate().for_each(|(i, lhs_elem)| {
94		// SAFETY: Width of PackedSubfield is always >= the width of the field implementing PackedExtension
95		let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
96
97		*lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
98	});
99
100	Ok(())
101}
102
#[cfg(test)]
mod tests {
	use proptest::prelude::*;

	use crate::{
		ext_base_mul, ext_base_mul_par,
		packed::{get_packed_slice, pack_slice},
		underlier::WithUnderlier,
		BinaryField128b, BinaryField16b, BinaryField8b, Error, PackedBinaryField16x16b,
		PackedBinaryField2x128b, PackedBinaryField32x8b,
	};

	fn strategy_8b_scalars() -> impl Strategy<Value = [BinaryField8b; 32]> {
		any::<[<BinaryField8b as WithUnderlier>::Underlier; 32]>()
			.prop_map(|arr| arr.map(<BinaryField8b>::from_underlier))
	}

	fn strategy_16b_scalars() -> impl Strategy<Value = [BinaryField16b; 32]> {
		any::<[<BinaryField16b as WithUnderlier>::Underlier; 32]>()
			.prop_map(|arr| arr.map(<BinaryField16b>::from_underlier))
	}

	fn strategy_128b_scalars() -> impl Strategy<Value = [BinaryField128b; 32]> {
		any::<[<BinaryField128b as WithUnderlier>::Underlier; 32]>()
			.prop_map(|arr| arr.map(<BinaryField128b>::from_underlier))
	}

	proptest! {
		#[test]
		fn test_base_ext_mul_8(base_scalars in strategy_8b_scalars(), ext_scalars in strategy_128b_scalars()) {
			let base_packed = pack_slice::<PackedBinaryField32x8b>(&base_scalars);
			let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

			ext_base_mul(&mut ext_packed, &base_packed).unwrap();

			for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate() {
				prop_assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
			}
		}

		#[test]
		fn test_base_ext_mul_16(base_scalars in strategy_16b_scalars(), ext_scalars in strategy_128b_scalars()) {
			let base_packed = pack_slice::<PackedBinaryField16x16b>(&base_scalars);
			let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

			ext_base_mul(&mut ext_packed, &base_packed).unwrap();

			for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate() {
				prop_assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
			}
		}

		#[test]
		fn test_base_ext_mul_par_8(base_scalars in strategy_8b_scalars(), ext_scalars in strategy_128b_scalars()) {
			let base_packed = pack_slice::<PackedBinaryField32x8b>(&base_scalars);
			let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

			ext_base_mul_par(&mut ext_packed, &base_packed).unwrap();

			for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate() {
				prop_assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
			}
		}

		#[test]
		fn test_base_ext_mul_par_16(base_scalars in strategy_16b_scalars(), ext_scalars in strategy_128b_scalars()) {
			let base_packed = pack_slice::<PackedBinaryField16x16b>(&base_scalars);
			let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars);

			ext_base_mul_par(&mut ext_packed, &base_packed).unwrap();

			for (i, (base, ext)) in base_scalars.iter().zip(ext_scalars).enumerate() {
				prop_assert_eq!(ext * *base, get_packed_slice(&ext_packed, i));
			}
		}

		#[test]
		fn test_base_ext_mul_mismatched_lengths(base_scalars in strategy_8b_scalars(), ext_scalars in strategy_128b_scalars()) {
			// One packed 32x8b element must be paired with 16 packed 2x128b elements;
			// only 8 are supplied here, so both entry points must reject the input.
			let base_packed = pack_slice::<PackedBinaryField32x8b>(&base_scalars);
			let mut ext_packed = pack_slice::<PackedBinaryField2x128b>(&ext_scalars[..16]);

			prop_assert!(matches!(
				ext_base_mul(&mut ext_packed, &base_packed),
				Err(Error::MismatchedLengths)
			));
			prop_assert!(matches!(
				ext_base_mul_par(&mut ext_packed, &base_packed),
				Err(Error::MismatchedLengths)
			));

			// On error the extension buffer must be left untouched.
			for (i, ext) in ext_scalars[..16].iter().enumerate() {
				prop_assert_eq!(*ext, get_packed_slice(&ext_packed, i));
			}
		}
	}
}