Improve Transformers support #2970

Merged 4 commits on Feb 18, 2025
2 changes: 1 addition & 1 deletion backends/llamacpp/requirements.txt
@@ -1,3 +1,3 @@
transformers==4.48.2
transformers==4.49
huggingface-hub==0.28.1
hf-transfer==0.1.9
2 changes: 1 addition & 1 deletion server/requirements_cuda.txt
@@ -346,7 +346,7 @@ tqdm==4.66.5
# outlines
# peft
# transformers
transformers==4.48.2
transformers==4.49
# via
# text-generation-server (pyproject.toml)
# compressed-tensors
2 changes: 1 addition & 1 deletion server/requirements_gen.txt
@@ -158,7 +158,7 @@ tqdm==4.67.1
# via
# huggingface-hub
# transformers
transformers==4.48.2
transformers==4.49
# via text-generation-server (pyproject.toml)
typer==0.15.1
# via text-generation-server (pyproject.toml)
2 changes: 1 addition & 1 deletion server/requirements_intel.txt
@@ -331,7 +331,7 @@ tqdm==4.66.5
# outlines
# peft
# transformers
transformers==4.48.2
transformers==4.49
# via
# text-generation-server (pyproject.toml)
# compressed-tensors
2 changes: 1 addition & 1 deletion server/requirements_rocm.txt
@@ -331,7 +331,7 @@ tqdm==4.66.5
# outlines
# peft
# transformers
transformers==4.48.2
transformers==4.49
# via
# text-generation-server (pyproject.toml)
# compressed-tensors
108 changes: 66 additions & 42 deletions server/text_generation_server/models/__init__.py
@@ -6,17 +6,18 @@
)
from compressed_tensors.quantization import QuantizationType
from pydantic import ValidationError
import torch
import enum
import os

from typing import Optional, List, Dict
from pathlib import Path
from loguru import logger

import torch
import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import modeling_auto
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List, Dict
from pathlib import Path
import transformers

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
@@ -736,7 +737,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
return transformers_causal_lm_class.fallback(
return CausalLM.fallback(
model_id=model_id,
revision=revision,
quantize=quantize,
@@ -857,6 +858,15 @@ def get_model(
lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
return CausalLM(
model_id=model_id,
@@ -1054,6 +1064,15 @@ def get_model(
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif FLASH_TRANSFORMERS_BACKEND:
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else:
@@ -1467,52 +1486,38 @@ def get_model(
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")

# Fast transformers if available
transformers_model_class = getattr(
transformers,
modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
None,
)
if (
FLASH_TRANSFORMERS_BACKEND
and transformers_model_class is not None
and transformers_model_class._supports_flex_attn
):
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

if sharded:
raise NotImplementedError("sharded is not supported for AutoModel")
auto_map = config_dict.get("auto_map", None)
model_class = None

if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
# If the model is already in the library
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
model_class = getattr(
transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
)
elif (
trust_remote_code
and auto_map is not None
and "AutoModelForCausalLM" in auto_map.keys()
):
model_class = get_class_from_dynamic_module(
config_dict["auto_map"]["AutoModelForCausalLM"], model_id
)

auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys():
return CausalLM.fallback(
# This means the model is ForCausalLM
if model_class is not None:
if FLASH_TRANSFORMERS_BACKEND and model_class.is_backend_compatible():
return TransformersFlashCausalLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if "AutoModelForSeq2SeqLM" in auto_map.keys():
return Seq2SeqLM.fallback(
elif sharded:
raise NotImplementedError("sharded is not supported for AutoModel")
else:
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@@ -1521,6 +1526,25 @@ def get_model(
trust_remote_code=trust_remote_code,
)

# Not supported at this point
if sharded:
raise NotImplementedError("sharded is not supported for AutoModel")

# This means it is a ForSeq2SeqLM model
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES or (
trust_remote_code
and auto_map is not None
and "AutoModelForSeq2SeqLM" in auto_map.keys()
):
return Seq2SeqLM.fallback(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)

raise ValueError(f"Unsupported model type {model_type}")


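The hunks above reorder the AutoModel fallback in get_model(): a concrete *ForCausalLM class is resolved first (from MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, or from the config's auto_map when trust_remote_code is set), and only then is the Transformers flash backend considered, gated on is_backend_compatible(). The following is a minimal sketch of the resulting decision order, not TGI's actual code; the returned strings and the parameter names are placeholders.

```python
from typing import Optional


def autodetect_fallback(
    model_type: str,
    model_class: Optional[type],      # resolved *ForCausalLM class, or None if unknown
    auto_map: Optional[dict],         # config_dict.get("auto_map")
    *,
    flash_backend: bool,              # stands in for FLASH_TRANSFORMERS_BACKEND
    sharded: bool,
    trust_remote_code: bool,
    seq2seq_types: frozenset = frozenset({"t5", "mt5", "bart"}),  # illustrative subset
) -> str:
    if model_class is not None:
        # The model is a *ForCausalLM: prefer the Transformers flash backend when supported.
        if flash_backend and getattr(model_class, "is_backend_compatible", lambda: False)():
            return "TransformersFlashCausalLM.fallback"
        if sharded:
            raise NotImplementedError("sharded is not supported for AutoModel")
        return "CausalLM.fallback"

    # Sharding is not supported for anything past this point.
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")

    if model_type in seq2seq_types or (
        trust_remote_code and auto_map is not None and "AutoModelForSeq2SeqLM" in auto_map
    ):
        return "Seq2SeqLM.fallback"

    raise ValueError(f"Unsupported model type {model_type}")
```

Compared to the old code, the backend-compatibility check now runs after the auto_map resolution rather than only on the static mapping, so trust-remote-code causal-LM models can also take the flash path.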
server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -81,6 +81,15 @@ def tgi_flash_attention_forward(

transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward

# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
# will be fully replicated on each process) and of GPU communication (additional all-gather operations), but due
# to internal constraints it was not (yet?) possible to circumvent this.
REPLICATED_ATTENTION_MODELS = [
"olmo2",
"phi3",
]


class TransformersFlashCausalLM(FlashCausalLM):
def __init__(
@@ -119,6 +128,7 @@ def __init__(
truncation_side="left",
trust_remote_code=trust_remote_code,
)

model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
@@ -130,6 +140,8 @@
tp_plan="auto" if world_size > 1 else None,
)

torch.distributed.barrier(group=self.process_group)

if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
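The hunk above adds a torch.distributed.barrier right after the checkpoint is loaded with Transformers' built-in tensor parallelism (tp_plan="auto"), so every rank finishes loading before the pad-token and head bookkeeping that follows. A minimal sketch of that loading pattern outside TGI; the checkpoint name is only illustrative and the snippet assumes the default process group was already initialized (e.g. by torchrun).

```python
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM

# Assumes `torchrun --nproc-per-node=N` (or equivalent) set up the process group.
world_size = dist.get_world_size() if dist.is_initialized() else 1

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",                # illustrative checkpoint
    torch_dtype=torch.bfloat16,
    tp_plan="auto" if world_size > 1 else None,  # shard weights across ranks when distributed
)

if dist.is_initialized():
    dist.barrier()  # do not touch the model/config until every rank has finished loading
```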
@@ -143,15 +155,19 @@ def __init__(
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

self.num_layers = model.config.num_hidden_layers
self.num_heads = model.config.num_attention_heads // self.process_group.size()
self.num_heads = model.config.num_attention_heads
self.num_kv_heads = model.config.num_key_value_heads
self.num_kv_heads = (
self.num_kv_heads // self.process_group.size()
if self.num_kv_heads > 1
else self.num_kv_heads
)
self.head_size = model.config.hidden_size // model.config.num_attention_heads

# Skip it for models in the exception list
if model.config.model_type not in REPLICATED_ATTENTION_MODELS:
self.num_heads = self.num_heads // self.process_group.size()
self.num_kv_heads = (
self.num_kv_heads // self.process_group.size()
if self.num_kv_heads > 1
else self.num_kv_heads
)

self.cuda_graphs = {}
self.kv_cache = []
self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
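The head bookkeeping changed in the hunk above is the visible consequence of REPLICATED_ATTENTION_MODELS: for model types whose base TP plan replicates q/k/v, each rank keeps the full head count (and therefore a fully replicated KV cache), while every other model still divides its heads by the world size. A small self-contained sketch of that accounting, assuming a standard HF config object; the function and variable names are illustrative, not TGI's.

```python
# Mirrors the list added earlier in this file.
REPLICATED_ATTENTION_MODELS = ["olmo2", "phi3"]


def kv_cache_heads(config, world_size: int) -> tuple[int, int]:
    """Return (num_heads, num_kv_heads) as seen by one rank when sizing the KV cache."""
    num_heads = config.num_attention_heads
    num_kv_heads = getattr(config, "num_key_value_heads", num_heads)

    if config.model_type not in REPLICATED_ATTENTION_MODELS:
        # Attention is sharded: each rank only stores its slice of the heads.
        num_heads //= world_size
        if num_kv_heads > 1:
            # A single shared KV head (MQA) is replicated rather than sharded.
            num_kv_heads //= world_size

    # Replicated plans keep every head on every rank, trading VRAM for simplicity.
    return num_heads, num_kv_heads
```

For example, a hypothetical phi3-style config with 32 attention heads on 2 ranks yields (32, 32) here, where a sharded plan would yield (16, 16).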
@@ -186,7 +202,6 @@ def __init__(
torch.tensor(1.0, device=device),
)

torch.distributed.barrier(group=self.process_group)
# Skip FlashCausalLM init.
super(FlashCausalLM, self).__init__(
model_id=model_id,
Expand All @@ -204,6 +219,8 @@ def __init__(
self.model.original_forward = self.model.forward
self.model.forward = self._model_forward

torch.distributed.barrier(group=self.process_group)

@classmethod
def fallback(
cls,
@@ -237,11 +254,16 @@ def _model_forward(
prefill_cache_indices=None, # not used, but passed to match original signature
adapter_data=None, # not supported, but passed to match original signature
):
hidden_states = self.model.model.forward(
# A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0

# This is equivalent to `self.model.forward`, see the monkey patch in __init__
logits = self.model.original_forward(
input_ids=input_ids.unsqueeze(0), # expand dim to fit Transformers
position_ids=position_ids.unsqueeze(0), # expand dim to fit Transformers
past_key_values=None, # we use self.kv_cache instead of transformers cache object
use_cache=False, # we use self.kv_cache instead of transformers cache object
logits_to_keep=logits_to_keep,
return_dict=True,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
Expand All @@ -251,20 +273,6 @@ def _model_forward(
max_s=max_s,
kv_head_mapping=self.kv_head_mapping,
kv_scales=self.kv_scales,
)[0].squeeze(dim=0)

# And compute logits from the lm_head, slicing correctly the indices
# NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
# To update with full Transformers support asap
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.model.lm_head(hidden_states)

# For Granite while next transformers version is released and we can use `lm_head_indices` natively
if hasattr(self.model.config, "logits_scaling"):
logits = logits / self.model.config.logits_scaling
# For Cohere for similar reasons
elif hasattr(self.model, "logit_scale"):
logits = logits * self.model.logit_scale
).logits.squeeze(dim=0)

return logits, None
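With the bump to transformers==4.49, _model_forward no longer re-implements the lm-head slicing and the per-model logit post-processing it used to handle by hand (Granite's logits_scaling, Cohere's logit_scale): it calls the saved original_forward and lets Transformers do the slicing via logits_to_keep, which accepts 0 (keep every position) or a tensor of positions. A minimal standalone sketch of that call pattern outside TGI; the checkpoint is only illustrative and none of TGI's KV-cache plumbing appears here.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM2-135M"  # illustrative, ungated checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Improve Transformers support", return_tensors="pt")

# Keep logits only for the last position, as TGI does at the end of a prefill.
lm_head_indices = torch.tensor([inputs.input_ids.shape[1] - 1])
logits_to_keep = lm_head_indices if lm_head_indices is not None else 0  # None -> 0 == keep all

out = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    use_cache=False,
    logits_to_keep=logits_to_keep,
    return_dict=True,
)

# Logits come back already sliced, with any model-specific post-processing
# (softcapping, logit scaling, ...) applied inside Transformers.
print(out.logits.shape)  # -> torch.Size([1, 1, vocab_size])
```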