Skip to content

Commit 50dc1e1

Browse files
dplassgitcopybara-github
authored andcommitted
[Proc-scoped channels] Update get_conversion_records to use all invocations of a proc instead of the proc functions themselves to determine if a proc should be added to the conversion records list.
Fix type info to properly record invocations. PiperOrigin-RevId: 820224956
1 parent 533a0e9 commit 50dc1e1

13 files changed

+380
-110
lines changed

xls/dslx/ir_convert/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ cc_library(
507507
"//xls/dslx/type_system:parametric_env",
508508
"//xls/dslx/type_system:type_info",
509509
"//xls/public:status_macros",
510+
"@com_google_absl//absl/container:flat_hash_set",
510511
"@com_google_absl//absl/log",
511512
"@com_google_absl//absl/status",
512513
"@com_google_absl//absl/status:statusor",

xls/dslx/ir_convert/conversion_record.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,11 @@ std::string ConversionRecord::ToString() const {
101101
config = config_record_->ToString();
102102
}
103103
return absl::StrFormat(
104-
"ConversionRecord{m=%s, f=%s, top=%s, pid=%s, parametric_env=%s, "
105-
"type_info=%p, config=%s}",
106-
module_->name(), f_->identifier(), is_top_ ? "true" : "false", proc_id,
107-
parametric_env_.ToString(), type_info_, config);
104+
"ConversionRecord{m=%s, invocation=%p, f=%s, top=%s, pid=%s, "
105+
"parametric_env=%s, type_info=%p, config=%s}",
106+
module_->name(), invocation_, f_->identifier(),
107+
is_top_ ? "true" : "false", proc_id, parametric_env_.ToString(),
108+
type_info_, config);
108109
}
109110

110111
} // namespace xls::dslx

xls/dslx/ir_convert/get_conversion_records.cc

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <variant>
2121
#include <vector>
2222

