Skip to content

Commit b85cda2

Browse files
Keyword spotting network and test-benches created
1 parent fd38187 commit b85cda2

13 files changed

+418
-267
lines changed

Diff for: c_reference/src/dscnn.c

+52-27
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,28 @@
99
int DSCNN_LR(float* output_signal, float* input_signal, unsigned in_T, unsigned in_channels, float* mean, float* var,
1010
unsigned affine, float* gamma, float* beta, unsigned in_place, unsigned cnn_hidden, int cnn_padding, unsigned cnn_kernel_size,
1111
const void* cnn_params, int cnn_activations){
12-
unsigned out_T;
13-
// BatchNorm
14-
float* norm_out = (float*)malloc(in_T*in_channels*sizeof(float));
15-
BatchNorm1d(norm_out, input_signal, in_T, in_channels,
16-
mean, var, affine, gamma, beta, in_place, 0.00001);
17-
18-
// CNN
19-
out_T = in_T - cnn_kernel_size + 2*cnn_padding + 1;
20-
Conv1D_LR(output_signal, out_T, cnn_hidden, norm_out,
21-
in_T, in_channels, cnn_padding, cnn_kernel_size,
22-
cnn_params, cnn_activations);
23-
free(norm_out);
24-
12+
13+
unsigned out_T = in_T - cnn_kernel_size + 2*cnn_padding + 1;
14+
if(in_place){
15+
// BatchNorm
16+
BatchNorm1d(0, input_signal, in_T, in_channels,
17+
mean, var, affine, gamma, beta, in_place, 0.00001);
18+
// CNN
19+
Conv1D_LR(output_signal, out_T, cnn_hidden, input_signal,
20+
in_T, in_channels, cnn_padding, cnn_kernel_size,
21+
cnn_params, cnn_activations);
22+
}
23+
else{
24+
// BatchNorm
25+
float* norm_out = (float*)malloc(in_T*in_channels*sizeof(float));
26+
BatchNorm1d(norm_out, input_signal, in_T, in_channels,
27+
mean, var, affine, gamma, beta, in_place, 0.00001);
28+
// CNN
29+
Conv1D_LR(output_signal, out_T, cnn_hidden, norm_out,
30+
in_T, in_channels, cnn_padding, cnn_kernel_size,
31+
cnn_params, cnn_activations);
32+
free(norm_out);
33+
}
2534
return 0;
2635
}
2736

@@ -36,33 +45,49 @@ int DSCNN_LR_Point_Depth(float* output_signal, float* input_signal, unsigned in_
3645
float* act_out= (float*)malloc(in_T * (in_channels>>1) * sizeof(float));
3746
TanhGate(act_out, input_signal, in_T, in_channels);
3847

39-
// Norm
4048
in_channels >>= 1;
41-
BatchNorm1d(0, act_out, in_T, in_channels,
42-
mean, var, affine, gamma, beta, in_place, 0.00001);
43-
44-
// Depth CNN
45-
out_T = in_T - depth_cnn_kernel_size + 2*depth_cnn_padding + 1;
46-
float* depth_out = (float*)malloc(out_T * depth_cnn_hidden * sizeof(float));
47-
Conv1D_Depth(depth_out, out_T, act_out,
48-
in_T, in_channels, depth_cnn_padding, depth_cnn_kernel_size,
49-
depth_cnn_params, depth_cnn_activations);
50-
free(act_out);
49+
float* depth_out;
50+
if(in_place){
51+
// Norm
52+
BatchNorm1d(0, act_out, in_T, in_channels,
53+
mean, var, affine, gamma, beta, in_place, 0.00001);
54+
// Depth CNN
55+
out_T = in_T - depth_cnn_kernel_size + 2*depth_cnn_padding + 1;
56+
depth_out = (float*)malloc(out_T * depth_cnn_hidden * sizeof(float));
57+
Conv1D_Depth(depth_out, out_T, act_out,
58+
in_T, in_channels, depth_cnn_padding, depth_cnn_kernel_size,
59+
depth_cnn_params, depth_cnn_activations);
60+
free(act_out);
61+
}
62+
else{
63+
// Norm
64+
float* norm_out = (float*)malloc(in_T * in_channels * sizeof(float));
65+
BatchNorm1d(norm_out, act_out, in_T, in_channels,
66+
mean, var, affine, gamma, beta, in_place, 0.00001);
67+
free(act_out);
68+
// Depth CNN
69+
out_T = in_T - depth_cnn_kernel_size + 2*depth_cnn_padding + 1;
70+
depth_out = (float*)malloc(out_T * depth_cnn_hidden * sizeof(float));
71+
Conv1D_Depth(depth_out, out_T, norm_out,
72+
in_T, in_channels, depth_cnn_padding, depth_cnn_kernel_size,
73+
depth_cnn_params, depth_cnn_activations);
74+
free(norm_out);
75+
}
5176

5277
// Point CNN
5378
in_T = out_T;
5479
out_T = in_T - point_cnn_kernel_size + 2*point_cnn_padding + 1;
5580
float* point_out = (float*)malloc(out_T * point_cnn_hidden * sizeof(float));
5681
Conv1D_LR(point_out, out_T, point_cnn_hidden, depth_out,
57-
in_T, depth_cnn_hidden, point_cnn_padding, point_cnn_kernel_size,
58-
point_cnn_params, point_cnn_activations);
82+
in_T, depth_cnn_hidden, point_cnn_padding, point_cnn_kernel_size,
83+
point_cnn_params, point_cnn_activations);
5984
free(depth_out);
6085

