Try to remove push work from ttnn
sminakov-tt committed Feb 19, 2025
1 parent c17e35a commit cdb9eef
Showing 12 changed files with 1 addition and 851 deletions.
45 changes: 0 additions & 45 deletions tests/ttnn/unit_tests/gtests/tensor/common_tensor_test_utils.cpp
@@ -10,49 +10,4 @@

namespace test_utils {

void test_tensor_on_device(const ttnn::Shape& input_shape, const TensorLayout& layout, tt::tt_metal::IDevice* device) {
using namespace tt::tt_metal;

const ttnn::QueueId io_cq = ttnn::DefaultQueueId;

const auto input_buf_size_bytes = layout.compute_packed_buffer_size_bytes(input_shape);
const auto host_buffer_datum_size_bytes = sizeof(uint32_t);
const auto input_buf_size = input_buf_size_bytes / host_buffer_datum_size_bytes;

auto host_data = std::make_shared<uint32_t[]>(input_buf_size);
auto readback_data = std::make_shared<uint32_t[]>(input_buf_size);

const auto random_prime_number = 4051;
for (int i = 0; i < input_buf_size; i++) {
host_data[i] = i % random_prime_number;
}

auto tensor = tt::tt_metal::create_device_tensor(TensorSpec(input_shape, layout), device);
ttnn::queue_synchronize(device->command_queue(*io_cq));

ttnn::write_buffer(io_cq, tensor, {host_data});
ttnn::queue_synchronize(device->command_queue(*io_cq));

ttnn::read_buffer(io_cq, tensor, {readback_data});
ttnn::queue_synchronize(device->command_queue(*io_cq));

for (int i = 0; i < input_buf_size; i++) {
EXPECT_EQ(host_data[i], readback_data[i]);
if (host_data[i] != readback_data[i]) {
break;
}
}

EXPECT_EQ(tensor.get_padded_shape(), layout.compute_padded_shape(input_shape));
tensor.deallocate();
}

void test_tensor_on_device(const ttnn::Shape& input_shape, const tt::tt_metal::TensorLayout& layout) {
tt::tt_metal::IDevice* device = tt::tt_metal::CreateDevice(0);

test_tensor_on_device(input_shape, layout, device);

tt::tt_metal::CloseDevice(device);
}

} // namespace test_utils
@@ -6,9 +6,3 @@

#include "ttnn/tensor/layout/tensor_layout.hpp"
#include "ttnn/tensor/tensor.hpp"

namespace test_utils {
void test_tensor_on_device(
const ttnn::Shape& input_shape, const tt::tt_metal::TensorLayout& layout, tt::tt_metal::IDevice* device);
void test_tensor_on_device(const ttnn::Shape& input_shape, const tt::tt_metal::TensorLayout& layout);
} // namespace test_utils
72 changes: 0 additions & 72 deletions tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp
@@ -18,44 +18,6 @@

namespace {

void run_create_tensor_test(tt::tt_metal::IDevice* device, const ttnn::Shape& input_shape) {
MemoryConfig mem_cfg = MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED,
.buffer_type = BufferType::DRAM,
.shard_spec = std::nullopt};

const ttnn::QueueId io_cq = ttnn::DefaultQueueId;
constexpr DataType dtype = DataType::BFLOAT16;
constexpr uint32_t datum_size_bytes = 2;

auto input_buf_size_datums = input_shape.volume();

auto host_data = std::shared_ptr<uint16_t[]>(new uint16_t[input_buf_size_datums]);
auto readback_data = std::shared_ptr<uint16_t[]>(new uint16_t[input_buf_size_datums]);

for (int i = 0; i < input_buf_size_datums; i++) {
host_data[i] = 1;
}

TensorSpec tensor_spec(input_shape, TensorLayout(dtype, PageConfig(Layout::TILE), mem_cfg));
ASSERT_EQ(input_buf_size_datums * datum_size_bytes, tensor_spec.compute_packed_buffer_size_bytes());
auto input_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(device, tensor_spec);

auto input_storage = tt::tt_metal::DeviceStorage{input_buffer};

Tensor input_tensor = Tensor(input_storage, input_shape, dtype, Layout::TILE);

ttnn::write_buffer(io_cq, input_tensor, {host_data});

ttnn::read_buffer(io_cq, input_tensor, {readback_data});

for (int i = 0; i < input_buf_size_datums; i++) {
EXPECT_EQ(host_data[i], readback_data[i]);
}

input_tensor.deallocate();
}