23+
#include "absl/container/flat_hash_set.h"
2324
#include "absl/log/log.h"
2425
#include "absl/status/status.h"
2526
#include "absl/status/statusor.h"
@@ -70,12 +71,6 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
7071
} else {
7172
VLOG(5) << "Processing fn " << f->ToString();
7273
}
73-
// TODO: davidplass - change this to gather invocations from *spawns*
74-
// instead of *functions*. This will allow function_converter to emit
75-
// multiple IR procs with the same parametrics but different config
76-
// parameters. Then, HandleFunction would only have to handle procs
77-
// that are not spawned explicitly, like test or top procs.
78-
7974
// Note, it's possible there is no config invocation if it's a
8075
// top proc or some other reason.
8176
std::unique_ptr<ConversionRecord> config_record;
@@ -124,16 +119,13 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
124119
if (f->proc().has_value()) {
125120
proc_id = proc_id_factory_.CreateProcId(
126121
/*parent=*/std::nullopt, f->proc().value(),
127-
// TODO: davidplass - For parametric procs we have to decide if this
128-
// is a new instance if it has been called with the same parametrics
129-
// before. Otherwise it needs a new procid.
130122
/*count_as_new_instance=*/false);
131123
}
132124
std::vector<InvocationCalleeData> calls =
133125
type_info_->GetUniqueInvocationCalleeData(f);
134126
if (f->IsParametric() && calls.empty()) {
135127
VLOG(5) << "No calls to parametric proc " << f->name_def()->ToString();
136-
return absl::OkStatus();
128+
return DefaultHandler(f);
137129
}
138130
for (auto& callee_data : calls) {
139131
XLS_ASSIGN_OR_RETURN(
@@ -184,7 +176,35 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
184176
}
185177

186178
absl::Status HandleProc(const Proc* p) override {
187-
return AddFunction(&p->next());
179+
const Function* next_fn = &p->next();
180+
181+
if (p->IsParametric()) {
182+
std::optional<ProcId> proc_id = proc_id_factory_.CreateProcId(
183+
/*parent=*/std::nullopt, const_cast<Proc*>(p),
184+
/*count_as_new_instance=*/false);
185+
186+
std::vector<InvocationCalleeData> next_calls =
187+
type_info_->GetUniqueInvocationCalleeData(next_fn);
188+
for (auto& callee_data : next_calls) {
189+
XLS_ASSIGN_OR_RETURN(
190+
ConversionRecord cr,
191+
InvocationToConversionRecord(
192+
next_fn, callee_data.invocation, callee_data.derived_type_info,
193+
callee_data.callee_bindings, callee_data.caller_bindings,
194+
// Since this proc is being spawned, it's certainly not top.
195+
/* is_top= */ false, proc_id));
196+
records_.push_back(std::move(cr));
197+
}
198+
}
199+
if (top_ == next_fn || !p->IsParametric()) {
200+
// "top" procs won't have spawns referencing them so they won't
201+
// otherwise be added to the list, so we have to manually do it here.
202+
203+
// Similarly, if a proc is not parametric, while it might not have any
204+
// spawns, we still want to convert it.
205+
return AddFunction(next_fn);
206+
}
207+
return absl::OkStatus();
188208
}
189209

190210
absl::Status HandleTestProc(const TestProc* tp) override {
@@ -215,6 +235,20 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
215235

216236
} // namespace
217237

238+
// Filters duplicate conversion records from the given vector and returns a new
239+
// vector without duplicates.
240+
std::vector<ConversionRecord> RemoveFunctionDuplicates(
241+
std::vector<ConversionRecord>& ready) {
242+
absl::flat_hash_set<std::pair<Function*, ParametricEnv>> records;
243+
std::vector<ConversionRecord> result;
244+
for (auto& record : ready) {
245+
if (records.emplace(record.f(), record.parametric_env()).second) {
246+
result.push_back(std::move(record));
247+
}
248+
}
249+
return result;
250+
}
251+
218252
absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecords(
219253
Module* module, TypeInfo* type_info, bool include_tests) {
220254
ProcIdFactory proc_id_factory;
@@ -224,7 +258,8 @@ absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecords(
224258
proc_id_factory, /*top=*/nullptr);
225259
XLS_RETURN_IF_ERROR(module->Accept(&visitor));
226260

227-
return visitor.records();
261+
std::vector<ConversionRecord> records = visitor.records();
262+
return RemoveFunctionDuplicates(records);
228263
}
229264

230265
absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecordsForEntry(
@@ -238,7 +273,9 @@ absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecordsForEntry(
238273
ConversionRecordVisitor visitor(m, type_info, /*include_tests=*/true,
239274
proc_id_factory, f);
240275
XLS_RETURN_IF_ERROR(m->Accept(&visitor));
241-
return visitor.records();
276+
277+
std::vector<ConversionRecord> records = visitor.records();
278+
return RemoveFunctionDuplicates(records);
242279
}
243280

244281
Proc* p = std::get<Proc*>(entry);
@@ -250,6 +287,8 @@ absl::StatusOr<std::vector<ConversionRecord>> GetConversionRecordsForEntry(
250287
ConversionRecordVisitor visitor(m, new_ti, /*include_tests=*/true,
251288
proc_id_factory, &p->next());
252289
XLS_RETURN_IF_ERROR(m->Accept(&visitor));
253-
return visitor.records();
290+
291+
std::vector<ConversionRecord> records = visitor.records();
292+
return RemoveFunctionDuplicates(records);
254293
}
255294
} // namespace xls::dslx

xls/dslx/ir_convert/get_conversion_records_test.cc

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,17 +164,13 @@ proc top {
164164
GetConversionRecords(tm.module, tm.type_info, false));
165165
ASSERT_EQ(3, order.size());
166166
EXPECT_EQ(order[0].f()->identifier(), "P.next");
167-
EXPECT_EQ(order[0].invocation()->ToString(),
168-
"P.next<u32:2>(P.init<u32:2>())");
169167
const ConversionRecord* config_record = order[0].config_record();
170168
EXPECT_NE(config_record, nullptr);
171169
EXPECT_EQ(config_record->invocation()->ToString(), "P.config<u32:2>(u2:1)");
172170
EXPECT_EQ(order[0].parametric_env(),
173171
ParametricEnv(absl::flat_hash_map<std::string, InterpValue>{
174172
{"N", InterpValue::MakeUBits(/*bit_count=*/32, /*value=*/2)}}));
175173
EXPECT_EQ(order[1].f()->identifier(), "P.next");
176-
EXPECT_EQ(order[1].invocation()->ToString(),
177-
"P.next<u32:4>(P.init<u32:4>())");
178174
config_record = order[1].config_record();
179175
EXPECT_NE(config_record, nullptr);
180176
EXPECT_EQ(config_record->invocation()->ToString(), "P.config<u32:4>(u4:2)");
@@ -388,8 +384,9 @@ fn my_test() -> bool { f<u32:8>(u8:1) == u32:8 }
388384
XLS_ASSERT_OK_AND_ASSIGN(
389385
TypecheckedModule tm,
390386
ParseAndTypecheck(kProgram, "test.x", "test", &import_data));
391-
XLS_ASSERT_OK_AND_ASSIGN(std::vector<ConversionRecord> order,
392-
GetConversionRecords(tm.module, tm.type_info, true));
387+
XLS_ASSERT_OK_AND_ASSIGN(
388+
std::vector<ConversionRecord> order,
389+
GetConversionRecords(tm.module, tm.type_info, /*include_tests=*/true));
393390
ASSERT_EQ(2, order.size());
394391
EXPECT_EQ(order[0].f()->identifier(), "f");
395392
EXPECT_EQ(order[0].parametric_env(),
@@ -440,13 +437,15 @@ proc test {
440437
next(x: ()) { () }
441438
}
442439
)";
440+
443441
TEST(GetConversionRecordsTest, TestProc) {
444442
auto import_data = CreateImportDataForTest();
445443
XLS_ASSERT_OK_AND_ASSIGN(
446444
TypecheckedModule tm,
447445
ParseAndTypecheck(kTestProc, "test.x", "test", &import_data));
448-
XLS_ASSERT_OK_AND_ASSIGN(std::vector<ConversionRecord> order,
449-
GetConversionRecords(tm.module, tm.type_info, true));
446+
XLS_ASSERT_OK_AND_ASSIGN(
447+
std::vector<ConversionRecord> order,
448+
GetConversionRecords(tm.module, tm.type_info, /*include_tests=*/true));
450449
ASSERT_EQ(2, order.size());
451450
EXPECT_EQ(order[0].f()->identifier(), "P.next");
452451
EXPECT_EQ(order[0].parametric_env(),
@@ -462,7 +461,7 @@ TEST(GetConversionRecordsTest, TestProcSkipped) {
462461
ParseAndTypecheck(kTestProc, "test.x", "test", &import_data));
463462
XLS_ASSERT_OK_AND_ASSIGN(
464463
std::vector<ConversionRecord> order,
465-
GetConversionRecords(tm.module, tm.type_info, false));
464+
GetConversionRecords(tm.module, tm.type_info, /*include_tests=*/false));
466465
// It still converts the parametric proc because there is still a spawn,
467466
// in the test proc.
468467
ASSERT_EQ(1, order.size());
@@ -511,6 +510,61 @@ impl S {
511510
EXPECT_EQ(order[0].parametric_env(), ParametricEnv());
512511
}
513512

513+
TEST(GetConversionRecordsTest, SpawnTree) {
514+
constexpr std::string_view kProgram = R"(
515+
proc spawnee2<N:u32, M:u16> {
516+
init { }
517+
config() { () }
518+
next(state: ()) { () }
519+
}
520+
521+
proc spawnee1<N:u32> {
522+
init { }
523+
config() {
524+
spawn spawnee2<N, u16:1>();
525+
spawn spawnee2<N, u16:2>();
526+
}
527+
next(state: ()) { () }
528+
}
529+
530+
pub proc main {
531+
init { }
532+
config() {
533+
spawn spawnee1<u32:3>();
534+
spawn spawnee1<u32:4>();
535+
}
536+
next(state: ()) { () }
537+
}
538+
)";
539+
540+
auto import_data = CreateImportDataForTest();
541+
XLS_ASSERT_OK_AND_ASSIGN(
542+
TypecheckedModule tm,
543+
ParseAndTypecheck(kProgram, "test.x", "main", &import_data, {},
544+
TypeInferenceVersion::kVersion2));
545+
XLS_ASSERT_OK_AND_ASSIGN(
546+
std::vector<ConversionRecord> order,
547+
GetConversionRecords(tm.module, tm.type_info, false));
548+
// 4 of spawnee2, 2 of spawnee1 and 1 of main.
549+
ASSERT_EQ(7, order.size());
550+
EXPECT_EQ(order[0].parametric_env(),
551+
ParametricEnv(absl::flat_hash_map<std::string, InterpValue>{
552+
{"N", InterpValue::MakeUBits(/*bit_count=*/32, /*value=*/3)},
553+
{"M", InterpValue::MakeUBits(/*bit_count=*/16, /*value=*/1)}}));
554+
EXPECT_EQ(order[1].parametric_env(),
555+
ParametricEnv(absl::flat_hash_map<std::string, InterpValue>{
556+
{"N", InterpValue::MakeUBits(/*bit_count=*/32, /*value=*/3)},
557+
{"M", InterpValue::MakeUBits(/*bit_count=*/16, /*value=*/2)}}));
558+
EXPECT_EQ(order[2].parametric_env(),
559+
ParametricEnv(absl::flat_hash_map<std::string, InterpValue>{
560+
{"N", InterpValue::MakeUBits(/*bit_count=*/32, /*value=*/4)},
561+
{"M", InterpValue::MakeUBits(/*bit_count=*/16, /*value=*/1)}}));
562+
EXPECT_EQ(order[3].parametric_env(),
563+
ParametricEnv(absl::flat_hash_map<std::string, InterpValue>{
564+
{"N", InterpValue::MakeUBits(/*bit_count=*/32, /*value=*/4)},
565+
{"M", InterpValue::MakeUBits(/*bit_count=*/16, /*value=*/2)}}));
566+
}
567+
514568
} // namespace
515569

516570
} // namespace xls::dslx

0 commit comments

Comments
 (0)