Skip to content

Commit ba29c1c

Browse files
committed
fixup! derive dense from multivector
1 parent 57536f6 commit ba29c1c

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

core/matrix/dense.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,6 +1578,58 @@ auto Dense<ValueType>::create_local_view_impl(
15781578
type);
15791579
}
15801580

1581+
template <typename ValueType>
1582+
auto Dense<ValueType>::temporary_precision_impl(
1583+
syn::variant_from_tuple<supported_value_types> type)
1584+
-> std::unique_ptr<MultiVector, std::function<void(MultiVector*)>>
1585+
{
1586+
if (std::holds_alternative<ValueType>(type)) {
1587+
return {this, null_deleter<MultiVector>{}};
1588+
}
1589+
return std::visit(
1590+
[this](auto type)
1591+
-> std::unique_ptr<MultiVector, std::function<void(MultiVector*)>> {
1592+
using SndValueType = std::decay_t<decltype(type)>;
1593+
if constexpr (is_complex<ValueType>() ==
1594+
is_complex<SndValueType>()) {
1595+
auto result = Dense<SndValueType>::create(this->get_executor());
1596+
this->convert_to(result.get());
1597+
return {result.release(),
1598+
gko::detail::dynamic_convert_back_deleter<
1599+
Dense<ValueType>, MultiVector, Dense<SndValueType>>{
1600+
this}};
1601+
} else {
1602+
// @todo: handle real <--> complex conversion
1603+
GKO_INVALID_STATE("Unsupported value type");
1604+
}
1605+
},
1606+
type);
1607+
}
1608+
1609+
template <typename ValueType>
1610+
auto Dense<ValueType>::temporary_precision_impl(
1611+
syn::variant_from_tuple<supported_value_types> type) const
1612+
-> std::unique_ptr<const MultiVector>
1613+
{
1614+
if (std::holds_alternative<ValueType>(type)) {
1615+
return make_const_dense_view(this);
1616+
}
1617+
return std::visit(
1618+
[this](auto type) -> std::unique_ptr<const MultiVector> {
1619+
using SndValueType = std::decay_t<decltype(type)>;
1620+
if constexpr (is_complex<ValueType>() ==
1621+
is_complex<SndValueType>()) {
1622+
auto result = Dense<SndValueType>::create(this->get_executor());
1623+
this->convert_to(result.get());
1624+
return result;
1625+
} else {
1626+
// @todo: handle real <--> complex conversion
1627+
GKO_INVALID_STATE("Unsupported value type");
1628+
}
1629+
},
1630+
type);
1631+
}
1632+
15811633
template <typename ValueType>
15821634
auto Dense<ValueType>::get_stride_impl() const -> size_type
15831635
{

include/ginkgo/core/matrix/dense.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1534,6 +1534,15 @@ class Dense : public EnableMultiVector<Dense<ValueType>>,
15341534
std::unique_ptr,
15351535
syn::apply_to_list<std::add_const_t, dense_types>>> override;
15361536

1537+
[[nodiscard]] auto temporary_precision_impl(
1538+
syn::variant_from_tuple<supported_value_types> type)
1539+
-> std::unique_ptr<MultiVector,
1540+
std::function<void(MultiVector*)>> override;
1541+
1542+
[[nodiscard]] auto temporary_precision_impl(
1543+
syn::variant_from_tuple<supported_value_types> type) const
1544+
-> std::unique_ptr<const MultiVector> override;
1545+
15371546
[[nodiscard]] auto get_stride_impl() const -> size_type override;
15381547

15391548
private:

0 commit comments

Comments
 (0)