Skip to content

Commit a0df13d

Browse files
committed
Applied review comments
1 parent 7b0c529 commit a0df13d

File tree

5 files changed

+51
-28
lines changed

5 files changed

+51
-28
lines changed

dpnp/backend/extensions/elementwise_functions/elementwise_functions.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
556556
const dpctl::tensor::usm_ndarray &src2,
557557
const dpctl::tensor::usm_ndarray &dst, // dst = op(src1, src2), elementwise
558558
sycl::queue &exec_q,
559-
const std::vector<sycl::event> depends,
559+
const std::vector<sycl::event> &depends,
560560
//
561561
const output_typesT &output_type_table,
562562
const contig_dispatchT &contig_dispatch_table,
@@ -870,7 +870,7 @@ std::pair<sycl::event, sycl::event>
870870
const dpctl::tensor::usm_ndarray &dst1,
871871
const dpctl::tensor::usm_ndarray &dst2,
872872
sycl::queue &exec_q,
873-
const std::vector<sycl::event> depends,
873+
const std::vector<sycl::event> &depends,
874874
//
875875
const output_typesT &output_types_table,
876876
const contig_dispatchT &contig_dispatch_table,
@@ -952,8 +952,10 @@ std::pair<sycl::event, sycl::event>
952952
auto const &same_logical_tensors =
953953
dpctl::tensor::overlap::SameLogicalTensors();
954954
if ((overlap(src1, dst1) && !same_logical_tensors(src1, dst1)) ||
955+
(overlap(src1, dst2) && !same_logical_tensors(src1, dst2)) ||
955956
(overlap(src2, dst1) && !same_logical_tensors(src2, dst1)) ||
956-
(overlap(dst1, dst2) && !same_logical_tensors(dst1, dst2)))
957+
(overlap(src2, dst2) && !same_logical_tensors(src2, dst2)) ||
958+
(overlap(dst1, dst2)))
957959
{
958960
throw py::value_error("Arrays index overlapping segments of memory");
959961
}
@@ -1142,7 +1144,7 @@ std::pair<sycl::event, sycl::event>
11421144
py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs,
11431145
const dpctl::tensor::usm_ndarray &rhs,
11441146
sycl::queue &exec_q,
1145-
const std::vector<sycl::event> depends,
1147+
const std::vector<sycl::event> &depends,
11461148
//
11471149
const output_typesT &output_type_table,
11481150
const contig_dispatchT &contig_dispatch_table,

dpnp/backend/kernels/elementwise_functions/divmod.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
#pragma once
3030

31+
#include <type_traits>
32+
3133
#include <sycl/sycl.hpp>
3234

3335
namespace dpnp::kernels::divmod

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
510510
511511
Args:
512512
name : {str}
513-
Name of the unary function
513+
Name of the binary function
514514
result_type_resovle_fn : {callable}
515515
Function that takes dtype of the input and returns the dtype of
516516
the result if the implementation functions supports it, or
@@ -530,7 +530,7 @@ class DPNPBinaryFunc(BinaryElementwiseFunc):
530530
corresponds to computational tasks associated with function
531531
evaluation.
532532
docs : {str}
533-
Documentation string for the unary function.
533+
Documentation string for the binary function.
534534
mkl_fn_to_call : {None, str}
535535
Check input arguments to answer if function from OneMKL VM library
536536
can be used.
@@ -828,17 +828,17 @@ def __call__(self, *args, **kwargs):
828828

829829
class DPNPBinaryTwoOutputsFunc(BinaryElementwiseFunc):
830830
"""
831-
Class that implements unary element-wise functions with two output arrays.
831+
Class that implements binary element-wise functions with two output arrays.
832832
833833
Parameters
834834
----------
835835
name : {str}
836-
Name of the unary function
836+
Name of the binary function
837837
result_type_resolver_fn : {callable}
838838
Function that takes dtype of the input and returns the dtype of
839839
the result if the implementation functions supports it, or
840840
returns `None` otherwise.
841-
unary_dp_impl_fn : {callable}
841+
binary_dp_impl_fn : {callable}
842842
Data-parallel implementation function with signature
843843
`impl_fn(src: usm_ndarray, dst: usm_ndarray,
844844
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
@@ -852,7 +852,7 @@ class DPNPBinaryTwoOutputsFunc(BinaryElementwiseFunc):
852852
computational tasks complete execution, while the second event
853853
corresponds to computational tasks associated with function evaluation.
854854
docs : {str}
855-
Documentation string for the unary function.
855+
Documentation string for the binary function.
856856
857857
"""
858858

