From 7ca505dcbb25097eb355b6cb3d0e9960134f11a2 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Apr 2025 22:12:20 +0000 Subject: [PATCH] Add torch_xla_graph_execution_check_level flag (disabled by default) that emits a warning (1) or throws an error (2) during tensor sync and logs the Python frames --- test/run_tests.sh | 1 + test/test_xla_graph_execution.py | 75 +++++++++++++++++++++++++ torch_xla/csrc/BUILD | 19 +++++++ torch_xla/csrc/config.cpp | 8 +++ torch_xla/csrc/config.h | 9 +++ torch_xla/csrc/init_python_bindings.cpp | 8 +++ torch_xla/csrc/xla_graph_executor.cpp | 41 ++++++++++++++ 7 files changed, 161 insertions(+) create mode 100644 test/test_xla_graph_execution.py create mode 100644 torch_xla/csrc/config.cpp create mode 100644 torch_xla/csrc/config.h diff --git a/test/run_tests.sh b/test/run_tests.sh index 874f80a26746..f56e3d23517f 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -151,6 +151,7 @@ function run_xla_op_tests1 { run_dynamic "$CDIR/ds/test_dynamic_shape_models.py" "$@" --verbosity=$VERBOSITY run_eager_debug "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY + run_test "$CDIR/test_xla_graph_execution.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py" run_pt_xla_debug_level1 "$CDIR/debug_tool/test_pt_xla_debug.py" diff --git a/test/test_xla_graph_execution.py b/test/test_xla_graph_execution.py new file mode 100644 index 000000000000..627d2af16bc3 --- /dev/null +++ b/test/test_xla_graph_execution.py @@ -0,0 +1,75 @@ +# Parse the local --verbosity option first and rewrite sys.argv[], so that +# unittest.main() only receives the leftover arguments and does not fail on +# an unrecognized one. 
+import argparse +import os +import sys +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.utils.utils as xu +import unittest +import test_utils +import time + +parser = argparse.ArgumentParser(add_help=False) +parser.add_argument('--verbosity', type=int, default=0) +FLAGS, leftovers = parser.parse_known_args() +sys.argv = [sys.argv[0]] + leftovers +print(FLAGS) + +XLA_DISABLE_FUNCTIONALIZATION = bool( + os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', False)) + + +class TestXlaGraphExecutionCheckLevel(test_utils.XlaTestCase): + + def test_graph_execution_check_level_disabled(self): + # Test disabled checking + print("Test check level disabled.") + torch_xla._XLAC._set_torch_xla_graph_execution_check_level(0) + start_time = time.time() + x = torch.ones(2, device=xm.xla_device()) + self.assertEqual(x[0], 1.0) # Level 0 disables the check: must not warn or raise + + torch_xla._XLAC._set_torch_xla_graph_execution_check_level(3) + start_time = time.time() + x = torch.ones(2, device=xm.xla_device()) + self.assertEqual(x[0], 1.0) # Levels > 2 are treated as disabled too + print("--- %s seconds ---" % (time.time() - start_time)) + del x + + def test_graph_execution_check_level_warning(self): + # Test WARNING level + print("Test check level as warning.") + torch_xla._XLAC._set_torch_xla_graph_execution_check_level(1) + start_time = time.time() + x = torch.ones(2, device=xm.xla_device()) + self.assertEqual(x[0], 1.0) # This should trigger the checking + print("--- %s seconds ---" % (time.time() - start_time)) + del x + + def test_graph_execution_check_level_error(self): + # Test ERROR level + print( + "Test check level as runtime error with warning messages before that.") + torch_xla._XLAC._set_torch_xla_graph_execution_check_level(2) + start_time = time.time() + x = torch.ones(2, device=xm.xla_device()) + with self.assertRaises(RuntimeError): + self.assertEqual(x[0], 1.0) # This should trigger the checking + print("--- %s seconds ---" % (time.time() - 
start_time)) + del x + print( + "--- Timers are added for reference. However, the 1st test runs slower due to memory initialization ---" + ) + + +if __name__ == '__main__': + torch.set_default_dtype(torch.float32) + torch.manual_seed(42) + torch_xla._XLAC._xla_set_mat_mul_precision('highest') + test = unittest.main(verbosity=FLAGS.verbosity, exit=False) + if xu.getenv_as('METRICS_DEBUG', bool, defval=False): + print(torch_xla._XLAC._xla_metrics_report()) + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD index 1287ffbde986..eeec2045d0eb 100644 --- a/torch_xla/csrc/BUILD +++ b/torch_xla/csrc/BUILD @@ -126,6 +126,7 @@ ptxla_cc_library( ":shape_builder", ":shape_helper", ":version", + ":config", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:stablehlo_helper", "//torch_xla/csrc/runtime:xla_util", @@ -177,6 +178,22 @@ ptxla_cc_library( ], ) +ptxla_cc_library( + name = "config", + srcs = ["config.cpp"], + hdrs = ["config.h"], + deps = [ + "//torch_xla/csrc/runtime:tf_logging", + "//torch_xla/csrc/runtime:debug_macros", + "//torch_xla/csrc/runtime:sys_util", + "//torch_xla/csrc/runtime:util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@xla//xla/client:xla_builder", + ], +) + + ptxla_cc_library( name = "dtype", srcs = ["dtype.cpp"], @@ -260,6 +277,7 @@ ptxla_cc_library( ":dtype", ":tensor", ":version", + ":config", "//torch_xla/csrc/runtime", "//torch_xla/csrc/runtime:pjrt_computation_client", "//torch_xla/csrc/runtime:metrics", @@ -348,3 +366,4 @@ ptxla_cc_library( "@pybind11//:pybind11_embed", ], ) + diff --git a/torch_xla/csrc/config.cpp b/torch_xla/csrc/config.cpp new file mode 100644 index 000000000000..3409f5cda017 --- /dev/null +++ b/torch_xla/csrc/config.cpp @@ -0,0 +1,8 @@ +// Definitions of the torch_xla runtime flags declared in config.h. +#include "torch_xla/csrc/config.h" + +#include <c10/util/Flags.h> + +C10_DEFINE_int(torch_xla_graph_execution_check_level, -1, + "set torch xla tensor graph execution check level, specify <= 0 " + 
"(DISABLED), 1 (WARN), 2 (ERROR), >2 (DISABLED)"); diff --git a/torch_xla/csrc/config.h b/torch_xla/csrc/config.h new file mode 100644 index 000000000000..369544cdd8dc --- /dev/null +++ b/torch_xla/csrc/config.h @@ -0,0 +1,9 @@ +// Runtime flags for torch_xla; the definitions live in config.cpp. +#ifndef XLA_TORCH_XLA_CSRC_CONFIG_H_ +#define XLA_TORCH_XLA_CSRC_CONFIG_H_ + +#include <c10/util/Flags.h> + +C10_DECLARE_int(torch_xla_graph_execution_check_level); + +#endif // XLA_TORCH_XLA_CSRC_CONFIG_H_ diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 04dcbf526ed0..417be762a45b 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -37,6 +37,7 @@ #include "torch_xla/csrc/aten_autograd_ops.h" #include "torch_xla/csrc/aten_fallback.h" #include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/config.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/dl_convertor.h" #include "torch_xla/csrc/dtype.h" @@ -2575,6 +2576,13 @@ void InitXlaModuleBindings(py::module m) { }); m.def("_get_xla_enable_device_data_cache", []() { return FLAGS_torch_lazy_enable_device_data_cache; }); + m.def("_set_torch_xla_graph_execution_check_level", + [](int torch_xla_graph_execution_check_level) { + FLAGS_torch_xla_graph_execution_check_level = + torch_xla_graph_execution_check_level; + }); + m.def("_get_torch_xla_graph_execution_check_level", + []() { return FLAGS_torch_xla_graph_execution_check_level; }); m.def("_set_use_eager_mode", [](bool use_eager_mode) { XLAGraphExecutor::Get()->SetUseEagerMode(use_eager_mode); }); diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 4d5d7935b0bf..57cefc269024 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -10,6 +11,7 @@ #include #include #include +#include #include #include @@ -28,6 +30,7 @@ #include "absl/strings/str_join.h" #include 
"stablehlo/dialect/Serialization.h" // from @stablehlo #include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/config.h" #include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir_dump_util.h" @@ -478,6 +481,44 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensors( std::vector<XLATensorPtr>* tensors) { TF_VLOG(4) << "Trying to get the value of " << tensors->size() << " tensor(s)"; + if (FLAGS_torch_xla_graph_execution_check_level > 0 && + FLAGS_torch_xla_graph_execution_check_level <= 2) { + // Log the Python call stack that requested the tensor values. + const auto frames = torch::lazy::GetPythonFrames(); + if (!frames.empty()) { + TF_LOG(WARNING) << "Python call stack:"; + for (const auto& location : frames) { + TF_LOG(WARNING) << " " << location.function << " (" << location.file + << ":" << location.line << ")"; + } + } else { + TF_LOG(WARNING) << "Python frame is empty ..."; + } + + // Log the shape and dtype of every tensor being fetched. + TF_VLOG(4) << "Printing tensor shape and type for investigation ..."; + + std::stringstream ss; + ss << "[" << __FILE__ << ":" << __LINE__ << "] Tensors:\n"; + + for (size_t i = 0; i < tensors->size(); ++i) { + const auto& tensor = (*tensors)[i]; + ss << " [" << i << "]: shape=" << tensor->shape() + << ", dtype=" << tensor->dtype() << "\n"; + } + + TF_LOG(WARNING) << ss.str(); + + if (FLAGS_torch_xla_graph_execution_check_level == 1) { + TF_LOG(WARNING) + << "Trying to get the value of tensor(s): Use the tensor value " + "during tracing may lead to unexpected behavior."; + } else { + XLA_ERROR() << "Trying to get the value of tensor(s): Use the tensor " + "value during tracing may lead to unexpected behavior."; + } + } + SyncTensorsConfig config; config.force_ltc_data = false; auto async = SyncTensorsGraphInternal(tensors, {}, config);