Skip to content

Commit 1853043

Browse files
committed
fixup! derive vector from multivector
1 parent b531f1a commit 1853043

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

core/distributed/vector.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "core/distributed/vector_kernels.hpp"
1010
#include "core/matrix/dense_kernels.hpp"
1111
#include "core/mpi/mpi_op.hpp"
12-
12+
#include "ginkgo/core/base/temporary_conversion.hpp"
1313

1414
namespace gko {
1515
namespace experimental {
@@ -417,6 +417,63 @@ auto Vector<ValueType>::create_local_view_impl(
417417
type);
418418
}
419419

420+
template <typename ValueType>
421+
auto Vector<ValueType>::temporary_precision_impl(
422+
syn::variant_from_tuple<matrix::supported_value_types> type)
423+
-> std::unique_ptr<matrix::MultiVector,
424+
std::function<void(matrix::MultiVector*)>>
425+
{
426+
if (std::holds_alternative<ValueType>(type)) {
427+
return {this, null_deleter<matrix::MultiVector>{}};
428+
}
429+
return std::visit(
430+
[this](auto type)
431+
-> std::unique_ptr<matrix::MultiVector,
432+
std::function<void(matrix::MultiVector*)>> {
433+
using SndValueType = std::decay_t<decltype(type)>;
434+
if constexpr (is_complex<ValueType>() ==
435+
is_complex<SndValueType>()) {
436+
auto result = Vector<SndValueType>::create(
437+
this->get_executor(), this->get_communicator());
438+
this->convert_to(result.get());
439+
return {result.release(),
440+
gko::detail::dynamic_convert_back_deleter<
441+
Vector<ValueType>, matrix::MultiVector,
442+
Vector<SndValueType>>{this}};
443+
} else {
444+
// @todo: handle real <--> complex conversion
445+
GKO_INVALID_STATE("Unsupported value type");
446+
}
447+
},
448+
type);
449+
}
450+
451+
template <typename ValueType>
452+
auto Vector<ValueType>::temporary_precision_impl(
453+
syn::variant_from_tuple<matrix::supported_value_types> type) const
454+
-> std::unique_ptr<const matrix::MultiVector>
455+
{
456+
if (std::holds_alternative<ValueType>(type)) {
457+
return Vector::create_const(this->get_executor(),
458+
this->get_communicator(), this->get_size(),
459+
make_const_dense_view(&local_));
460+
}
461+
return std::visit(
462+
[this](auto type) -> std::unique_ptr<const matrix::MultiVector> {
463+
using SndValueType = std::decay_t<decltype(type)>;
464+
if constexpr (is_complex<ValueType>() ==
465+
is_complex<SndValueType>()) {
466+
auto result = Vector<SndValueType>::create(
467+
this->get_executor(), this->get_communicator());
468+
this->convert_to(result.get());
469+
return result;
470+
} else {
471+
// @todo: handle real <--> complex conversion
472+
GKO_INVALID_STATE("Unsupported value type");
473+
}
474+
},
475+
type);
476+
}
420477

421478
template <typename ValueType>
422479
auto Vector<ValueType>::get_stride_impl() const -> size_type

include/ginkgo/core/distributed/vector.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,14 @@ class Vector : public matrix::EnableMultiVector<Vector<ValueType>>,
746746
-> syn::variant_from_tuple<syn::apply_to_list<
747747
std::unique_ptr, syn::apply_to_list<std::add_const_t,
748748
matrix::dense_types>>> override;
749+
[[nodiscard]] auto temporary_precision_impl(
750+
syn::variant_from_tuple<matrix::supported_value_types> type)
751+
-> std::unique_ptr<matrix::MultiVector,
752+
std::function<void(matrix::MultiVector*)>> override;
753+
754+
[[nodiscard]] auto temporary_precision_impl(
755+
syn::variant_from_tuple<matrix::supported_value_types> type) const
756+
-> std::unique_ptr<const matrix::MultiVector> override;
749757

750758
[[nodiscard]] auto get_stride_impl() const -> size_type override;
751759

0 commit comments

Comments
 (0)