Skip to content

Commit

Permalink
#16364: proper address calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
nhuang-tt committed Feb 6, 2025
1 parent b093594 commit e93160a
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 23 deletions.
31 changes: 26 additions & 5 deletions tt_metal/api/tt-metalium/command_queue_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ enum class CommandQueueDeviceAddrType : uint8_t {
COMPLETION_Q1_LAST_EVENT = 5,
DISPATCH_S_SYNC_SEM = 6,
DISPATCH_MESSAGE = 7,
UNRESERVED = 8
KERNEL_DEBUG_STATUS = 8,
UNRESERVED = 9,
};

// Identifies the host-side command queue address slots.
// Values are explicit and contiguous starting at 0; UNRESERVED is the last
// enumerator (presumably marking the start of the unreserved region, mirroring
// CommandQueueDeviceAddrType::UNRESERVED -- confirm against the users of this enum).
enum class CommandQueueHostAddrType : uint8_t {
    ISSUE_Q_RD = 0,
    ISSUE_Q_WR = 1,
    COMPLETION_Q_WR = 2,
    COMPLETION_Q_RD = 3,
    // Single definition only: a second `UNRESERVED = 4` line would redeclare the
    // enumerator and fail to compile.
    UNRESERVED = 4,
};

//
Expand Down Expand Up @@ -82,13 +83,22 @@ class DispatchMemMap {
return instance;
}

// Returns the stored base address of the prefetcher buffer region
// (prefetch_buffer_base_).
uint32_t prefetch_buffer_base() const {
    return prefetch_buffer_base_;
}

// Returns the prefetch queue entry count taken from the dispatch settings
// (settings.prefetch_q_entries_).
uint32_t prefetch_q_entries() const {
    return settings.prefetch_q_entries_;
}

// Returns the prefetch queue size taken from the dispatch settings
// (settings.prefetch_q_size_).
uint32_t prefetch_q_size() const {
    return settings.prefetch_q_size_;
}

// Returns the maximum prefetch command size taken from the dispatch settings
// (settings.prefetch_max_cmd_size_).
uint32_t max_prefetch_command_size() const {
    return settings.prefetch_max_cmd_size_;
}

uint32_t cmddat_q_base() const { return cmddat_q_base_; }
// Returns the base address of the cmddat queue.
//
// is_prefetch_d selects which stored base is reported at compile time:
//   true  -> cmddat_q_base_d_variant_
//   false -> cmddat_q_base_
template <bool is_prefetch_d>
uint32_t cmddat_q_base() const {
    return is_prefetch_d ? cmddat_q_base_d_variant_ : cmddat_q_base_;
}

// Returns the cmddat queue size taken from the dispatch settings
// (settings.prefetch_cmddat_q_size_).
uint32_t cmddat_q_size() const {
    return settings.prefetch_cmddat_q_size_;
}