6186
// Pool
6287
in_T = out_T;
6388
out_T = in_T - pool_kernel_size + 2*pool_padding + 1;
6489
AvgPool1D(output_signal, out_T, point_out, in_T, point_cnn_hidden,
65-
pool_padding, pool_kernel_size, pool_activation);
90+
pool_padding, pool_kernel_size, pool_activation);
6691
free(point_out);
6792

6893
return 0;

Diff for: c_reference/tests/Makefile

+3-5
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@ MODEL_DIR=../models
88
SRC_DIR=../src
99
IFLAGS = -I $(INCLUDE_DIR) -I $(MODEL_DIR)
1010

11-
all: test_rnn test_postcnn test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_dscnn_lr_depth_point test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast
11+
all: test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_dscnn_lr_depth_point test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast test_keyword_spotting
1212

1313
KWS_DIR=kws
14-
test_postcnn: $(KWS_DIR)/test_postcnn.c $(SRC_DIR)/conv_utils.o $(SRC_DIR)/utils.o $(SRC_DIR)/conv1d.o $(SRC_DIR)/dscnn.o
15-
$(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm
16-
test_rnn: $(KWS_DIR)/test_rnn.c $(SRC_DIR)/fastgrnn.o $(SRC_DIR)/utils.o
14+
test_keyword_spotting: $(KWS_DIR)/test_keyword_spotting.c $(SRC_DIR)/conv_utils.o $(SRC_DIR)/utils.o $(SRC_DIR)/conv1d.o $(SRC_DIR)/dscnn.o $(SRC_DIR)/fastgrnn.o
1715
$(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm
1816

1917
DSCNN_DIR=dscnn
@@ -63,7 +61,7 @@ test_quantized_face_detection_fast: $(FACE_DETECTION_DIR)/test_quantized_face_de
6361
.PHONY: clean cleanest
6462

6563
clean:
66-
rm -f *.o *.gch test_rnn test_postcnn test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_dscnn_lr_depth_point test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast
64+
rm -f *.o *.gch test_avg_pool test_conv1d test_conv1d_depth test_conv1d_lr test_conv1d_lr_depth test_dscnn_lr test_dscnn_lr_depth_point test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast test_keyword_spotting
6765

6866
cleanest: clean
6967
rm *~

Diff for: c_reference/tests/kws/keyword_spotting_io_1.h

+14
Large diffs are not rendered by default.

Diff for: c_reference/tests/kws/keyword_spotting_io_2.h

+14
Large diffs are not rendered by default.

Diff for: c_reference/tests/kws/keyword_spotting_io_3.h

+14
Large diffs are not rendered by default.

Diff for: c_reference/tests/kws/postcnn_io.h

-12
This file was deleted.

Diff for: c_reference/tests/kws/postcnn_params.h

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT license.
33

4-
#define I_T 101
5-
#define O_T 97
6-
#define I_F 400
7-
#define O_F 41
8-
#define DEPTH_FILT 5
9-
#define POOL_FILT 2
10-
#define POINT_FILT 1
4+
#define POST_CNN_I_F 400
5+
#define POST_CNN_O_F 41
6+
#define POST_CNN_DEPTH_FILT 5
7+
#define POST_CNN_DEPTH_PAD 2
8+
#define POST_CNN_DEPTH_ACT 0
9+
#define POST_CNN_POOL 2
10+
#define POST_CNN_POOL_PAD 0
11+
#define POST_CNN_POOL_ACT 0
12+
#define POST_CNN_POINT_FILT 1
13+
#define POST_CNN_POINT_PAD 0
14+
#define POST_CNN_POINT_ACT 0
1115
#define LOW_RANK 50
1216

1317
static float CNN2_BNORM_MEAN[200] = {0.015971357, 6.501892e-05, 0.038588, 0.097353816, 0.015376935, 0.015602097, 0.036061652, 0.03898482, 0.039924037, 0.006442192, 0.03593987, 0.04222, 0.07892465, 0.040109057, 0.04670995, 0.034607653, 0.08764101, 0.0654985, 0.05729792, 0.016125392, 0.04164587, 0.067503765, -0.028087385, -0.015071113, -0.059104428, 0.016970858, 0.066376075, 0.08096571, 0.07588531, -0.052501034, 0.059453476, 0.051389035, 0.068797044, 0.07448322, -0.039696753, 0.035946004, -0.0014653725, 0.088247605, 0.041575737, 0.10277397, 0.0526724, -0.036430024, 0.045154247, 0.025572386, -0.03798233, -0.03787984, 0.038794607, 0.069571935, 0.04857156, -0.0071943738, -0.032641422, 0.09272896, 0.051213413, 0.027016614, 0.016180614, 0.06295424, 0.037129566, 0.032467406, -0.025120424, 0.030453552, 0.06195704, -0.018342853, 0.01071218, 0.08289012, -0.0020247104, -0.025325423, 0.09067498, 0.08771331, -0.00087628997, 0.058319252, -0.008090766, 0.017746149, 0.018517531, 0.086093254, -0.027057227, -0.016622052, 0.046957616, 0.09430719, 0.034838118, -0.012873425, 0.037685357, 0.06542477, 0.019381318, 0.056773286, 0.013210703, 0.061603934, 0.058845002, 0.038058348, 0.059714716, 0.08331196, 0.07923841, 0.062044818, 0.0026301693, 0.06795837, -0.0030768516, 0.06081296, 0.007542504, 0.023446025, 0.029961232, 0.036956295, 0.038154498, 0.0475312, -0.002375686, 0.06024979, 0.008005025, 0.077196926, 0.0789222, 0.08032208, -0.017859492, 0.09035916, 0.0018038531, 0.09352108, 0.01635931, 0.019340234, 0.08410177, 0.039363027, 0.072246365, 0.014267202, 0.09342878, 0.035540942, -0.024538206, -0.029167939, 0.022879561, 0.038776945, -0.008962599, 0.028289083, 0.07435437, 0.033813186, 0.03850727, 0.029347735, -0.0625054, -0.040992733, 0.07381808, 0.066816054, 0.0124026975, 0.071189724, 0.030857489, 0.0691046, 0.0816209, 0.055097148, 0.0014903225, 0.03434464, -0.0059871743, 0.056997567, 0.03338638, 0.006343891, -0.012877238, -0.06268586, -0.021703077, 0.04223155, -0.0071792454, 0.042931616, 0.07306048, 0.069892146, 0.04687559, 0.070037335, 0.04052685, 0.014894321, 0.0052992613, 0.095167615, -0.013672871, 0.041835114, 0.029129915, 0.044560347, 0.050107688, 0.053555574, 0.072749875, 0.078805596, 0.010685162, 0.0071339705, 0.064971276, 0.00079534744, -0.017808538, 0.03917341, 0.08676169, 0.07253845, -0.05193372, -0.019195765, 0.0065121413, 0.03828087, -0.0030095147, 0.04439866, 0.077847764, 0.033887736, -0.0059973625, 0.031676162, -0.06647171, 0.044209484, 0.053383026, -0.009077284, -0.046332154, -0.0028125723, 0.09074347, 0.08587855, -0.0076064896, 0.0626982, -0.021182764, 0.064857244, 0.02505869, -0.017837577};

Diff for: c_reference/tests/kws/precnn_params.h

+23
Large diffs are not rendered by default.

Diff for: c_reference/tests/kws/rnn_io.h

-12
This file was deleted.

0 commit comments

Comments
 (0)