diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 1f6d4a77..bc7d7268 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -273,6 +273,22 @@ async def serve_inner(
     batch_safety_margin: int,
     sharded: bool = False,
 ):
+    if quantize not in [None, "gptq", "bitsandbytes"]:
+        raise ValueError(f"Unrecognized quantization method specified: {quantize}")
+
+    if quantize is None and dtype_str == "int8":
+        print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
+        quantize = "bitsandbytes"
+
+    cuda_available = torch.cuda.is_available()
+
+    # Default dtype based on device if not provided
+    if dtype_str is None:
+        dtype_str = "float16" if cuda_available else "float32"
+
+    if quantize is not None and not cuda_available:
+        raise ValueError("Quantization requires CUDA")
+
     if ESTIMATE_MEMORY == "auto" and PAGED_ATTENTION:
         # fit memory model using flash model in separate process (ensures GPU memory is entirely cleaned up)
         from text_generation_server.utils.paged import fit_memory_scaling_model
@@ -296,22 +312,6 @@ async def serve_inner(
     ]
     local_url = server_urls[local_rank]
 
-    if quantize not in [None, "gptq", "bitsandbytes"]:
-        raise ValueError(f"Unrecognized quantization method specified: {quantize}")
-
-    # Default dtype based on device if not provided
-    if dtype_str is None:
-        dtype_str = "float16" if torch.cuda.is_available() else "float32"
-
-    if quantize is None and dtype_str == "int8":
-        print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
-        quantize = "bitsandbytes"
-
-    cuda_available = torch.cuda.is_available()
-
-    if quantize is not None and not cuda_available:
-        raise ValueError("Quantization requires CUDA")
-
     # Set the fraction of cuda/gpu mem available to this process, then load the model
     if cuda_available and cuda_process_memory_fraction < 1:
         torch.cuda.set_per_process_memory_fraction(cuda_process_memory_fraction)
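
The patch moves the quantize/dtype validation to the top of `serve_inner`, before the memory-estimation step. A minimal sketch of the relocated checks is below; the helper name `validate_quantize_and_dtype` is hypothetical and not part of the patch, it simply mirrors the checks in the order the new code applies them: reject unknown quantization methods, infer `bitsandbytes` from `dtype == "int8"`, default the dtype from CUDA availability, and refuse quantization without CUDA.

```python
from typing import Optional, Tuple

import torch


def validate_quantize_and_dtype(
    quantize: Optional[str], dtype_str: Optional[str]
) -> Tuple[Optional[str], str]:
    """Sketch of the checks the patch hoists to the top of serve_inner."""
    # Reject anything other than the supported quantization methods.
    if quantize not in [None, "gptq", "bitsandbytes"]:
        raise ValueError(f"Unrecognized quantization method specified: {quantize}")

    # dtype == int8 implies bitsandbytes quantization when none was requested
    # explicitly (the real code logs this via print_rank_n).
    if quantize is None and dtype_str == "int8":
        quantize = "bitsandbytes"

    cuda_available = torch.cuda.is_available()

    # Default dtype based on device if not provided.
    if dtype_str is None:
        dtype_str = "float16" if cuda_available else "float32"

    # Quantized loading is only supported on GPU.
    if quantize is not None and not cuda_available:
        raise ValueError("Quantization requires CUDA")

    return quantize, dtype_str


# Example: on a CPU-only machine this returns (None, "float32") for default
# inputs, and raises ValueError if any quantization method is requested.
quantize, dtype_str = validate_quantize_and_dtype(None, None)
```

Running these checks first means invalid arguments fail fast, before the separate memory-estimation process is launched, and `cuda_available` is already bound by the time the later `cuda_process_memory_fraction` check uses it.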