Skip to content

Commit 8164ea9

Browse files
Fixing several bugs in the inference-api and the kernels (#1951)
Co-authored-by: Jeff Rasley <[email protected]>
1 parent: b8ff482 · commit: 8164ea9

File tree

14 files changed

+1087
-134
lines changed

14 files changed

+1087
-134
lines changed

.github/workflows/amd.yml

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,12 +37,23 @@ jobs:
3737
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
3838
sudo apt-get update
3939
sudo apt-get install -y libaio-dev
40+
41+
- name: Install transformers
42+
run: |
43+
git clone https://github.com/huggingface/transformers
44+
cd transformers
45+
# if needed switch to the last known good SHA until transformers@master is fixed
46+
# git checkout 1cc453d33
47+
git rev-parse --short HEAD
48+
pip install .
49+
4050
# Runs a set of commands using the runners shell
4151
- name: Install deepspeed
4252
run: |
4353
sudo /opt/conda/bin/pip install .[dev,1bit,autotuning]
4454
#python -c "from deepspeed.env_report import cli_main; cli_main()"
4555
ds_report
56+
4657
# Runs a set of commands using the runners shell
4758
- name: Unit tests
4859
run: |

.github/workflows/nv-torch12-p40.yml

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,15 @@ jobs:
3232
python -c "import torch; print('torch:', torch.__version__, torch)"
3333
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
3434
35+
- name: Install transformers
36+
run: |
37+
git clone https://github.com/huggingface/transformers
38+
cd transformers
39+
# if needed switch to the last known good SHA until transformers@master is fixed
40+
# git checkout 1cc453d33
41+
git rev-parse --short HEAD
42+
pip install .
43+
3544
- name: Install deepspeed
3645
run: |
3746
pip install .[dev,autotuning]

.github/workflows/nv-torch18-v100.yml

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,10 +32,21 @@ jobs:
3232
pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
3333
python -c "import torch; print('torch:', torch.__version__, torch)"
3434
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
35+
36+
- name: Install transformers
37+
run: |
38+
git clone https://github.com/huggingface/transformers
39+
cd transformers
40+
# if needed switch to the last known good SHA until transformers@master is fixed
41+
# git checkout 1cc453d33
42+
git rev-parse --short HEAD
43+
pip install .
44+
3545
- name: Install deepspeed
3646
run: |
3747
pip install .[dev,1bit,autotuning,sparse_attn]
3848
ds_report
49+
3950
- name: Unit tests
4051
run: |
4152
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch

csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
#include <cuda_profiler_api.h>
55
#endif
66

7+
namespace cg = cooperative_groups;
78
namespace cg = cooperative_groups;
89

910
__global__ void apply_rotary_pos_emb(float* mixed_query,
@@ -153,7 +154,9 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
153154
int lane = id & 0x1f;
154155

155156
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
157+
unsigned seq_index = head_id % seq_len;
156158
unsigned offset = head_id * head_size;
159+
unsigned k_offset = (seq_index + (head_id / seq_len) * MAX_OUT_TOKES) * head_size;
157160

158161
constexpr unsigned mask[32] = {
159162
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
@@ -171,7 +174,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
171174
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
172175
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
173176
float q = (float)mixed_query[offset + lane];
174-
float k = (float)key_layer[offset + lane];
177+
float k = (float)key_layer[k_offset + lane];
175178
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
176179
float q_rot = (q * rotary_sign);
177180
float k_rot = (k * rotary_sign);
@@ -183,7 +186,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
183186
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
184187

185188
mixed_query[offset + lane] = (__half)q;
186-
key_layer[offset + lane] = (__half)k;
189+
key_layer[k_offset + lane] = (__half)k;
187190

188191
lane += WARP_SIZE;
189192
}
@@ -237,6 +240,7 @@ template void launch_apply_rotary_pos_emb<__half>(__half*,
237240
bool,
238241
bool,
239242
cudaStream_t);
243+
240244
/*
241245
__global__ void apply_rotary_pos_emb(float* mixed_query,
242246
float* key_layer,

csrc/transformer/inference/csrc/gelu.cu

Lines changed: 26 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -317,12 +317,18 @@ __global__ void gptj_residual_add(float* input,
317317
float4 out = output_cast[offset];
318318
float4 res_vec = attn_cast[offset];
319319
float4 bias_data = bias_cast[offset % intermediate_size];
320-
float4 attn_bias = attnbias_cast[offset % intermediate_size];
321320

322-
data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x);
323-
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y);
324-
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z);
325-
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w);
321+
if (attnbias) {
322+
float4 attn_bias = attnbias_cast[offset % intermediate_size];
323+
data.x += attn_bias.x;
324+
data.y += attn_bias.y;
325+
data.z += attn_bias.z;
326+
data.w += attn_bias.w;
327+
}
328+
data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x);
329+
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y);
330+
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z);
331+
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w);
326332

327333
output_cast[offset] = data;
328334
}
@@ -354,13 +360,11 @@ __global__ void gptj_residual_add(__half* input,
354360
float2 res_vec = attn_cast[offset];
355361

356362
float2 bias_vec = bias_cast[offset % intermediate_size];
357-
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
358363

359364
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
360365
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
361366
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
362367
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
363-
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
364368

365369
float2 low_data = __half22float2(vals_half[0]);
366370
float2 high_data = __half22float2(vals_half[1]);
@@ -373,18 +377,21 @@ __global__ void gptj_residual_add(__half* input,
373377

374378
float2 low_bias = __half22float2(bias_half[0]);
375379
float2 high_bias = __half22float2(bias_half[1]);
376-
377-
float2 attn_low_bias = __half22float2(attnbias_half[0]);
378-
float2 attn_high_bias = __half22float2(attnbias_half[1]);
379-
380-
low_data.x =
381-
low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x));
382-
low_data.y =
383-
low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y));
384-
high_data.x =
385-
high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x));
386-
high_data.y =
387-
high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y));
380+
if (attn_bias) {
381+
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
382+
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
383+
float2 attn_low_bias = __half22float2(attnbias_half[0]);
384+
float2 attn_high_bias = __half22float2(attnbias_half[1]);
385+
low_data.x += attn_low_bias.x;
386+
low_data.y += attn_low_bias.y;
387+
high_data.x += attn_high_bias.x;
388+
high_data.y += attn_high_bias.y;
389+
}
390+
391+
low_data.x = low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x));
392+
low_data.y = low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y));
393+
high_data.x = high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x));
394+
high_data.y = high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y));
388395

389396
vals_half[0] = __float22half2_rn(low_data);
390397
vals_half[1] = __float22half2_rn(high_data);

0 commit comments

Comments
 (0)