|
9 | 9 | #include "core/distributed/vector_kernels.hpp"
|
10 | 10 | #include "core/matrix/dense_kernels.hpp"
|
11 | 11 | #include "core/mpi/mpi_op.hpp"
|
12 |
| - |
| 12 | +#include "ginkgo/core/base/temporary_conversion.hpp" |
13 | 13 |
|
14 | 14 | namespace gko {
|
15 | 15 | namespace experimental {
|
@@ -417,6 +417,63 @@ auto Vector<ValueType>::create_local_view_impl(
|
417 | 417 | type);
|
418 | 418 | }
|
419 | 419 |
|
| 420 | +template <typename ValueType> |
| 421 | +auto Vector<ValueType>::temporary_precision_impl( |
| 422 | + syn::variant_from_tuple<matrix::supported_value_types> type) |
| 423 | + -> std::unique_ptr<matrix::MultiVector, |
| 424 | + std::function<void(matrix::MultiVector*)>> |
| 425 | +{ |
| 426 | + if (std::holds_alternative<ValueType>(type)) { |
| 427 | + return {this, null_deleter<matrix::MultiVector>{}}; |
| 428 | + } |
| 429 | + return std::visit( |
| 430 | + [this](auto type) |
| 431 | + -> std::unique_ptr<matrix::MultiVector, |
| 432 | + std::function<void(matrix::MultiVector*)>> { |
| 433 | + using SndValueType = std::decay_t<decltype(type)>; |
| 434 | + if constexpr (is_complex<ValueType>() == |
| 435 | + is_complex<SndValueType>()) { |
| 436 | + auto result = Vector<SndValueType>::create( |
| 437 | + this->get_executor(), this->get_communicator()); |
| 438 | + this->convert_to(result.get()); |
| 439 | + return {result.release(), |
| 440 | + gko::detail::dynamic_convert_back_deleter< |
| 441 | + Vector<ValueType>, matrix::MultiVector, |
| 442 | + Vector<SndValueType>>{this}}; |
| 443 | + } else { |
| 444 | + // @todo: handle real <--> complex conversion |
| 445 | + GKO_INVALID_STATE("Unsupported value type"); |
| 446 | + } |
| 447 | + }, |
| 448 | + type); |
| 449 | +} |
| 450 | + |
| 451 | +template <typename ValueType> |
| 452 | +auto Vector<ValueType>::temporary_precision_impl( |
| 453 | + syn::variant_from_tuple<matrix::supported_value_types> type) const |
| 454 | + -> std::unique_ptr<const matrix::MultiVector> |
| 455 | +{ |
| 456 | + if (std::holds_alternative<ValueType>(type)) { |
| 457 | + return Vector::create_const(this->get_executor(), |
| 458 | + this->get_communicator(), this->get_size(), |
| 459 | + make_const_dense_view(&local_)); |
| 460 | + } |
| 461 | + return std::visit( |
| 462 | + [this](auto type) -> std::unique_ptr<const matrix::MultiVector> { |
| 463 | + using SndValueType = std::decay_t<decltype(type)>; |
| 464 | + if constexpr (is_complex<ValueType>() == |
| 465 | + is_complex<SndValueType>()) { |
| 466 | + auto result = Vector<SndValueType>::create( |
| 467 | + this->get_executor(), this->get_communicator()); |
| 468 | + this->convert_to(result.get()); |
| 469 | + return result; |
| 470 | + } else { |
| 471 | + // @todo: handle real <--> complex conversion |
| 472 | + GKO_INVALID_STATE("Unsupported value type"); |
| 473 | + } |
| 474 | + }, |
| 475 | + type); |
| 476 | +} |
420 | 477 |
|
421 | 478 | template <typename ValueType>
|
422 | 479 | auto Vector<ValueType>::get_stride_impl() const -> size_type
|
|
0 commit comments