diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index 371f2d83084f..283112f578f4 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -82,6 +82,18 @@ if [[ "$BAZEL_VERB" == "coverage" ]]; then EXTRA_FLAGS="$EXTRA_FLAGS --remote_download_outputs=all" # for lcov symlink fi +# Forward PJRT_DEVICE and CPU_NUM_DEVICES to bazel test environment. +# Set sensible defaults when not provided so tests run reproducibly. +: "${PJRT_DEVICE:=CPU}" +: "${CPU_NUM_DEVICES:=2}" +export PJRT_DEVICE CPU_NUM_DEVICES +if [[ -n "${PJRT_DEVICE}" ]]; then + EXTRA_FLAGS="$EXTRA_FLAGS --test_env=PJRT_DEVICE=${PJRT_DEVICE}" +fi +if [[ -n "${CPU_NUM_DEVICES}" ]]; then + EXTRA_FLAGS="$EXTRA_FLAGS --test_env=CPU_NUM_DEVICES=${CPU_NUM_DEVICES}" +fi + test_names=("all") if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then test_names=("test_aten_xla_tensor_1" diff --git a/test/cpp/test_xla_generator.cpp b/test/cpp/test_xla_generator.cpp index d45991f72d39..ebf90e93b457 100644 --- a/test/cpp/test_xla_generator.cpp +++ b/test/cpp/test_xla_generator.cpp @@ -1,6 +1,9 @@ +#include #include #include +#include + #include "test/cpp/torch_xla_test.h" #include "torch_xla/csrc/xla_generator.h" @@ -102,5 +105,122 @@ TEST_F(XLAGeneratorTest, Clone) { ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed()); } +TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) { + // Test getting default generator for device 0 + auto result = at::detail::GetDefaultXLAGenerator(0); + ASSERT_TRUE(result.ok()) << "Failed to get default generator: " + << result.status(); + + const at::Generator& default_gen = result.value(); + ASSERT_EQ(default_gen.device().type(), at::DeviceType::XLA); + ASSERT_EQ(default_gen.device().index(), 0); + + // Test with -1: resolves to the current device (expected to be device 0) + auto result_default = at::detail::GetDefaultXLAGenerator(-1); + ASSERT_TRUE(result_default.ok()) + << "Failed to get default generator with -1: " << result_default.status(); + + const at::Generator& 
default_gen_neg1 = result_default.value(); + ASSERT_EQ(default_gen_neg1.device().type(), at::DeviceType::XLA); + ASSERT_EQ(default_gen_neg1.device().index(), 0); + ASSERT_EQ(default_gen, default_gen_neg1); + + // Test that subsequent calls return the same generator instance + auto result2 = at::detail::GetDefaultXLAGenerator(0); + ASSERT_TRUE(result2.ok()); + const at::Generator& default_gen2 = result2.value(); + ASSERT_EQ(default_gen, default_gen2); + + // Test getting the default generator of a non-default device + auto result_device1 = at::detail::GetDefaultXLAGenerator(1); + ASSERT_TRUE(result_device1.ok()) + << "Failed to get default generator for device 1: " + << result_device1.status(); + + const at::Generator& default_gen_device1 = result_device1.value(); + ASSERT_EQ(default_gen_device1.device().type(), at::DeviceType::XLA); + ASSERT_EQ(default_gen_device1.device().index(), 1); + ASSERT_NE(default_gen_device1, default_gen); +} + +TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) { + // Test with invalid device indices + auto result_neg2 = at::detail::GetDefaultXLAGenerator(-2); + ASSERT_FALSE(result_neg2.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status())); + ASSERT_THAT(result_neg2.status().message(), + testing::HasSubstr("Invalid XLA device index")); + + // Test with very large device index (assuming there aren't 100 XLA devices) + auto result_large = at::detail::GetDefaultXLAGenerator(100); + ASSERT_FALSE(result_large.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_large.status())); + ASSERT_THAT(result_large.status().message(), + testing::HasSubstr("Invalid XLA device index")); +} + +TEST_F(XLAGeneratorTest, CreateXLAGenerator) { + // Test creating generator for device 1 + auto result = at::detail::CreateXLAGenerator(1); + ASSERT_TRUE(result.ok()) << "Failed to create generator: " << result.status(); + + at::Generator created_gen = result.value(); + ASSERT_EQ(created_gen.device().type(), at::DeviceType::XLA); + 
ASSERT_EQ(created_gen.device().index(), 1); + + // Test that the generator is initialized with the default seed + ASSERT_EQ(created_gen.current_seed(), c10::default_rng_seed_val); + + // Test creating generator with -1 (should use current device) + auto result_current = at::detail::CreateXLAGenerator(-1); + ASSERT_TRUE(result_current.ok()) + << "Failed to create generator with -1: " << result_current.status(); + + at::Generator created_gen_neg1 = result_current.value(); + ASSERT_EQ(created_gen_neg1.device().type(), at::DeviceType::XLA); + // Device index should be >= 0 (actual device depends on current XLA device) + ASSERT_GE(created_gen_neg1.device().index(), 0); +} + +TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) { + // Test that each call creates a new generator instance + auto result1 = at::detail::CreateXLAGenerator(0); + auto result2 = at::detail::CreateXLAGenerator(0); + + ASSERT_TRUE(result1.ok()); + ASSERT_TRUE(result2.ok()); + + at::Generator gen1 = result1.value(); + at::Generator gen2 = result2.value(); + + // Should be different instances (compare generators, not their stack + // addresses) + ASSERT_NE(gen1, gen2); + + // But should have the same device and initial seed + ASSERT_EQ(gen1.device(), gen2.device()); + ASSERT_EQ(gen1.current_seed(), gen2.current_seed()); + + // Modifying one should not affect the other + gen1.set_current_seed(12345); + ASSERT_NE(gen1.current_seed(), gen2.current_seed()); +} + +TEST_F(XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) { + // Test with an invalid negative index (only -1 maps to the current device) + auto result_neg2 = at::detail::CreateXLAGenerator(-2); + ASSERT_FALSE(result_neg2.ok()); + ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status())); + ASSERT_THAT(result_neg2.status().message(), + testing::HasSubstr("Invalid XLA device index")); + + // Test with very large device index (assuming there aren't 100 XLA devices) + auto result_large = at::detail::CreateXLAGenerator(100); + ASSERT_FALSE(result_large.ok()); + 
ASSERT_TRUE(absl::IsInvalidArgument(result_large.status())); + ASSERT_THAT(result_large.status().message(), + testing::HasSubstr("Invalid XLA device index")); +} + } // namespace cpp_test -} // namespace torch_xla \ No newline at end of file +} // namespace torch_xla diff --git a/torch_xla/csrc/xla_generator.cpp b/torch_xla/csrc/xla_generator.cpp index 5d0a7c15866b..e86f7de3448d 100644 --- a/torch_xla/csrc/xla_generator.cpp +++ b/torch_xla/csrc/xla_generator.cpp @@ -5,10 +5,115 @@ #include #include #include +#include #include +#include #include #include +#include +#include + +#include "absl/status/status.h" +#include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/runtime.h" +#include "torch_xla/csrc/status.h" + +namespace at { + +namespace detail { + +namespace { + +// Total number of XLA devices in the system. +static int64_t num_xla_devices; + +// Per-device once-flags: each default_gens_xla entry is initialized once. +static std::deque xla_gens_init_flag; + +// Default, global XLA generators, one per XLA device. +static std::vector default_gens_xla; + +/* + * Populates the global variables related to XLA generators (runs once). + * Safe to call repeatedly: a function-local static caches the status. + */ +static absl::Status InitXLAGenVector() { + static const absl::Status* init_status = new absl::Status([]() { + XLA_ASSIGN_OR_RETURN(auto c_client, + torch_xla::runtime::GetComputationClient()); + num_xla_devices = static_cast(c_client->GetNumDevices()); + xla_gens_init_flag.resize(num_xla_devices); + default_gens_xla.resize(num_xla_devices); + return absl::OkStatus(); + }()); + return *init_status; +} + +// Validates and normalizes an XLA device index. +// If requested_index == -1, the current device index is used. +// Returns InvalidArgument if the resolved index is out of range. 
+static absl::StatusOr NormalizeXLADeviceIndex( + c10::DeviceIndex requested_index) { + c10::DeviceIndex idx = requested_index; + if (idx == -1) { + idx = torch_xla::bridge::GetCurrentAtenDevice().index(); + } + if (idx < 0 || idx >= num_xla_devices) { + return absl::InvalidArgumentError( + "Invalid device index for XLA generator. Provided index: " + + std::to_string(idx)); + } + return idx; +} + +} // anonymous namespace + +/** + * PyTorch maintains a collection of default generators that get + * initialized once. The purpose of these default generators is to + * maintain a global running state of the pseudo-random number generation, + * when a user does not explicitly specify a generator. + * GetDefaultXLAGenerator gets the default generator for a particular + * XLA device; a device_index of -1 selects the current device. + */ +absl::StatusOr GetDefaultXLAGenerator( + c10::DeviceIndex device_index) { + XLA_RETURN_IF_ERROR(InitXLAGenVector(), + "Failed to initialize XLA generators"); + // Normalize and validate the target device index; -1 resolves to the + // current device + XLA_ASSIGN_OR_RETURN(c10::DeviceIndex idx, + NormalizeXLADeviceIndex(device_index), + "Invalid XLA device index"); + c10::call_once(xla_gens_init_flag[idx], [&] { + default_gens_xla[idx] = at::make_generator(idx); + default_gens_xla[idx].seed(); + }); + return default_gens_xla[idx]; +} + +/** + * Utility to create a XLAGeneratorImpl. 
Returns a shared_ptr + */ +absl::StatusOr CreateXLAGenerator( + c10::DeviceIndex device_index) { + XLA_RETURN_IF_ERROR(InitXLAGenVector(), + "Failed to initialize XLA generators"); + // Normalize and validate the target device index; -1 resolves to the + // current device + XLA_ASSIGN_OR_RETURN(c10::DeviceIndex idx, + NormalizeXLADeviceIndex(device_index), + "Invalid XLA device index"); + auto gen = at::make_generator(idx); + auto xla_gen = at::check_generator(gen); + xla_gen->set_current_seed(c10::default_rng_seed_val); + return gen; +} + +} // namespace detail +} // namespace at namespace at { diff --git a/torch_xla/csrc/xla_generator.h b/torch_xla/csrc/xla_generator.h index 330d32861200..8001737e795c 100644 --- a/torch_xla/csrc/xla_generator.h +++ b/torch_xla/csrc/xla_generator.h @@ -2,10 +2,17 @@ #include #include +#include +#include +#include +#include #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" + namespace at { // Holds the actual state variables for the XLA generator. @@ -53,4 +60,13 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl { c10::intrusive_ptr state_; }; -} // namespace at \ No newline at end of file +namespace detail { + +absl::StatusOr GetDefaultXLAGenerator( + c10::DeviceIndex device_index = -1); +absl::StatusOr CreateXLAGenerator( + c10::DeviceIndex device_index = -1); + +} // namespace detail + +} // namespace at