@@ -1020,6 +1020,14 @@ def __call__(
10201020
if not res.flags.writable:
10211021
raise ValueError("output array is read-only")
10221022

1023+
for other_out in out[:i]:
1024+
if other_out is None:
1025+
continue
1026+
1027+
other_out = dpnp.get_usm_ndarray(other_out)
1028+
if dti._array_overlap(res, other_out):
1029+
raise ValueError("Output arrays cannot overlap")
1030+
10231031
if res.shape != res_shape:
10241032
raise ValueError(
10251033
"The shape of input and output arrays are inconsistent. "
@@ -1042,19 +1050,20 @@ def __call__(
10421050
# Allocate a temporary buffer with the required dtype
10431051
out[i] = dpt.empty_like(res, dtype=res_dt)
10441052
else:
1045-
for x, dt in zip([x1, x2], buf_dts):
1046-
if dpnp.isscalar(x):
1047-
pass
1048-
elif dt is not None:
1049-
pass
1050-
elif not dti._array_overlap(x, res):
1051-
pass
1052-
elif dti._same_logical_tensors(x, res):
1053-
pass
1054-
1055-
# Allocate a temporary buffer to avoid memory overlapping.
1056-
# Note if `dt` is not None, a temporary copy of `x` will be
1057-
# created, so the array overlap check isn't needed.
1053+
# If `dt` is not None, a temporary copy of `x` will be created,
1054+
# so the array overlap check isn't needed.
1055+
x_to_check = [
1056+
x
1057+
for x, dt in zip([x1, x2], buf_dts)
1058+
if not dpnp.isscalar(x) and dt is None
1059+
]
1060+
1061+
if any(
1062+
dti._array_overlap(x, res)
1063+
and not dti._same_logical_tensors(x, res)
1064+
for x in x_to_check
1065+
):
1066+
# allocate a temporary buffer to avoid memory overlapping
10581067
out[i] = dpt.empty_like(res)
10591068

10601069
x1 = dpnp.as_usm_ndarray(x1, dtype=x1_dt, sycl_queue=exec_q)

dpnp/dpnp_iface_mathematical.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,19 +1581,19 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):
15811581
Divisor input array, expected to have a real-valued floating-point data
15821582
type.
15831583
out1 : {None, dpnp.ndarray, usm_ndarray}, optional
1584-
Output array for the quotient to populate. Array must have the same shape
1585-
as `x` and the expected data type.
1584+
Output array for the quotient to populate. Array must have a shape that
1585+
the inputs broadcast to and the expected data type.
15861586
15871587
Default: ``None``.
15881588
out2 : {None, dpnp.ndarray, usm_ndarray}, optional
1589-
Output array for the remainder to populate. Array must have the same shape
1590-
as `x` and the expected data type.
1589+
Output array for the remainder to populate. Array must have a shape that
1590+
the inputs broadcast to and the expected data type.
15911591
15921592
Default: ``None``.
15931593
out : tuple of None, dpnp.ndarray, or usm_ndarray, optional
15941594
A location into which the result is stored. If provided, it must be a tuple
15951595
and have length equal to the number of outputs. Each provided array must
1596-
have the same shape as `x` and the expected data type.
1596+
have a shape that the inputs broadcast to and the expected data type.
15971597
It is prohibited to pass output arrays through `out` keyword when either
15981598
`out1` or `out2` is passed.
15991599

dpnp/tests/test_binary_two_outputs_ufuncs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,16 @@ def test_out1_overlap(self, func, dt):
217217
_ = getattr(numpy, func)(a[size::], 1, a[::2])
218218
assert_array_equal(ia, a)
219219

220+
def test_out_arrays_overlap(self, func):
221+
a = dpnp.arange(7)
222+
out = dpnp.zeros_like(a)
223+
224+
with pytest.raises(
225+
ValueError,
226+
match="Output arrays cannot overlap",
227+
):
228+
_ = getattr(dpnp, func)(a, 2, out=(out, out))
229+
220230
def test_out_scalar_input(self, func):
221231
a = generate_random_numpy_array((3, 7), low=1, dtype=int)
222232
out = numpy.zeros_like(a)

0 commit comments

Comments
 (0)