Skip to content

Commit d48923e

Browse files
amit0365claude
andcommitted
Add mutable and parallel iteration to Matrix trait
- Add `cells_mut` to the `Matrix<T>` trait - Add `par_cells` and `par_cells_mut` under `parallel` feature - Implement `as_rows_par` and `as_rows_mut_par` for `DenseRowMatrix` Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent f410914 commit d48923e

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ num-bigint = { version = "0.4", default-features = false, optional = true }
2727

2828
# Optional utility for which implementations are provided
2929
rand = { workspace = true, optional = true }
30+
rayon = { version = "1", optional = true }
3031
serde = { version = "1", default-features = false, optional = true }
3132
zeroize = { version = "1", default-features = false, optional = true }
3233

@@ -39,6 +40,7 @@ rand = { workspace = true, features = ["std_rng"] }
3940
[features]
4041
default = ["std"]
4142

43+
parallel = ["dep:rayon"]
4244
rand = ["dep:rand", "crypto-bigint/rand_core"]
4345
serde = ["dep:serde", "crypto-bigint/serde"]
4446
zeroize = ["dep:zeroize", "crypto-bigint/zeroize"]

src/matrix.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use alloc::vec::Vec;
22
use core::mem::{ManuallyDrop, MaybeUninit};
33

4+
#[cfg(feature = "parallel")]
5+
use rayon::prelude::*;
6+
47
/// A matrix, rectangular table of values
58
pub trait Matrix<T> {
69
/// Number of rows in this matrix
@@ -13,9 +16,29 @@ pub trait Matrix<T> {
1316
where
1417
T: 'a;
1518

19+
fn cells_mut<'a>(
20+
&'a mut self,
21+
) -> impl Iterator<Item = impl Iterator<Item = (usize, &'a mut T)>>
22+
where
23+
T: 'a;
24+
1625
fn is_empty(&self) -> bool {
1726
self.num_rows() == 0 || self.num_cols() == 0
1827
}
28+
29+
#[cfg(feature = "parallel")]
30+
fn par_cells<'a>(
31+
&'a self,
32+
) -> impl ParallelIterator<Item = impl Iterator<Item = (usize, &'a T)>>
33+
where
34+
T: 'a + Send + Sync;
35+
36+
#[cfg(feature = "parallel")]
37+
fn par_cells_mut<'a>(
38+
&'a mut self,
39+
) -> impl ParallelIterator<Item = impl Iterator<Item = (usize, &'a mut T)>>
40+
where
41+
T: 'a + Send + Sync;
1942
}
2043

2144
/// Sparse matrix is a matrix with a fixed number non-zero of elements per row
@@ -56,6 +79,37 @@ impl<T> Matrix<T> for SparseMatrix<T> {
5679
.chunks(self.density)
5780
.map(|chunk| chunk.iter().map(|v| (v.0, &v.1)))
5881
}
82+
83+
fn cells_mut<'a>(&'a mut self) -> impl Iterator<Item = impl Iterator<Item = (usize, &'a mut T)>>
84+
where
85+
T: 'a,
86+
{
87+
self.cells
88+
.chunks_mut(self.density)
89+
.map(|chunk| chunk.iter_mut().map(|v| (v.0, &mut v.1)))
90+
}
91+
92+
#[cfg(feature = "parallel")]
93+
fn par_cells<'a>(&'a self) -> impl ParallelIterator<Item = impl Iterator<Item = (usize, &'a T)>>
94+
where
95+
T: 'a + Send + Sync,
96+
{
97+
self.cells
98+
.par_chunks(self.density)
99+
.map(|chunk| chunk.iter().map(|v| (v.0, &v.1)))
100+
}
101+
102+
#[cfg(feature = "parallel")]
103+
fn par_cells_mut<'a>(
104+
&'a mut self,
105+
) -> impl ParallelIterator<Item = impl Iterator<Item = (usize, &'a mut T)>>
106+
where
107+
T: 'a + Send + Sync,
108+
{
109+
self.cells
110+
.par_chunks_mut(self.density)
111+
.map(|chunk| chunk.iter_mut().map(|v| (v.0, &mut v.1)))
112+
}
59113
}
60114

61115
impl<T, T2> From<&SparseMatrix<T2>> for SparseMatrix<T>
@@ -111,6 +165,22 @@ impl<T: Clone> DenseRowMatrix<T> {
111165
self.data.chunks_exact_mut(self.num_cols)
112166
}
113167

168+
#[cfg(feature = "parallel")]
169+
pub fn as_rows_par(&self) -> impl ParallelIterator<Item = &[T]>
170+
where
171+
T: Send + Sync,
172+
{
173+
self.data.par_chunks_exact(self.num_cols)
174+
}
175+
176+
#[cfg(feature = "parallel")]
177+
pub fn as_rows_mut_par(&mut self) -> impl ParallelIterator<Item = &mut [T]>
178+
where
179+
T: Send + Sync,
180+
{
181+
self.data.par_chunks_exact_mut(self.num_cols)
182+
}
183+
114184
pub fn to_rows_slices(&self) -> Vec<&[T]> {
115185
self.as_rows().collect()
116186
}
@@ -167,6 +237,37 @@ impl<T> Matrix<T> for DenseRowMatrix<T> {
167237
.chunks_exact(self.num_cols)
168238
.map(|row| row.iter().enumerate())
169239
}
240+
241+
fn cells_mut<'a>(&'a mut self) -> impl Iterator<Item = impl Iterator<Item = (usize, &'a mut T)>>
242+
where
243+
T: 'a,
244+
{
245+
self.data
246+
.chunks_exact_mut(self.num_cols)
247+
.map(|row| row.iter_mut().enumerate())
248+
}
249+
250+
#[cfg(feature = "parallel")]
251+
fn par_cells<'a>(&'a self) -> impl ParallelIterator<Item = impl Iterator<Item = (usize, &'a T)>>
252+
where
253+
T: 'a + Send + Sync,
254+
{
255+
self.data
256+
.par_chunks_exact(self.num_cols)
257+
.map(|row| row.iter().enumerate())
258+
}
259+
260+
#[cfg(feature = "parallel")]
261+
fn par_cells_mut<'a>(
262+
&'a mut self,
263+
) -> impl ParallelIterator<Item = impl Iterator<Item = (usize, &'a mut T)>>
264+
where
265+
T: 'a + Send + Sync,
266+
{
267+
self.data
268+
.par_chunks_exact_mut(self.num_cols)
269+
.map(|row| row.iter_mut().enumerate())
270+
}
170271
}
171272

172273
impl<T> From<Vec<Vec<T>>> for DenseRowMatrix<T> {

0 commit comments

Comments
 (0)