This repository was archived by the owner on Jul 24, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathblas_routine.cpp
More file actions
102 lines (86 loc) · 2.39 KB
/
blas_routine.cpp
File metadata and controls
102 lines (86 loc) · 2.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/**
* MATH3512 Matrix Computations, The Australian National University
* Supervisor : Professor Linda Stals
* Student : u6633756 Junming Zhao
* blas_routine.cpp - implementation of BLAS level-3 operations (a) and (d)
**/
#include "matrix.cpp"
#include <stdexcept>
/** BLAS level3 function a): C = alpha*A*B + beta*C
---- parameters ----
[in] alpha : scalar
[in] beta : scalar
[in] A : m x n matrix
[in] B : n x p matrix
[in,out] C : m x p matrix
**/
/**
* @brief BLAS level3 function a): C = alpha*A*B + beta*C
*
* @param alpha
* @param beta
* @param A mxn matrix
* @param B nxp matrix
* @param C mxp matrix, overwritten
*/
void BLAS_3A(double alpha, double beta,
Matrix A, Matrix B, Matrix &C){
C *= beta;
if (A.get_rows() < B.get_cols()){
// A = alpha * A
C += Matrix::strassen(alpha*A,B);
}
else{
// B = alpha * B
C += Matrix::strassen(A, alpha*B);
}
}
/**
 * @brief BLAS level3 function d): C = alpha*(T^-1)*B
 *
 * Solves the upper-triangular system T*C = alpha*B by recursive block back
 * substitution. T and B are split by rows as
 *
 *     T = [T11 T12]      B = [B1]
 *         [ 0  T22]          [B2]
 *
 * where T11 is m1 x m1 (m1 = m/2) and T22 is (m-m1) x (m-m1). Because the
 * bottom block takes the remainder, odd m loses no row or column. The columns
 * of B are never split: solution columns are independent, so splitting them
 * buys nothing. This also keeps a single right-hand side (p == 1) working.
 * The recursion computes
 *
 *     C2 = T22^-1 * B2
 *     C1 = T11^-1 * (B1 - T12*C2)
 *
 * and alpha is applied once to the final result. Only the upper triangle of
 * T is read. No singularity check is made; a zero on the diagonal of T
 * yields inf/NaN entries, as with reference-BLAS xTRSM.
 *
 * @param alpha scalar applied to the solution
 * @param T mxm upper triangular matrix
 * @param B mxp matrix
 * @param s whether to use Strassen's algorithm for the T12*C2 products
 * @return Matrix mxp matrix
 * @throws std::invalid_argument if T is not square, or if B does not have
 *         the same number of rows as T
 */
Matrix BLAS_3D(double alpha, Matrix T, Matrix B, bool s){
    const int m = T.get_rows();
    const int p = B.get_cols();
    if (m != T.get_cols()){
        throw std::invalid_argument("T is not square matrix");
    }
    if (m != B.get_rows()){
        throw std::invalid_argument("dimensions not matching for BLAS3-d");
    }
    // Empty system: nothing to solve, and the block split below would not
    // shrink the problem.
    if (m == 0){
        return B;
    }
    // Apply alpha exactly once, to the final result; every recursive call
    // below works with alpha == 1.
    if (alpha != 1.0){
        Matrix X = BLAS_3D(1.0, T, B, s);
        return alpha * X;
    }
    // base case: 1x1 system
    if (m == 1){
        return (1/(T(0,0)))*B;
    }

    const int m1 = m/2;     // order of the top-left block T11
    const int m2 = m - m1;  // order of the bottom-right block T22 (takes the remainder)
    // slice(nrows, first_row, ncols, first_col)
    Matrix T11 = T.slice(m1, 0, m1, 0);
    Matrix T12 = T.slice(m1, 0, m2, m1);
    Matrix T22 = T.slice(m2, m1, m2, m1);
    Matrix B1 = B.slice(m1, 0, p, 0);
    Matrix B2 = B.slice(m2, m1, p, 0);

    Matrix C(m,p);
    C.assign_zeros();
    // NOTE(review): as in the original implementation, the result is written
    // through slices of C. This assumes Matrix::slice returns a view sharing
    // C's storage; confirm against matrix.cpp.
    Matrix C1 = C.slice(m1, 0, p, 0);
    Matrix C2 = C.slice(m2, m1, p, 0);

    C2 += BLAS_3D(1, T22, B2, s);
    if (s){
        C1 += BLAS_3D(1, T11, (B1 - Matrix::strassen(T12, C2)), s);
    }
    else{
        C1 += BLAS_3D(1, T11, (B1 - T12*C2), s);
    }
    return C;
}