Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 118 additions & 1 deletion test/cpp/test_xla_generator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <gtest/gtest.h>
#include <torch/torch.h>

#include <cstdlib>

#include "test/cpp/torch_xla_test.h"
#include "torch_xla/csrc/xla_generator.h"

Expand All @@ -18,6 +20,20 @@ class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest {
at::Generator gen_;
};

// Ensure PJRT is configured to a CPU backend for tests that touch the PJRT
// runtime.
static void EnsurePjrtCpuBackend() {
const char* pjrt = std::getenv("PJRT_DEVICE");
if (pjrt == nullptr || pjrt[0] == '\0') {
// Use CPU backend with a single device by default.
setenv("PJRT_DEVICE", "CPU", 1);
}
const char* cpu_devices = std::getenv("CPU_NUM_DEVICES");
if (cpu_devices == nullptr || cpu_devices[0] == '\0') {
setenv("CPU_NUM_DEVICES", "1", 0);
}
}

TEST_F(XLAGeneratorTest, Constructor) {
// Check that the generator was created for the correct device
ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA);
Expand Down Expand Up @@ -102,5 +118,106 @@ TEST_F(XLAGeneratorTest, Clone) {
ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed());
}

TEST_F(XLAGeneratorTest, GetDefaultXLAGenerator) {
EnsurePjrtCpuBackend();
// Test getting default generator for device 0
auto result = at::detail::GetDefaultXLAGenerator(0);
ASSERT_TRUE(result.ok()) << "Failed to get default generator: "
<< result.status();

const at::Generator& default_gen = result.value();
ASSERT_EQ(default_gen.device().type(), at::DeviceType::XLA);
ASSERT_EQ(default_gen.device().index(), 0);

// Test getting default generator with -1 (should default to device 0)
auto result_default = at::detail::GetDefaultXLAGenerator(-1);
ASSERT_TRUE(result_default.ok())
<< "Failed to get default generator with -1: " << result_default.status();

const at::Generator& default_gen_neg1 = result_default.value();
ASSERT_EQ(default_gen_neg1.device().type(), at::DeviceType::XLA);
ASSERT_EQ(default_gen_neg1.device().index(), 0);

// Test that subsequent calls return the same generator instance
auto result2 = at::detail::GetDefaultXLAGenerator(0);
ASSERT_TRUE(result2.ok());
const at::Generator& default_gen2 = result2.value();
ASSERT_EQ(std::addressof(default_gen), std::addressof(default_gen2));
}

TEST_F(XLAGeneratorTest, GetDefaultXLAGeneratorInvalidDevice) {
EnsurePjrtCpuBackend();
// Test with invalid device indices
auto result_neg2 = at::detail::GetDefaultXLAGenerator(-2);
ASSERT_FALSE(result_neg2.ok());
ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status()));

// Test with very large device index (assuming there aren't 1000 XLA devices)
auto result_large = at::detail::GetDefaultXLAGenerator(1000);
ASSERT_FALSE(result_large.ok());
ASSERT_TRUE(absl::IsInvalidArgument(result_large.status()));
}

TEST_F(XLAGeneratorTest, CreateXLAGenerator) {
EnsurePjrtCpuBackend();
// Test creating generator for device 0
auto result = at::detail::CreateXLAGenerator(0);
ASSERT_TRUE(result.ok()) << "Failed to create generator: " << result.status();

at::Generator created_gen = result.value();
ASSERT_EQ(created_gen.device().type(), at::DeviceType::XLA);
ASSERT_EQ(created_gen.device().index(), 0);

// Test that the generator is initialized with default seed
ASSERT_EQ(created_gen.current_seed(), c10::default_rng_seed_val);

// Test creating generator with -1 (should use current device)
auto result_default = at::detail::CreateXLAGenerator(-1);
ASSERT_TRUE(result_default.ok())
<< "Failed to create generator with -1: " << result_default.status();

at::Generator created_gen_neg1 = result_default.value();
ASSERT_EQ(created_gen_neg1.device().type(), at::DeviceType::XLA);
// Device index should be >= 0 (actual device depends on current XLA device)
ASSERT_GE(created_gen_neg1.device().index(), 0);
}

TEST_F(XLAGeneratorTest, CreateXLAGeneratorUniqueness) {
EnsurePjrtCpuBackend();
// Test that each call creates a new generator instance
auto result1 = at::detail::CreateXLAGenerator(0);
auto result2 = at::detail::CreateXLAGenerator(0);

ASSERT_TRUE(result1.ok());
ASSERT_TRUE(result2.ok());

at::Generator gen1 = result1.value();
at::Generator gen2 = result2.value();

// Should be different instances
ASSERT_NE(std::addressof(gen1), std::addressof(gen2));

// But should have same device and initial seed
ASSERT_EQ(gen1.device(), gen2.device());
ASSERT_EQ(gen1.current_seed(), gen2.current_seed());

// Modifying one should not affect the other
gen1.set_current_seed(12345);
ASSERT_NE(gen1.current_seed(), gen2.current_seed());
}

