Skip to content

Commit 9d505e7

Browse files
committed
Extend status propagation to XlaDataToTensors and update callers
Modify `XlaDataToTensors` function to use proper status propagation instead of `GetValueOrThrow`, and update all callers to handle the new `StatusOr<T>` return type. This continues the status propagation improvements started with `ReleaseGilAndTransferData`.

Changes:
- Update `XlaDataToTensors` signature to return `absl::StatusOr<std::vector<at::Tensor>>`
- Replace `GetValueOrThrow` with `XLA_ASSIGN_OR_RETURN` for the `ReleaseGilAndTransferData` call
- Update all callers to use the `GetValueOrThrow` wrapper:
  - `XLATensor::ToTensor` in tensor.cpp:515
  - test_xla_sharding.cpp:31
  - init_python_bindings.cpp:2716
  - xla_backend_impl.cpp:95
- Add necessary status.h includes to xla_backend_impl.cpp and test_xla_sharding.cpp

This maintains backward compatibility at the API level while enabling proper status propagation internally within the tensor conversion pipeline.
1 parent 8df24b7 commit 9d505e7

File tree

7 files changed

+13
-8
lines changed

7 files changed

+13
-8
lines changed

test/cpp/test_xla_sharding.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
#include "torch_xla/csrc/tensor.h"
1818
#include "torch_xla/csrc/tensor_methods.h"
1919
#include "torch_xla/csrc/tensor_util.h"
20+
#include "torch_xla/csrc/status.h"
2021
#include "torch_xla/csrc/xla_sharding_util.h"
2122
#include "xla/protobuf_util.h"
2223
#include "xla/xla_data.pb.h"
@@ -28,7 +29,7 @@ bool XlaDataValuesEqual(torch::lazy::BackendDataPtr a,
2829
torch::lazy::BackendDataPtr b,
2930
at::ScalarType element_type) {
3031
std::vector<at::Tensor> tensors =
31-
XlaDataToTensors({a, b}, {element_type, element_type});
32+
GetValueOrThrow(XlaDataToTensors({a, b}, {element_type, element_type}));
3233
return TensorCompare(tensors[0], tensors[1]);
3334
}
3435
} // namespace

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2713,7 +2713,7 @@ void InitXlaModuleBindings(py::module m) {
27132713
}
27142714

27152715
std::vector<at::Tensor> cpu_shards =
2716-
XlaDataToTensors(WrapXlaData(handles), element_types);
2716+
GetValueOrThrow(XlaDataToTensors(WrapXlaData(handles), element_types));
27172717
// Populate the resulting vector of shards and device strings
27182718
std::vector<std::vector<std::pair<at::Tensor, std::string>>> result;
27192719
int shards_per_tensor =

torch_xla/csrc/tensor.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
4141
#include "torch_xla/csrc/runtime/sys_util.h"
4242
#include "torch_xla/csrc/runtime/xla_util.h"
43+
#include "torch_xla/csrc/status.h"
4344
#include "torch_xla/csrc/tensor_util.h"
4445
#include "torch_xla/csrc/torch_util.h"
4546
#include "torch_xla/csrc/xla_graph_executor.h"
@@ -512,7 +513,7 @@ at::Tensor XLATensor::ToTensor(bool detached) {
512513
// The GetXlaData() call will trigger an ApplyPendingGraph() if an IR
513514
// XlaNode is available on the tensor.
514515
std::vector<at::Tensor> tensors =
515-
XlaDataToTensors({GetXlaData()}, {dtype()});
516+
GetValueOrThrow(XlaDataToTensors({GetXlaData()}, {dtype()}));
516517
tensor = std::move(tensors.front());
517518
if (!detached) {
518519
SetTensorData(tensor);

torch_xla/csrc/tensor_util.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -922,11 +922,11 @@ absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
922922
return literals;
923923
}
924924

925-
std::vector<at::Tensor> XlaDataToTensors(
925+
absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
926926
absl::Span<const torch::lazy::BackendDataPtr> xla_data,
927927
absl::Span<const at::ScalarType> dest_element_type) {
928-
std::vector<xla::Literal> literals =
929-
GetValueOrThrow(ReleaseGilAndTransferData(xla_data));
928+
XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
929+
ReleaseGilAndTransferData(xla_data));
930930
std::vector<at::Tensor> tensors(literals.size());
931931
absl::BlockingCounter counter(literals.size());
932932
for (size_t i = 0; i < tensors.size(); ++i) {

torch_xla/csrc/tensor_util.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
3232
absl::Span<const torch::lazy::BackendDataPtr> xla_data);
3333

3434
// TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice
35-
std::vector<at::Tensor> XlaDataToTensors(
35+
absl::StatusOr<std::vector<at::Tensor>> XlaDataToTensors(
3636
absl::Span<const torch::lazy::BackendDataPtr> xla_data,
3737
absl::Span<const at::ScalarType> dest_element_type);
3838

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,8 @@
1010
#include "torch_xla/csrc/runtime/computation_client.h"
1111
#include "torch_xla/csrc/runtime/debug_macros.h"
1212
#include "torch_xla/csrc/runtime/runtime.h"
13+
#include "torch_xla/csrc/status.h"
14+
#include "torch_xla/csrc/tensor_util.h"
1315

1416
namespace at {
1517
// This function is defined in the codegenerated RegisterDispatchKey.cpp file.
@@ -92,7 +94,7 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
9294
const torch::lazy::BackendDataPtr data,
9395
std::optional<at::ScalarType> logical_scalar_type) const override {
9496
// TODO(JackCaoG): handle the logical_scalar_type == nullptr case
95-
return XlaDataToTensors({data}, {*logical_scalar_type})[0];
97+
return GetValueOrThrow(XlaDataToTensors({data}, {*logical_scalar_type}))[0];
9698
}
9799

98100
std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@
5252
#include "torch_xla/csrc/runtime/sys_util.h"
5353
#include "torch_xla/csrc/runtime/xla_util.h"
5454
#include "torch_xla/csrc/shape_helper.h"
55+
#include "torch_xla/csrc/status.h"
5556
#include "torch_xla/csrc/tensor_util.h"
5657
#include "torch_xla/csrc/thread_pool.h"
5758
#include "torch_xla/csrc/torch_util.h"

0 commit comments

Comments (0)