@@ -63,7 +63,7 @@ template <typename ValueType>
6363Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
6464 mpi::communicator comm, dim<2 > global_size,
6565 dim<2 > local_size, size_type stride)
66- : EnableLinOp <Vector>{exec, global_size},
66+ : matrix::EnableMultiVector <Vector>{exec, global_size},
6767 DistributedBase{comm},
6868 local_{exec, local_size, stride}
6969{
@@ -74,7 +74,7 @@ template <typename ValueType>
7474Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
7575 mpi::communicator comm, dim<2 > global_size,
7676 std::unique_ptr<local_vector_type> local_vector)
77- : EnableLinOp <Vector>{exec, global_size},
77+ : matrix::EnableMultiVector <Vector>{exec, global_size},
7878 DistributedBase{comm},
7979 local_{exec}
8080{
@@ -86,7 +86,9 @@ template <typename ValueType>
8686Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
8787 mpi::communicator comm,
8888 std::unique_ptr<local_vector_type> local_vector)
89- : EnableLinOp<Vector>{exec, {}}, DistributedBase{comm}, local_{exec}
89+ : matrix::EnableMultiVector<Vector>{exec, {}},
90+ DistributedBase{comm},
91+ local_{exec}
9092{
9193 this ->set_size (compute_global_size (exec, comm, local_vector->get_size ()));
9294 local_vector->move_to (&local_);
@@ -158,15 +160,225 @@ std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_const(
158160}
159161
160162
163+ template <typename ValueType>
164+ std::unique_ptr<typename Vector<ValueType>::absolute_type>
165+ Vector<ValueType>::compute_absolute_impl() const
166+ {
167+ return compute_absolute ();
168+ }
169+
170+ template <typename ValueType>
171+ void Vector<ValueType>::compute_absolute_inplace_impl()
172+ {
173+ compute_absolute_inplace ();
174+ }
175+
176+ template <typename ValueType>
177+ std::unique_ptr<typename Vector<ValueType>::complex_type>
178+ Vector<ValueType>::make_complex_impl() const
179+ {
180+ return make_complex ();
181+ }
182+
183+ template <typename ValueType>
184+ std::unique_ptr<typename Vector<ValueType>::real_type>
185+ Vector<ValueType>::get_real_impl() const
186+ {
187+ return get_real ();
188+ }
189+
190+ template <typename ValueType>
191+ std::unique_ptr<typename Vector<ValueType>::real_type>
192+ Vector<ValueType>::get_imag_impl() const
193+ {
194+ return get_imag ();
195+ }
196+
197+ template <typename ValueType>
198+ void Vector<ValueType>::fill_impl(matrix::any_value_t value)
199+ {
200+ std::visit (
201+ [this ](auto value) {
202+ using SndValueType = std::decay_t <decltype (value)>;
203+ if constexpr (!is_complex<ValueType>() &&
204+ is_complex<SndValueType>()) {
205+ GKO_INVALID_STATE (
206+ " Trying to fill a real vector with a complex value." );
207+ } else {
208+ fill (static_cast <ValueType>(value));
209+ }
210+ },
211+ value);
212+ }
213+
214+ template <typename ValueType>
215+ void Vector<ValueType>::scale_impl(matrix::any_const_dense_t alpha)
216+ {
217+ std::visit ([this ](auto alpha) { scale (alpha); }, alpha);
218+ }
219+
220+ template <typename ValueType>
221+ void Vector<ValueType>::inv_scale_impl(matrix::any_const_dense_t alpha)
222+ {
223+ std::visit ([this ](auto alpha) { inv_scale (alpha); }, alpha);
224+ }
225+
226+ template <typename ValueType>
227+ std::unique_ptr<const typename Vector<ValueType>::real_type>
228+ Vector<ValueType>::create_real_view_impl() const
229+ {
230+ return create_real_view ();
231+ }
232+
233+ template <typename ValueType>
234+ std::unique_ptr<typename Vector<ValueType>::real_type>
235+ Vector<ValueType>::create_real_view_impl()
236+ {
237+ return create_real_view ();
238+ }
239+
240+ template <typename ValueType>
241+ std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_subview_impl(
242+ matrix::local_span rows, matrix::local_span columns)
243+ {
244+ auto exec = this ->get_executor ();
245+ auto comm = this ->get_communicator ();
246+ auto global_rows = this ->get_size ()[0 ];
247+ auto global_cols = this ->get_size ()[1 ];
248+ comm.all_reduce (exec, &global_rows, 1 , MPI_SUM);
249+ comm.all_reduce (exec, &global_cols, 1 , MPI_SUM);
250+ return create_subview_impl (rows, columns, global_rows, global_cols);
251+ }
252+
253+
254+ template <typename ValueType>
255+ std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_subview_impl(
256+ matrix::local_span rows, matrix::local_span columns) const
257+ {
258+ auto exec = this ->get_executor ();
259+ auto comm = this ->get_communicator ();
260+ auto global_rows = this ->get_size ()[0 ];
261+ auto global_cols = this ->get_size ()[1 ];
262+ comm.all_reduce (exec, &global_rows, 1 , MPI_SUM);
263+ comm.all_reduce (exec, &global_cols, 1 , MPI_SUM);
264+ return create_subview_impl (rows, columns, global_rows, global_cols);
265+ }
266+
267+
268+ template <typename ValueType>
269+ std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_subview_impl(
270+ matrix::local_span rows, matrix::local_span columns, size_type global_rows,
271+ size_type globals_cols) const
272+ {
273+ // @todo: use const-cast here until dense also has const create_submatrix
274+ return create (
275+ this ->get_executor (), this ->get_communicator (),
276+ dim<2 >{global_rows, globals_cols},
277+ const_cast <local_vector_type&>(local_).create_submatrix (rows, columns));
278+ }
279+
280+
281+ template <typename ValueType>
282+ std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_subview_impl(
283+ matrix::local_span rows, matrix::local_span columns, size_type global_rows,
284+ size_type globals_cols)
285+ {
286+ return create (this ->get_executor (), this ->get_communicator (),
287+ dim<2 >{global_rows, globals_cols},
288+ local_.create_submatrix (rows, columns));
289+ }
290+
291+ template <typename ValueType>
292+ void Vector<ValueType>::make_complex_impl(complex_type* result) const
293+ {
294+ make_complex (result);
295+ }
296+
297+ template <typename ValueType>
298+ void Vector<ValueType>::get_real_impl(real_type* result) const
299+ {
300+ get_real (result);
301+ }
302+
303+ template <typename ValueType>
304+ void Vector<ValueType>::get_imag_impl(real_type* result) const
305+ {
306+ get_imag (result);
307+ }
308+
309+ template <typename ValueType>
310+ void Vector<ValueType>::add_scaled_impl(matrix::any_const_dense_t alpha,
311+ const Vector* b)
312+ {
313+ std::visit ([this , b](auto alpha) { add_scaled (alpha, b); }, alpha);
314+ }
315+
316+ template <typename ValueType>
317+ void Vector<ValueType>::sub_scaled_impl(matrix::any_const_dense_t alpha,
318+ const Vector* b)
319+ {
320+ std::visit ([this , b](auto alpha) { sub_scaled (alpha, b); }, alpha);
321+ }
322+
323+ template <typename ValueType>
324+ void Vector<ValueType>::compute_dot_impl(const Vector* b, Vector* result) const
325+ {
326+ compute_dot (b, result);
327+ }
328+
329+ template <typename ValueType>
330+ void Vector<ValueType>::compute_dot_impl(const Vector* b, Vector* result,
331+ array<char >& tmp) const
332+ {
333+ compute_dot (b, result, tmp);
334+ }
335+
336+ template <typename ValueType>
337+ void Vector<ValueType>::compute_conj_dot_impl(const Vector* b,
338+ Vector* result) const
339+ {
340+ compute_conj_dot (b, result);
341+ }
342+
343+ template <typename ValueType>
344+ void Vector<ValueType>::compute_conj_dot_impl(const Vector* b, Vector* result,
345+ array<char >& tmp) const
346+ {
347+ compute_conj_dot (b, result, tmp);
348+ }
349+
350+ template <typename ValueType>
351+ void Vector<ValueType>::compute_norm2_impl(absolute_type* result) const
352+ {
353+ compute_norm2 (result);
354+ }
355+
356+ template <typename ValueType>
357+ void Vector<ValueType>::compute_norm2_impl(absolute_type* result,
358+ array<char >& tmp) const
359+ {
360+ compute_norm2 (result, tmp);
361+ }
362+
363+ template <typename ValueType>
364+ void Vector<ValueType>::compute_norm1_impl(absolute_type* result) const
365+ {
366+ compute_norm1 (result);
367+ }
368+
369+ template <typename ValueType>
370+ void Vector<ValueType>::compute_norm1_impl(absolute_type* result,
371+ array<char >& tmp) const
372+ {
373+ compute_norm1 (result, tmp);
374+ }
375+
376+
161377template <typename ValueType>
162378std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_config_of(
163379 ptr_param<const Vector> other)
164380{
165- // De-referencing `other` before calling the functions (instead of
166- // using operator `->`) is currently required to be compatible with
167- // CUDA 10.1.
168- // Otherwise, it results in a compile error.
169- return (*other).create_with_same_config ();
381+ return other->create_with_same_config_impl ();
170382}
171383
172384
@@ -750,8 +962,8 @@ Vector<ValueType>::create_real_view()
750962
751963
752964template <typename ValueType>
753- std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_same_config()
754- const
965+ std::unique_ptr<Vector<ValueType>>
966+ Vector<ValueType>::create_with_same_config_impl() const
755967{
756968 return Vector::create (
757969 this ->get_executor (), this ->get_communicator (), this ->get_size (),
0 commit comments