Skip to content

Commit c9034b9

Browse files
committed
Fixed Wrapper behaviour => now correct output
1 parent 20223cd commit c9034b9

5 files changed

Lines changed: 98 additions & 40 deletions

File tree

Readme.md

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Repo Content:
1+
# Repo Content: (Work in progress)
22

33
This repository is a port of the Escoin Caffe branch to PyTorch.\
44
If interested, there is also a TensorFlow port of the same code: TODO
@@ -10,11 +10,15 @@ It aim at accelerating on GPU convoluntion in context where kernels has hight sp
1010

1111
Its working principle is based on CSR kernel compression.
1212

13-
The paper that describe the SparseConvolution implementation contained in Ecosin is available at: https://arxiv.org/pdf/1802.10280.pdf\
13+
The paper that describes the SparseConvolution implementation contained in Escoin is available at: https://arxiv.org/pdf/1802.10280.pdf
14+
1415
The original C++ code is available in the following repository:\
1516
https://github.com/chenxuhao/caffe-escoin\
17+
1618
More specifically in this file: https://github.com/chenxuhao/caffe-escoin/blob/master/src/caffe/util/math_functions.cu => function caffe_gpu_sconv(...)
19+
1720
# How To use:
21+
TODO => make better instructions\
1822
To use our custom PyTorch layer, simply compile it with the Makefile, then:\
1923
```
2024
import sparse_conv as sp
@@ -34,9 +38,15 @@ make all
3438
```
3539
- Execute the example script:
3640
```
37-
python main.py
41+
python test_behaviour.py
42+
```
43+
- If you see the following output, everything works fine.
44+
```
45+
Vanilla vs SparseConv:
46+
SUCCESS => Same Outputs
47+
IN -shape: torch.Size([1, 1, 32, 32])
48+
OUT-shape: torch.Size([1, 6, 28, 28])
3849
```
39-
4050
# How It works:
4151

4252
We have simply written a CUDA => Python wrapper using Python's ctypes package.\

lib/sparse_conv.so

0 Bytes
Binary file not shown.

sparse_conv.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from sparse_conv_wrapper import *
1212

1313
#--------------------------------------------------------
14-
#-----------TEST THE FUNCTIONS---------------------------
14+
#-----------Sparse Conv Custom Layer---------------------
1515
#--------------------------------------------------------
1616

17-
17+
#THIS LAYER MAKES SENSE ONLY ON CUDA; WE HAVE NOT PORTED THE C++ VERSION. IF CUDA IS NOT AVAILABLE IT WILL USE THE CLASSIC nn.Conv2d
1818

1919
class SparseConv2D(torch.nn.Conv2d):
2020
def __init__(self, in_channels: int, out_channels: int, kernel_size: _size_2_t, stride: _size_2_t = 1, padding: _size_2_t = 0, dilation: _size_2_t = 1, bias: bool = None):
@@ -35,7 +35,7 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: _size_2_t,
3535
def make_kernel_sparse(self,in_height,in_width):
3636
#Copy kernel
3737
k = copy.deepcopy(self.weight.detach())
38-
print(k)
38+
#print(k)
3939
print(f"Kernel Shape:{k.shape}")
4040

4141
#Reshape kernel from (CH , W , H) => (CH , W * H)
@@ -47,21 +47,21 @@ def make_kernel_sparse(self,in_height,in_width):
4747
self.sparse_kernel = x.to_sparse_csr()
4848

4949
#Define the CSR format in tenrors to CUDA
50-
self.rowptr = self.sparse_kernel.crow_indices().cuda()
51-
self.colidx = self.sparse_kernel.col_indices().cuda()
50+
self.rowptr = self.sparse_kernel.crow_indices().to(torch.int).cuda()
51+
self.colidx = self.sparse_kernel.col_indices().to(torch.int).cuda()
5252
self.values = self.sparse_kernel.values().cuda()
5353

54-
print(f"colidx: {self.colidx}")
5554
#Stretch the Kernel to input size: (SEE PAPER OF ESCOTING)
5655
kernel_h = self.weight.shape[2]
5756
kernel_w = self.weight.shape[3]
5857
gpu_kernel_stretch(self.rowptr,self.colidx,self.out_channels,in_height,in_width,self.padding,self.padding,kernel_h,kernel_w)
5958

60-
print(f"rowptr: {self.rowptr}")
61-
print(f"colidx: {self.colidx}")
62-
print(f"values: {self.values}")
59+
#print(f"rowptr: {self.rowptr} => {self.rowptr.type()}")
60+
#print(f"colidx: {self.colidx} => {self.colidx.type()}")
61+
#print(f"values: {self.values} => {self.values.type()}")
6362

