Skip to content

Commit 481c45d

Browse files
Add a basic implementation for slice-assign. (huggingface#1377)
1 parent 14a2bdc commit 481c45d

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

candle-core/src/tensor.rs

+58
Original file line numberDiff line numberDiff line change
@@ -2503,6 +2503,64 @@ impl Tensor {
25032503
t.transpose(dim, last)
25042504
}
25052505
}
2506+
2507+
/// Returns a copy of `self` where the values within `ranges` have been replaced with the
2508+
/// content of `src`.
2509+
pub fn slice_assign<D: std::ops::RangeBounds<usize>>(
2510+
&self,
2511+
ranges: &[D],
2512+
src: &Tensor,
2513+
) -> Result<Self> {
2514+
let src_dims = src.dims();
2515+
let self_dims = self.dims();
2516+
if self_dims.len() != src_dims.len() {
2517+
crate::bail!(
2518+
"slice-assign requires input with the same rank {} <> {}",
2519+
self_dims.len(),
2520+
src_dims.len()
2521+
)
2522+
}
2523+
if self_dims.len() != ranges.len() {
2524+
crate::bail!(
2525+
"slice-assign requires input with the same rank as there are ranges {} <> {}",
2526+
self_dims.len(),
2527+
ranges.len()
2528+
)
2529+
}
2530+
let mut src = src.clone();
2531+
let mut mask = Self::ones(src.shape(), DType::U8, src.device())?;
2532+
for (i, range) in ranges.iter().enumerate() {
2533+
let start_included = match range.start_bound() {
2534+
std::ops::Bound::Unbounded => 0,
2535+
std::ops::Bound::Included(v) => *v,
2536+
std::ops::Bound::Excluded(v) => *v + 1,
2537+
};
2538+
let end_excluded = match range.end_bound() {
2539+
std::ops::Bound::Unbounded => self_dims[i],
2540+
std::ops::Bound::Included(v) => *v + 1,
2541+
std::ops::Bound::Excluded(v) => *v,
2542+
};
2543+
if end_excluded <= start_included {
2544+
crate::bail!(
2545+
"slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
2546+
)
2547+
}
2548+
if self_dims[i] < end_excluded {
2549+
crate::bail!(
2550+
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
2551+
self_dims[i]
2552+
)
2553+
}
2554+
if end_excluded - start_included != src_dims[i] {
2555+
crate::bail!(
2556+
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
2557+
)
2558+
}
2559+
src = src.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?;
2560+
mask = mask.pad_with_zeros(i, start_included, self_dims[i] - end_excluded)?
2561+
}
2562+
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
2563+
}
25062564
}
25072565

25082566
macro_rules! bin_trait {

candle-core/tests/indexing_tests.rs

+29
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,32 @@ fn index_3d() -> Result<()> {
9191
assert_eq!(tensor.i((1, .., 3))?.to_vec1::<u32>()?, &[15, 19, 23]);
9292
Ok(())
9393
}
94+
95+
#[test]
96+
fn slice_assign() -> Result<()> {
97+
let dev = Device::Cpu;
98+
99+
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
100+
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
101+
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
102+
assert_eq!(
103+
out.to_vec2::<u32>()?,
104+
&[
105+
[0, 1, 2, 3, 4],
106+
[5, 6, 7, 0, 1],
107+
[10, 11, 12, 2, 3],
108+
[15, 16, 17, 4, 5]
109+
]
110+
);
111+
let out = tensor.slice_assign(&[0..3, 0..2], &src)?;
112+
assert_eq!(
113+
out.to_vec2::<u32>()?,
114+
&[
115+
[0, 1, 2, 3, 4],
116+
[2, 3, 7, 8, 9],
117+
[4, 5, 12, 13, 14],
118+
[15, 16, 17, 18, 19]
119+
]
120+
);
121+
Ok(())
122+
}

0 commit comments

Comments
 (0)