Skip to content

Commit 3737124

Browse files
committed
derive vector from multivector
1 parent ba29c1c commit 3737124

File tree

2 files changed

+305
-26
lines changed

2 files changed

+305
-26
lines changed

core/distributed/vector.cpp

Lines changed: 222 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ template <typename ValueType>
6363
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
6464
mpi::communicator comm, dim<2> global_size,
6565
dim<2> local_size, size_type stride)
66-
: EnableLinOp<Vector>{exec, global_size},
66+
: matrix::EnableMultiVector<Vector>{exec, global_size},
6767
DistributedBase{comm},
6868
local_{exec, local_size, stride}
6969
{
@@ -74,7 +74,7 @@ template <typename ValueType>
7474
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
7575
mpi::communicator comm, dim<2> global_size,
7676
std::unique_ptr<local_vector_type> local_vector)
77-
: EnableLinOp<Vector>{exec, global_size},
77+
: matrix::EnableMultiVector<Vector>{exec, global_size},
7878
DistributedBase{comm},
7979
local_{exec}
8080
{
@@ -86,7 +86,9 @@ template <typename ValueType>
8686
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
8787
mpi::communicator comm,
8888
std::unique_ptr<local_vector_type> local_vector)
89-
: EnableLinOp<Vector>{exec, {}}, DistributedBase{comm}, local_{exec}
89+
: matrix::EnableMultiVector<Vector>{exec, {}},
90+
DistributedBase{comm},
91+
local_{exec}
9092
{
9193
this->set_size(compute_global_size(exec, comm, local_vector->get_size()));
9294
local_vector->move_to(&local_);
@@ -158,15 +160,225 @@ std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_const(
158160
}
159161

160162

163+
template <typename ValueType>
164+
std::unique_ptr<typename Vector<ValueType>::absolute_type>
165+
Vector<ValueType>::compute_absolute_impl() const
166+
{
167+
return compute_absolute();
168+
}
169+
170+
template <typename ValueType>
171+
void Vector<ValueType>::compute_absolute_inplace_impl()
172+
{
173+
compute_absolute_inplace();
174+
}
175+
176+
template <typename ValueType>
177+
std::unique_ptr<typename Vector<ValueType>::complex_type>
178+
Vector<ValueType>::make_complex_impl() const
179+
{
180+
return make_complex();
181+
}
182+
183+
template <typename ValueType>
184+
std::unique_ptr<typename Vector<ValueType>::real_type>
185+
Vector<ValueType>::get_real_impl() const
186+
{
187+
return get_real();
188+
}
189+
190+
template <typename ValueType>
191+
std::unique_ptr<typename Vector<ValueType>::real_type>
192+
Vector<ValueType>::get_imag_impl() const
193+
{
194+
return get_imag();
195+
}
196+
197+
template <typename ValueType>
198+
void Vector<ValueType>::fill_impl(matrix::any_value_t value)
199+
{
200+
std::visit(
201+
[this](auto value) {
202+
using SndValueType = std::decay_t<decltype(value)>;
203+
if constexpr (!is_complex<ValueType>() &&
204+
is_complex<SndValueType>()) {
205+
GKO_INVALID_STATE(
206+
"Trying to fill a real vector with a complex value.");
207+
} else {
208+
fill(static_cast<ValueType>(value));
209+
}
210+
},
211+
value);
212+
}
213+
214+
template <typename ValueType>
215+
void Vector<ValueType>::scale_impl(matrix::any_const_dense_t alpha)
216+
{
217+
std::visit([this](auto alpha) { scale(alpha); }, alpha);
218+
}
219+
220+
template <typename ValueType>
221+
void Vector<ValueType>::inv_scale_impl(matrix::any_const_dense_t alpha)
222+
{
223+
std::visit([this](auto alpha) { inv_scale(alpha); }, alpha);
224+
}
225+
226+
template <typename ValueType>
227+
std::unique_ptr<const typename Vector<ValueType>::real_type>
228+
Vector<ValueType>::create_real_view_impl() const
229+
{
230+
return create_real_view();
231+
}
232+
233+
template <typename ValueType>
234+
std::unique_ptr<typename Vector<ValueType>::real_type>
235+
Vector<ValueType>::create_real_view_impl()
236+
{
237+
return create_real_view();
238+
}
239+
240+
template <typename ValueType>
241+
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_subview_impl(
242+
matrix::local_span rows, matrix::local_span columns)
243+
{
244+
auto exec = this->get_executor();
245+
auto comm = this->get_communicator();
246+
auto global_rows = this->get_size()[0];
247+
auto global_cols = this->get_size()[1];
248+
comm.all_reduce(exec, &global_rows, 1, MPI_SUM);
249+
comm.all_reduce(exec, &global_cols, 1, MPI_SUM);
250+
return create_subview_impl(rows, columns, global_rows, global_cols);
251+
}
252+
253+
254+
template <typename ValueType>
255+
std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_subview_impl(
256+
matrix::local_span rows, matrix::local_span columns) const
257+
{
258+
auto exec = this->get_executor();
259+
auto comm = this->get_communicator();
260+
auto global_rows = this->get_size()[0];
261+
auto global_cols = this->get_size()[1];
262+
comm.all_reduce(exec, &global_rows, 1, MPI_SUM);
263+
comm.all_reduce(exec, &global_cols, 1, MPI_SUM);
264+
return create_subview_impl(rows, columns, global_rows, global_cols);
265+
}
266+
267+
268+
template <typename ValueType>
269+
std::unique_ptr<const Vector<ValueType>> Vector<ValueType>::create_subview_impl(
270+
matrix::local_span rows, matrix::local_span columns, size_type global_rows,
271+
size_type globals_cols) const
272+
{
273+
// @todo: use const-cast here until dense also has const create_submatrix
274+
return create(
275+
this->get_executor(), this->get_communicator(),
276+
dim<2>{global_rows, globals_cols},
277+
const_cast<local_vector_type&>(local_).create_submatrix(rows, columns));
278+
}
279+
280+
281+
template <typename ValueType>
282+
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_subview_impl(
283+
matrix::local_span rows, matrix::local_span columns, size_type global_rows,
284+
size_type globals_cols)
285+
{
286+
return create(this->get_executor(), this->get_communicator(),
287+
dim<2>{global_rows, globals_cols},
288+
local_.create_submatrix(rows, columns));
289+
}
290+
291+
template <typename ValueType>
292+
void Vector<ValueType>::make_complex_impl(complex_type* result) const
293+
{
294+
make_complex(result);
295+
}
296+
297+
template <typename ValueType>
298+
void Vector<ValueType>::get_real_impl(real_type* result) const
299+
{
300+
get_real(result);
301+
}
302+
303+
template <typename ValueType>
304+
void Vector<ValueType>::get_imag_impl(real_type* result) const
305+
{
306+
get_imag(result);
307+
}
308+
309+
template <typename ValueType>
310+
void Vector<ValueType>::add_scaled_impl(matrix::any_const_dense_t alpha,
311+
const Vector* b)
312+
{
313+
std::visit([this, b](auto alpha) { add_scaled(alpha, b); }, alpha);
314+
}
315+
316+
template <typename ValueType>
317+
void Vector<ValueType>::sub_scaled_impl(matrix::any_const_dense_t alpha,
318+
const Vector* b)
319+
{
320+
std::visit([this, b](auto alpha) { sub_scaled(alpha, b); }, alpha);
321+
}
322+
323+
template <typename ValueType>
324+
void Vector<ValueType>::compute_dot_impl(const Vector* b, Vector* result) const
325+
{
326+
compute_dot(b, result);
327+
}
328+
329+
template <typename ValueType>
330+
void Vector<ValueType>::compute_dot_impl(const Vector* b, Vector* result,
331+
array<char>& tmp) const
332+
{
333+
compute_dot(b, result, tmp);
334+
}
335+
336+
template <typename ValueType>
337+
void Vector<ValueType>::compute_conj_dot_impl(const Vector* b,
338+
Vector* result) const
339+
{
340+
compute_conj_dot(b, result);
341+
}
342+
343+
template <typename ValueType>
344+
void Vector<ValueType>::compute_conj_dot_impl(const Vector* b, Vector* result,
345+
array<char>& tmp) const
346+
{
347+
compute_conj_dot(b, result, tmp);
348+
}
349+
350+
template <typename ValueType>
351+
void Vector<ValueType>::compute_norm2_impl(absolute_type* result) const
352+
{
353+
compute_norm2(result);
354+
}
355+
356+
template <typename ValueType>
357+
void Vector<ValueType>::compute_norm2_impl(absolute_type* result,
358+
array<char>& tmp) const
359+
{
360+
compute_norm2(result, tmp);
361+
}
362+
363+
template <typename ValueType>
364+
void Vector<ValueType>::compute_norm1_impl(absolute_type* result) const
365+
{
366+
compute_norm1(result);
367+
}
368+
369+
template <typename ValueType>
370+
void Vector<ValueType>::compute_norm1_impl(absolute_type* result,
371+
array<char>& tmp) const
372+
{
373+
compute_norm1(result, tmp);
374+
}
375+
376+
161377
template <typename ValueType>
162378
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_config_of(
163379
ptr_param<const Vector> other)
164380
{
165-
// De-referencing `other` before calling the functions (instead of
166-
// using operator `->`) is currently required to be compatible with
167-
// CUDA 10.1.
168-
// Otherwise, it results in a compile error.
169-
return (*other).create_with_same_config();
381+
return other->create_with_same_config_impl();
170382
}
171383

172384

@@ -750,8 +962,8 @@ Vector<ValueType>::create_real_view()
750962

751963

752964
template <typename ValueType>
753-
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_same_config()
754-
const
965+
std::unique_ptr<Vector<ValueType>>
966+
Vector<ValueType>::create_with_same_config_impl() const
755967
{
756968
return Vector::create(
757969
this->get_executor(), this->get_communicator(), this->get_size(),

0 commit comments

Comments
 (0)