6463
return
64+
#End Deprecated => Sequential Code
6565
for out_channel in range(self.out_channels):
6666
print(f"ROW [{out_channel}]")
6767
for j in range(self.rowptr[out_channel] , self.rowptr[out_channel+1]):
@@ -72,10 +72,15 @@ def make_kernel_sparse(self,in_height,in_width):
7272
self.colidx[j] = math.floor((in_channel*(in_height + self.padding) + kernel_row)*(in_width + self.padding) + kernel_col)
7373
print(f"Changing colidx[{j}] from {col} => {self.colidx[j]}")
7474

75-
#End
75+
7676

7777

7878
def forward(self, input: Tensor) -> Tensor: # input: HWCN
79+
#TODO CHECK if CUDA is available and in case not use nn.conv2D forward
80+
81+
#TODO CHECK SPARSITY
82+
83+
#TODO ADD "Group > 1" compatibility
7984

8085
#Training mode
8186
if self.training:
@@ -107,12 +112,21 @@ def forward(self, input: Tensor) -> Tensor: # input: HWCN
107112
self.make_kernel_sparse(in_height,in_width)
108113

109114
#Allocate outputs
110-
output = torch.zeros(batch_size, self.out_channels,output_h, output_w).cuda()
111-
115+
output = torch.zeros(batch_size, self.out_channels,output_h, output_w).cuda()
116+
input = input.cuda()
112117
#Calculate sparse conv
113118
sparse_conv(input,self.in_channels,1,in_height,in_width,self.padding,self.padding,self.stride,self.stride,self.dilation,self.dilation,self.rowptr,self.colidx,self.values,kernel_h,kernel_w,self.bias,output,self.out_channels,self.groups)
114119

115120
#Return output
116121
return output
117122

118123

124+
125+
#-------------------------------------------------------------------------------------------------------
126+
#-----------Helper Model Module with some custom method to initialize sparseConv layers-----------------
127+
#-------------------------------------------------------------------------------------------------------
128+
''''
129+
class SparseModel(nn.Module):
130+
def __init__():
131+
'''
132+