Expand Down Expand Up @@ -174,6 +184,9 @@ class DispatchMemMap {
device_cq_addr_sizes_[dev_addr_idx] = settings.dispatch_s_sync_sem_;
} else if (dev_addr_type == CommandQueueDeviceAddrType::DISPATCH_MESSAGE) {
device_cq_addr_sizes_[dev_addr_idx] = settings.dispatch_message_;
} else if (dev_addr_type == CommandQueueDeviceAddrType::KERNEL_DEBUG_STATUS) {
// May be 0
device_cq_addr_sizes_[dev_addr_idx] = settings.kernel_debug_status_enable_;
} else {
device_cq_addr_sizes_[dev_addr_idx] = settings.other_ptrs_size;
}
Expand All @@ -193,8 +206,14 @@ class DispatchMemMap {
uint32_t prefetch_dispatch_unreserved_base =
device_cq_addrs_[tt::utils::underlying_type<CommandQueueDeviceAddrType>(
CommandQueueDeviceAddrType::UNRESERVED)];
cmddat_q_base_ = prefetch_dispatch_unreserved_base + round_size(settings.prefetch_q_size_, pcie_alignment);
scratch_db_base_ = cmddat_q_base_ + round_size(settings.prefetch_cmddat_q_size_, pcie_alignment);

// Prefetcher: FetchQ | Cmddat | Scratch
// Dispatcher: Dispatch Buffer
prefetch_buffer_base_ = prefetch_dispatch_unreserved_base; // Already aligned from above
cmddat_q_base_d_variant_ = align(prefetch_buffer_base_ + settings.prefetch_d_buffer_size_, pcie_alignment);
cmddat_q_base_ = align(prefetch_buffer_base_ + settings.prefetch_q_size_, pcie_alignment);
scratch_db_base_ = align(cmddat_q_base_ + settings.prefetch_cmddat_q_size_, pcie_alignment);

dispatch_buffer_base_ = align(prefetch_dispatch_unreserved_base, 1 << DispatchSettings::DISPATCH_BUFFER_LOG_PAGE_SIZE);
dispatch_buffer_block_size_pages_ = settings.dispatch_pages_ / DispatchSettings::DISPATCH_BUFFER_SIZE_BLOCKS;
const uint32_t dispatch_cb_end = dispatch_buffer_base_ + settings.dispatch_size_;
Expand Down Expand Up @@ -223,7 +242,9 @@ class DispatchMemMap {
return {l1_base, l1_size};
}

uint32_t prefetch_buffer_base_;
uint32_t cmddat_q_base_;
uint32_t cmddat_q_base_d_variant_;
uint32_t scratch_db_base_;
uint32_t dispatch_buffer_base_;

Expand Down
8 changes: 8 additions & 0 deletions tt_metal/api/tt-metalium/dispatch_settings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ class DispatchSettings {
uint32_t tunneling_buffer_size_;
uint32_t tunneling_buffer_pages_; // tunneling_buffer_size_ / PREFETCH_D_BUFFER_LOG_PAGE_SIZE

uint32_t kernel_debug_status_enable_{0}; // How much space reserved for debug status. 0 means disabled.

CoreType core_type_; // Which core this settings is for

bool operator==(const DispatchSettings& other) const {
Expand Down Expand Up @@ -278,6 +280,12 @@ class DispatchSettings {
return *this;
}

// Builder-style setter for the dispatch kernel debug status reservation.
// Stores `size` in kernel_debug_status_enable_ and returns this settings
// object so further setter calls can be chained.
DispatchSettings& kernel_debug_status_enable(uint32_t size) {
    kernel_debug_status_enable_ = size;
    return *this;
}

// Returns a list of errors
std::vector<std::string> get_errors() const;

Expand Down
12 changes: 7 additions & 5 deletions tt_metal/impl/dispatch/kernel_config/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,9 @@ void DispatchKernel::GenerateDependentConfigs() {
auto prefetch_kernel = dynamic_cast<PrefetchKernel*>(upstream_kernels_[0]);
TT_ASSERT(prefetch_kernel);
dependent_config_.upstream_logical_core = prefetch_kernel->GetLogicalCore();
dependent_config_.upstream_dispatch_cb_sem_id = prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id;
dependent_config_.upstream_sync_sem = prefetch_kernel->GetStaticConfig().downstream_sync_sem_id;
dependent_config_.upstream_dispatch_cb_sem_id =
prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id.value();
dependent_config_.upstream_sync_sem = prefetch_kernel->GetStaticConfig().downstream_sync_sem_id.value();

if (prefetch_kernel->GetStaticConfig().is_h_variant.value() &&
prefetch_kernel->GetStaticConfig().is_d_variant.value()) {
Expand Down Expand Up @@ -228,7 +229,7 @@ void DispatchKernel::GenerateDependentConfigs() {
dependent_config_.prefetch_h_noc_xy = tt::tt_metal::hal.noc_xy_encoding(
prefetch_h_kernel->GetVirtualCore().x, prefetch_h_kernel->GetVirtualCore().y);
dependent_config_.prefetch_h_local_downstream_sem_addr =
prefetch_h_kernel->GetStaticConfig().my_downstream_cb_sem_id;
prefetch_h_kernel->GetStaticConfig().my_downstream_cb_sem_id.value();
dependent_config_.downstream_cb_base = 0; // Unused
dependent_config_.downstream_cb_size = 0; // Unused
dependent_config_.downstream_cb_sem_id = 0; // Unused
Expand All @@ -238,8 +239,9 @@ void DispatchKernel::GenerateDependentConfigs() {
auto prefetch_kernel = dynamic_cast<PrefetchKernel*>(upstream_kernels_[0]);
TT_ASSERT(prefetch_kernel);
dependent_config_.upstream_logical_core = prefetch_kernel->GetLogicalCore();
dependent_config_.upstream_dispatch_cb_sem_id = prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id;
dependent_config_.upstream_sync_sem = prefetch_kernel->GetStaticConfig().downstream_sync_sem_id;
dependent_config_.upstream_dispatch_cb_sem_id =
prefetch_kernel->GetStaticConfig().my_downstream_cb_sem_id.value();
dependent_config_.upstream_sync_sem = prefetch_kernel->GetStaticConfig().downstream_sync_sem_id.value();

if (prefetch_kernel->GetStaticConfig().is_h_variant.value() &&
prefetch_kernel->GetStaticConfig().is_d_variant.value()) {
Expand Down
28 changes: 15 additions & 13 deletions tt_metal/impl/dispatch/kernel_config/prefetch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <host_api.hpp>
#include <tt_metal.hpp>
#include "tt_align.hpp"
#include "tt_metal/impl/dispatch/dispatch_query_manager.hpp"

#include <tt-metalium/command_queue_interface.hpp>
Expand All @@ -27,23 +28,21 @@ void PrefetchKernel::GenerateStaticConfigs() {
uint32_t issue_queue_start_addr = command_queue_start_addr + cq_start;
uint32_t issue_queue_size = device_->sysmem_manager().get_issue_queue_size(cq_id_);

dependent_config_.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base();
static_config_.downstream_cb_log_page_size = DispatchSettings::DISPATCH_BUFFER_LOG_PAGE_SIZE;
static_config_.downstream_cb_pages = my_dispatch_constants.dispatch_buffer_pages();
static_config_.my_downstream_cb_sem_id = tt::tt_metal::CreateSemaphore(
*program_, logical_core_, my_dispatch_constants.dispatch_buffer_pages(), GetCoreType());

static_config_.pcie_base = issue_queue_start_addr;
static_config_.pcie_size = issue_queue_size;
static_config_.prefetch_q_base =
my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::UNRESERVED);
static_config_.prefetch_q_base = my_dispatch_constants.prefetch_buffer_base();
static_config_.prefetch_q_size = my_dispatch_constants.prefetch_q_size();
static_config_.prefetch_q_rd_ptr_addr =
my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_RD);
static_config_.prefetch_q_pcie_rd_ptr_addr =
my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_PCIE_RD);

static_config_.cmddat_q_base = my_dispatch_constants.cmddat_q_base();
static_config_.cmddat_q_base = my_dispatch_constants.cmddat_q_base<false>();
static_config_.cmddat_q_size = my_dispatch_constants.cmddat_q_size();

static_config_.scratch_db_base = my_dispatch_constants.scratch_db_base();
Expand Down Expand Up @@ -94,15 +93,14 @@ void PrefetchKernel::GenerateStaticConfigs() {

static_config_.pcie_base = issue_queue_start_addr;
static_config_.pcie_size = issue_queue_size;
static_config_.prefetch_q_base =
my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::UNRESERVED);
static_config_.prefetch_q_base = my_dispatch_constants.prefetch_buffer_base();
static_config_.prefetch_q_size = my_dispatch_constants.prefetch_q_size();
static_config_.prefetch_q_rd_ptr_addr =
my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_RD);
static_config_.prefetch_q_pcie_rd_ptr_addr =
my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_PCIE_RD);

static_config_.cmddat_q_base = my_dispatch_constants.cmddat_q_base();
static_config_.cmddat_q_base = my_dispatch_constants.cmddat_q_base<false>();
static_config_.cmddat_q_size = my_dispatch_constants.cmddat_q_size();

static_config_.scratch_db_base = my_dispatch_constants.scratch_db_base();
Expand All @@ -123,7 +121,6 @@ void PrefetchKernel::GenerateStaticConfigs() {
static_config_.dispatch_s_buffer_size = 0;
static_config_.dispatch_s_cb_log_page_size = 0;
} else if (static_config_.is_d_variant.value()) {
dependent_config_.downstream_cb_base = my_dispatch_constants.dispatch_buffer_base();
static_config_.downstream_cb_log_page_size = DispatchSettings::PREFETCH_D_BUFFER_LOG_PAGE_SIZE;
static_config_.downstream_cb_pages = my_dispatch_constants.dispatch_buffer_pages();
static_config_.my_downstream_cb_sem_id = tt::tt_metal::CreateSemaphore(
Expand All @@ -138,13 +135,14 @@ void PrefetchKernel::GenerateStaticConfigs() {
static_config_.prefetch_q_pcie_rd_ptr_addr =
my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_PCIE_RD);

// prefetch_q is not used in d variant so buffer can start at the first region which is prefetch_buffer_base()
static_config_.cmddat_q_base = my_dispatch_constants.dispatch_buffer_base();
static_config_.cmddat_q_size = my_dispatch_constants.prefetch_d_buffer_size();

uint32_t pcie_alignment = hal.get_alignment(HalMemType::HOST);
static_config_.scratch_db_base = (my_dispatch_constants.dispatch_buffer_base() +
my_dispatch_constants.prefetch_d_buffer_size() + pcie_alignment - 1) &
(~(pcie_alignment - 1));
// scratch_db_base() is based on cmddat_q_base(). calculate manually instead.
static_config_.scratch_db_base =
tt::align(static_config_.cmddat_q_base.value() + static_config_.cmddat_q_size.value(), pcie_alignment);
static_config_.scratch_db_size = my_dispatch_constants.scratch_db_size();
static_config_.downstream_sync_sem_id =
tt::tt_metal::CreateSemaphore(*program_, logical_core_, 0, GetCoreType());
Expand Down Expand Up @@ -202,7 +200,9 @@ void PrefetchKernel::GenerateDependentConfigs() {
found_dispatch = true;

dependent_config_.downstream_logical_core = dispatch_kernel->GetLogicalCore();
dependent_config_.downstream_cb_sem_id = dispatch_kernel->GetStaticConfig().my_dispatch_cb_sem_id;
dependent_config_.downstream_cb_sem_id =
dispatch_kernel->GetStaticConfig().my_dispatch_cb_sem_id.value();
dependent_config_.downstream_cb_base = dispatch_kernel->GetStaticConfig().dispatch_cb_base.value();
} else if (auto dispatch_s_kernel = dynamic_cast<DispatchSKernel*>(k)) {
TT_ASSERT(!found_dispatch_s, "PREFETCH kernel has multiple downstream DISPATCH kernels.");
found_dispatch_s = true;
Expand Down Expand Up @@ -265,7 +265,9 @@ void PrefetchKernel::GenerateDependentConfigs() {
found_dispatch = true;

dependent_config_.downstream_logical_core = dispatch_kernel->GetLogicalCore();
dependent_config_.downstream_cb_sem_id = dispatch_kernel->GetStaticConfig().my_dispatch_cb_sem_id;
dependent_config_.downstream_cb_sem_id =
dispatch_kernel->GetStaticConfig().my_dispatch_cb_sem_id.value();
dependent_config_.downstream_cb_base = dispatch_kernel->GetStaticConfig().dispatch_cb_base.value();
} else if (auto dispatch_s_kernel = dynamic_cast<DispatchSKernel*>(k)) {
TT_ASSERT(!found_dispatch_s, "PREFETCH kernel has multiple downstream DISPATCH kernels.");
found_dispatch_s = true;
Expand Down

0 comments on commit e93160a

Please sign in to comment.