diff --git a/backends/llamacpp/requirements.txt b/backends/llamacpp/requirements.txt
index 5c5d0cc7f11..d19c9e5bd8e 100644
--- a/backends/llamacpp/requirements.txt
+++ b/backends/llamacpp/requirements.txt
@@ -1,3 +1,3 @@
-transformers==4.48.2
+transformers==4.49
 huggingface-hub==0.28.1
 hf-transfer==0.1.9
diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt
index 051045bc902..0764bf9246f 100644
--- a/server/requirements_cuda.txt
+++ b/server/requirements_cuda.txt
@@ -346,7 +346,7 @@ tqdm==4.66.5
     #   outlines
     #   peft
     #   transformers
-transformers==4.48.2
+transformers==4.49
     # via
     #   text-generation-server (pyproject.toml)
     #   compressed-tensors
diff --git a/server/requirements_gen.txt b/server/requirements_gen.txt
index d9836ad71f8..6d64a34bd8e 100644
--- a/server/requirements_gen.txt
+++ b/server/requirements_gen.txt
@@ -158,7 +158,7 @@ tqdm==4.67.1
     # via
     #   huggingface-hub
     #   transformers
-transformers==4.48.2
+transformers==4.49
     # via text-generation-server (pyproject.toml)
 typer==0.15.1
     # via text-generation-server (pyproject.toml)
diff --git a/server/requirements_intel.txt b/server/requirements_intel.txt
index 778d892ece2..c671199f4ae 100644
--- a/server/requirements_intel.txt
+++ b/server/requirements_intel.txt
@@ -331,7 +331,7 @@ tqdm==4.66.5
     #   outlines
     #   peft
     #   transformers
-transformers==4.48.2
+transformers==4.49
     # via
     #   text-generation-server (pyproject.toml)
     #   compressed-tensors
diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt
index 65eb998b746..fe7ca572a24 100644
--- a/server/requirements_rocm.txt
+++ b/server/requirements_rocm.txt
@@ -331,7 +331,7 @@ tqdm==4.66.5
     #   outlines
     #   peft
     #   transformers
-transformers==4.48.2
+transformers==4.49
     # via
     #   text-generation-server (pyproject.toml)
     #   compressed-tensors
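
The pin bump above goes hand in hand with the server changes below: the new backend code leans on transformers 4.49-era APIs, notably the `logits_to_keep` forward argument and the ALL_ATTENTION_FUNCTIONS registry used in transformers_flash_causal_lm.py. A minimal sketch of a startup guard that makes this assumption explicit (hypothetical, not part of the patch):

    # Hypothetical guard: fail fast if the installed transformers predates the APIs
    # the backend below relies on.
    from packaging import version
    import transformers

    if version.parse(transformers.__version__) < version.parse("4.49"):
        raise RuntimeError(
            "transformers>=4.49 is required for `logits_to_keep` and the "
            "ALL_ATTENTION_FUNCTIONS registry used by the TGI transformers backend"
        )
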
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index f8150b5e67c..6d53a72b5bb 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -6,17 +6,18 @@
 )
 from compressed_tensors.quantization import QuantizationType
 from pydantic import ValidationError
-import torch
 import enum
 import os
-
+from typing import Optional, List, Dict
+from pathlib import Path
 from loguru import logger
+
+import torch
+import transformers
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from huggingface_hub import hf_hub_download, HfApi
-from typing import Optional, List, Dict
-from pathlib import Path
-import transformers

 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
@@ -736,7 +737,7 @@ def get_model(
                 FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
             )
         else:
-            return transformers_causal_lm_class.fallback(
+            return CausalLM.fallback(
                 model_id=model_id,
                 revision=revision,
                 quantize=quantize,
@@ -857,6 +858,15 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
                 config_class=GPTNeoXConfig,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             return CausalLM(
                 model_id=model_id,
@@ -1054,6 +1064,15 @@ def get_model(
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
             )
+        elif FLASH_TRANSFORMERS_BACKEND:
+            return TransformersFlashCausalLM.fallback(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
         else:
@@ -1467,43 +1486,27 @@ def get_model(
     elif quantize == "exl2":
         raise NotImplementedError("exl2 quantization is not supported for AutoModel")

-    # Fast transformers if available
-    transformers_model_class = getattr(
-        transformers,
-        modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.get(model_type, ""),
-        None,
-    )
-    if (
-        FLASH_TRANSFORMERS_BACKEND
-        and transformers_model_class is not None
-        and transformers_model_class._supports_flex_attn
-    ):
-        return TransformersFlashCausalLM.fallback(
-            model_id,
-            revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    if sharded:
-        raise NotImplementedError("sharded is not supported for AutoModel")
+    auto_map = config_dict.get("auto_map", None)
+    model_class = None

-    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
-        return Seq2SeqLM.fallback(
-            model_id,
-            revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
+    # If the model is already in the library
+    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+        model_class = getattr(
+            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+        )
+    elif (
+        trust_remote_code
+        and auto_map is not None
+        and "AutoModelForCausalLM" in auto_map.keys()
+    ):
+        model_class = get_class_from_dynamic_module(
+            config_dict["auto_map"]["AutoModelForCausalLM"], model_id
         )

-    auto_map = config_dict.get("auto_map", None)
-    if trust_remote_code and auto_map is not None:
-        if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM.fallback(
+    # This means the model is ForCausalLM
+    if model_class is not None:
+        if FLASH_TRANSFORMERS_BACKEND and model_class.is_backend_compatible():
+            return TransformersFlashCausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1511,8 +1514,10 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        if "AutoModelForSeq2SeqLM" in auto_map.keys():
-            return Seq2SeqLM.fallback(
+        elif sharded:
+            raise NotImplementedError("sharded is not supported for AutoModel")
+        else:
+            return CausalLM.fallback(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -1521,6 +1526,25 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )

+    # Not supported at this point
+    if sharded:
+        raise NotImplementedError("sharded is not supported for AutoModel")
+
+    # This means it is a ForSeq2SeqLM model
+    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES or (
+        trust_remote_code
+        and auto_map is not None
+        and "AutoModelForSeq2SeqLM" in auto_map.keys()
+    ):
+        return Seq2SeqLM.fallback(
+            model_id,
+            revision,
+            quantize=quantize,
+            speculator=speculator,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
     raise ValueError(f"Unsupported model type {model_type}")

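
The AutoModel fallback above now resolves a concrete *ForCausalLM class first, either from the static transformers mapping or, with trust_remote_code, from the checkpoint's auto_map, and only then picks a backend. A condensed sketch of that resolution step, assuming the same imports as the file and with `model_type`, `config_dict`, and `model_id` taken from the surrounding function:

    import transformers
    from transformers.models.auto import modeling_auto
    from transformers.dynamic_module_utils import get_class_from_dynamic_module

    def resolve_causal_lm_class(model_type, config_dict, model_id, trust_remote_code):
        """Return the concrete *ForCausalLM class, or None if the model is not a causal LM."""
        if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            # The model ships with transformers: look the class up by name.
            return getattr(
                transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
            )
        auto_map = config_dict.get("auto_map") or {}
        if trust_remote_code and "AutoModelForCausalLM" in auto_map:
            # Remote code: import the class referenced by the checkpoint's config.
            return get_class_from_dynamic_module(auto_map["AutoModelForCausalLM"], model_id)
        return None

A non-None result is then routed to TransformersFlashCausalLM, the sharded rejection, or CausalLM exactly as in the hunks above; Seq2Seq models keep their own branch.
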
diff --git a/server/text_generation_server/models/transformers_flash_causal_lm.py b/server/text_generation_server/models/transformers_flash_causal_lm.py
index 36de89b4b4d..8773bfd3a38 100644
--- a/server/text_generation_server/models/transformers_flash_causal_lm.py
+++ b/server/text_generation_server/models/transformers_flash_causal_lm.py
@@ -81,6 +81,15 @@ def tgi_flash_attention_forward(

 transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS["tgi"] = tgi_flash_attention_forward

+# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,
+# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache
+# will be fully replicated on each process) and GPU communication (additional all-gather operations), however,
+# due to internal constraints it was not (yet?) possible to circumvent this.
+REPLICATED_ATTENTION_MODELS = [
+    "olmo2",
+    "phi3",
+]
+

 class TransformersFlashCausalLM(FlashCausalLM):
     def __init__(
@@ -119,6 +128,7 @@ def __init__(
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
+
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
@@ -130,6 +140,8 @@ def __init__(
             tp_plan="auto" if world_size > 1 else None,
         )

+        torch.distributed.barrier(group=self.process_group)
+
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
                 tokenizer.pad_token_id = model.config.pad_token_id
@@ -143,15 +155,19 @@ def __init__(
                 tokenizer.add_special_tokens({"pad_token": "[PAD]"})

         self.num_layers = model.config.num_hidden_layers
-        self.num_heads = model.config.num_attention_heads // self.process_group.size()
+        self.num_heads = model.config.num_attention_heads
         self.num_kv_heads = model.config.num_key_value_heads
-        self.num_kv_heads = (
-            self.num_kv_heads // self.process_group.size()
-            if self.num_kv_heads > 1
-            else self.num_kv_heads
-        )
         self.head_size = model.config.hidden_size // model.config.num_attention_heads

+        # Skip the division for models in the exception list above
+        if model.config.model_type not in REPLICATED_ATTENTION_MODELS:
+            self.num_heads = self.num_heads // self.process_group.size()
+            self.num_kv_heads = (
+                self.num_kv_heads // self.process_group.size()
+                if self.num_kv_heads > 1
+                else self.num_kv_heads
+            )
+
         self.cuda_graphs = {}
         self.kv_cache = []
         self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
@@ -186,7 +202,6 @@ def __init__(
             torch.tensor(1.0, device=device),
         )

-        torch.distributed.barrier(group=self.process_group)
         # Skip FlashCausalLM init.
         super(FlashCausalLM, self).__init__(
             model_id=model_id,
@@ -204,6 +219,8 @@ def __init__(
         self.model.original_forward = self.model.forward
         self.model.forward = self._model_forward

+        torch.distributed.barrier(group=self.process_group)
+
     @classmethod
     def fallback(
         cls,
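
A worked example of the head bookkeeping introduced above, with assumed numbers (world_size = 4, a Llama-style 32 query / 8 KV head layout): models outside REPLICATED_ATTENTION_MODELS shard both head counts across ranks, while olmo2/phi3 keep the full counts on every rank, which is the VRAM and all-gather cost the comment calls out.

    def per_rank_heads(num_attention_heads, num_key_value_heads, world_size, replicated):
        # Mirrors the __init__ logic above: replicated-attention models keep full counts.
        if replicated:
            return num_attention_heads, num_key_value_heads
        num_kv = (
            num_key_value_heads // world_size
            if num_key_value_heads > 1
            else num_key_value_heads
        )
        return num_attention_heads // world_size, num_kv

    assert per_rank_heads(32, 8, world_size=4, replicated=False) == (8, 2)
    assert per_rank_heads(32, 8, world_size=4, replicated=True) == (32, 8)
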
@@ -237,11 +254,16 @@ def _model_forward(
         prefill_cache_indices=None,  # not used, but passed to match original signature
         adapter_data=None,  # not supported, but passed to match original signature
     ):
-        hidden_states = self.model.model.forward(
+        # A value of `None` (i.e. no logit slicing) translates to `0` in Transformers
+        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0
+
+        # This is equivalent to `self.model.forward`, see the monkey patch in __init__
+        logits = self.model.original_forward(
             input_ids=input_ids.unsqueeze(0),  # expand dim to fit Transformers
             position_ids=position_ids.unsqueeze(0),  # expand dim to fit Transformers
             past_key_values=None,  # we use self.kv_cache instead of transformers cache object
             use_cache=False,  # we use self.kv_cache instead of transformers cache object
+            logits_to_keep=logits_to_keep,
             return_dict=True,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=kv_cache,
@@ -251,20 +273,6 @@ def _model_forward(
             max_s=max_s,
             kv_head_mapping=self.kv_head_mapping,
             kv_scales=self.kv_scales,
-        )[0].squeeze(dim=0)
-
-        # And compute logits from the lm_head, slicing correctly the indices
-        # NOTE: some logits post-processing (e.g. in gemma2) may be absent here with the split of the modules
-        # To update with full Transformers support asap
-        if lm_head_indices is not None:
-            hidden_states = hidden_states[lm_head_indices]
-        logits = self.model.lm_head(hidden_states)
-
-        # For Granite while next transformers version is released and we can use `lm_head_indices` natively
-        if hasattr(self.model.config, "logits_scaling"):
-            logits = logits / self.model.config.logits_scaling
-        # For Cohere for similar reasons
-        elif hasattr(self.model, "logit_scale"):
-            logits = logits * self.model.logit_scale
+        ).logits.squeeze(dim=0)

         return logits, None
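
For reference, the `logits_to_keep` change replaces the manual lm_head slicing that the removed lines performed by hand: when an index tensor is passed, transformers applies the lm_head only to those positions (an int keeps the last N, and 0 keeps everything). A small, self-contained illustration of the equivalence with made-up shapes; `hidden`, `lm_head`, and the sizes are toy stand-ins, not values from the patch:

    import torch

    hidden = torch.randn(1, 10, 64)         # [batch=1, seq, hidden]
    lm_head = torch.nn.Linear(64, 100)
    lm_head_indices = torch.tensor([3, 9])  # e.g. last token of each packed sequence

    # Old path (removed above): slice the hidden states, then project.
    old_logits = lm_head(hidden[0, lm_head_indices])

    # New path: the model slices internally when `logits_to_keep` is an index tensor,
    # i.e. original_forward(..., logits_to_keep=lm_head_indices).logits
    new_logits = lm_head(hidden[:, lm_head_indices, :]).squeeze(dim=0)

    assert torch.allclose(old_logits, new_logits)

Calling `original_forward` end to end also runs each model's own logits post-processing again (e.g. Granite's `logits_scaling`, Cohere's `logit_scale`), which is presumably why the manual corrections after the old lm_head call could be dropped.
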