src/sparse_conv.cu

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ void caffe_gpu_stretch(const int *rowptr, int *colidx, int M,
7979
void gpu_kernel_stretch(const void *rowptr, void *colidx, int M,
8080
int height, int width, int pad_h, int pad_w, int kernel_h, int kernel_w){
8181

82-
printf("Stretch kernel\n");
82+
//printf("Stretch kernel\n");
8383
caffe_gpu_stretch((int*)rowptr,(int*)colidx,M,height,width,pad_h,pad_w,kernel_h,kernel_w);
8484
}
8585

@@ -208,7 +208,25 @@ __global__ void sconv_shm(const int * rowptr, const int * colidx, const Dtype *
208208
const int output_col = blockIdx.x * blockDim.x + threadIdx.x;
209209
const int oc = blockIdx.z * blockDim.z + threadIdx.z;
210210

211+
/*
212+
if(output_row==0 && output_col==0 && oc == 0){
213+
for(int o=0;o<1;o++){
214+
for(int i=0;i<10;i++){
215+
printf("\nRow[%d]\n",i);
216+
for(int j=0;j<10;j++){
217+
printf("%f,",input[i*10 + j]);
218+
}
219+
}
220+
}
221+
}
211222
223+
if(output_row==0 && output_col==0 && oc == 0){
224+
for(int i=0;i<6;i++){
225+
printf("\nRowptr[%d] = %d\n",i,rowptr[i]);
226+
}
227+
}
228+
*/
229+
212230

213231
__shared__ Dtype values_s[SHMEM_SIZE];
214232
__shared__ int colidx_s[SHMEM_SIZE];
@@ -220,9 +238,10 @@ __global__ void sconv_shm(const int * rowptr, const int * colidx, const Dtype *
220238
const int length = row_end - row_start;
221239
const int BLK_SIZE = TILE_H * TILE_W;
222240

223-
//printf("Thread: (%d,%d,%d) => [%d,%d]\n",output_row,output_col,oc,row_start,row_end);
224241
Dtype sum = 0;
225242
//Dtype sum = bias[oc];
243+
244+
//Each thread has a dedicated 2D grid of weights in shared memory
226245
for(int i = 0; i < length; i += SHMEM_SIZE) {
227246
int base_addr = row_start + i;
228247
for (int j = 0; j < SHMEM_SIZE; j += BLK_SIZE) {
@@ -238,27 +257,32 @@ __global__ void sconv_shm(const int * rowptr, const int * colidx, const Dtype *
238257
__syncthreads();
239258
}
240259

241-
if (output_row < output_h) {
242-
if (output_col < output_w) {
260+
//Actual Sparse Conv Code
261+
262+
//Each thread compute a specific output pixel (With tiling since we cant actualy spawn that many threads)
263+
if (output_row < output_h) { //Check we are accessing valid row
264+
if (output_col < output_w) { //Check we are accessing valid col
265+
//printf("Thread: (%d,%d,%d) => [%d,%d]\n",output_row,output_col,oc,row_start,row_end);
266+
//Projecting the output pixel on the input grid
243267
const Dtype *in_ptr = input + output_row * stride_h * (width + pad_w) + output_col * stride_w;
244268
int end = MIN(SHMEM_SIZE, length - i);
245269
for (int off = 0; off < end; ++off) {
246-
Dtype weight = values_s[off];
247-
int pos = colidx_s[off];
248-
sum += weight * __ldg(in_ptr+pos); //This instruction store the value in the L1 cache
270+
Dtype weight = values_s[off]; //Get the current NonZeroWeight to elaborate
271+
int pos = colidx_s[off]; //Add the projected offset saved in colIdx to get the corresponding input pixel
272+
sum += weight * __ldg(in_ptr+pos); //This instruction store the value in the L1 cache
249273
}
250274
}
251275
}
252276
__syncthreads();
253277
}
254278

255-
//if (oc < num_oc) {
279+
if (oc < num_oc) {
256280
if (output_row < output_h) {
257281
if (output_col < output_w) {
258282
output[(oc * output_h + output_row) * output_w + output_col] = sum;
259283
}
260284
}
261-
//}
285+
}
262286
}
263287

264288
template <typename Dtype, int TILE_H, int TILE_W, int WIDTH, int K, int PAD = (K - 1) / 2>
@@ -535,6 +559,7 @@ void caffe_gpu_sconv(bool FUSE_RELU, int num, const Dtype *input, const int ifma
535559
const int *colidx, const Dtype *values, const Dtype *bias, int height, int width, int pad_h, int pad_w,
536560
int stride_h, int stride_w, int dilation_h, int dilation_w, int kernel_h, int kernel_w, Dtype *output, int num_oc, int num_groups)
537561
{
562+
/*
538563
printf("\033[93m");
539564
printf("Num of inputs: %d\n",num);
540565
printf("Num of input channels: %d\n",num);
@@ -546,46 +571,49 @@ void caffe_gpu_sconv(bool FUSE_RELU, int num, const Dtype *input, const int ifma
546571
printf("Padding Shape (%d,%d)\n",pad_h,pad_w);
547572
printf("Stride Shape (%d,%d)\n",stride_h,stride_w);
548573
printf("Dilation Shape (%d,%d)\n",dilation_h,dilation_w);
549-
/*
574+
550575
printf("Rowptr: [");
551576
for(int i=0;i<6;i++){
552577
printf("%f,",rowptr[i]);
553-
}*/
578+
}
554579
printf("]\n");
580+
*/
555581
//print_device_info(0);
556582
//Compute the output shape based on the
557583
const int output_h = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
558584
const int output_w = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
559585

560-
printf("\033[93mOUTPUT SHAPE: [%d,%d]\n",output_h,output_w);
586+
//printf("\033[93mOUTPUT SHAPE: [%d,%d]\n",output_h,output_w);
561587
//printf("We have a sparse conv with:\n-INPUT:(%d,%d,%d)\n-Kernel:(%d,%d,%d)-Output:(%d,%d,%d)\n",);
562588
int TILE_H = 16;
563589
int TILE_W = 16;
564590
int ntiles_h = (output_h - 1) / TILE_H + 1;
565591
int ntiles_w = (output_w - 1) / TILE_W + 1;
566592
int nblocks = (num_oc - 1) / OC_BLOCK + 1;
567-
//printf("num=%d, nblocks=%d, num_oc=%d\n", num, nblocks, num_oc);
593+
568594
//printf("height=%d, width=%d, output_h=%d, output_w=%d\n", height, width, output_h, output_w);
569595
//printf("stride_h=%d, stride_w=%d, pad_h=%d, pad_width=%d\n", stride_h, stride_w, pad_h, pad_w);
570596

571597
//Dilatation is a kernel with some empty columns and rows
572598
if (dilation_h != 1 || dilation_w != 1) {
573-
printf("SparseConv With Dilation\n");
599+
//printf("SparseConv With Dilation\n");
574600
dim3 threads(TILE_W, TILE_H, OC_BLOCK);
575601
dim3 grid(ntiles_w, ntiles_h, nblocks);
576602
sconv_dilation<Dtype><<<grid, threads>>>(rowptr, colidx, values, input,
577603
height, width, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
578604
kernel_h, kernel_w, bias, output, num_oc, output_h, output_w);
579605
} else if (stride_h == 1 && stride_w == 1 && height == width && kernel_h == kernel_w && pad_h == pad_w) {
580606
if(FUSE_RELU) {
581-
printf("SparseConv With Dilation\n");
607+
//printf("SparseConv With Symmetric shapes + RELU => eg square\n");
582608
dim3 threads(16, 16, OC_BLOCK);
583609
dim3 grid(ntiles_w, ntiles_h, nblocks);
584610
sconv_relu_tiled<Dtype,16,16><<<grid, threads>>>(rowptr, colidx, values, input,
585611
height, width, pad_h, pad_w, stride_h, stride_w, kernel_h, kernel_w,
586612
bias, output, num_oc, output_h, output_w);
587613
} else {
614+
588615
if(num == 1) {
616+
//printf("SparseConv With Symmetric shapes (BatchSize = 1) => eg square\n");
589617
if(height == 27) {
590618
//if(0) {
591619
ntiles_w = DIVIDE_INTO(output_w, 32);
@@ -612,13 +640,15 @@ void caffe_gpu_sconv(bool FUSE_RELU, int num, const Dtype *input, const int ifma
612640
nblocks = std::min(max_blocks, nblocks);
613641
//printf("Launching CUDA solver: %d CTAs (max %d/SM), %d threads/CTA ...\n", nblocks, max_blocks_per_SM, nthreads);
614642
//*/
643+
//printf("num=%d, ntiles_h=%d, ntiles_w=%d, nblocks=%d, num_oc=%d\n", num, ntiles_h, ntiles_w, nblocks, num_oc);
615644
dim3 threads(TILE_W, TILE_H, 1);
616645
dim3 grid(ntiles_w, ntiles_h, nblocks);
617646
sconv_shm<Dtype,16,16><<<grid, threads>>>(rowptr, colidx, values, input,
618647
height, width, pad_h, pad_w, stride_h, stride_w, kernel_h, kernel_w,
619648
bias, output, num_oc, output_h, output_w);
620649
}
621650
} else {
651+
//printf("SparseConv With Symmetric shapes (BatchSize = N) => eg square\n");
622652
dim3 threads(16, 16, 1);
623653
//if(nblocks >= 128 && nblocks < 224) {
624654
if(0) {
@@ -637,6 +667,7 @@ void caffe_gpu_sconv(bool FUSE_RELU, int num, const Dtype *input, const int ifma
637667
}
638668
}
639669
} else {
670+
//printf("SparseConv With Asymmetric shapes => eg Rectangle\n");
640671
// fall through to the default path
641672
dim3 threads(TILE_W, TILE_H, OC_BLOCK);
642673
dim3 grid(ntiles_w, ntiles_h, nblocks);
@@ -659,10 +690,10 @@ void caffe_gpu_sconv(bool FUSE_RELU, int num, const Dtype *input, const int ifma
659690
}
660691
}
661692

662-
663-
printf("End of computation\n");
664-
printf("\033[0m");
665693
CudaTest("sconv_kernel solving failed");
694+
695+
//printf("End of computation\n");
696+
//printf("\033[0m");
666697
}
667698

668699
template void caffe_gpu_sconv<int>(bool FUSE_RELU, int num, const int *input, const int ifmap_size, const int *rowptr,

main.py renamed to test_behaviour.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
import torch.nn.functional as F
66
import torch.nn.utils.prune as prune
7+
import copy
78

89
class LeNet5(nn.Module):
910
"""
@@ -111,23 +112,25 @@ def pruning_model_random(model, px):
111112
#Generate a dummy input to give the convolution
112113
dummy_input = torch.randn(1, 1,IMG_SIZE,IMG_SIZE, dtype=torch.float).to(device)
113114
dummy_input = dummy_input.cuda()
115+
input = copy.deepcopy(dummy_input)
116+
input = input.cuda()
114117

115118
#Generate sparse conv ouptput
116119
sp_out = model.conv1.forward(dummy_input)
117120

118121
#Generate vanilla conv output
119122
model.conv1.use_sparse = False
120-
out = model.conv1.forward(dummy_input)
123+
out = model.conv1.forward(input)
121124

122125
#TODO Compare vanilla vs sparse output
123126

124-
print(f"SP_OUT: {sp_out}")
125-
print(f"OUT: {out}")
126-
127+
#print(f"SP_OUT: {sp_out}")
128+
#print(f"OUT: {out}")
129+
print("Vanilla vs SparseConv:")
127130
if torch.all(sp_out.eq(out)):
128131
print("\033[92mSUCCESS => Same Outputs\033[0m")
129132
else:
130133
print("\033[91mFAIL => Divergent Outputs\033[0m")
131134

132-
print(dummy_input.shape)
133-
print(out.shape)
135+
print(f"IN -shape: {dummy_input.shape}")
136+
print(f"OUT-shape: {out.shape}")

0 commit comments

Comments
 (0)