binius_field/
packed_extension_ops.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_utils::rayon::prelude::{
4	IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
5};
6
7use crate::{Error, ExtensionField, Field, PackedExtension, PackedField};
8
9pub fn ext_base_mul<PE: PackedExtension<F>, F: Field>(
10	lhs: &mut [PE],
11	rhs: &[PE::PackedSubfield],
12) -> Result<(), Error> {
13	ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| PE::cast_ext(lhs.cast_base() * broadcasted_rhs))
14}
15
16pub fn ext_base_mul_par<PE: PackedExtension<F>, F: Field>(
17	lhs: &mut [PE],
18	rhs: &[PE::PackedSubfield],
19) -> Result<(), Error> {
20	ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| {
21		PE::cast_ext(lhs.cast_base() * broadcasted_rhs)
22	})
23}
24
25/// # Safety
26///
27/// Width of PackedSubfield is >= the width of the field implementing PackedExtension.
28pub unsafe fn get_packed_subfields_at_pe_idx<PE: PackedExtension<F>, F: Field>(
29	packed_subfields: &[PE::PackedSubfield],
30	i: usize,
31) -> PE::PackedSubfield {
32	let bottom_most_scalar_idx = i * PE::WIDTH;
33	let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH;
34	let bottom_most_scalar_idx_within_packed_subfield =
35		bottom_most_scalar_idx % PE::PackedSubfield::WIDTH;
36	let block_idx = bottom_most_scalar_idx_within_packed_subfield / PE::WIDTH;
37
38	unsafe {
39		packed_subfields
40			.get_unchecked(bottom_most_scalar_idx_in_subfield_arr)
41			.spread_unchecked(PE::LOG_WIDTH, block_idx)
42	}
43}
44
45/// Refer to the functions above for examples of closures to pass
46/// Func takes in the following parameters
47///
48/// Note that this function overwrites the lhs buffer, copy that data before
49/// invoking this function if you need to use it elsewhere
50///
51/// lhs: PE::WIDTH extension field scalars
52///
53/// broadcasted_rhs: a broadcasted version of PE::WIDTH subfield scalars
54/// with each one occurring PE::PackedSubfield::WIDTH/PE::WIDTH times in  a row
55/// such that the bits of the broadcasted scalars align with the lhs scalars
56pub fn ext_base_op<PE, F, Func>(
57	lhs: &mut [PE],
58	rhs: &[PE::PackedSubfield],
59	op: Func,
60) -> Result<(), Error>
61where
62	PE: PackedExtension<F>,
63	F: Field,
64	Func: Fn(usize, PE, PE::PackedSubfield) -> PE,
65{
66	if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
67		return Err(Error::MismatchedLengths);
68	}
69
70	lhs.iter_mut().enumerate().for_each(|(i, lhs_elem)| {
71		// SAFETY: Width of PackedSubfield is always >= the width of the field implementing
72		// PackedExtension
73		let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
74
75		*lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
76	});
77	Ok(())
78}
79
80/// A multithreaded version of the function directly above, use for long arrays
81/// on the prover side
82pub fn ext_base_op_par<PE, F, Func>(
83	lhs: &mut [PE],
84	rhs: &[PE::PackedSubfield],
85	op: Func,
86) -> Result<(), Error>
87where
88	PE: PackedExtension<F>,
89	F: Field,
90	Func: Fn(usize, PE, PE::PackedSubfield) -> PE + std::marker::Sync,
91{
92	if lhs.len() != rhs.len() * PE::Scalar::DEGREE {
93		return Err(Error::MismatchedLengths);
94	}
95
96	lhs.par_iter_mut().enumerate().for_each(|(i, lhs_elem)| {
97		// SAFETY: Width of PackedSubfield is always >= the width of the field implementing
98		// PackedExtension
99		let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };
100
101		*lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
102	});
103
104	Ok(())
105}