Commit

fix: move parameter validation before fit_memory_scaling_model
Signed-off-by: Travis Johnson <[email protected]>
tjohnson31415 committed May 30, 2024
1 parent 9b4aea8 commit 8065972
Showing 1 changed file with 16 additions and 16 deletions.
server/text_generation_server/server.py
@@ -273,6 +273,22 @@ async def serve_inner(
     batch_safety_margin: int,
     sharded: bool = False,
 ):
+    if quantize not in [None, "gptq", "bitsandbytes"]:
+        raise ValueError(f"Unrecognized quantization method specified: {quantize}")
+
+    if quantize is None and dtype_str == "int8":
+        print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
+        quantize = "bitsandbytes"
+
+    cuda_available = torch.cuda.is_available()
+
+    # Default dtype based on device if not provided
+    if dtype_str is None:
+        dtype_str = "float16" if cuda_available else "float32"
+
+    if quantize is not None and not cuda_available:
+        raise ValueError("Quantization requires CUDA")
+
     if ESTIMATE_MEMORY == "auto" and PAGED_ATTENTION:
         # fit memory model using flash model in separate process (ensures GPU memory is entirely cleaned up)
         from text_generation_server.utils.paged import fit_memory_scaling_model
@@ -296,22 +312,6 @@ async def serve_inner(
     ]
     local_url = server_urls[local_rank]

-    if quantize not in [None, "gptq", "bitsandbytes"]:
-        raise ValueError(f"Unrecognized quantization method specified: {quantize}")
-
-    # Default dtype based on device if not provided
-    if dtype_str is None:
-        dtype_str = "float16" if torch.cuda.is_available() else "float32"
-
-    if quantize is None and dtype_str == "int8":
-        print_rank_n("Inferring quantize = bitsandbytes because dtype == int8")
-        quantize = "bitsandbytes"
-
-    cuda_available = torch.cuda.is_available()
-
-    if quantize is not None and not cuda_available:
-        raise ValueError("Quantization requires CUDA")
-
     # Set the fraction of cuda/gpu mem available to this process, then load the model
     if cuda_available and cuda_process_memory_fraction < 1:
         torch.cuda.set_per_process_memory_fraction(cuda_process_memory_fraction)
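
The practical effect of the move is fail-fast argument handling: an invalid quantize/dtype combination now raises before fit_memory_scaling_model, which loads the model in a separate process, rather than after that expensive step has already run. A minimal, self-contained sketch of the ordering; the names validate_args, fit_memory_model, and serve and their bodies are simplified stand-ins for illustration, not the repository's actual implementations:

# Minimal sketch of the fail-fast ordering this commit introduces.
# The helpers below are illustrative stand-ins, not the real code.
import time


def validate_args(quantize, dtype_str, cuda_available):
    """Cheap argument checks that should run before any expensive setup."""
    if quantize not in (None, "gptq", "bitsandbytes"):
        raise ValueError(f"Unrecognized quantization method specified: {quantize}")
    if quantize is None and dtype_str == "int8":
        quantize = "bitsandbytes"  # infer quantization from the dtype
    if dtype_str is None:
        dtype_str = "float16" if cuda_available else "float32"
    if quantize is not None and not cuda_available:
        raise ValueError("Quantization requires CUDA")
    return quantize, dtype_str


def fit_memory_model():
    # Stand-in for fit_memory_scaling_model, which loads the model in a
    # separate process and is comparatively slow.
    time.sleep(2)


def serve(quantize, dtype_str, cuda_available=True):
    # After this commit: validate first, so a bad argument fails here
    # immediately instead of after the memory fit has already run.
    quantize, dtype_str = validate_args(quantize, dtype_str, cuda_available)
    fit_memory_model()
    print(f"serving with quantize={quantize}, dtype={dtype_str}")


serve(None, "int8")    # infers quantize="bitsandbytes", then serves
# serve("awq", None)   # raises ValueError before the memory fit runs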