TEST_F(XLAGeneratorTest, CreateXLAGeneratorInvalidDevice) {
EnsurePjrtCpuBackend();
// Test with invalid device indices
auto result_neg2 = at::detail::CreateXLAGenerator(-2);
ASSERT_FALSE(result_neg2.ok());
ASSERT_TRUE(absl::IsInvalidArgument(result_neg2.status()));

// Test with very large device index (assuming there aren't 1000 XLA devices)
auto result_large = at::detail::CreateXLAGenerator(1000);
ASSERT_FALSE(result_large.ok());
ASSERT_TRUE(absl::IsInvalidArgument(result_large.status()));
}

} // namespace cpp_test
} // namespace torch_xla
} // namespace torch_xla
94 changes: 94 additions & 0 deletions torch_xla/csrc/xla_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,104 @@
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/CallOnce.h>
#include <c10/util/intrusive_ptr.h>

#include <cstring>
#include <deque>
#include <vector>

#include "absl/status/status.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/runtime.h"
#include "torch_xla/csrc/status.h"

namespace at {

namespace detail {

namespace {

// Total number of XLA devices in the system.
static int64_t num_xla_devices;

// Ensures default_gens_xla is initialized once.
static std::deque<c10::once_flag> xla_gens_init_flag;

// Default, global XLA generators, one per XLA device.
static std::vector<at::Generator> default_gens_xla;

/*
* Populates the global variables related to XLA generators
* Warning: this function must only be called once!
*/
static void InitXLAGenVector() {
// Ensures we only call deviceCount only once.
static bool num_xla_device_init_flag [[maybe_unused]] = []() {
// Get local num of XLA devices
XLA_ASSIGN_OR_THROW(auto c_client,
torch_xla::runtime::GetComputationClient());
num_xla_devices = static_cast<int64_t>(c_client->GetNumDevices());
xla_gens_init_flag.resize(num_xla_devices);
default_gens_xla.resize(num_xla_devices);
return true;
}();
}

} // anonymous namespace

/**
* PyTorch maintains a collection of default generators that get
* initialized once. The purpose of these default generators is to
* maintain a global running state of the pseudo random number generation,
* when a user does not explicitly mention any generator.
* GetDefaultXLAGenerator gets the default generator for a particular
* XLA device.
*/
absl::StatusOr<const at::Generator&> GetDefaultXLAGenerator(
c10::DeviceIndex device_index) {
InitXLAGenVector();
c10::DeviceIndex idx = device_index;
if (idx == -1) {
idx = 0; // Default to device 0 for XLA
} else if (idx < -1 || idx >= num_xla_devices) {
return absl::InvalidArgumentError(
"Invalid device index for XLA generator. Provided index: " +
std::to_string(idx));
}
c10::call_once(xla_gens_init_flag[idx], [&] {
default_gens_xla[idx] = at::make_generator<XLAGeneratorImpl>(idx);
default_gens_xla[idx].seed();
});
return default_gens_xla[idx];
}

/**
* Utility to create a XLAGeneratorImpl. Returns a shared_ptr
*/
absl::StatusOr<at::Generator> CreateXLAGenerator(
c10::DeviceIndex device_index) {
InitXLAGenVector();
c10::DeviceIndex idx = device_index;
if (idx == -1) {
idx = torch_xla::bridge::GetCurrentAtenDevice()
.index(); // Use current XLA device
} else if (idx < -1 || idx >= num_xla_devices) {
return absl::InvalidArgumentError(
"Invalid device index for XLA generator. Provided index: " +
std::to_string(idx));
}
auto gen = at::make_generator<XLAGeneratorImpl>(idx);
auto xla_gen = at::check_generator<XLAGeneratorImpl>(gen);
xla_gen->set_current_seed(c10::default_rng_seed_val);
return gen;
}

} // namespace detail
} // namespace at

namespace at {

Expand Down
18 changes: 17 additions & 1 deletion torch_xla/csrc/xla_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@

#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/GeneratorImpl.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/intrusive_ptr.h>

#include <cstdint>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

namespace at {

// Holds the actual state variables for the XLA generator.
Expand Down Expand Up @@ -53,4 +60,13 @@ struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
c10::intrusive_ptr<XLAGeneratorState> state_;
};

} // namespace at
namespace detail {

absl::StatusOr<const at::Generator&> GetDefaultXLAGenerator(
c10::DeviceIndex device_index = -1);
absl::StatusOr<at::Generator> CreateXLAGenerator(
c10::DeviceIndex device_index = -1);

} // namespace detail

} // namespace at