struct CreateTensorParams {
ttnn::Shape shape;
};
@@ -65,11 +27,6 @@ struct CreateTensorParams {
class CreateTensorTest : public ttnn::TTNNFixtureWithDevice,
public ::testing::WithParamInterface<CreateTensorParams> {};

TEST_P(CreateTensorTest, Tile) {
CreateTensorParams params = GetParam();
run_create_tensor_test(device_, params.shape);
}

INSTANTIATE_TEST_SUITE_P(
CreateTensorTestWithShape,
CreateTensorTest,
@@ -90,35 +47,6 @@ using CombinationInputParams =
class EmptyTensorTest : public ttnn::TTNNFixtureWithDevice,
public ::testing::WithParamInterface<CombinationInputParams> {};

TEST_P(EmptyTensorTest, Combinations) {
auto params = GetParam();
auto shape = std::get<0>(params);
auto dtype = std::get<1>(params);
auto layout = std::get<2>(params);
auto memory_config = std::get<3>(params);
tt::log_info(
"Running test with shape={}, dtype={}, layout={}, memory_config={}", shape, dtype, layout, memory_config);

if (layout == tt::tt_metal::Layout::ROW_MAJOR && dtype == tt::tt_metal::DataType::BFLOAT8_B) {
GTEST_SKIP() << "Skipping test with ROW_MAJOR layout and BFLOAT8_B dtype!";
}

auto tensor_layout = tt::tt_metal::TensorLayout::fromPaddedShape(
dtype, PageConfig(layout), memory_config, /* logical */ shape, /* padded */ shape);

// Ignoring too large single bank allocations
if (memory_config.memory_layout == TensorMemoryLayout::SINGLE_BANK) {
if (tensor_layout.compute_page_size_bytes(shape) >= 500 * 1024) {
GTEST_SKIP() << "Skipping test with page size exceeding single bank size of 500 kB!";
}
}

auto tensor = tt::tt_metal::create_device_tensor(shape, dtype, layout, device_, memory_config);
EXPECT_EQ(tensor.get_logical_shape(), shape);

test_utils::test_tensor_on_device(shape, tensor_layout, device_);
}

INSTANTIATE_TEST_SUITE_P(
EmptyTensorTestWithShape,
EmptyTensorTest,
22 changes: 0 additions & 22 deletions tests/ttnn/unit_tests/gtests/tensor/test_tensor_layout.cpp
@@ -41,19 +41,6 @@ struct TensorLayoutTestParams {

class TensorLayoutComputeTests : public ::testing::TestWithParam<TensorLayoutTestParams> {};

TEST_P(TensorLayoutComputeTests, TensorLayout_Generic) {
const auto& params = GetParam();
TensorLayout layout(params.inputs.data_type, PageConfig(params.inputs.layout), DefaultMemoryConfig);

EXPECT_EQ(layout.get_alignment(), params.expected.alignment);
EXPECT_EQ(layout.compute_physical_shape(params.inputs.shape), params.expected.physical_size);
EXPECT_EQ(layout.compute_strides(params.inputs.shape), params.expected.strides);

if (params.expected.tensor_creation_works) {
test_utils::test_tensor_on_device(params.inputs.shape, layout);
}
}

INSTANTIATE_TEST_SUITE_P(
TensorLayoutTests,
TensorLayoutComputeTests,
@@ -162,15 +149,6 @@ struct LegacyPaddingRoundtripTestParams {

class TensorLayoutLegacyPaddingRoundtipTests : public ::testing::TestWithParam<LegacyPaddingRoundtripTestParams> {};

TEST_P(TensorLayoutLegacyPaddingRoundtipTests, Tensor_LagacyPaddingRoundtrip) {
const auto& params = GetParam();
TensorLayout layout = TensorLayout::fromPaddedShape(
DataType::BFLOAT16, Layout::ROW_MAJOR, DefaultMemoryConfig, params.shape, params.padded_shape);
EXPECT_EQ(layout.compute_padded_shape(params.shape), params.padded_shape);

test_utils::test_tensor_on_device(params.shape, layout);
}

INSTANTIATE_TEST_SUITE_P(
TensorLayoutTests,
TensorLayoutLegacyPaddingRoundtipTests,
11 changes: 0 additions & 11 deletions tests/ttnn/unit_tests/gtests/tensor/test_tensor_sharding.cpp
@@ -756,17 +756,6 @@ class CreateShardedTensorWithAlignmentTests
: public ttnn::TTNNFixtureWithDevice,
public ::testing::WithParamInterface<CreateShardedTensorWithAlignmentParams> {};

TEST_P(CreateShardedTensorWithAlignmentTests, AllocateTensor) {
const auto& params = GetParam();
const auto& input_shape = params.inputs.shape;

TensorLayout layout(params.inputs.data_type, params.inputs.page_config, params.inputs.memory_config);

test_utils::test_tensor_on_device(input_shape, layout, device_);

EXPECT_EQ(layout.compute_physical_shape(input_shape), params.expected.physical_shape);
}

INSTANTIATE_TEST_SUITE_P(
TensorShardingTests,
CreateShardedTensorWithAlignmentTests,
142 changes: 0 additions & 142 deletions tests/ttnn/unit_tests/gtests/test_async_runtime.cpp
@@ -19,148 +19,6 @@ namespace {

using MultiCommandQueueSingleDeviceFixture = ::ttnn::MultiCommandQueueSingleDeviceFixture;

TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) {
IDevice* device = this->device_;
MemoryConfig mem_cfg = MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED,
.buffer_type = BufferType::DRAM,
.shard_spec = std::nullopt};

uint32_t input_buf_size_datums = 1024 * 1024;
uint32_t output_buf_size_datums = 1024 * 32;
uint32_t datum_size_bytes = 2;
ttnn::QueueId io_cq = ttnn::QueueId(1);  // Data reads and writes done through CQ1
ttnn::QueueId workload_dispatch_cq = ttnn::QueueId(0);  // Workload dispatched through CQ0

ttnn::Shape input_shape({1, 1, 1024, 1024});
auto host_data = std::shared_ptr<bfloat16[]>(new bfloat16[input_buf_size_datums]);
auto readback_data = std::shared_ptr<bfloat16[]>(new bfloat16[output_buf_size_datums]);

for (int i = 0; i < input_buf_size_datums; i++) {
host_data[i] = bfloat16(static_cast<float>(1));
}
// Create golden data using tt_eager APIs
Tensor np_tensor = ttnn::full(input_shape, static_cast<float>(1), DataType::BFLOAT16, Layout::TILE, *device_);
ttnn::SmallVector<int64_t> reduce_dims = {3};
Tensor np_out = ttnn::moreh_sum(np_tensor, reduce_dims, false, std::nullopt, std::nullopt, std::nullopt);
Tensor np_out_host = np_out.cpu();
const bfloat16* golden_output =
std::get<owned_buffer::Buffer<bfloat16>>(std::get<OwnedStorage>(np_out_host.get_storage()).buffer).begin();
// Enable Asynchronous Execution and test ttnn runtime APIs
device_->enable_async(true);
// Events for host - device synchronization
auto write_event = std::make_shared<Event>();
auto workload_event = std::make_shared<Event>();
// Running sum-reduce with preallocated output
// Preallocate Input and Output Tensors on Device
tt_metal::TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg);
ASSERT_EQ(input_buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(input_shape));
ASSERT_EQ(
output_buf_size_datums * datum_size_bytes,
tensor_layout.compute_packed_buffer_size_bytes(np_out.get_padded_shape()));
auto input_buffer =
tt::tt_metal::tensor_impl::allocate_buffer_on_device(device_, TensorSpec(input_shape, tensor_layout));
auto output_buffer = tt::tt_metal::tensor_impl::allocate_buffer_on_device(
device_, TensorSpec(np_out.get_padded_shape(), tensor_layout));
auto input_storage = tt::tt_metal::DeviceStorage{input_buffer};
auto output_storage = tt::tt_metal::DeviceStorage{output_buffer};
Tensor input_tensor = Tensor(
input_storage,
TensorSpec(input_shape, TensorLayout(DataType::BFLOAT16, PageConfig(Layout::TILE), MemoryConfig{})));
Tensor output_tensor = Tensor(output_storage, np_out.get_logical_shape(), DataType::BFLOAT16, Layout::TILE);
// Populate input_tensor with data
ttnn::write_buffer(io_cq, input_tensor, {host_data});
// Record the completion of the write event
ttnn::record_event(device_->command_queue(*io_cq), write_event);
// Host stalls until write is completed, before sending workload
ttnn::event_synchronize(write_event);
EXPECT_EQ(ttnn::event_query(write_event), true);
// Dispatch workload. Preallocated output_tensor is populated by the op.
ttnn::moreh_sum(input_tensor, /*dim*/ 3, false, output_tensor, std::nullopt, std::nullopt);
// Record completion of workload
ttnn::record_event(device_->command_queue(*workload_dispatch_cq), workload_event);
ttnn::event_synchronize(workload_event);
EXPECT_EQ(ttnn::event_query(workload_event), true);
// Read output back, once workload is complete
ttnn::read_buffer(io_cq, output_tensor, {readback_data});
// Ensure that reference count book keeping is done correctly
// Tensors only have one reference in the main thread. Ensure this is true.
EXPECT_EQ(input_tensor.tensor_attributes->main_thread_ref_count, 1);
EXPECT_EQ(output_tensor.tensor_attributes->main_thread_ref_count, 1);
// Buffers are currently jointly owned by the original buffer object, the storage object and the tensor (3).
EXPECT_EQ(input_buffer.use_count(), 3);
EXPECT_EQ(output_buffer.use_count(), 3);
// Deallocate tensors (tensor gives up buffer). Done asynchronously, so sync on queue after.
input_tensor.deallocate();
output_tensor.deallocate();
ttnn::queue_synchronize(device_->command_queue(*io_cq));
// Buffer only has 2 owners in main thread.
EXPECT_EQ(input_buffer.use_count(), 2);
EXPECT_EQ(output_buffer.use_count(), 2);
for (int i = 0; i < output_buf_size_datums; i++) {
EXPECT_EQ(readback_data[i], golden_output[i]);
}
}

TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeAllocatedBuffers) {
device_->enable_async(true);
MemoryConfig mem_cfg = MemoryConfig{
.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED,
.buffer_type = BufferType::DRAM,
.shard_spec = std::nullopt};

uint32_t buf_size_datums = 1024 * 1024;
uint32_t datum_size_bytes = 2;
std::vector<uint32_t> inputs = {4, 9, 16, 25, 36, 64};
ttnn::QueueId io_cq = ttnn::QueueId(1);
ttnn::QueueId workload_dispatch_cq = ttnn::QueueId(0);
ttnn::Shape shape{1, 1, 1024, 1024};

auto host_data = std::shared_ptr<bfloat16[]>(new bfloat16[buf_size_datums]);
auto readback_data = std::shared_ptr<bfloat16[]>(new bfloat16[buf_size_datums]);
for (int loop = 0; loop < 10; loop++) {
log_info(LogTest, "Running outer loop {}", loop);
for (auto input_val : inputs) {
for (int i = 0; i < buf_size_datums; i++) {
host_data[i] = bfloat16(static_cast<float>(input_val));
}

auto write_event = std::make_shared<Event>();
auto workload_event = std::make_shared<Event>();
TensorLayout tensor_layout(DataType::BFLOAT16, PageConfig(Layout::TILE), mem_cfg);
ASSERT_EQ(buf_size_datums * datum_size_bytes, tensor_layout.compute_packed_buffer_size_bytes(shape));
auto input_buffer =
tt::tt_metal::tensor_impl::allocate_buffer_on_device(device_, TensorSpec(shape, tensor_layout));
auto input_storage = tt::tt_metal::DeviceStorage{input_buffer};
Tensor input_tensor = Tensor(input_storage, shape, DataType::BFLOAT16, Layout::TILE);
ttnn::write_buffer(io_cq, input_tensor, {host_data}); // Write using cq 1
ttnn::record_event(device_->command_queue(*io_cq), write_event); // Record write on cq 1
// Wait until cq 1 write is complete
ttnn::wait_for_event(device_->command_queue(*workload_dispatch_cq), write_event);

// Run operation on cq 0
Tensor output_tensor = ttnn::sqrt(workload_dispatch_cq, input_tensor);
auto dummy_buffer_0 =
tt::tt_metal::tensor_impl::allocate_buffer_on_device(device_, TensorSpec(shape, tensor_layout));
output_tensor = ttnn::neg(workload_dispatch_cq, output_tensor);
// Allocate this buffer to stress test async allocation across op execution and explicit allocation
auto dummy_buffer_1 =
tt::tt_metal::tensor_impl::allocate_buffer_on_device(device_, TensorSpec(shape, tensor_layout));
// Record cq 0 prog execution
ttnn::record_event(device_->command_queue(*workload_dispatch_cq), workload_event);
// Wait until cq 0 prog execution is done
ttnn::wait_for_event(device_->command_queue(*io_cq), workload_event);
// Read using cq 1
ttnn::read_buffer(io_cq, output_tensor, {readback_data});
for (int i = 0; i < buf_size_datums; i++) {
EXPECT_EQ(
static_cast<int>(std::floor(bfloat16(readback_data[i]).to_float())),
static_cast<int>(-1 * sqrt(input_val)));
}
}
}
}

TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeBufferDestructor) {
// Test functionality for the buffer destructor, which will call deallocate asynchronously
// We must ensure that the deallocate step, which can run after the buffer has been destroyed