@@ -2503,6 +2503,64 @@ impl Tensor {
2503
2503
t. transpose ( dim, last)
2504
2504
}
2505
2505
}
2506
+
2507
+ /// Returns a copy of `self` where the values within `ranges` have been replaced with the
2508
+ /// content of `src`.
2509
+ pub fn slice_assign < D : std:: ops:: RangeBounds < usize > > (
2510
+ & self ,
2511
+ ranges : & [ D ] ,
2512
+ src : & Tensor ,
2513
+ ) -> Result < Self > {
2514
+ let src_dims = src. dims ( ) ;
2515
+ let self_dims = self . dims ( ) ;
2516
+ if self_dims. len ( ) != src_dims. len ( ) {
2517
+ crate :: bail!(
2518
+ "slice-assign requires input with the same rank {} <> {}" ,
2519
+ self_dims. len( ) ,
2520
+ src_dims. len( )
2521
+ )
2522
+ }
2523
+ if self_dims. len ( ) != ranges. len ( ) {
2524
+ crate :: bail!(
2525
+ "slice-assign requires input with the same rank as there are ranges {} <> {}" ,
2526
+ self_dims. len( ) ,
2527
+ ranges. len( )
2528
+ )
2529
+ }
2530
+ let mut src = src. clone ( ) ;
2531
+ let mut mask = Self :: ones ( src. shape ( ) , DType :: U8 , src. device ( ) ) ?;
2532
+ for ( i, range) in ranges. iter ( ) . enumerate ( ) {
2533
+ let start_included = match range. start_bound ( ) {
2534
+ std:: ops:: Bound :: Unbounded => 0 ,
2535
+ std:: ops:: Bound :: Included ( v) => * v,
2536
+ std:: ops:: Bound :: Excluded ( v) => * v + 1 ,
2537
+ } ;
2538
+ let end_excluded = match range. end_bound ( ) {
2539
+ std:: ops:: Bound :: Unbounded => self_dims[ i] ,
2540
+ std:: ops:: Bound :: Included ( v) => * v + 1 ,
2541
+ std:: ops:: Bound :: Excluded ( v) => * v,
2542
+ } ;
2543
+ if end_excluded <= start_included {
2544
+ crate :: bail!(
2545
+ "slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
2546
+ )
2547
+ }
2548
+ if self_dims[ i] < end_excluded {
2549
+ crate :: bail!(
2550
+ "slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}" ,
2551
+ self_dims[ i]
2552
+ )
2553
+ }
2554
+ if end_excluded - start_included != src_dims[ i] {
2555
+ crate :: bail!(
2556
+ "slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}" , src_dims[ i]
2557
+ )
2558
+ }
2559
+ src = src. pad_with_zeros ( i, start_included, self_dims[ i] - end_excluded) ?;
2560
+ mask = mask. pad_with_zeros ( i, start_included, self_dims[ i] - end_excluded) ?
2561
+ }
2562
+ mask. where_cond ( /* on_true= */ & src, /* on_false= */ self )
2563
+ }
2506
2564
}
2507
2565
2508
2566
macro_rules! bin_trait {
0 commit comments