Skip to content

Commit f191153

Browse files
Modify RobertaEmbedding forward as custom op method (HabanaAI#996)
This is custom op change as PR HabanaAI#786 follow-up. Removed RobertaEmbedding class from model file and implemented it as CustomOp class in new file. forward_cuda() is the original forward function and forward_hpu() is our specific change. <!--- pyml disable-next-line no-emphasis-as-heading --> --------- Co-authored-by: Michał Kuligowski <mkuligowski@habana.ai>
1 parent 3b20086 commit f191153

3 files changed

Lines changed: 120 additions & 4 deletions

File tree

README_GAUDI.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM
386386

387387
- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used. `1` is the default.
388388
- `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
389-
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava and qwen models.
389+
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava, qwen and roberta models.
390390
- `VLLM_PROMPT_USE_FLEX_ATTENTION` is enabled only for llama model, and allows to use torch.nn.attention.flex_attention instead of FusedSDPA. Note, this requires `VLLM_PROMPT_USE_FUSEDSDPA=0`
391391

392392
# Quantization, FP8 Inference and Model Calibration Process

docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM
361361

362362
- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used. `1` is the default.
363363
- `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
364-
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava model.
364+
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava, qwen and roberta models.
365365

366366
## Quantization, FP8 Inference and Model Calibration Process
367367

vllm/model_executor/models/roberta.py

Lines changed: 118 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import itertools
4+
import os
45
from typing import Iterable, Optional, Tuple
56

67
import torch
78
from torch import nn
89
from transformers import RobertaConfig
910

1011
from vllm.config import VllmConfig
12+
from vllm.model_executor.custom_op import CustomOp
1113
from vllm.model_executor.layers.pooler import CrossEncodingPooler
1214
from vllm.model_executor.layers.vocab_parallel_embedding import (
1315
VocabParallelEmbedding)
@@ -46,7 +48,8 @@ def encoder_decoder_weights():
4648
if not n.startswith("roberta."))
4749

4850

49-
class RobertaEmbedding(nn.Module):
51+
@CustomOp.register("roberta_embedding")
52+
class RobertaEmbedding(CustomOp):
5053

5154
def __init__(self, config: RobertaConfig):
5255
super().__init__()
@@ -70,7 +73,80 @@ def __init__(self, config: RobertaConfig):
7073
raise ValueError("Only 'absolute' position_embedding_type" +
7174
" is supported")
7275

73-
def forward(
76+
self.use_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL',
77+
'false').lower() == 'true'
78+
79+
def forward_hpu(
80+
self,
81+
input_ids: torch.Tensor,
82+
seq_lens: torch.Tensor,
83+
position_ids: torch.Tensor,
84+
token_type_ids: Optional[torch.Tensor] = None,
85+
) -> torch.Tensor:
86+
input_shape = input_ids.size()
87+
inputs_embeds = self.word_embeddings(input_ids)
88+
89+
# Replace position ids because in RoBERTa models
90+
# they have to start at padding_idx + 1 and ignore
91+
# existing padding tokens
92+
# Modified replace position ids
93+
# for HPU set position_ids and input_ids as [batch_size, bucket_size]
94+
# References:
95+
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
96+
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
97+
pos_list = []
98+
token_list = []
99+
if self.use_merged_prefill:
100+
offset = 0
101+
for seq_len in seq_lens:
102+
pos_list.append(position_ids[0][offset:offset + seq_len])
103+
token_list.append(input_ids[0][offset:offset + seq_len])
104+
offset += seq_len
105+
106+
offset = 0
107+
for positions, tokens, seq_len in zip(pos_list, token_list,
108+
seq_lens):
109+
# Verify assumption that incoming position are
110+
# always a sequence from 0 to N.
111+
expected_pos = torch.arange(positions.size()[0],
112+
dtype=torch.long,
113+
device=inputs_embeds.device)
114+
assert torch.equal(positions, expected_pos)
115+
position_ids[0][offset:offset +
116+
seq_len] = create_position_ids_from_input_ids(
117+
tokens, self.padding_idx)
118+
offset += seq_len
119+
else:
120+
for offset in range(position_ids.size()[0]):
121+
pos_list.append(position_ids[offset])
122+
token_list.append(input_ids[offset])
123+
124+
for index, (positions, tokens, seq_len) in enumerate(
125+
zip(pos_list, token_list, seq_lens)):
126+
# Verify assumption that incoming position are
127+
# always a sequence from 0 to N.
128+
expected_pos = torch.arange(positions.size()[0],
129+
dtype=torch.long,
130+
device=inputs_embeds.device)
131+
valid_input_mask = expected_pos < seq_len
132+
expected_pos = expected_pos * valid_input_mask
133+
assert torch.equal(positions, expected_pos)
134+
position_ids[index] = create_position_ids_from_input_ids_hpu(
135+
tokens, self.padding_idx, seq_len)
136+
137+
# Position embeddings.
138+
position_embeddings = self.position_embeddings(position_ids)
139+
if token_type_ids is None:
140+
token_type_ids = torch.zeros(input_shape,
141+
dtype=torch.long,
142+
device=inputs_embeds.device)
143+
144+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
145+
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
146+
embeddings = self.LayerNorm(embeddings)
147+
return embeddings
148+
149+
def forward_native(
74150
self,
75151
input_ids: torch.Tensor,
76152
seq_lens: torch.Tensor,
@@ -118,6 +194,46 @@ def forward(
118194
embeddings = self.LayerNorm(embeddings)
119195
return embeddings
120196

197+
def forward_cuda(
198+
self,
199+
input_ids: torch.Tensor,
200+
seq_lens: torch.Tensor,
201+
position_ids: torch.Tensor,
202+
token_type_ids: Optional[torch.Tensor] = None,
203+
) -> torch.Tensor:
204+
return self.forward_native(input_ids, seq_lens, position_ids,
205+
token_type_ids)
206+
207+
208+
# Adapted from transformers
209+
def create_position_ids_from_input_ids_hpu(input_ids,
210+
padding_idx,
211+
seq_len,
212+
past_key_values_length=0):
213+
"""
214+
Replace non-padding symbols with their position numbers.
215+
Position numbers begin at padding_idx+1. Padding symbols
216+
are ignored. This is modified from fairseq's `utils.make_positions`.
217+
218+
Args:
219+
x: torch.Tensor x:
220+
221+
Returns: torch.Tensor
222+
"""
223+
# The series of casts and type-conversions here are carefully
224+
# balanced to both work with ONNX export and XLA.
225+
valid_input_mask = torch.arange(input_ids.size()[0],
226+
dtype=torch.int,
227+
device=input_ids.device)
228+
valid_input_mask = valid_input_mask < seq_len
229+
230+
mask = input_ids.ne(padding_idx).int()
231+
232+
incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
233+
past_key_values_length) * mask
234+
235+
return (incremental_indices.long() + padding_idx) * valid_input_mask
236+
121237

122238
# Adapted from transformers
123239
def create_position_ids_from_input_ids(input_ids,

0 commit comments

Comments
 (0)