Skip to content

Commit b531f1a

Browse files
committed
fixup! derive vector from multivector
1 parent 3737124 commit b531f1a

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

core/distributed/vector.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,57 @@ void Vector<ValueType>::compute_norm1_impl(absolute_type* result,
374374
}
375375

376376

377+
template <typename ValueType>
378+
syn::variant_from_tuple<
379+
syn::apply_to_list<std::unique_ptr, matrix::dense_types>>
380+
Vector<ValueType>::create_local_view_impl(
381+
syn::variant_from_tuple<matrix::supported_value_types> type)
382+
{
383+
return std::visit(
384+
[this](auto type)
385+
-> syn::variant_from_tuple<
386+
syn::apply_to_list<std::unique_ptr, matrix::dense_types>> {
387+
using SndValueType = std::decay_t<decltype(type)>;
388+
if constexpr (std::is_same_v<ValueType, SndValueType>) {
389+
return make_dense_view(&local_);
390+
} else {
391+
GKO_INVALID_STATE("Unsupported value type");
392+
}
393+
},
394+
type);
395+
}
396+
397+
398+
template <typename ValueType>
399+
auto Vector<ValueType>::create_local_view_impl(
400+
syn::variant_from_tuple<matrix::supported_value_types> type) const
401+
-> syn::variant_from_tuple<syn::apply_to_list<
402+
std::unique_ptr,
403+
syn::apply_to_list<std::add_const_t, matrix::dense_types>>>
404+
{
405+
return std::visit(
406+
[this](auto type)
407+
-> syn::variant_from_tuple<syn::apply_to_list<
408+
std::unique_ptr,
409+
syn::apply_to_list<std::add_const_t, matrix::dense_types>>> {
410+
using SndValueType = std::decay_t<decltype(type)>;
411+
if constexpr (std::is_same_v<ValueType, SndValueType>) {
412+
return make_const_dense_view(&local_);
413+
} else {
414+
GKO_INVALID_STATE("Unsupported value type");
415+
}
416+
},
417+
type);
418+
}
419+
420+
421+
template <typename ValueType>
422+
auto Vector<ValueType>::get_stride_impl() const -> size_type
423+
{
424+
return local_.get_stride();
425+
}
426+
427+
377428
template <typename ValueType>
378429
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_config_of(
379430
ptr_param<const Vector> other)

include/ginkgo/core/distributed/vector.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,19 @@ class Vector : public matrix::EnableMultiVector<Vector<ValueType>>,
736736
void compute_norm1_impl(absolute_type* result,
737737
array<char>& tmp) const override;
738738

739+
[[nodiscard]] syn::variant_from_tuple<
740+
syn::apply_to_list<std::unique_ptr, matrix::dense_types>>
741+
create_local_view_impl(
742+
syn::variant_from_tuple<matrix::supported_value_types> type) override;
743+
744+
[[nodiscard]] auto create_local_view_impl(
745+
syn::variant_from_tuple<matrix::supported_value_types> type) const
746+
-> syn::variant_from_tuple<syn::apply_to_list<
747+
std::unique_ptr, syn::apply_to_list<std::add_const_t,
748+
matrix::dense_types>>> override;
749+
750+
[[nodiscard]] auto get_stride_impl() const -> size_type override;
751+
739752
private:
740753
local_vector_type local_;
741754
::gko::detail::DenseCache<ValueType> host_reduction_buffer_;

0 commit comments

Comments
 (0)