
[cherry-pick to r2.6_aws_neuron] add torch_xla_graph_execution_check_level (default disabled) flag that emits a warning (1) or throws an error (2) during tensor sync and outputs the Python frame #9077
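For context, a minimal usage sketch of the new flag as exposed through the Python bindings added in this PR (the binding and flag names match the diff below; the specific sync-triggering operations are illustrative):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# 1 = warn on tensor sync during tracing, 2 = raise; any other value disables.
torch_xla._XLAC._set_torch_xla_graph_execution_check_level(1)
x = torch.ones(2, device=xm.xla_device())
print(x[0])  # Reading a value syncs the tensor: warning + Python stack logged.

torch_xla._XLAC._set_torch_xla_graph_execution_check_level(2)
try:
  print(x[0])  # The same kind of sync now raises a RuntimeError.
except RuntimeError as e:
  print("caught:", e)

The default (-1) leaves the check disabled, so existing workloads are unaffected unless the level is set explicitly.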

Open. Wants to merge 1 commit into base: r2.6_aws_neuron.
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -151,6 +151,7 @@ function run_xla_op_tests1 {
  run_dynamic "$CDIR/ds/test_dynamic_shape_models.py" "$@" --verbosity=$VERBOSITY
  run_eager_debug "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
  run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
  run_test "$CDIR/test_xla_graph_execution.py" "$@" --verbosity=$VERBOSITY
  run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
  run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py"
  run_pt_xla_debug_level1 "$CDIR/debug_tool/test_pt_xla_debug.py"
75 changes: 75 additions & 0 deletions test/test_xla_graph_execution.py
@@ -0,0 +1,75 @@
# Parse local options first, and rewrite sys.argv[].
# We need to do that before importing the test helpers (e.g. "test_utils"),
# as otherwise we get an error for unrecognized arguments.
import argparse
import os
import sys
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.utils.utils as xu
import unittest
import test_utils
import time

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--verbosity', type=int, default=0)
FLAGS, leftovers = parser.parse_known_args()
sys.argv = [sys.argv[0]] + leftovers
print(FLAGS)

XLA_DISABLE_FUNCTIONALIZATION = bool(
    os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', False))


class TestXlaGraphExecutionCheckLevel(test_utils.XlaTestCase):

  def test_graph_execution_check_level_disabled(self):
    # Test disabled checking: levels outside [1, 2] turn the check off.
    print("Test check level disabled.")
    torch_xla._XLAC._set_torch_xla_graph_execution_check_level(0)
    x = torch.ones(2, device=xm.xla_device())
    self.assertEqual(x[0], 1.0)  # Syncs the tensor; the check stays silent

    torch_xla._XLAC._set_torch_xla_graph_execution_check_level(3)  # > 2 also disables
    start_time = time.time()
    x = torch.ones(2, device=xm.xla_device())
    self.assertEqual(x[0], 1.0)  # Syncs the tensor; the check stays silent
    print("--- %s seconds ---" % (time.time() - start_time))
    del x

  def test_graph_execution_check_level_warning(self):
    # Test WARNING level
    print("Test check level as warning.")
    torch_xla._XLAC._set_torch_xla_graph_execution_check_level(1)
    start_time = time.time()
    x = torch.ones(2, device=xm.xla_device())
    self.assertEqual(x[0], 1.0)  # This should trigger the checking
    print("--- %s seconds ---" % (time.time() - start_time))
    del x

  def test_graph_execution_check_level_error(self):
    # Test ERROR level
    print(
        "Test check level as runtime error with warning messages before that.")
    torch_xla._XLAC._set_torch_xla_graph_execution_check_level(2)
    start_time = time.time()
    x = torch.ones(2, device=xm.xla_device())
    with self.assertRaises(RuntimeError):
      self.assertEqual(x[0], 1.0)  # This should trigger the checking
    print("--- %s seconds ---" % (time.time() - start_time))
    del x
    print(
        "--- Timers are added for reference. However, the 1st test runs "
        "slower due to memory initialization ---")


if __name__ == '__main__':
  torch.set_default_dtype(torch.float32)
  torch.manual_seed(42)
  torch_xla._XLAC._xla_set_mat_mul_precision('highest')
  test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
  if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
    print(met.metrics_report())
  sys.exit(0 if test.result.wasSuccessful() else 1)
19 changes: 19 additions & 0 deletions torch_xla/csrc/BUILD
@@ -126,6 +126,7 @@ ptxla_cc_library(
        ":shape_builder",
        ":shape_helper",
        ":version",
        ":config",
        "//torch_xla/csrc/runtime",
        "//torch_xla/csrc/runtime:stablehlo_helper",
        "//torch_xla/csrc/runtime:xla_util",
@@ -177,6 +178,22 @@ ptxla_cc_library(
    ],
)

ptxla_cc_library(
    name = "config",
    srcs = ["config.cpp"],
    hdrs = ["config.h"],
    deps = [
        "//torch_xla/csrc/runtime:tf_logging",
        "//torch_xla/csrc/runtime:debug_macros",
        "//torch_xla/csrc/runtime:sys_util",
        "//torch_xla/csrc/runtime:util",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
        "@xla//xla/client:xla_builder",
    ],
)


ptxla_cc_library(
    name = "dtype",
    srcs = ["dtype.cpp"],
@@ -260,6 +277,7 @@ ptxla_cc_library(
        ":dtype",
        ":tensor",
        ":version",
        ":config",
        "//torch_xla/csrc/runtime",
        "//torch_xla/csrc/runtime:pjrt_computation_client",
        "//torch_xla/csrc/runtime:metrics",
@@ -348,3 +366,4 @@ ptxla_cc_library(
        "@pybind11//:pybind11_embed",
    ],
)

8 changes: 8 additions & 0 deletions torch_xla/csrc/config.cpp
@@ -0,0 +1,8 @@
// torch_xla/csrc/config.cpp
#include "torch_xla/csrc/config.h"

#include <c10/util/Flags.h>

C10_DEFINE_int(torch_xla_graph_execution_check_level, -1,
               "Set the torch_xla graph execution check level: <= 0 "
               "(DISABLED), 1 (WARN), 2 (ERROR), > 2 (DISABLED)");
9 changes: 9 additions & 0 deletions torch_xla/csrc/config.h
@@ -0,0 +1,9 @@
// torch_xla/csrc/config.h
#ifndef XLA_TORCH_XLA_CSRC_CONFIG_H_
#define XLA_TORCH_XLA_CSRC_CONFIG_H_

#include <c10/util/Flags.h>

C10_DECLARE_int(torch_xla_graph_execution_check_level);

#endif // XLA_TORCH_XLA_CSRC_CONFIG_H_
8 changes: 8 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -37,6 +37,7 @@
#include "torch_xla/csrc/aten_autograd_ops.h"
#include "torch_xla/csrc/aten_fallback.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/config.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/dl_convertor.h"
#include "torch_xla/csrc/dtype.h"
@@ -2575,6 +2576,13 @@ void InitXlaModuleBindings(py::module m) {
  });
  m.def("_get_xla_enable_device_data_cache",
        []() { return FLAGS_torch_lazy_enable_device_data_cache; });
  m.def("_set_torch_xla_graph_execution_check_level",
        [](int torch_xla_graph_execution_check_level) {
          FLAGS_torch_xla_graph_execution_check_level =
              torch_xla_graph_execution_check_level;
        });
  m.def("_get_torch_xla_graph_execution_check_level",
        []() { return FLAGS_torch_xla_graph_execution_check_level; });
  m.def("_set_use_eager_mode", [](bool use_eager_mode) {
    XLAGraphExecutor::Get()->SetUseEagerMode(use_eager_mode);
  });
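A quick round-trip through the setter/getter pair added above (a sketch, assuming a build that includes this change):

import torch_xla

torch_xla._XLAC._set_torch_xla_graph_execution_check_level(2)
assert torch_xla._XLAC._get_torch_xla_graph_execution_check_level() == 2
# Restore the default (-1, disabled).
torch_xla._XLAC._set_torch_xla_graph_execution_check_level(-1)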
41 changes: 41 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -2,6 +2,7 @@

#include <Python.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_util.h>
@@ -10,6 +11,7 @@
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/unique.h>
#include <torch/csrc/lazy/core/util.h>
#include <torch/csrc/lazy/python/python_util.h>

#include <algorithm>
#include <cmath>
@@ -28,6 +30,7 @@
#include "absl/strings/str_join.h"
#include "stablehlo/dialect/Serialization.h" // from @stablehlo
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/config.h"
#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ir_dump_util.h"
@@ -478,6 +481,44 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensors(
    std::vector<XLATensorPtr>* tensors) {
  TF_VLOG(4) << "Trying to get the value of " << tensors->size()
             << " tensor(s)";
  if (FLAGS_torch_xla_graph_execution_check_level > 0 &&
      FLAGS_torch_xla_graph_execution_check_level <= 2) {
    // Add Python stack trace information.
    auto frames = torch::lazy::GetPythonFrames();
    if (!frames.empty()) {
      TF_LOG(WARNING) << "Python call stack:";
      for (auto& location : frames) {
        TF_LOG(WARNING) << " " << location.function << " (" << location.file
                        << ":" << location.line << ")";
      }
    } else {
      TF_LOG(WARNING) << "Python frame is empty ...";
    }

    // Print each tensor's shape and dtype for investigation.
    TF_VLOG(4) << "Printing tensor shape and type for investigation ...";

    std::stringstream ss;
    ss << "[" << __FILE__ << ":" << __LINE__ << "] Tensors:\n";

    for (size_t i = 0; i < tensors->size(); ++i) {
      const auto& tensor = (*tensors)[i];
      ss << " [" << i << "]: shape=" << tensor->shape()
         << ", dtype=" << tensor->dtype() << "\n";
    }

    TF_LOG(WARNING) << ss.str();

    if (FLAGS_torch_xla_graph_execution_check_level == 1) {
      TF_LOG(WARNING)
          << "Trying to get the value of tensor(s): Using the tensor value "
             "during tracing may lead to unexpected behavior.";
    } else {
      XLA_ERROR() << "Trying to get the value of tensor(s): Using the tensor "
                     "value during tracing may lead to unexpected behavior.";
    }
  }

  SyncTensorsConfig config;
  config.force_ltc_data = false;
  auto async = SyncTensorsGraphInternal(tensors, {}, config);
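The check fires on the GetTensors path, i.e. whenever Python-side code pulls concrete values out of lazy tensors mid-trace. A sketch of the kind of accidental sync this is meant to surface (the surrounding code is illustrative, not part of this PR):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla._XLAC._set_torch_xla_graph_execution_check_level(1)  # warn only

x = torch.ones(4, device=xm.xla_device())
loss = (x * 2).sum()

# .item() materializes the tensor, which goes through
# XLAGraphExecutor::GetTensors and, at level 1, logs the Python call stack
# plus each tensor's shape and dtype before continuing.
if loss.item() > 0:
  pass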