diff --git a/.github/scripts/run_tests.sh b/.github/scripts/run_tests.sh index d6127e44ebf4..a44e83c47b39 100755 --- a/.github/scripts/run_tests.sh +++ b/.github/scripts/run_tests.sh @@ -57,7 +57,6 @@ function run_torch_xla_cpp_tests() { #"test_xla_backend_intf" "test_xla_sharding" "test_runtime" - "test_status" "test_status_dont_show_cpp_error_context" "test_status_show_cpp_error_context") for name in "${test_names[@]}"; do diff --git a/BUILD b/BUILD index 7fb03efa757a..2e52f0e9e9b7 100644 --- a/BUILD +++ b/BUILD @@ -78,7 +78,6 @@ test_suite( "//test/cpp:test_tensor", "//test/cpp:test_xla_sharding", "//test/cpp:test_runtime", - "//test/cpp:test_status", "//test/cpp:test_status_dont_show_cpp_error_context", "//test/cpp:test_status_show_cpp_error_context", "//torch_xla/csrc/runtime:pjrt_computation_client_test", diff --git a/test/cpp/BUILD b/test/cpp/BUILD index 3d447f352664..eca2f5646b01 100644 --- a/test/cpp/BUILD +++ b/test/cpp/BUILD @@ -40,6 +40,7 @@ ptxla_cc_library( "//torch_xla/csrc/runtime:runtime", "//torch_xla/csrc/runtime:debug_macros", "//torch_xla/csrc/runtime:sys_util", + "//torch_xla/csrc:status", "//torch_xla/csrc:tensor", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", @@ -159,15 +160,6 @@ ptxla_cc_test( ], ) -ptxla_cc_test( - name = "test_status", - srcs = ["test_status.cpp"], - deps = [ - "//torch_xla/csrc:status", - "@com_google_googletest//:gtest_main", - ], -) - ptxla_cc_test( name = "test_status_dont_show_cpp_error_context", srcs = ["test_status_dont_show_cpp_error_context.cpp"], diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 203736215974..b13981d850df 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -14,6 +14,7 @@ #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/runtime.h" #include "torch_xla/csrc/runtime/sys_util.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/tensor_impl.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" @@ -301,9 +302,8 @@ std::vector Execute( std::vector Fetch( absl::Span device_data) { - std::vector literals = - torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice( - device_data); + std::vector literals = GetValueOrThrow( + runtime::GetComputationClientOrDie()->TransferFromDevice(device_data)); std::vector tensors; for (auto& literal : literals) { tensors.push_back(MakeTensorFromXlaLiteral( diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index 95dd6ac5da38..28e633460271 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -101,7 +101,6 @@ if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then #"test_xla_backend_intf" "test_xla_sharding" "test_runtime" - "test_status" "test_status_dont_show_cpp_error_context" "test_status_show_cpp_error_context") fi diff --git a/test/cpp/test_replication.cpp b/test/cpp/test_replication.cpp index e8bc7ca6add3..82f851aaaa9b 100644 --- a/test/cpp/test_replication.cpp +++ b/test/cpp/test_replication.cpp @@ -10,6 +10,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" @@ -78,9 +79,8 @@ void TestSingleReplication( counter.Wait(); for (size_t i = 0; i < results.size(); ++i) { - std::vector literals = - torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice( - results[i]); + std::vector 
literals = GetValueOrThrow( + runtime::GetComputationClientOrDie()->TransferFromDevice(results[i])); ASSERT_EQ(literals.size(), 1); // The result must be the original tensor value, multiplied by the number of diff --git a/test/cpp/test_status.cpp b/test/cpp/test_status.cpp deleted file mode 100644 index 1cfe9bafcf50..000000000000 --- a/test/cpp/test_status.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "torch_xla/csrc/status.h" - -namespace torch_xla { - -TEST(StatusTest, MaybeThrowWithOkStatus) { - absl::Status ok_status = absl::OkStatus(); - EXPECT_NO_THROW(MaybeThrow(ok_status)); -} - -TEST(StatusTest, MaybeThrowWithErrorStatus) { - absl::Status error_status = absl::InvalidArgumentError("Test error"); - EXPECT_THROW(MaybeThrow(error_status), std::runtime_error); -} - -TEST(StatusTest, GetValueOrThrowWithOkStatusOr) { - int value = 42; - absl::StatusOr status_or = value; - int result = GetValueOrThrow(std::move(status_or)); - EXPECT_EQ(result, value); -} - -TEST(StatusTest, GetValueOrThrowWithErrorStatusOr) { - absl::StatusOr status_or = absl::InvalidArgumentError("Test error"); - EXPECT_THROW(GetValueOrThrow(std::move(status_or)), std::runtime_error); -} - -TEST(StatusTest, MacroReturnIfError) { - int value = 42; - - auto test_function = [=]() -> absl::StatusOr { - absl::Status ok_status = absl::OkStatus(); - XLA_RETURN_IF_ERROR(ok_status); - return value; - }; - - absl::StatusOr result = test_function(); - ASSERT_TRUE(result.ok()); - EXPECT_EQ(result.value(), value); -} - -TEST(StatusTest, MacroAssignOrReturn) { - int initial_value = 42; - int expected_value = initial_value * 2; - - auto test_function = [=]() -> absl::StatusOr { - absl::StatusOr status_or = initial_value; - XLA_ASSIGN_OR_RETURN(int value, status_or); - return value * 2; - }; - - absl::StatusOr result = test_function(); - ASSERT_TRUE(result.ok()); - EXPECT_EQ(result.value(), expected_value); -} - -} // namespace torch_xla diff --git a/test/cpp/test_status_dont_show_cpp_error_context.cpp b/test/cpp/test_status_dont_show_cpp_error_context.cpp index 7177769f2bee..d0d89efeab54 100644 --- a/test/cpp/test_status_dont_show_cpp_error_context.cpp +++ b/test/cpp/test_status_dont_show_cpp_error_context.cpp @@ -5,62 +5,181 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/status.h" +// Reminder +// ======== +// +// This file is a companion to test_status_show_cpp_error_context.cpp. +// This file specifically tests behavior when XLA_SHOW_CPP_ERROR_CONTEXT is +// set to "false". +// +// If you add or delete a test in this file, please make the corresponding +// change in test_status_show_cpp_error_context.cpp as well, adapting for +// XLA_SHOW_CPP_ERROR_CONTEXT being "true" in that file. 
+ namespace torch_xla { +namespace { + +using absl::Status; +using absl::StatusCode; +using absl::StatusOr; +using absl::StrCat; + +constexpr char new_message[] = "New test error message"; +constexpr char message[] = "Test error message"; +constexpr char test_file[] = "test_file.cpp"; +constexpr int32_t line = 42; + +TEST(StatusWithoutErrorContextTest, MaybeThrowWithOkStatus) { + Status ok_status = absl::OkStatus(); + EXPECT_NO_THROW(MaybeThrow(ok_status)); +} + +TEST(StatusWithoutErrorContextTest, MaybeThrowWithErrorStatus) { + Status error_status = absl::InvalidArgumentError(message); + EXPECT_THROW(MaybeThrow(error_status), std::runtime_error); +} + +TEST(StatusWithoutErrorContextTest, GetValueOrThrowWithOkStatusOr) { + int value = 42; + StatusOr status_or = value; + int result = GetValueOrThrow(std::move(status_or)); + EXPECT_EQ(result, value); +} + +TEST(StatusWithoutErrorContextTest, GetValueOrThrowWithErrorStatusOr) { + StatusOr status_or = absl::InvalidArgumentError(message); + EXPECT_THROW(GetValueOrThrow(std::move(status_or)), std::runtime_error); +} TEST(StatusWithoutErrorContextTest, MaybeWithLocationRetunsSameStatus) { - absl::Status error_status = absl::InvalidArgumentError("Test error message"); - absl::Status result = MaybeWithLocation(error_status, "test_file.cpp", 42); + Status error_status = absl::InvalidArgumentError(message); + Status result = MaybeWithLocation(error_status, test_file, line); EXPECT_EQ(result, error_status); } TEST(StatusWithoutErrorContextTest, MaybeWithNewMessageEmptyNewMessage) { - absl::Status error_status = absl::InvalidArgumentError("Original error"); - absl::Status result = MaybeWithNewMessage(error_status, "test_file.cpp", 42); + Status error_status = absl::InvalidArgumentError(message); + Status result = MaybeWithNewMessage(error_status, test_file, line); EXPECT_EQ(result, error_status); } TEST(StatusWithoutErrorContextTest, MaybeWithNewMessageNonEmptyNewMessage) { - constexpr char new_err_string[] = "New error message"; - absl::Status error_status = absl::InvalidArgumentError("Original error"); - absl::Status result = - MaybeWithNewMessage(error_status, "test_file.cpp", 42, new_err_string); + Status error_status = absl::InvalidArgumentError(message); + Status result = + MaybeWithNewMessage(error_status, test_file, line, new_message); - ASSERT_FALSE(result.ok()); ASSERT_NE(result, error_status); + ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), error_status.code()); - EXPECT_EQ(result.message(), new_err_string); + EXPECT_EQ(result.message(), std::string_view(new_message)); } -TEST(StatusWithoutErrorContextTest, MacroReturnIfErrorWithError) { - constexpr char err_string[] = "Test error"; +TEST(StatusWithoutErrorContextTest, MacroReturnIfError) { + int value = 42; + + auto test_function = [=]() -> StatusOr { + Status ok_status = absl::OkStatus(); + XLA_RETURN_IF_ERROR(ok_status); + return value; + }; - auto test_function = [=]() -> absl::Status { - absl::Status error_status = absl::InvalidArgumentError(err_string); + StatusOr result = test_function(); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(result.value(), value); +} + +TEST(StatusWithoutErrorContextTest, MacroReturnIfErrorWithError) { + auto test_function = [=]() -> Status { + Status error_status = absl::InvalidArgumentError(message); XLA_RETURN_IF_ERROR(error_status); return absl::OkStatus(); }; - absl::Status result = test_function(); + Status result = test_function(); ASSERT_FALSE(result.ok()); - EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); - EXPECT_EQ(result.message(), 
err_string); + EXPECT_EQ(result.code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), std::string_view(message)); +} + +TEST(StatusWithErrorContextTest, MacroReturnIfErrorWithNestedError) { + auto inner_test_function = []() -> Status { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message)); + }; + + auto test_function = [&]() -> Status { + XLA_RETURN_IF_ERROR(inner_test_function()); + return absl::OkStatus(); + }; + + auto outer_test_function = [&]() -> Status { + XLA_RETURN_IF_ERROR(test_function()); + return absl::OkStatus(); + }; + + Status result = outer_test_function(); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), std::string_view(message)); +} + +TEST(StatusWithoutErrorContextTest, MacroReturnIfErrorWithErrorWithNewMessage) { + auto test_function = [=]() -> Status { + Status error_status = absl::InvalidArgumentError(message); + XLA_RETURN_IF_ERROR(error_status, new_message); + return absl::OkStatus(); + }; + + Status result = test_function(); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), std::string_view(new_message)); +} + +TEST(StatusWithoutErrorContextTest, MacroAssignOrReturn) { + int initial_value = 42; + int expected_value = initial_value * 2; + + auto test_function = [=]() -> StatusOr { + StatusOr status_or = initial_value; + XLA_ASSIGN_OR_RETURN(int value, status_or); + return value * 2; + }; + + StatusOr result = test_function(); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(result.value(), expected_value); } TEST(StatusWithoutErrorContextTest, MacroAssignOrReturnWithError) { - auto test_function = []() -> absl::StatusOr { - absl::StatusOr status_or = absl::InvalidArgumentError("Test error"); + auto test_function = []() -> StatusOr { + StatusOr status_or = absl::InvalidArgumentError(message); XLA_ASSIGN_OR_RETURN(int value, status_or); return value * 2; }; - absl::StatusOr result = test_function(); + StatusOr result = test_function(); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.status().code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.status().message(), std::string_view(message)); +} + +TEST(StatusWithoutErrorContextTest, + MacroAssignOrReturnWithErrorWithNewMessage) { + auto test_function = []() -> StatusOr { + StatusOr status_or = absl::InvalidArgumentError(message); + XLA_ASSIGN_OR_RETURN(int value, status_or, new_message); + return value * 2; + }; + + StatusOr result = test_function(); ASSERT_FALSE(result.ok()); - EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_EQ(result.status().code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.status().message(), std::string_view(new_message)); } TEST(StatusWithoutErrorContextTest, MacroErrorWithLocation) { - absl::Status error_status = absl::InvalidArgumentError("Test error"); - absl::Status result = XLA_ERROR_WITH_LOCATION(error_status); + Status error_status = absl::InvalidArgumentError(message); + Status result = XLA_ERROR_WITH_LOCATION(error_status); EXPECT_EQ(result, error_status); } @@ -69,6 +188,7 @@ void SetUp() { /* replace= */ 1); } +} // namespace } // namespace torch_xla int main(int argc, char **argv) { diff --git a/test/cpp/test_status_show_cpp_error_context.cpp b/test/cpp/test_status_show_cpp_error_context.cpp index 929e1ab088a9..fa60a9161c7a 100644 --- a/test/cpp/test_status_show_cpp_error_context.cpp +++ b/test/cpp/test_status_show_cpp_error_context.cpp @@ -8,86 +8,237 @@ #include 
"torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/status.h" +// Reminder +// ======== +// +// This file is a companion to test_status_dont_show_cpp_error_context.cpp. +// This file specifically tests behavior when XLA_SHOW_CPP_ERROR_CONTEXT is +// set to "true". +// +// If you add or delete a test in this file, please make the corresponding +// change in test_status_dont_show_cpp_error_context.cpp as well, adapting +// for XLA_SHOW_CPP_ERROR_CONTEXT being "false" in that file. + namespace torch_xla { namespace { +using absl::Status; +using absl::StatusCode; +using absl::StatusOr; +using absl::StrCat; + constexpr char new_message[] = "New test error message"; constexpr char message[] = "Test error message"; constexpr char test_file[] = "test_file.cpp"; constexpr int32_t line = 42; +TEST(StatusWithErrorContextTest, MaybeThrowWithOkStatus) { + Status ok_status = absl::OkStatus(); + EXPECT_NO_THROW(MaybeThrow(ok_status)); +} + +TEST(StatusWithErrorContextTest, MaybeThrowWithErrorStatus) { + Status error_status = absl::InvalidArgumentError(message); + EXPECT_THROW(MaybeThrow(error_status), std::runtime_error); +} + +TEST(StatusWithErrorContextTest, GetValueOrThrowWithOkStatusOr) { + int value = 42; + StatusOr status_or = value; + int result = GetValueOrThrow(std::move(status_or)); + EXPECT_EQ(result, value); +} + +TEST(StatusWithErrorContextTest, GetValueOrThrowWithErrorStatusOr) { + StatusOr status_or = absl::InvalidArgumentError(message); + EXPECT_THROW(GetValueOrThrow(std::move(status_or)), std::runtime_error); +} + TEST(StatusWithErrorContextTest, MaybeWithLocationRetunsSameStatus) { - absl::Status error_status = absl::InvalidArgumentError(message); - absl::Status result = MaybeWithLocation(error_status, test_file, line); + Status error_status = absl::InvalidArgumentError(message); + Status result = MaybeWithLocation(error_status, test_file, line); ASSERT_NE(result, error_status); - ASSERT_EQ(result.code(), error_status.code()); + EXPECT_FALSE(result.ok()); + EXPECT_EQ(result.code(), error_status.code()); EXPECT_EQ(result.message(), "Test error message (at test_file.cpp:42)"); } TEST(StatusWithErrorContextTest, MaybeWithNewMessageEmptyNewMessage) { - absl::Status error_status = absl::InvalidArgumentError(message); - absl::Status result = MaybeWithNewMessage(error_status, test_file, line); - ASSERT_NE(result, error_status); - ASSERT_EQ(result.code(), error_status.code()); - EXPECT_EQ(result.message(), "Test error message (at test_file.cpp:42)"); + Status error_status = absl::InvalidArgumentError(message); + Status result = MaybeWithNewMessage(error_status, test_file, line); + EXPECT_EQ(result, error_status); } TEST(StatusWithErrorContextTest, MaybeWithNewMessageNonEmptyNewMessage) { - absl::Status error_status = absl::InvalidArgumentError(message); - absl::Status result = + Status error_status = absl::InvalidArgumentError(message); + Status result = MaybeWithNewMessage(error_status, test_file, line, new_message); + ASSERT_NE(result, error_status); ASSERT_FALSE(result.ok()); EXPECT_EQ(result.code(), error_status.code()); EXPECT_EQ(result.message(), - "New test error message (at test_file.cpp:42)\n" - "From Error: Test error message"); + StrCat("New test error message (at test_file.cpp:42)\n" + "From Error: Test error message")); } -TEST(StatusWithErrorContextTest, MacroReturnIfErrorWithError) { - int32_t err_line = 0; +TEST(StatusWithErrorContextTest, MacroReturnIfError) { + int value = 42; + + auto test_function = [=]() -> StatusOr { + Status ok_status = absl::OkStatus(); + 
XLA_RETURN_IF_ERROR(ok_status); + return value; + }; + + StatusOr result = test_function(); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(result.value(), value); +} - auto test_function = [=, &err_line]() -> absl::Status { - absl::Status error_status = absl::InvalidArgumentError(message); - err_line = __LINE__ + 1; +TEST(StatusWithErrorContextTest, MacroReturnIfErrorWithError) { + auto test_function = [=]() -> Status { + Status error_status = absl::InvalidArgumentError(message); XLA_RETURN_IF_ERROR(error_status); return absl::OkStatus(); }; - absl::Status result = test_function(); + Status result = test_function(); ASSERT_FALSE(result.ok()); - EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); - EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ", __FILE__, - ":", err_line, ")")); + EXPECT_EQ(result.code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), std::string_view(message)); } -TEST(StatusWithErrorContextTest, MacroAssignOrReturnWithError) { - int32_t err_line = 0; +TEST(StatusWithErrorContextTest, MacroReturnIfErrorWithNestedError) { + int32_t errline = 0; + auto inner_test_function = [&errline]() -> Status { + errline = __LINE__ + 1; + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message)); + }; + + auto test_function = [&]() -> Status { + XLA_RETURN_IF_ERROR(inner_test_function()); + return absl::OkStatus(); + }; + + auto outer_test_function = [&]() -> Status { + XLA_RETURN_IF_ERROR(test_function()); + return absl::OkStatus(); + }; + + Status result = outer_test_function(); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), + StrCat("Test error message (at ", __FILE__, ":", errline, ")")); +} + +TEST(StatusWithErrorContextTest, MacroReturnIfErrorWithErrorWithNewMessage) { + int32_t errline = 0; + + auto test_function = [&errline]() -> Status { + Status error_status = absl::InvalidArgumentError(message); + errline = __LINE__ + 1; + XLA_RETURN_IF_ERROR(error_status, new_message); + return absl::OkStatus(); + }; - auto test_function = [&err_line]() -> absl::StatusOr { - absl::StatusOr status_or = absl::InvalidArgumentError(message); - err_line = __LINE__ + 1; + Status result = test_function(); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), + StrCat("New test error message (at ", __FILE__, ":", errline, + ")\nFrom Error: Test error message")); +} + +TEST(StatusWithErrorContextTest, MacroReturnIfErrorWithLocationWithError) { + int32_t errline = 0; + + auto test_function = [&errline]() -> Status { + Status error_status = absl::InvalidArgumentError(message); + errline = __LINE__ + 1; + XLA_RETURN_IF_ERROR_WITH_LOCATION(error_status); + return absl::OkStatus(); + }; + + Status result = test_function(); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), + StrCat("Test error message (at ", __FILE__, ":", errline, ")")); +} + +TEST(StatusWithErrorContextTest, MacroAssignOrReturn) { + int initial_value = 42; + int expected_value = initial_value * 2; + + auto test_function = [=]() -> StatusOr { + StatusOr status_or = initial_value; + XLA_ASSIGN_OR_RETURN(int value, status_or); + return value * 2; + }; + + StatusOr result = test_function(); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(result.value(), expected_value); +} + +TEST(StatusWithErrorContextTest, MacroAssignOrReturnWithError) { + auto test_function = []() -> StatusOr { + StatusOr status_or = 
absl::InvalidArgumentError(message); XLA_ASSIGN_OR_RETURN(int value, status_or); return value * 2; }; - absl::StatusOr result = test_function(); + StatusOr result = test_function(); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.status().code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.status().message(), std::string_view(message)); +} + +TEST(StatusWithErrorContextTest, MacroAssignOrReturnWithErrorWithNewMessage) { + int32_t errline = 0; + + auto test_function = [&errline]() -> StatusOr { + StatusOr status_or = absl::InvalidArgumentError(message); + errline = __LINE__ + 1; + XLA_ASSIGN_OR_RETURN(int value, status_or, new_message); + return value * 2; + }; + + StatusOr result = test_function(); ASSERT_FALSE(result.ok()); - EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); - EXPECT_EQ( - result.status().message(), - absl::StrCat("Test error message (at ", __FILE__, ":", err_line, ")")); + EXPECT_EQ(result.status().code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.status().message(), + StrCat("New test error message (at ", __FILE__, ":", errline, + ")\nFrom Error: Test error message")); +} + +TEST(StatusWithErrorContextTest, MacroAssignOrReturnWithLocationWithError) { + int32_t errline = 0; + + auto test_function = [&errline]() -> StatusOr { + StatusOr status_or = absl::InvalidArgumentError(message); + errline = __LINE__ + 1; + XLA_ASSIGN_OR_RETURN_WITH_LOCATION(int value, status_or); + return value * 2; + }; + + StatusOr result = test_function(); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.status().code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.status().message(), + StrCat("Test error message (at ", __FILE__, ":", errline, ")")); } TEST(StatusWithErrorContextTest, MacroErrorWithLocation) { - absl::Status error_status = absl::InvalidArgumentError(message); - int32_t err_line = __LINE__ + 1; - absl::Status result = XLA_ERROR_WITH_LOCATION(error_status); - ASSERT_NE(result, error_status); + Status error_status = absl::InvalidArgumentError(message); + int32_t errline = __LINE__ + 1; + Status result = XLA_ERROR_WITH_LOCATION(error_status); ASSERT_FALSE(result.ok()); - EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument); - EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ", __FILE__, - ":", err_line, ")")); + EXPECT_EQ(result.code(), StatusCode::kInvalidArgument); + EXPECT_EQ(result.message(), + StrCat("Test error message (at ", __FILE__, ":", errline, ")")); } void SetUp() { diff --git a/test/test_operations.py b/test/test_operations.py index 7b52d6c9fa8d..2cb015c9211b 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2445,6 +2445,13 @@ def test_isneginf_no_fallback(self): t = t.to(torch.float16) self._test_no_fallback(torch.isneginf, (t,)) + def test_construct_large_tensor_raises_error(self): + a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device()) + + with self.assertRaisesRegex(RuntimeError, + r"Out of memory allocating \d* bytes"): + a.cpu() + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index f9039bd6df73..74f3440b7ee8 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -125,6 +125,7 @@ ptxla_cc_library( ":layout_manager", ":shape_builder", ":shape_helper", + ":status", ":version", "//torch_xla/csrc:hash_util", "//torch_xla/csrc:thread_pool", diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c52ca3c74911..e37aed5374d6 100644 --- 
a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1229,9 +1229,9 @@ class PyLoweringContext { lowering_ctx.GetParametersData(); // Fetch this parameter data - std::vector literals = + std::vector literals = GetValueOrThrow( runtime::GetComputationClientOrDie()->TransferFromDevice( - UnwrapXlaData(device_data)); + UnwrapXlaData(device_data))); // Create a mapping from paramater id to the tensor data std::unordered_map results; diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 7444a1977fe3..86df04f27cf3 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -121,6 +121,7 @@ cc_library( ":tf_logging", ":xla_coordinator", "//torch_xla/csrc:status", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 9f2a04a5d348..2ed28e8c8033 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -16,6 +16,7 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "torch_xla/csrc/device.h" @@ -317,7 +318,7 @@ class ComputationClient { // Note: `TransferFromDevice` call will block until the `DataPtrs` are ready // if they were created by `TransferToDevice` or `Execute*`. Calling this from // python while holding the GIL can cause deadlocks! - virtual std::vector TransferFromDevice( + virtual absl::StatusOr> TransferFromDevice( absl::Span handles) = 0; virtual std::uintptr_t UnsafeBufferPointer(const DataPtr handle) = 0; @@ -345,7 +346,7 @@ class ComputationClient { // The passed device must match the common device of the arguments Data. // If options.explode_tuple is true, the output tuple will be decomposed into // its single elements. - virtual std::vector ExecuteComputation( + virtual absl::StatusOr> ExecuteComputation( const Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options = @@ -356,7 +357,7 @@ class ComputationClient { // as `devices`. If options.explode_tuple is true, the output tuples will be // decomposed into their single elements. Returns a vector of outputs, each // of which is sharded in the same order as `devices`. 
- virtual std::vector ExecuteReplicated( + virtual absl::StatusOr> ExecuteReplicated( const Computation& computation, absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) = 0; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cpp b/torch_xla/csrc/runtime/ifrt_computation_client.cpp index 0aa79dcae431..295a3104ebc0 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cpp @@ -4,6 +4,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/ascii.h" #include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" @@ -403,8 +404,8 @@ tsl::RCReference IfrtComputationClient::ReplicateShardedData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto sharded_results = ExecuteReplicated(*computations.front(), {{handle}}, - GetLocalDevices(), execute_options); + auto sharded_results = GetValueOrThrow(ExecuteReplicated( + *computations.front(), {{handle}}, GetLocalDevices(), execute_options)); auto replicated_output = std::dynamic_pointer_cast(sharded_results[0]) ->buffer->FullyReplicatedShard( @@ -423,8 +424,8 @@ std::shared_ptr IfrtComputationClient::GetPjRtBuffer( XLA_ERROR() << __FUNCTION__ << " not implemented"; } -std::vector IfrtComputationClient::TransferFromDevice( - absl::Span handles) { +absl::StatusOr> +IfrtComputationClient::TransferFromDevice(absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); tsl::profiler::TraceMe activity("IfrtComputationClient::TransferFromDevice", tsl::profiler::TraceMeLevel::kInfo); @@ -442,9 +443,9 @@ std::vector IfrtComputationClient::TransferFromDevice( auto& literal = literals.emplace_back( xla::ShapeUtil::DeviceShapeToHostShape(ifrt_data->shape())); std::vector byte_strides(literal.shape().dimensions_size()); - XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal.shape(), - absl::MakeSpan(byte_strides))); - XLA_CHECK_OK( + XLA_RETURN_IF_ERROR(xla::ShapeUtil::ByteStrides( + literal.shape(), absl::MakeSpan(byte_strides))); + XLA_RETURN_IF_ERROR( replicated_array ->CopyToHostBuffer(literal.untyped_data(), byte_strides, xla::ifrt::ArrayCopySemantics::kAlwaysCopy) @@ -524,16 +525,16 @@ std::vector IfrtComputationClient::Compile( return computations; } -std::vector +absl::StatusOr> IfrtComputationClient::ExecuteComputation( const ComputationClient::Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) { // TODO: Implement sharded exec in IFRT - XLA_ERROR() << __FUNCTION__ << " not implemented"; + return absl::UnimplementedError("ExecuteComputation not implemented"); } -std::vector +absl::StatusOr> IfrtComputationClient::ExecuteReplicated( const ComputationClient::Computation& computation, const absl::Span arguments, @@ -578,11 +579,10 @@ IfrtComputationClient::ExecuteReplicated( TF_VLOG(5) << "ExecuteReplicated acquiring IFRT device lock for " << spmd_device_str << " Done"; - xla::ifrt::LoadedExecutable::ExecuteResult result = - ifrt_computation.executable - ->Execute(absl::MakeSpan(argument_handles), execute_options, - std::nullopt) - .value(); + XLA_ASSIGN_OR_RETURN_WITH_LOCATION( + xla::ifrt::LoadedExecutable::ExecuteResult result, + ifrt_computation.executable->Execute(absl::MakeSpan(argument_handles), + execute_options, std::nullopt)); result.status.OnReady(std::move([timed, op_tracker = std::move(op_tracker)]( absl::Status status) mutable { @@ -599,7 +599,7 @@ IfrtComputationClient::ExecuteReplicated( 
? *ifrt_computation.output_shardings_ : std::vector(outputs.size(), xla::HloSharding::Replicate().ToProto()); - XLA_CHECK_EQ(output_shardings.size(), outputs.size()); + ABSL_CHECK_EQ(output_shardings.size(), outputs.size()); std::vector data_handles(outputs.size()); { diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index d43fa5bd0167..6bd0b792a34e 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -53,7 +53,7 @@ class IfrtComputationClient : public ComputationClient { XLA_ERROR() << __FUNCTION__ << " not implemented"; } - std::vector TransferFromDevice( + absl::StatusOr> TransferFromDevice( absl::Span handles) override; std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; @@ -69,12 +69,12 @@ class IfrtComputationClient : public ComputationClient { std::vector Compile( std::vector instances) override; - std::vector ExecuteComputation( + absl::StatusOr> ExecuteComputation( const Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) override; - std::vector ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( const Computation& computation, const absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) override; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp index 04636e0d06e9..b5c6d86902ec 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/ifrt_computation_client_test.cpp @@ -64,13 +64,14 @@ TEST(PjRtComputationClientTest, Init) { std::make_shared(std::move(literal_y), device)}; // Execute the graph. - std::vector results = client->ExecuteReplicated( - *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)), - {device}, options); + std::vector results = + GetValueOrThrow(client->ExecuteReplicated( + *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)), + {device}, options)); // Copy the output from device back to host and assert correctness.. 
ASSERT_EQ(results.size(), 1); - auto result_literals = client->TransferFromDevice(results); + auto result_literals = GetValueOrThrow(client->TransferFromDevice(results)); ASSERT_THAT(result_literals, ::testing::SizeIs(1)); EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cpp b/torch_xla/csrc/runtime/pjrt_computation_client.cpp index 2012db9ffa5a..72d633db438b 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cpp @@ -4,6 +4,7 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/strings/ascii.h" #include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" @@ -374,8 +375,8 @@ PjRtComputationClient::ReplicateShardedData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; auto sharded_results = - ExecuteReplicated(*computations.front(), {sharded_data}, - GetLocalDevices(), execute_options); + GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data}, + GetLocalDevices(), execute_options)); XLA_CHECK(sharded_results.size() > 0) << "empty ExecuteReplicated results returned."; XLA_CHECK(sharded_results.size() == 1) @@ -461,8 +462,8 @@ std::vector PjRtComputationClient::ReshardData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto resharded_results = ExecuteReplicated( - *computation, handles, GetLocalDevices(), execute_options); + auto resharded_results = GetValueOrThrow(ExecuteReplicated( + *computation, handles, GetLocalDevices(), execute_options)); return resharded_results; } @@ -496,8 +497,8 @@ std::shared_ptr PjRtComputationClient::GetPjRtBuffer( } } -std::vector PjRtComputationClient::TransferFromDevice( - absl::Span handles) { +absl::StatusOr> +PjRtComputationClient::TransferFromDevice(absl::Span handles) { metrics::TimedSection timed(TransferFromDeviceMetric()); tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice", tsl::profiler::TraceMeLevel::kInfo); @@ -510,21 +511,18 @@ std::vector PjRtComputationClient::TransferFromDevice( // Use XLA replication to reassemble the sharded data. If input handle // is not sharded, then it is a no-op. 
std::shared_ptr pjrt_data = ReplicateShardedData(handle); - XLA_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; - XLA_CHECK(pjrt_data->buffer != nullptr) + ABSL_CHECK(pjrt_data) << "PjRt_data is null in " << __FUNCTION__; + ABSL_CHECK(pjrt_data->buffer != nullptr) << "PjRt buffer is null in " << __FUNCTION__; - xla::Literal& literal = - literals.emplace_back(host_output_shape(pjrt_data->buffer.get())); + xla::Literal& literal = literals.emplace_back( + xla::Literal(host_output_shape(pjrt_data->buffer.get()), + /* allocate_arrays= */ false)); futures.push_back(pjrt_data->buffer->ToLiteral(&literal)); total_size += literal.size_bytes(); } - for (auto& future : futures) { - absl::Status status = future.Await(); - XLA_CHECK_OK(status) << "Failed to await future from buffer to literal in" - << __FUNCTION__; - } + XLA_RETURN_IF_ERROR_WITH_LOCATION(xla::JoinFutures(futures).Await()); InboundDataMetric()->AddSample(total_size); return literals; @@ -713,7 +711,7 @@ torch::lazy::hash_t PjRtComputationClient::HashCompilationEnv() { return comp_env_hash_; } -std::vector +absl::StatusOr> PjRtComputationClient::ExecuteComputation( const ComputationClient::Computation& computation, absl::Span arguments, @@ -733,14 +731,14 @@ PjRtComputationClient::ExecuteComputation( dynamic_cast(computation); xla::PjRtDevice* pjrt_device = StringToPjRtDevice(device); - XLA_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); + ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); std::vector buffers; buffers.reserve(arguments.size()); for (auto& argument : arguments) { const PjRtData* pjrt_data = dynamic_cast(argument.get()); - XLA_CHECK(pjrt_device == pjrt_data->buffer->device()) + ABSL_CHECK(pjrt_device == pjrt_data->buffer->device()) << "The device currently being used : " << pjrt_device->DebugString() << " is different from the device where the buffer resides: " << pjrt_data->buffer->device()->DebugString(); @@ -760,11 +758,10 @@ PjRtComputationClient::ExecuteComputation( << " Done"; std::optional> returned_future; - std::vector> results = - pjrt_computation.executable - ->ExecuteSharded(buffers, pjrt_device, execute_options, - returned_future) - .value(); + XLA_ASSIGN_OR_RETURN_WITH_LOCATION( + std::vector> results, + pjrt_computation.executable->ExecuteSharded( + buffers, pjrt_device, execute_options, returned_future)); returned_future->OnReady(std::move( [timed, op_tracker = std::move(op_tracker)](absl::Status unused) mutable { @@ -788,7 +785,7 @@ PjRtComputationClient::ExecuteComputation( return datas; } -std::vector +absl::StatusOr> PjRtComputationClient::ExecuteReplicated( const ComputationClient::Computation& computation, absl::Span arguments, @@ -822,15 +819,15 @@ PjRtComputationClient::ExecuteReplicated( for (int32_t i = start; i < end; ++i) { auto pjrt_data = std::dynamic_pointer_cast(arguments[i]); - XLA_CHECK_EQ(pjrt_data->shards.size(), devices.size()) + ABSL_CHECK_EQ(pjrt_data->shards.size(), devices.size()) << "Expected one shard per device"; for (int32_t d = 0; d < devices.size(); d++) { std::shared_ptr shard = pjrt_data->shards[d]; xla::PjRtDevice* pjrt_device = StringToPjRtDevice(devices[d]); - XLA_CHECK_EQ(shard->buffer->device(), pjrt_device); - XLA_CHECK(pjrt_device->IsAddressable()) + ABSL_CHECK_EQ(shard->buffer->device(), pjrt_device); + ABSL_CHECK(pjrt_device->IsAddressable()) << pjrt_device->DebugString(); argument_handles[d][i] = shard->buffer.get(); @@ -866,10 +863,10 @@ PjRtComputationClient::ExecuteReplicated( tsl::profiler::TraceMe activity( 
"PjRtComputationClient::ExecuteReplicated_execute", tsl::profiler::TraceMeLevel::kInfo); - results = pjrt_computation.executable - ->Execute(std::move(argument_handles), execute_options, - returned_futures) - .value(); + XLA_ASSIGN_OR_RETURN_WITH_LOCATION( + results, + pjrt_computation.executable->Execute( + std::move(argument_handles), execute_options, returned_futures)); (*returned_futures)[0].OnReady( std::move([timed, op_tracker = std::move(op_tracker)]( @@ -892,7 +889,7 @@ PjRtComputationClient::ExecuteReplicated( const std::vector& output_shapes = result_shape.IsTuple() ? result_shape.tuple_shapes() : std::vector({result_shape}); - XLA_CHECK_EQ(output_shapes.size(), num_outputs); + ABSL_CHECK_EQ(output_shapes.size(), num_outputs); const std::vector& output_shardings = pjrt_computation.output_shardings_.has_value() && num_outputs > 0 @@ -901,7 +898,7 @@ PjRtComputationClient::ExecuteReplicated( // Without an explicit sharding annotation, the output is implicitly // replicated, and we mark explicitly replicated here. std::vector(num_outputs); - XLA_CHECK_EQ(output_shardings.size(), num_outputs); + ABSL_CHECK_EQ(output_shardings.size(), num_outputs); absl::BlockingCounter counter(num_outputs); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index b845d73dadba..c16ef6a5cb94 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -56,7 +56,7 @@ class PjRtComputationClient : public ComputationClient { absl::Span handles, absl::Span shardings) override; - std::vector TransferFromDevice( + absl::StatusOr> TransferFromDevice( absl::Span handles) override; std::uintptr_t UnsafeBufferPointer(const DataPtr handle) override; @@ -76,12 +76,12 @@ class PjRtComputationClient : public ComputationClient { ComputationPtr DeserializeComputation(const std::string& serialized) override; - std::vector ExecuteComputation( + absl::StatusOr> ExecuteComputation( const Computation& computation, absl::Span arguments, const std::string& device, const ExecuteComputationOptions& options) override; - std::vector ExecuteReplicated( + absl::StatusOr> ExecuteReplicated( const Computation& computation, absl::Span arguments, absl::Span devices, const ExecuteReplicatedOptions& options) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp index 5caa057240c3..1683f77dbea5 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp +++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cpp @@ -115,13 +115,15 @@ TEST_F(PjRtComputationClientTest, Init) { std::make_shared(std::move(literal_y), device_)}; // Execute the graph. - std::vector results = client_->ExecuteComputation( - *computations[0], client_->TransferToDevice(absl::MakeConstSpan(args)), - device_, options); + std::vector results = + GetValueOrThrow(client_->ExecuteComputation( + *computations[0], + client_->TransferToDevice(absl::MakeConstSpan(args)), device_, + options)); // Copy the output from device back to host and assert correctness. 
ASSERT_EQ(results.size(), 1); - auto result_literals = client_->TransferFromDevice(results); + auto result_literals = GetValueOrThrow(client_->TransferFromDevice(results)); ASSERT_THAT(result_literals, ::testing::SizeIs(1)); EXPECT_TRUE(xla::LiteralTestUtil::Equal( xla::LiteralUtil::CreateR2({{6.0f, 8.0f}, {10.0f, 12.0f}}), diff --git a/torch_xla/csrc/status.cpp b/torch_xla/csrc/status.cpp index afca4590fdc9..0a92b9aec3f9 100644 --- a/torch_xla/csrc/status.cpp +++ b/torch_xla/csrc/status.cpp @@ -63,13 +63,11 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file, // // This should give more context for developers. Showing the older error // messages alongside their debug information. - std::string context; - if (ShouldShowCppErrorContext()) { - context = LocationStrWithSpace(file, line); - if (!new_message.empty()) { - context = absl::StrCat(context, "\nFrom Error: ", old_message); - } - } + std::string location = LocationStrWithSpace(file, line); + std::string context = + (ShouldShowCppErrorContext() && !new_message.empty()) + ? std::string(absl::StrCat(location, "\nFrom Error: ", old_message)) + : std::string(); return absl::Status(status.code(), absl::StrCat(message, context)); } diff --git a/torch_xla/csrc/status.h b/torch_xla/csrc/status.h index 43c4759be0b3..33536654a448 100644 --- a/torch_xla/csrc/status.h +++ b/torch_xla/csrc/status.h @@ -39,21 +39,30 @@ namespace torch_xla { // Unique identifier for the status variable for the current line. #define XLA_STATUS_VAR_ XLA_CONCAT_(status_, __LINE__) +// Fake wrapper to `status`. +// +// This is used in place of `XLA_ERROR_WITH_LOCATION`, whenever we don't +// want to append source code location information to the error message, +// e.g. `XLA_RETURN_IF_ERROR` and `XLA_ASSIGN_OR_RETURN`. +#define XLA_NO_WRAP_(status) status + // Provides a flexible way to handle error checking with optional message // modification. It evaluates `expr`, checks if it's OK, and either: // 1. Returns early with an error status (potentially modified by the provided // additional messages) // 2. Proceeds with the given `then` block if successful -#define XLA_RETURN_IF_ERROR_IMPL_(expr, var, then, ...) \ - auto var = (expr); \ - if (!var.ok()) { \ - return ::torch_xla::MaybeWithNewMessage( \ - ::torch_xla::GetStatus(var), __FILE__, __LINE__, ##__VA_ARGS__); \ - } \ +#define XLA_RETURN_IF_ERROR_IMPL_(expr, var, wrapper, then, ...) \ + auto var = (expr); \ + if (!var.ok()) { \ + return wrapper(::torch_xla::MaybeWithNewMessage( \ + ::torch_xla::GetStatus(var), __FILE__, __LINE__, ##__VA_ARGS__)); \ + } \ then // Propagates `rexpr`, in case it's a non-ok status. // +// This macro should be used for propagating status internally. +// // Example: // // XLA_RETURN_IF_ERROR( @@ -69,14 +78,37 @@ namespace torch_xla { // Previous error message. (at :) // ... // -#define XLA_RETURN_IF_ERROR(rexpr, ...) \ - do { \ - XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, {}, ##__VA_ARGS__) \ +#define XLA_RETURN_IF_ERROR(rexpr, ...) \ + do { \ + XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, XLA_NO_WRAP_, {}, \ + ##__VA_ARGS__) \ + } while (false) + +// Propagates `rexpr`, in case it's a non-ok status, appending the source code +// location to it. +// +// Note that while the macro above will append the source code location only if +// a new message is given, this macro will append the source code location if +// `XLA_SHOW_CPP_ERROR_CONTEXT` is set. 
+// +// This macro should be used whenever we are propagating some status that came +// from some external library. +// +// Example: +// +// XLA_RETURN_IF_ERROR_WITH_LOCATION(FnThatReturnsStatus()); +// +#define XLA_RETURN_IF_ERROR_WITH_LOCATION(rexpr) \ + do { \ + XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, XLA_ERROR_WITH_LOCATION, \ + {}) \ } while (false) // Propagates `rexpr`, in case it's a non-ok status. Otherwise, assign // its result to `lhs`. // +// This macro should be used for propagating status internally. +// // Note 1: `lhs` might be a variable declarate, e.g: // // Note 2: this macro will be replaced by multiple statements that live on @@ -100,10 +132,31 @@ namespace torch_xla { // ... // #define XLA_ASSIGN_OR_RETURN(lhs, rexpr, ...) \ - XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, \ + XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, XLA_NO_WRAP_, \ lhs = std::move(XLA_STATUS_VAR_).value(), \ ##__VA_ARGS__) +// Propagates `rexpr`, in case it's a non-ok status. Otherwise, assign +// its result to `lhs`. +// +// Note that while the macro above will append the source code location only if +// a new message is given, this macro will append the source code location if +// `XLA_SHOW_CPP_ERROR_CONTEXT` is set. +// +// This macro should be used whenever we are propagating some status that came +// from some external library. +// +// Example: +// +// XLA_ASSIGN_OR_RETURN_WITH_LOCATION( +// int result, +// FnThatReturnsStatus(), +// ); +// +#define XLA_ASSIGN_OR_RETURN_WITH_LOCATION(lhs, rexpr) \ + XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, XLA_ERROR_WITH_LOCATION, \ + lhs = std::move(XLA_STATUS_VAR_).value()) + // Maybe shows location information in the status message. // // This function assumes that `status` is a non-ok status. diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 0a7f184cda77..e2cd3a025f59 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -24,6 +24,7 @@ #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/util.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla/csrc/xla_backend_impl.h" @@ -909,8 +910,8 @@ std::vector ReleaseGilAndTransferData( save = PyEval_SaveThread(); } std::vector literals = - runtime::GetComputationClientOrDie()->TransferFromDevice( - UnwrapXlaData(xla_data)); + GetValueOrThrow(runtime::GetComputationClientOrDie()->TransferFromDevice( + UnwrapXlaData(xla_data))); if (save) { PyEval_RestoreThread(save); } diff --git a/torch_xla/csrc/xla_backend_impl.cpp b/torch_xla/csrc/xla_backend_impl.cpp index bf130e1fab73..06e6510cc1a1 100644 --- a/torch_xla/csrc/xla_backend_impl.cpp +++ b/torch_xla/csrc/xla_backend_impl.cpp @@ -10,6 +10,7 @@ #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/status.h" namespace at { // This function is defined in the codegenerated RegisterDispatchKey.cpp file. 
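For illustration only — a minimal sketch, not part of this patch, of how the two macro flavors documented above in torch_xla/csrc/status.h are intended to be used, assuming the signatures shown in this diff. `CallExternalLibrary`, `Doubled`, and `Run` are hypothetical names standing in for a call into an external library (e.g. PJRT) and for internal torch_xla code.

    // Hypothetical usage sketch for the status macros above; nothing here is
    // part of this patch.
    #include "absl/status/status.h"
    #include "absl/status/statusor.h"
    #include "torch_xla/csrc/status.h"

    namespace torch_xla {

    // Hypothetical stand-in for a function coming from an external library.
    absl::StatusOr<int> CallExternalLibrary() {
      return absl::InternalError("failure inside external library");
    }

    absl::StatusOr<int> Doubled() {
      // External status: wrap it so the source location is appended when
      // XLA_SHOW_CPP_ERROR_CONTEXT is set.
      XLA_ASSIGN_OR_RETURN_WITH_LOCATION(int value, CallExternalLibrary());
      return value * 2;
    }

    absl::Status Run() {
      // Internal status: propagate it, optionally replacing the message.
      XLA_ASSIGN_OR_RETURN(int doubled, Doubled(), "Failed to compute value");
      (void)doubled;
      return absl::OkStatus();
    }

    }  // namespace torch_xla

The split mirrors the convention stated in the comments above: plain `XLA_RETURN_IF_ERROR` / `XLA_ASSIGN_OR_RETURN` for statuses produced inside torch_xla, and the `_WITH_LOCATION` variants when adopting a status from an external library.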
@@ -161,11 +162,11 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface { torch::lazy::ComputationPtr computation, c10::ArrayRef arguments, const torch::lazy::BackendDevice& device) const override { - std::vector results = + std::vector results = GetValueOrThrow( runtime::GetComputationClientOrDie()->ExecuteComputation( *std::dynamic_pointer_cast( computation), - UnwrapXlaData(arguments), device.toString()); + UnwrapXlaData(arguments), device.toString())); return WrapXlaData(results); } diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 031700fd9b55..ca58b1a7c403 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -52,6 +52,7 @@ #include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/xla_util.h" #include "torch_xla/csrc/shape_helper.h" +#include "torch_xla/csrc/status.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/thread_pool.h" #include "torch_xla/csrc/torch_util.h" @@ -843,10 +844,11 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. std::vector outputs = - runtime::GetComputationClientOrDie()->ExecuteReplicated( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), devices, - execute_options); + GetValueOrThrow( + runtime::GetComputationClientOrDie()->ExecuteReplicated( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), devices, + execute_options)); results = WrapXlaData(outputs); TF_VLOG(3) << "Executing Dynamo IR sharded graph hash " << torch::lazy::HashToString(hash) << " on devices " @@ -940,8 +942,8 @@ std::vector XLAGraphExecutor::ExecuteStablehlo( } std::vector result_data = - runtime::GetComputationClientOrDie()->ExecuteComputation( - *computations[0], UnwrapXlaData(arguments), device.toString()); + GetValueOrThrow(runtime::GetComputationClientOrDie()->ExecuteComputation( + *computations[0], UnwrapXlaData(arguments), device.toString())); return WrapXlaData(result_data); } @@ -1117,10 +1119,11 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( // tensor results. Both sharded and unsharded results should be // "Assign"ed to the corresponding data placeholders. 
std::vector outputs = - runtime::GetComputationClientOrDie()->ExecuteReplicated( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), devices, - execute_options); + GetValueOrThrow( + runtime::GetComputationClientOrDie()->ExecuteReplicated( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), devices, + execute_options)); results = WrapXlaData(outputs); TORCH_LAZY_COUNTER("ExecuteReplicated", 1); TF_VLOG(3) << "Executing IR graph hash " @@ -1132,11 +1135,13 @@ XLAGraphExecutor::ScheduleSyncTensorsGraph( << torch::lazy::HashToString(hash) << " on device " << async->device << " ..."; std::vector outputs = - runtime::GetComputationClientOrDie()->ExecuteComputation( - *async->cached_computation->computation, - UnwrapXlaData(async->parameters_data), async->device.toString(), - {/*explode_tuple=*/true, - /*eager_mode=*/use_eager_mode}); + GetValueOrThrow( + runtime::GetComputationClientOrDie()->ExecuteComputation( + *async->cached_computation->computation, + UnwrapXlaData(async->parameters_data), + async->device.toString(), + {/*explode_tuple=*/true, + /*eager_mode=*/use_eager_mode})); results = WrapXlaData(outputs); TORCH_LAZY_COUNTER("ExecuteComputation", 1); TF_VLOG(3) << "Executing IR graph hash "
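For illustration only — a minimal caller-side sketch, not part of this patch, of the two patterns this change applies now that `TransferFromDevice`, `ExecuteComputation`, and `ExecuteReplicated` return `absl::StatusOr` values. It assumes the element type remains `ComputationClient::DataPtr` as in the original signatures; `RunForPython` and `RunInternal` are hypothetical names.

    // Hypothetical caller-side sketch of the StatusOr-returning API;
    // nothing here is part of this patch.
    #include <string>
    #include <vector>

    #include "absl/status/statusor.h"
    #include "absl/types/span.h"
    #include "torch_xla/csrc/runtime/runtime.h"
    #include "torch_xla/csrc/status.h"

    namespace torch_xla {

    using DataPtr = runtime::ComputationClient::DataPtr;

    // At a throwing boundary (e.g. code reachable from Python), unwrap the
    // StatusOr or throw, matching the GetValueOrThrow pattern used above.
    std::vector<DataPtr> RunForPython(
        const runtime::ComputationClient::Computation& computation,
        absl::Span<const DataPtr> arguments, const std::string& device) {
      return GetValueOrThrow(
          runtime::GetComputationClientOrDie()->ExecuteComputation(
              computation, arguments, device));
    }

    // Inside status-returning code, propagate the error instead of throwing.
    absl::StatusOr<std::vector<DataPtr>> RunInternal(
        const runtime::ComputationClient::Computation& computation,
        absl::Span<const DataPtr> arguments, const std::string& device) {
      XLA_ASSIGN_OR_RETURN(
          std::vector<DataPtr> outputs,
          runtime::GetComputationClientOrDie()->ExecuteComputation(
              computation, arguments, device));
      return outputs;
    }

    }  // namespace torch_xla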