binius_field/packed_extension_ops.rs

// Copyright 2024-2025 Irreducible Inc.

use binius_utils::rayon::prelude::{
	IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
};

use crate::{ExtensionField, Field, PackedExtension, PackedField};

/// Multiplies each extension field scalar in `lhs` in place by the corresponding subfield scalar
/// in `rhs`. See [`ext_base_op`] for the expected layout of `rhs` and the length precondition.
pub fn ext_base_mul<PE: PackedExtension<F>, F: Field>(lhs: &mut [PE], rhs: &[PE::PackedSubfield]) {
	ext_base_op(lhs, rhs, |_, lhs, broadcasted_rhs| PE::cast_ext(lhs.cast_base() * broadcasted_rhs))
}

/// A multithreaded version of [`ext_base_mul`].
pub fn ext_base_mul_par<PE: PackedExtension<F>, F: Field>(
	lhs: &mut [PE],
	rhs: &[PE::PackedSubfield],
) {
	ext_base_op_par(lhs, rhs, |_, lhs, broadcasted_rhs| {
		PE::cast_ext(lhs.cast_base() * broadcasted_rhs)
	})
}
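
// Worked layout example (illustrative widths, not tied to any particular packed type): suppose
// `PE::WIDTH = 2` and the extension degree is 4, so `PE::PackedSubfield::WIDTH = 8`. Then
// `rhs.len() = 1` packs 8 subfield scalars, which pair off one-per-extension-scalar with the
// 8 extension scalars held in `lhs.len() = 4` packed extension elements, matching the
// precondition `lhs.len() == rhs.len() * DEGREE`.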

/// Returns the `PE::WIDTH` subfield scalars at packed-extension index `i` of `packed_subfields`,
/// broadcast so that they align with the scalars of the `i`-th packed extension element
/// (see [`ext_base_op`]).
///
/// # Safety
///
/// * `packed_subfields` must be long enough that the index `i * PE::WIDTH /
///   PE::PackedSubfield::WIDTH` is in bounds.
/// * The width of `PE::PackedSubfield` must be >= the width of `PE` (the type implementing
///   [`PackedExtension`]).
pub unsafe fn get_packed_subfields_at_pe_idx<PE: PackedExtension<F>, F: Field>(
	packed_subfields: &[PE::PackedSubfield],
	i: usize,
) -> PE::PackedSubfield {
	let bottom_most_scalar_idx = i * PE::WIDTH;
	let bottom_most_scalar_idx_in_subfield_arr = bottom_most_scalar_idx / PE::PackedSubfield::WIDTH;
	let bottom_most_scalar_idx_within_packed_subfield =
		bottom_most_scalar_idx % PE::PackedSubfield::WIDTH;
	let block_idx = bottom_most_scalar_idx_within_packed_subfield / PE::WIDTH;

	unsafe {
		packed_subfields
			.get_unchecked(bottom_most_scalar_idx_in_subfield_arr)
			.spread_unchecked(PE::LOG_WIDTH, block_idx)
	}
}
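
// Worked index example, continuing the illustrative widths above (PE::WIDTH = 2, degree 4,
// PE::PackedSubfield::WIDTH = 8): for i = 3, bottom_most_scalar_idx = 6, so the scalars live in
// packed_subfields[6 / 8] = packed_subfields[0] at offset 6 % 8 = 6, giving block_idx = 6 / 2 = 3.
// spread_unchecked(PE::LOG_WIDTH = 1, 3) then repeats each of those 2 subfield scalars
// 8 / 2 = 4 times in a row, aligning them with the 2 extension scalars of lhs[3].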

/// Applies `op` to each packed extension element of `lhs` together with the matching broadcasted
/// packed subfield element derived from `rhs`, writing the result back into `lhs`. Refer to the
/// functions above for examples of closures to pass.
///
/// Note that this function overwrites the `lhs` buffer; copy that data before invoking this
/// function if you need to use it elsewhere.
///
/// `op` takes the following parameters:
///
/// * the index of the current packed extension element within `lhs`
/// * `lhs`: `PE::WIDTH` extension field scalars
/// * `broadcasted_rhs`: a broadcasted version of `PE::WIDTH` subfield scalars, with each one
///   occurring `PE::PackedSubfield::WIDTH / PE::WIDTH` times in a row such that the bits of the
///   broadcasted scalars align with the `lhs` scalars
///
/// # Preconditions
///
/// * `lhs.len()` must equal `rhs.len() * PE::Scalar::DEGREE`.
pub fn ext_base_op<PE, F, Func>(lhs: &mut [PE], rhs: &[PE::PackedSubfield], op: Func)
where
	PE: PackedExtension<F>,
	F: Field,
	Func: Fn(usize, PE, PE::PackedSubfield) -> PE,
{
	assert!(
		lhs.len() == rhs.len() * PE::Scalar::DEGREE,
		"lhs.len() ({}) must equal rhs.len() * PE::Scalar::DEGREE ({})",
		lhs.len(),
		rhs.len() * PE::Scalar::DEGREE
	);

	lhs.iter_mut().enumerate().for_each(|(i, lhs_elem)| {
		// SAFETY: The length assert above guarantees that the computed index into `rhs` is in
		// bounds, and the width of PackedSubfield is always >= the width of PE.
		let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };

		*lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
	});
}

/// A multithreaded version of [`ext_base_op`]; use it for long slices on the prover side.
///
/// # Preconditions
///
/// * `lhs.len()` must equal `rhs.len() * PE::Scalar::DEGREE`.
pub fn ext_base_op_par<PE, F, Func>(lhs: &mut [PE], rhs: &[PE::PackedSubfield], op: Func)
where
	PE: PackedExtension<F>,
	F: Field,
	Func: Fn(usize, PE, PE::PackedSubfield) -> PE + Sync,
{
	assert!(
		lhs.len() == rhs.len() * PE::Scalar::DEGREE,
		"lhs.len() ({}) must equal rhs.len() * PE::Scalar::DEGREE ({})",
		lhs.len(),
		rhs.len() * PE::Scalar::DEGREE
	);

	lhs.par_iter_mut().enumerate().for_each(|(i, lhs_elem)| {
		// SAFETY: The length assert above guarantees that the computed index into `rhs` is in
		// bounds, and the width of PackedSubfield is always >= the width of PE.
		let broadcasted_rhs = unsafe { get_packed_subfields_at_pe_idx::<PE, F>(rhs, i) };

		*lhs_elem = op(i, *lhs_elem, broadcasted_rhs);
	});
}
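
// A hedged usage sketch (illustrative, not part of the original module): the index passed to the
// closure lets it pull matching data from a captured slice. The hypothetical helper below adds
// `extra * rhs` into `lhs` element-wise, where `extra` is an extension field buffer of the same
// length as `lhs` and `rhs` supplies one subfield scalar per extension scalar, exactly as in
// `ext_base_mul`.
#[allow(dead_code)]
fn ext_base_mul_add<PE: PackedExtension<F>, F: Field>(
	lhs: &mut [PE],
	rhs: &[PE::PackedSubfield],
	extra: &[PE],
) {
	assert_eq!(lhs.len(), extra.len());
	ext_base_op(lhs, rhs, |i, lhs, broadcasted_rhs| {
		lhs + PE::cast_ext(extra[i].cast_base() * broadcasted_rhs)
	})
}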