Skip to content

Commit 33f8612

Browse files
committed
Rewrite crop cuda kernel
1 parent eeebdab commit 33f8612

File tree

3 files changed

+69
-80
lines changed

3 files changed

+69
-80
lines changed

include/caffe/layers/crop_layer.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ class CropLayer : public Layer<Dtype> {
4141
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
4242
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
4343

44-
vector<int> offsets;
44+
Blob<int> offsets;
45+
Blob<int> src_strides_;
46+
Blob<int> dest_strides_;
4547

4648
private:
4749
// Recursive copy function.
4850
void crop_copy(const vector<Blob<Dtype>*>& bottom,
4951
const vector<Blob<Dtype>*>& top,
50-
const vector<int>& offsets,
52+
const int* offsets,
5153
vector<int> indices,
5254
int cur_dim,
5355
const Dtype* src_data,

src/caffe/layers/crop_layer.cpp

+16-5
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@ void CropLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
4040
const int start_axis = bottom[0]->CanonicalAxisIndex(param.axis());
4141

4242
// Initialize offsets to 0 and the new shape to the current shape of the data.
43-
offsets = vector<int>(input_dim, 0);
4443
vector<int> new_shape(bottom[0]->shape());
44+
vector<int> offsets_shape(1, input_dim);
45+
offsets.Reshape(offsets_shape);
46+
int* offset_data = offsets.mutable_cpu_data();
4547

4648
// Determine crop offsets and the new shape post-crop.
4749
for (int i = 0; i < input_dim; ++i) {
@@ -63,15 +65,22 @@ void CropLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
6365
<< "size " << bottom[1]->shape(i) << " and offset " << crop_offset;
6466
}
6567
new_shape[i] = new_size;
66-
offsets[i] = crop_offset;
68+
offset_data[i] = crop_offset;
6769
}
6870
top[0]->Reshape(new_shape);
71+
// Compute strides
72+
src_strides_.Reshape(offsets_shape);
73+
dest_strides_.Reshape(offsets_shape);
74+
for (int i = 0; i < input_dim; ++i) {
75+
src_strides_.mutable_cpu_data()[i] = bottom[0]->count(i + 1, input_dim);
76+
dest_strides_.mutable_cpu_data()[i] = top[0]->count(i + 1, input_dim);
77+
}
6978
}
7079

7180
template <typename Dtype>
7281
void CropLayer<Dtype>::crop_copy(const vector<Blob<Dtype>*>& bottom,
7382
const vector<Blob<Dtype>*>& top,
74-
const vector<int>& offsets,
83+
const int* offsets,
7584
vector<int> indices,
7685
int cur_dim,
7786
const Dtype* src_data,
@@ -115,7 +124,8 @@ void CropLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
115124
std::vector<int> indices(top[0]->num_axes(), 0);
116125
const Dtype* bottom_data = bottom[0]->cpu_data();
117126
Dtype* top_data = top[0]->mutable_cpu_data();
118-
crop_copy(bottom, top, offsets, indices, 0, bottom_data, top_data, true);
127+
crop_copy(bottom, top, offsets.cpu_data(), indices, 0, bottom_data, top_data,
128+
true);
119129
}
120130

121131
template <typename Dtype>
@@ -127,7 +137,8 @@ void CropLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
127137
if (propagate_down[0]) {
128138
caffe_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
129139
std::vector<int> indices(top[0]->num_axes(), 0);
130-
crop_copy(bottom, top, offsets, indices, 0, top_diff, bottom_diff, false);
140+
crop_copy(bottom, top, offsets.cpu_data(), indices, 0, top_diff,
141+
bottom_diff, false);
131142
}
132143
}
133144

src/caffe/layers/crop_layer.cu

+49-73
Original file line numberDiff line numberDiff line change
@@ -4,103 +4,79 @@
44

55
namespace caffe {
66

// Map a flattened index in the cropped (destination) blob onto the
// corresponding flattened index in the uncropped (source) blob.
// The destination coordinate along each axis is peeled off with the
// destination strides, shifted by that axis's crop offset, and then
// re-linearized using the source strides.
__device__ int compute_uncropped_index(
    int index,
    const int ndims,
    const int* src_strides,
    const int* dest_strides,
    const int* offsets) {
  int src_index = 0;
  int remainder = index;
  for (int axis = 0; axis < ndims; ++axis) {
    const int coord = remainder / dest_strides[axis];
    remainder -= coord * dest_strides[axis];
    src_index += src_strides[axis] * (coord + offsets[axis]);
  }
  return src_index;
}
22+
923
// Forward crop: one thread per output element.  Each thread copies its
// element of the cropped top blob from the matching (offset-shifted)
// position in the uncropped bottom blob.
template <typename Dtype>
__global__ void crop_kernel_forward(const int nthreads,
    const int ndims,
    const int* src_strides,
    const int* dest_strides,
    const int* offsets,
    const Dtype* src, Dtype* dest) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    dest[index] = src[compute_uncropped_index(
        index, ndims, src_strides, dest_strides, offsets)];
  }
}
2236

2337
// Backward crop: one thread per top-gradient element.  Each thread
// scatters its element of the top diff back into the offset-shifted
// position of the bottom diff.  Note the roles are reversed relative
// to the forward kernel: here `src` (the uncropped blob) is written
// and `dest` (the cropped blob) is read.
template <typename Dtype>
__global__ void crop_kernel_backward(const int nthreads,
    const int ndims,
    const int* src_strides,
    const int* dest_strides,
    const int* offsets,
    Dtype* src, const Dtype* dest) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    src[compute_uncropped_index(
        index, ndims, src_strides, dest_strides, offsets)] = dest[index];
  }
}
8350

8451
template <typename Dtype>
void CropLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  // Copy the cropped window of bottom[0] into top[0], one thread per
  // output element.  The stride and offset blobs were filled in Reshape.
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = top[0]->mutable_gpu_data();
  const int n = top[0]->count();
  // NOLINT_NEXT_LINE(whitespace/operators)
  crop_kernel_forward<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(n,
      bottom[0]->num_axes(),
      src_strides_.gpu_data(),
      dest_strides_.gpu_data(),
      offsets.gpu_data(),
      bottom_data, top_data);
  // Kernel launches return no status; surface launch/config errors now
  // instead of at an arbitrary later synchronization point.
  CUDA_POST_KERNEL_CHECK;
}
9264

9365
template <typename Dtype>
void CropLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  if (propagate_down[0]) {
    // Only touch the diff buffers when a gradient is actually requested;
    // mutable_gpu_diff() may otherwise force a needless device sync.
    const Dtype* top_diff = top[0]->gpu_diff();
    Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
    const int n = top[0]->count();
    // Elements outside the crop window receive no gradient, so clear the
    // whole bottom diff before scattering the top diff into it.
    caffe_gpu_set(bottom[0]->count(), static_cast<Dtype>(0), bottom_diff);
    // NOLINT_NEXT_LINE(whitespace/operators)
    crop_kernel_backward<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(n,
        bottom[0]->num_axes(),
        src_strides_.gpu_data(),
        dest_strides_.gpu_data(),
        offsets.gpu_data(),
        bottom_diff, top_diff);
    // Surface launch/config errors immediately (Caffe convention).
    CUDA_POST_KERNEL_CHECK;
  }
}
10682

0 commit comments

Comments
 (0)