|
1 | 1 | use ndarray::prelude::*;
|
| 2 | +use ndarray::{ShapeError, ErrorKind, arr3}; |
2 | 3 |
|
3 | 4 | #[test]
|
4 | 5 | #[cfg(feature = "std")]
|
@@ -81,3 +82,33 @@ fn test_broadcast_1d() {
|
81 | 82 | println!("b2=\n{:?}", b2);
|
82 | 83 | assert_eq!(b0, b2);
|
83 | 84 | }
|
| 85 | + |
| 86 | +#[test] |
| 87 | +fn test_broadcast_with() { |
| 88 | + let a = arr2(&[[1., 2.], [3., 4.]]); |
| 89 | + let b = aview0(&1.); |
| 90 | + let (a1, b1) = a.broadcast_with(&b).unwrap(); |
| 91 | + assert_eq!(a1, arr2(&[[1.0, 2.0], [3.0, 4.0]])); |
| 92 | + assert_eq!(b1, arr2(&[[1.0, 1.0], [1.0, 1.0]])); |
| 93 | + |
| 94 | + let a = arr2(&[[2], [3], [4]]); |
| 95 | + let b = arr1(&[5, 6, 7]); |
| 96 | + let (a1, b1) = a.broadcast_with(&b).unwrap(); |
| 97 | + assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]])); |
| 98 | + assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]])); |
| 99 | + |
| 100 | + // Negative strides and non-contiguous memory |
| 101 | + let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; |
| 102 | + let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap(); |
| 103 | + let a = s.slice(s![..;-1,..;2,..]); |
| 104 | + let b = s.slice(s![..2, -1, ..]); |
| 105 | + let (a1, b1) = a.broadcast_with(&b).unwrap(); |
| 106 | + assert_eq!(a1, arr3(&[[[2, 4], [10, 12]], [[1, 3], [9, 11]]])); |
| 107 | + assert_eq!(b1, arr3(&[[[9, 11], [10, 12]], [[9, 11], [10, 12]]])); |
| 108 | + |
| 109 | + // ShapeError |
| 110 | + let a = arr2(&[[2, 2], [3, 3], [4, 4]]); |
| 111 | + let b = arr1(&[5, 6, 7]); |
| 112 | + let e = a.broadcast_with(&b); |
| 113 | + assert_eq!(e, Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))); |
| 114 | +} |
0 commit comments