Skip to content

Commit ca04f80

Browse files
andralexcursoragent
andcommitted
Make equality and ordering consistent
Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 53f84c1 commit ca04f80

File tree

4 files changed

+16
-17
lines changed

4 files changed

+16
-17
lines changed

cudax/include/cuda/experimental/__stf/places/data_place_extension.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class exec_place;
5858
* int get_device_ordinal() const override { return my_device_id; }
5959
* ::std::string to_string() const override { return "my_custom_place"; }
6060
* size_t hash() const override { return std::hash<int>{}(my_device_id); }
61-
* bool equals(const data_place_extension& other) const override { ... }
61+
* bool operator==(const data_place_extension& other) const override { ... }
6262
* };
6363
* @endcode
6464
*/
@@ -104,15 +104,15 @@ public:
104104
* @param other The other extension to compare with
105105
* @return true if the extensions represent the same place
106106
*/
107-
virtual bool equals(const data_place_extension& other) const = 0;
107+
virtual bool operator==(const data_place_extension& other) const = 0;
108108

109109
/**
110110
* @brief Compare ordering with another extension
111111
*
112112
* @param other The other extension to compare with
113113
* @return true if this extension is less than the other
114114
*/
115-
virtual bool less_than(const data_place_extension& other) const = 0;
115+
virtual bool operator<(const data_place_extension& other) const = 0;
116116

117117
/**
118118
* @brief Create a physical memory allocation for this place (VMM API)

cudax/include/cuda/experimental/__stf/places/exec/cuda_stream.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ public:
7171

7272
bool operator==(const exec_place::impl& rhs) const override
7373
{
74-
auto other = dynamic_cast<const impl*>(&rhs);
75-
if (!other)
74+
if (typeid(*this) != typeid(rhs))
7675
{
7776
return false;
7877
}
78+
const auto& other = static_cast<const impl&>(rhs);
7979
// Compare by stream handle
80-
return dstream.stream == other->dstream.stream;
80+
return dstream.stream == other.dstream.stream;
8181
}
8282

8383
size_t hash() const override

cudax/include/cuda/experimental/__stf/places/exec/green_context.cuh

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,17 @@ public:
9191
return hash_all(view_.g_ctx, view_.pool, view_.devid);
9292
}
9393

94-
bool equals(const data_place_extension& other) const override
94+
bool operator==(const data_place_extension& other) const override
9595
{
96-
const auto* other_gc = dynamic_cast<const extension*>(&other);
97-
if (!other_gc)
96+
if (typeid(*this) != typeid(other))
9897
{
9998
return false;
10099
}
101-
return view_ == other_gc->view_;
100+
const auto& other_gc = static_cast<const extension&>(other);
101+
return view_ == other_gc.view_;
102102
}
103103

104-
bool less_than(const data_place_extension& other) const override
104+
bool operator<(const data_place_extension& other) const override
105105
{
106106
if (typeid(*this) != typeid(other))
107107
{
@@ -424,14 +424,13 @@ public:
424424

425425
bool operator==(const exec_place::impl& rhs) const override
426426
{
427-
// First, check if rhs is also a green context impl
428-
auto other = dynamic_cast<const impl*>(&rhs);
429-
if (!other)
427+
if (typeid(*this) != typeid(rhs))
430428
{
431429
return false;
432430
}
431+
const auto& other = static_cast<const impl&>(rhs);
433432
// Compare green context handles
434-
return g_ctx == other->g_ctx;
433+
return g_ctx == other.g_ctx;
435434
}
436435

437436
size_t hash() const override

cudax/include/cuda/experimental/__stf/places/places.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ public:
174174
// If both are extensions, delegate to the extension
175175
if (is_extension() && rhs.is_extension())
176176
{
177-
return extension->less_than(*rhs.extension);
177+
return *extension < *rhs.extension;
178178
}
179179

180180
// Extensions sort after non-extensions
@@ -1881,7 +1881,7 @@ inline bool data_place::operator==(const data_place& rhs) const
18811881
if (is_extension())
18821882
{
18831883
_CCCL_ASSERT(devid == extension_devid, "");
1884-
return (rhs.devid == extension_devid && extension->equals(*rhs.extension));
1884+
return (rhs.devid == extension_devid && *extension == *rhs.extension);
18851885
}
18861886

18871887
return (get_grid() == rhs.get_grid() && (get_partitioner() == rhs.get_partitioner()));

0 commit comments

Comments
 (0)