Skip to content

Commit 58508ec

Browse files
committed
aarch64: Remove computation cache from conv, matmul, inner product
conv, matmul and inner product have been converted to being stateless in ACL, thus there is no longer a need for the computation cache. This has been replaced with a placeholder cache as the cache is still expected to exist. The placeholder cache can be removed once deconv is also converted to stateless. No change in performance.
1 parent 41b6162 commit 58508ec

File tree

4 files changed

+86
-115
lines changed

4 files changed

+86
-115
lines changed

include/ideep/lru_cache.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,20 @@ class computation_cache {
212212
return t_store_;
213213
}
214214
};
215+
216+
struct placeholder_lru {
217+
void clear() {};
218+
};
219+
220+
// Placeholder cache to be used where computation cache
221+
// isn't required but expected. Missing methods should
222+
// be added as required
223+
class placeholder_computation_cache {
224+
public:
225+
static inline placeholder_lru t_store() {
226+
return placeholder_lru();
227+
}
228+
};
215229
} // namespace utils
216230
} // namespace ideep
217231
#endif

include/ideep/operators/conv.hpp

Lines changed: 30 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ struct conv_deconv_utils {
294294
struct convolution_forward
295295
: public dnnl::convolution_forward,
296296
#ifdef __aarch64__
297-
utils::computation_cache<std::pair<dnnl::convolution_forward::primitive_desc, dnnl::convolution_forward> > {
297+
utils::placeholder_computation_cache {
298298
#else
299299
utils::computation_cache<dnnl::convolution_forward::primitive_desc> {
300300
#endif
@@ -1417,54 +1417,37 @@ struct convolution_forward
14171417
dst_desc_query = dst_desc.to_format(memory_format);
14181418
}
14191419

1420-
auto key = utils::create_key(
1421-
aprop_kind,
1422-
aalgorithm,
1423-
src_desc_query,
1424-
weights_desc_query,
1425-
bias_desc_query,
1426-
dst_desc_query,
1427-
with_bias,
1428-
strides,
1429-
dilates,
1430-
padding_l,
1431-
padding_r,
1432-
attr,
1433-
omp_get_max_threads());
1434-
14351420
dnnl::convolution_forward::primitive_desc pd;
1436-
return fetch_or_create(key, [&]() {
1437-
if (with_bias) {
1438-
pd = primitive_desc(
1439-
aengine,
1440-
aprop_kind,
1441-
aalgorithm,
1442-
src_desc_query,
1443-
weights_desc_query,
1444-
bias_desc_query,
1445-
dst_desc_query,
1446-
strides,
1447-
dilates,
1448-
padding_l,
1449-
padding_r,
1450-
attr
1451-
);
1452-
} else {
1421+
if (with_bias) {
14531422
pd = primitive_desc(
1454-
aengine,
1455-
aprop_kind,
1456-
aalgorithm,
1457-
src_desc_query,
1458-
weights_desc_query,
1459-
dst_desc_query,
1460-
strides,
1461-
dilates,
1462-
padding_l,
1463-
padding_r,
1464-
attr);
1465-
}
1466-
return std::make_pair(pd, super(pd));
1467-
});
1423+
aengine,
1424+
aprop_kind,
1425+
aalgorithm,
1426+
src_desc_query,
1427+
weights_desc_query,
1428+
bias_desc_query,
1429+
dst_desc_query,
1430+
strides,
1431+
dilates,
1432+
padding_l,
1433+
padding_r,
1434+
attr
1435+
);
1436+
} else {
1437+
pd = primitive_desc(
1438+
aengine,
1439+
aprop_kind,
1440+
aalgorithm,
1441+
src_desc_query,
1442+
weights_desc_query,
1443+
dst_desc_query,
1444+
strides,
1445+
dilates,
1446+
padding_l,
1447+
padding_r,
1448+
attr);
1449+
}
1450+
return {pd, super(pd)};
14681451
}
14691452
#else
14701453
template <bool with_bias>

include/ideep/operators/inner_product.hpp

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct inner_product_forward_params {
3535
struct inner_product_forward
3636
: public dnnl::inner_product_forward,
3737
#ifdef __aarch64__
38-
utils::computation_cache<std::pair<dnnl::inner_product_forward::primitive_desc, dnnl::inner_product_forward>> {
38+
utils::placeholder_computation_cache {
3939
#else
4040
utils::computation_cache<dnnl::inner_product_forward::primitive_desc> {
4141
#endif
@@ -254,27 +254,15 @@ struct inner_product_forward
254254
const attr_t& attr = attr_t(),
255255
const prop_kind aprop_kind = prop_kind::forward,
256256
const engine& aengine = engine::cpu_engine()) {
257-
auto key = utils::create_key(
258-
aprop_kind,
259-
src_desc,
260-
weights_desc,
261-
bias_desc,
262-
dst_desc,
263-
attr,
264-
with_bias,
265-
omp_get_max_threads());
266-
267-
return fetch_or_create(key, [&]() {
268-
dnnl::inner_product_forward::primitive_desc pd;
269-
if (with_bias) {
270-
pd = primitive_desc(
271-
aengine, aprop_kind, src_desc, weights_desc, bias_desc, dst_desc, attr);
272-
} else {
273-
pd = primitive_desc(
274-
aengine, aprop_kind, src_desc, weights_desc, dst_desc, attr);
275-
}
276-
return std::make_pair(pd, super(pd));
277-
});
257+
dnnl::inner_product_forward::primitive_desc pd;
258+
if (with_bias) {
259+
pd = primitive_desc(
260+
aengine, aprop_kind, src_desc, weights_desc, bias_desc, dst_desc, attr);
261+
} else {
262+
pd = primitive_desc(
263+
aengine, aprop_kind, src_desc, weights_desc, dst_desc, attr);
264+
}
265+
return std::make_pair(pd, super(pd));
278266
}
279267
#else
280268
static primitive_desc get_primitive_desc(

include/ideep/operators/matmul.hpp

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ struct matmul_forward_params {
4949

5050
struct matmul_forward : public dnnl::matmul,
5151
#ifdef __aarch64__
52-
utils::computation_cache<std::pair<dnnl::matmul::primitive_desc, dnnl::matmul> > {
52+
utils::placeholder_computation_cache {
5353
#else
5454
utils::computation_cache<dnnl::matmul::primitive_desc> {
5555
#endif
@@ -887,6 +887,16 @@ struct matmul_forward : public dnnl::matmul,
887887
dst_desc = dst.get_desc().to_type(dst_data_type);
888888
}
889889

890+
#ifdef __aarch64__
891+
if (with_bias) {
892+
param.pd = primitive_desc(
893+
aengine, src_desc, weights_desc, bias_desc, dst_desc, op_attr);
894+
} else {
895+
param.pd = primitive_desc(
896+
aengine, src_desc, weights_desc, dst_desc, op_attr);
897+
}
898+
param.primitive = std::move(super(param.pd));
899+
#else
890900
auto key = utils::create_key(
891901
src_desc,
892902
weights_desc,
@@ -895,21 +905,6 @@ struct matmul_forward : public dnnl::matmul,
895905
op_attr,
896906
with_bias,
897907
omp_get_max_threads());
898-
899-
#ifdef __aarch64__
900-
auto pd_pair = fetch_or_create(key, [&]() {
901-
if (with_bias) {
902-
param.pd = primitive_desc(
903-
aengine, src_desc, weights_desc, bias_desc, dst_desc, op_attr);
904-
} else {
905-
param.pd = primitive_desc(
906-
aengine, src_desc, weights_desc, dst_desc, op_attr);
907-
}
908-
return std::make_pair(param.pd, super(param.pd));
909-
});
910-
param.pd = std::move(pd_pair.first);
911-
param.primitive = std::move(pd_pair.second);
912-
#else
913908
param.pd = fetch_or_create(key, [&]() {
914909
if (with_bias) {
915910
return primitive_desc(
@@ -1070,6 +1065,16 @@ struct matmul_forward : public dnnl::matmul,
10701065
if (!dst.is_empty()) {
10711066
dst_desc = dst.get_desc().to_type(dst_data_type);
10721067
}
1068+
#ifdef __aarch64__
1069+
if (with_bias) {
1070+
param.pd = primitive_desc(
1071+
aengine, src_desc, weights_desc, bias_desc, dst_desc, op_attr);
1072+
} else {
1073+
param.pd = primitive_desc(
1074+
aengine, src_desc, weights_desc, dst_desc, op_attr);
1075+
}
1076+
param.primitive = std::move(super(param.pd));
1077+
#else
10731078
auto key = utils::create_key(
10741079
src_desc,
10751080
weights_desc,
@@ -1078,20 +1083,6 @@ struct matmul_forward : public dnnl::matmul,
10781083
op_attr,
10791084
with_bias,
10801085
omp_get_max_threads());
1081-
#ifdef __aarch64__
1082-
auto pd_pair = fetch_or_create(key, [&]() {
1083-
if (with_bias) {
1084-
param.pd = primitive_desc(
1085-
aengine, src_desc, weights_desc, bias_desc, dst_desc, op_attr);
1086-
} else {
1087-
param.pd = primitive_desc(
1088-
aengine, src_desc, weights_desc, dst_desc, op_attr);
1089-
}
1090-
return std::make_pair(param.pd, super(param.pd));
1091-
});
1092-
param.pd = std::move(pd_pair.first);
1093-
param.primitive = std::move(pd_pair.second);
1094-
#else
10951086
param.pd = fetch_or_create(key, [&]() {
10961087
if (with_bias) {
10971088
return primitive_desc(
@@ -1225,6 +1216,17 @@ struct matmul_forward : public dnnl::matmul,
12251216
bias_desc = {bias.get_dims(), data_type::f32, bia_tag};
12261217
}
12271218

1219+
// Create pd and primitive
1220+
#ifdef __aarch64__
1221+
if (with_bias) {
1222+
param.pd = primitive_desc(
1223+
aengine, src_desc, weights.get_desc(), bias_desc, dst_desc, op_attr);
1224+
} else {
1225+
param.pd = primitive_desc(
1226+
aengine, src_desc, weights.get_desc(), dst_desc, op_attr);
1227+
}
1228+
param.primitive = std::move(super(param.pd));
1229+
#else
12281230
auto key = utils::create_key(
12291231
src_desc,
12301232
weights.get_desc(),
@@ -1233,22 +1235,6 @@ struct matmul_forward : public dnnl::matmul,
12331235
op_attr,
12341236
with_bias,
12351237
omp_get_max_threads());
1236-
1237-
// Create pd and primitive
1238-
#ifdef __aarch64__
1239-
auto pd_pair = fetch_or_create(key, [&]() {
1240-
if (with_bias) {
1241-
param.pd = primitive_desc(
1242-
aengine, src_desc, weights.get_desc(), bias_desc, dst_desc, op_attr);
1243-
} else {
1244-
param.pd = primitive_desc(
1245-
aengine, src_desc, weights.get_desc(), dst_desc, op_attr);
1246-
}
1247-
return std::make_pair(param.pd, super(param.pd));
1248-
});
1249-
param.pd = std::move(pd_pair.first);
1250-
param.primitive = std::move(pd_pair.second);
1251-
#else
12521238
param.pd = fetch_or_create(key, [&]() {
12531239
if (with_bias) {
12541240
return primitive_desc(

0 commit comments

Comments
 (0)