Unit test regression fix and dead code removal
SeanNijjar committed Feb 9, 2025
1 parent 8f653f5 commit 78f8965
Showing 4 changed files with 122 additions and 149 deletions.
@@ -122,17 +122,18 @@ void kernel_main() {

// bit of a hack to extract X/Y
const auto dest_noc_address = get_noc_addr(p, dest_addr_gen, 0, NORMALIZED_NOC_INDEX);
-const size_t packet_size = page_size;
+const size_t packet_size = page_size + sizeof(tt::fabric::PacketHeader);
auto packet_addr = get_read_ptr(cb_id_in0);
-auto& packet_header = *reinterpret_cast<tt::fabric::PacketHeader*>(packet_addr);
+auto* packet_header = reinterpret_cast<volatile tt::fabric::PacketHeader*>(packet_addr);
if constexpr (mcast_mode) {
packet_header
-    .to_chip_multicast(tt::fabric::MulticastRoutingCommandHeader{config.mcast.distance, config.mcast.range})
-    .to_noc_unicast_write(
+    ->to_chip_multicast(
+        tt::fabric::MulticastRoutingCommandHeader{config.mcast.distance, config.mcast.range})
+    ->to_noc_unicast_write(
tt::fabric::NocUnicastCommandHeader{dest_noc_address}, (pages_to_send * page_size));
} else {
-packet_header.to_chip_unicast(config.unicast.distance)
-    .to_noc_unicast_write(
+packet_header->to_chip_unicast(config.unicast.distance)
+    ->to_noc_unicast_write(
tt::fabric::NocUnicastCommandHeader{dest_noc_address}, (pages_to_send * page_size));
}
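
Aside: the move from dot-chaining to arrow-chaining above works because each header-populating method returns the header pointer, so calls compose on a volatile PacketHeader*. A minimal sketch of that shape, for orientation only (not the actual tt::fabric definition):

struct PacketHeaderSketch {
    uint8_t distance_in_hops;
    uint64_t dest_noc_address;
    uint32_t payload_size_bytes;
    volatile PacketHeaderSketch* to_chip_unicast(uint8_t distance) volatile {
        // stage the routing distance in the L1-resident header, then return this so calls chain
        distance_in_hops = distance;
        return this;
    }
    volatile PacketHeaderSketch* to_noc_unicast_write(uint64_t noc_address, uint32_t size_bytes) volatile {
        // stage the NoC write command fields, then return this
        dest_noc_address = noc_address;
        payload_size_bytes = size_bytes;
        return this;
    }
};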

@@ -3266,7 +3266,6 @@ TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWra
RunWriteThroughputStabilityTestWithPersistentFabric(
num_mcasts, num_unicasts, num_links, num_op_invocations, params);
}
-// hangs with DPRINT
TEST(EdmFabric, BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWrap_2Device) {
const size_t num_mcasts = 9;
const size_t num_unicasts = 0;
@@ -3294,7 +3293,6 @@ TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWra
RunWriteThroughputStabilityTestWithPersistentFabric(
num_mcasts, num_unicasts, num_links, num_op_invocations, params);
}
-// First to hang - maybe something to do with merging traffic
TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWrap_TwoWorkers_4Device) {
const size_t num_mcasts = 9;
const size_t num_unicasts = 0;
@@ -3603,6 +3601,18 @@ TEST(EdmFabric, BasicMcastThroughputTest_3) {
RunWriteThroughputStabilityTestWithPersistentFabric(
num_mcasts, num_unicasts, num_links, num_op_invocations, params);
}
+TEST(EdmFabric, BasicMcastThroughputTest_3_onehop) {
+    const size_t num_mcasts = 200000;
+    const size_t num_unicasts = 2;
+    const size_t num_links = 1;
+    const size_t num_op_invocations = 1;
+    const bool line_sync = true;
+    WriteThroughputStabilityTestWithPersistentFabricParams params;
+    params.line_sync = line_sync;
+    params.line_size = 2;
+    RunWriteThroughputStabilityTestWithPersistentFabric(
+        num_mcasts, num_unicasts, num_links, num_op_invocations, params);
+}
TEST(EdmFabric, BasicMcastThroughputTest_4) {
const size_t num_mcasts = 800000;
const size_t num_unicasts = 2;
@@ -23,6 +23,9 @@

using ttnn::ccl::WorkerXY;

+static constexpr bool enable_first_level_ack = true;
+static constexpr bool fuse_receiver_flush_and_completion_ptr = true;

/*
The fabric Erisc Data Mover (EDM) is a component that can be used to build *very* simple linear topology fabrics.
@@ -247,11 +250,11 @@ constexpr uint8_t NUM_TRANSACTION_IDS = 4;

template <uint8_t MAX_TRANSACTION_IDS>
struct TransactionIdCounter {
-void increment() {
+FORCE_INLINE void increment() {
this->next_trid = tt::fabric::wrap_increment<MAX_TRANSACTION_IDS>(this->next_trid);
}

-uint8_t get() const {
+FORCE_INLINE uint8_t get() const {
return this->next_trid;
}
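
The counter above leans on tt::fabric::wrap_increment. A sketch of its assumed semantics (the real implementation may instead mask against a power-of-two count):

template <uint8_t MAX>
FORCE_INLINE uint8_t wrap_increment_sketch(uint8_t value) {
    // advance by one, wrapping back to 0 at MAX (cheaper than a general modulo on erisc)
    return value == MAX - 1 ? 0 : value + 1;
}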

@@ -314,41 +317,37 @@ constexpr uint32_t to_sender_1_pkts_completed_id = 4;

// This will be an atomic read of the stream register
template <uint32_t stream_id>
-int32_t get_ptr_val() {
-    constexpr uint32_t addr = STREAM_REG_ADDR(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX);
-    return *reinterpret_cast<volatile uint32_t*>(addr);
+FORCE_INLINE int32_t get_ptr_val() {
+    return NOC_STREAM_READ_REG(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX);
}
-int32_t get_ptr_val(uint8_t stream_id) {
-    const uint32_t addr = STREAM_REG_ADDR(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX);
-    return *reinterpret_cast<volatile uint32_t*>(addr);
+FORCE_INLINE int32_t get_ptr_val(uint8_t stream_id) {
+    return NOC_STREAM_READ_REG(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX);
}

// Writing to this register leverages the built-in stream hardware, which automatically performs an atomic increment
// on the register. This can save precious erisc cycles by offloading a lot of pointer manipulation.
// Additionally, these registers are accessible via eth_reg_write calls, which can write a value
// inline in the eth command (without requiring a source L1 buffer).
template <uint32_t stream_id>
-void increment_local_update_ptr_val(int32_t val) {
+FORCE_INLINE void increment_local_update_ptr_val(int32_t val) {
NOC_STREAM_WRITE_REG_FIELD(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX, REMOTE_DEST_BUF_WORDS_FREE_INC, val);
}
-void increment_local_update_ptr_val(uint8_t stream_id, int32_t val) {
+FORCE_INLINE void increment_local_update_ptr_val(uint8_t stream_id, int32_t val) {
NOC_STREAM_WRITE_REG_FIELD(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX, REMOTE_DEST_BUF_WORDS_FREE_INC, val);
}

template <uint32_t stream_id>
-void remote_update_ptr_val(int32_t val) {
+FORCE_INLINE void remote_update_ptr_val(int32_t val) {
constexpr uint32_t addr = STREAM_REG_ADDR(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX);
internal_::eth_write_remote_reg_no_txq_check(DEFAULT_ETH_TXQ, addr, val << REMOTE_DEST_BUF_WORDS_FREE_INC);
}
-void remote_update_ptr_val(uint32_t stream_id, int32_t val) {
+FORCE_INLINE void remote_update_ptr_val(uint32_t stream_id, int32_t val) {
const uint32_t addr = STREAM_REG_ADDR(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX);
internal_::eth_write_remote_reg_no_txq_check(DEFAULT_ETH_TXQ, addr, val << REMOTE_DEST_BUF_WORDS_FREE_INC);
}

template <uint32_t stream_id>
-void init_ptr_val(int32_t val) {
+FORCE_INLINE void init_ptr_val(int32_t val) {
NOC_STREAM_WRITE_REG(stream_id, STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX, val);
}
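
Taken together, these helpers form a free-running credit loop between the two ends of the link. A hedged usage sketch reusing the stream ids above (the polarity and counts are illustrative, not the exact protocol):

// once, at bring-up, on the core that owns the counter:
init_ptr_val<to_receiver_pkts_sent_id>(0);
// sender side, after pushing one packet over eth: atomically bump the remote counter
remote_update_ptr_val<to_receiver_pkts_sent_id>(1);
// receiver side, each poll: read arrivals, process them, then credit the counter back down
int32_t newly_arrived = get_ptr_val<to_receiver_pkts_sent_id>();
if (newly_arrived > 0) {
    // ... process newly_arrived packets ...
    increment_local_update_ptr_val<to_receiver_pkts_sent_id>(-newly_arrived);
}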

@@ -371,19 +370,19 @@ struct OutboundReceiverChannelPointers {
tt::fabric::ChannelBufferPointer<RECEIVER_NUM_BUFFERS> ack_ptr;
tt::fabric::ChannelBufferPointer<RECEIVER_NUM_BUFFERS> completion_ptr;

-bool has_space_for_packet() const {
+FORCE_INLINE bool has_space_for_packet() const {
return completion_ptr.distance_behind(wrptr) < RECEIVER_NUM_BUFFERS;
}

-bool has_unacknowledged_eth_packets() const {
+FORCE_INLINE bool has_unacknowledged_eth_packets() const {
return ack_ptr.get_ptr() != wrptr.get_ptr();
}

-bool has_incomplete_eth_packets() const {
+FORCE_INLINE bool has_incomplete_eth_packets() const {
return completion_ptr.get_ptr() != wrptr.get_ptr();
}

-bool has_unacknowledged_or_incomplete_eth_packets() const {
+FORCE_INLINE bool has_unacknowledged_or_incomplete_eth_packets() const {
return has_incomplete_eth_packets() || has_unacknowledged_eth_packets();
}
};
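
These predicates gate the sender-side loop. An illustrative sketch of the intended call pattern (control flow only; the real loop is run_sender_channel_step below):

OutboundReceiverChannelPointers<RECEIVER_NUM_BUFFERS> ptrs;
// wrptr leads; ack_ptr trails once the receiver buffers a packet (first-level ack);
// completion_ptr trails until the payload is flushed and the slot is reusable.
if (ptrs.has_space_for_packet() && !internal_::eth_txq_is_busy(DEFAULT_ETH_TXQ)) {
    // safe to stage the next packet over eth and advance wrptr
}
while (ptrs.has_unacknowledged_or_incomplete_eth_packets()) {
    // at teardown: keep polling acks/completions until the channel drains
}
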
@@ -486,20 +485,9 @@ static constexpr size_t worker_info_offset_past_connection_semaphore = 32;
// SENDER SIDE HELPERS
/////////////////////////////////////////////

-template <uint8_t SENDER_NUM_BUFFERS, uint8_t RECEIVER_NUM_BUFFERS>
-void send_channel_sync(
-    tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS> &sender_buffer_channel,
-    tt::fabric::ChannelBufferPointer<SENDER_NUM_BUFFERS> &sender_wrptr,
-    tt::fabric::EthChannelBuffer<RECEIVER_NUM_BUFFERS> &receiver_buffer_channel,
-    tt::fabric::ChannelBufferPointer<RECEIVER_NUM_BUFFERS> &remote_receiver_wrptr
-) {
-    auto src_addr = sender_buffer_channel.get_bytes_sent_address(sender_wrptr.get_buffer_index());
-    auto dest_addr = receiver_buffer_channel.get_bytes_sent_address(remote_receiver_wrptr.get_buffer_index());
-    internal_::eth_send_packet_bytes_unsafe(DEFAULT_ETH_TXQ, src_addr, dest_addr, sizeof(eth_channel_sync_t));
-}

template <uint8_t SENDER_NUM_BUFFERS, uint8_t RECEIVER_NUM_BUFFERS>
-void send_next_data(
+FORCE_INLINE void send_next_data(
tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS> &sender_buffer_channel,
tt::fabric::EdmChannelWorkerInterface<SENDER_NUM_BUFFERS> &sender_worker_interface,
OutboundReceiverChannelPointers<RECEIVER_NUM_BUFFERS> &outbound_to_receiver_channel_pointers,
@@ -550,7 +538,7 @@ void send_next_data(
* MUST CHECK !is_eth_txq_busy() before calling
*/
template <size_t NUM_SENDER_CHANNELS, uint8_t SENDER_NUM_BUFFERS, uint8_t RECEIVER_NUM_BUFFERS>
-void receiver_send_received_ack(
+FORCE_INLINE void receiver_send_received_ack(
std::array<tt::fabric::ChannelBufferPointer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_eth_sender_ackptrs,
std::array<tt::fabric::EthChannelBuffer<SENDER_NUM_BUFFERS>, NUM_SENDER_CHANNELS> &remote_sender_channels,
// currently the pointer is working multiple jobs (ack, completion, read) because we haven't implemented the
@@ -595,7 +583,7 @@ FORCE_INLINE bool can_forward_packet_completely(
}

// !!!WARNING!!! - MAKE SURE CONSUMER HAS SPACE BEFORE CALLING
-void receiver_forward_packet(
+FORCE_INLINE void receiver_forward_packet(
// TODO: have a separate cached copy of the packet header to save some additional L1 loads
volatile tt::fabric::PacketHeader *packet_start,
tt::fabric::RoutingFields cached_routing_fields,
@@ -663,22 +651,30 @@ FORCE_INLINE bool run_sender_channel_step(
outbound_to_receiver_channel_pointers.completion_ptr.increment_n(completions_since_last_check);
sender_rdptr.increment_n(completions_since_last_check);
increment_local_update_ptr_val(to_sender_packets_completed_streams[sender_channel_index], -completions_since_last_check);
+if constexpr (!enable_first_level_ack) {
+    if (channel_connection_established) {
+        local_sender_channel_worker_interface.update_worker_copy_of_read_ptr(sender_rdptr.get_ptr());
+    }
+}
}

// Process ACKs from receiver
// ACKs are processed second to avoid races: since every completion is preceded by its ack,
// processing completions first guarantees we observe at least as many acks as completions.
-auto acks_since_last_check = get_ptr_val(to_sender_packets_acked_streams[sender_channel_index]);
-
-auto& sender_ackptr = local_sender_channel_worker_interface.local_ackptr;
-if (acks_since_last_check > 0) {
-    sender_ackptr.increment_n(acks_since_last_check);
-    if (channel_connection_established) {
-        local_sender_channel_worker_interface.update_worker_copy_of_read_ptr();
-    }
-    increment_local_update_ptr_val(to_sender_packets_acked_streams[sender_channel_index], -acks_since_last_check);
-}
-did_something = did_something || (completions_since_last_check + acks_since_last_check) > 0;
+if constexpr (enable_first_level_ack) {
+    auto acks_since_last_check = get_ptr_val(to_sender_packets_acked_streams[sender_channel_index]);
+    auto& sender_ackptr = local_sender_channel_worker_interface.local_ackptr;
+    if (acks_since_last_check > 0) {
+        sender_ackptr.increment_n(acks_since_last_check);
+        if (channel_connection_established) {
+            local_sender_channel_worker_interface.update_worker_copy_of_read_ptr(sender_ackptr.get_ptr());
+        }
+        increment_local_update_ptr_val(to_sender_packets_acked_streams[sender_channel_index], -acks_since_last_check);
+    }
+    did_something = did_something || (completions_since_last_check + acks_since_last_check) > 0;
+} else {
+    did_something = did_something || (completions_since_last_check > 0);
+}


if (!channel_connection_established) {
@@ -698,7 +694,11 @@ }
}
did_something = true;
channel_connection_established = true;
-local_sender_channel_worker_interface.update_worker_copy_of_read_ptr();
+if constexpr (enable_first_level_ack) {
+    local_sender_channel_worker_interface.update_worker_copy_of_read_ptr(local_sender_channel_worker_interface.local_ackptr.get_ptr());
+} else {
+    local_sender_channel_worker_interface.update_worker_copy_of_read_ptr(local_sender_channel_worker_interface.local_rdptr.get_ptr());
+}
}
} else if (local_sender_channel_worker_interface.has_worker_teardown_request()) {
did_something = true;
Expand All @@ -725,23 +725,27 @@ FORCE_INLINE void run_receiver_channel_step(
auto &ack_ptr = receiver_channel_pointers.ack_ptr;
auto pkts_received_since_last_check = get_ptr_val<to_receiver_pkts_sent_id>();
bool pkts_received = pkts_received_since_last_check > 0;
-bool can_send_over_eth = !internal_::eth_txq_is_busy(DEFAULT_ETH_TXQ);
-ASSERT(receiver_channel_pointers.completion_ptr.distance_behind(ack_ptr) < RECEIVER_NUM_BUFFERS);
-if (pkts_received && can_send_over_eth) {
-    // currently only support processing one packet at a time, so we only decrement by 1
-    increment_local_update_ptr_val<to_receiver_pkts_sent_id>(-1);
-    receiver_send_received_ack(
-        remote_eth_sender_wrptrs,
-        remote_sender_channnels,
-        ack_ptr,
-        local_receiver_channel);
-    ack_ptr.increment();
-}
+if constexpr (enable_first_level_ack) {
+    bool can_send_over_eth = !internal_::eth_txq_is_busy(DEFAULT_ETH_TXQ);
+    ASSERT(receiver_channel_pointers.completion_ptr.distance_behind(ack_ptr) < RECEIVER_NUM_BUFFERS);
+    if (pkts_received && can_send_over_eth) {
+        // currently only support processing one packet at a time, so we only decrement by 1
+        increment_local_update_ptr_val<to_receiver_pkts_sent_id>(-1);
+        receiver_send_received_ack(
+            remote_eth_sender_wrptrs,
+            remote_sender_channnels,
+            ack_ptr,
+            local_receiver_channel);
+        ack_ptr.increment();
+    }
+} else {
+    increment_local_update_ptr_val<to_receiver_pkts_sent_id>(-pkts_received_since_last_check);
+    ack_ptr.increment_n(pkts_received_since_last_check);
+}

auto &wr_sent_ptr = receiver_channel_pointers.wr_sent_ptr;
bool unwritten_packets = !wr_sent_ptr.is_caught_up_to(ack_ptr);
if (unwritten_packets) {
DeviceZoneScopedN("EDMR-Send-Chk");
auto receiver_buffer_index = wr_sent_ptr.get_buffer_index();
volatile auto packet_header = local_receiver_channel.get_packet_header(receiver_buffer_index);

@@ -751,37 +755,57 @@ FORCE_INLINE void run_receiver_channel_step(
can_forward_packet_completely(packet_header, cached_routing_fields, downstream_edm_interface);
bool trid_flushed = receiver_channel_trid_tracker.transaction_flushed(receiver_buffer_index);
if (can_send_to_all_local_chip_receivers && trid_flushed) {
DeviceZoneScopedN("EDMR-Send-Impl");
// DeviceZoneScopedN("EDMR-Send-Impl");
uint8_t trid = receiver_channel_trid_tracker.update_buffer_slot_to_next_trid_and_advance_trid_counter(receiver_buffer_index);
receiver_forward_packet(packet_header, cached_routing_fields, downstream_edm_interface, trid);
wr_sent_ptr.increment();
}
}

-auto &wr_flush_ptr = receiver_channel_pointers.wr_flush_ptr;
-bool unflushed_writes = !wr_flush_ptr.is_caught_up_to(wr_sent_ptr);
-if (unflushed_writes) {
-    auto receiver_buffer_index = wr_flush_ptr.get_buffer_index();
-    bool next_trid_flushed = receiver_channel_trid_tracker.transaction_flushed(receiver_buffer_index);
-    if (next_trid_flushed) {
-        local_receiver_channel.eth_clear_sender_channel_ack(receiver_buffer_index);
-        wr_flush_ptr.increment();
-        receiver_channel_trid_tracker.clear_trid_at_buffer_slot(receiver_buffer_index);
-    }
-}
-
-auto &completion_ptr = receiver_channel_pointers.completion_ptr;
-bool unsent_completions = !completion_ptr.is_caught_up_to(wr_flush_ptr);
-if (unsent_completions) {
-    bool can_send_without_blocking = !internal_::eth_txq_is_busy(DEFAULT_ETH_TXQ);
-    if (can_send_without_blocking) {
-        // completion ptr incremented in callee
-        receiver_send_completion_ack(
-            remote_eth_sender_wrptrs,
-            remote_sender_channnels,
-            completion_ptr,
-            local_receiver_channel);
-    }
-}
+if constexpr (!fuse_receiver_flush_and_completion_ptr) {
+    auto &wr_flush_ptr = receiver_channel_pointers.wr_flush_ptr;
+    bool unflushed_writes = !wr_flush_ptr.is_caught_up_to(wr_sent_ptr);
+    if (unflushed_writes) {
+        auto receiver_buffer_index = wr_flush_ptr.get_buffer_index();
+        bool next_trid_flushed = receiver_channel_trid_tracker.transaction_flushed(receiver_buffer_index);
+        if (next_trid_flushed) {
+            wr_flush_ptr.increment();
+            receiver_channel_trid_tracker.clear_trid_at_buffer_slot(receiver_buffer_index);
+        }
+    }
+
+    auto &completion_ptr = receiver_channel_pointers.completion_ptr;
+    bool unsent_completions = !completion_ptr.is_caught_up_to(wr_flush_ptr);
+    if (unsent_completions) {
+        bool can_send_without_blocking = !internal_::eth_txq_is_busy(DEFAULT_ETH_TXQ);
+        if (can_send_without_blocking) {
+            // completion ptr incremented in callee
+            receiver_send_completion_ack(
+                remote_eth_sender_wrptrs,
+                remote_sender_channnels,
+                completion_ptr,
+                local_receiver_channel);
+        }
+    }
+} else {
+    auto &wr_flush_ptr = receiver_channel_pointers.wr_flush_ptr;
+    // Currently unclear if it's better to loop here or not... Also unclear if merging these
+    // two pointers is better or not... Seems to be maybe 5-10% better merged but need more data
+    if (!wr_flush_ptr.is_caught_up_to(wr_sent_ptr) && !internal_::eth_txq_is_busy(DEFAULT_ETH_TXQ)) {
+        auto receiver_buffer_index = wr_flush_ptr.get_buffer_index();
+        bool next_trid_flushed = receiver_channel_trid_tracker.transaction_flushed(receiver_buffer_index);
+        if (next_trid_flushed) {
+            auto &completion_ptr = receiver_channel_pointers.completion_ptr;
+            wr_flush_ptr.increment();
+            receiver_channel_trid_tracker.clear_trid_at_buffer_slot(receiver_buffer_index);
+            receiver_send_completion_ack(
+                remote_eth_sender_wrptrs,
+                remote_sender_channnels,
+                completion_ptr,
+                local_receiver_channel);
+        }
+    }
+}
};

@@ -976,7 +1000,7 @@ void kernel_main() {
static constexpr size_t sender_channel_0_counters_address = get_compile_time_arg_val(18);
static constexpr size_t sender_channel_1_counters_address = get_compile_time_arg_val(19);

-static constexpr bool enable_packet_header_recording = get_compile_time_arg_val(20) != 0;
+static constexpr bool enable_packet_header_recording = false; //get_compile_time_arg_val(20) != 0;
static constexpr size_t receiver_completed_packet_header_cb_address = get_compile_time_arg_val(21);
static constexpr size_t receiver_completed_packet_header_cb_size_headers = get_compile_time_arg_val(22);
static constexpr size_t sender_0_completed_packet_header_cb_address = get_compile_time_arg_val(23);