diff --git a/Feature_Guide/Data_Pipelines/model_repository/model1/config.pbtxt b/Feature_Guide/Data_Pipelines/model_repository/model1/config.pbtxt index 55e32cc5..8e171674 100644 --- a/Feature_Guide/Data_Pipelines/model_repository/model1/config.pbtxt +++ b/Feature_Guide/Data_Pipelines/model_repository/model1/config.pbtxt @@ -82,3 +82,25 @@ output [ dims: [-1] } ] + +# The configuration of engineArgs +parameters { + key: "model" + value: { + string_value: "facebook/opt-125m", + } +} + +parameters { + key: "disable_log_requests" + value: { + string_value: "true" + } +} + +parameters { + key: "gpu_memory_utilization" + value: { + string_value: "0.5" + } +} \ No newline at end of file diff --git a/Popular_Models_Guide/Llama2/trtllm_guide.md b/Popular_Models_Guide/Llama2/trtllm_guide.md new file mode 100644 index 00000000..309ea27f --- /dev/null +++ b/Popular_Models_Guide/Llama2/trtllm_guide.md @@ -0,0 +1,161 @@ +<!-- +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--> + +## Pre-build instructions + +For this tutorial, we are using the Llama2-7B HuggingFace model with pre-trained weights. +Clone the repo of the model with weights and tokens [here](https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main). +You will need to get permissions for the Llama2 repository as well as get access to the huggingface cli. To get access to the huggingface cli, go here: [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). + +## Installation + +1. The installation starts with cloning the TensorRT-LLM Backend and update the TensorRT-LLM submodule: +```bash +git clone https://github.com/triton-inference-server/tensorrtllm_backend.git --branch <release branch> +# Update the submodules +cd tensorrtllm_backend +# Install git-lfs if needed +apt-get update && apt-get install git-lfs -y --no-install-recommends +git lfs install +git submodule update --init --recursive +``` + +2. Launch Triton docker container with TensorRT-LLM backend. 
Note that we mount `tensorrtllm_backend` to `/tensorrtllm_backend` and the Llama2 model to `/Llama-2-7b-hf` in the docker container for simplicity. Make an `engines` folder outside docker to reuse engines for future runs.
+```bash
+docker run --rm -it --net host --shm-size=2g \
+    --ulimit memlock=-1 --ulimit stack=67108864 --gpus all \
+    -v /path/to/tensorrtllm_backend:/tensorrtllm_backend \
+    -v /path/to/Llama2/repo:/Llama-2-7b-hf \
+    -v /path/to/engines:/engines \
+    nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3
+```
+
+Alternatively, you can follow the instructions [here](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/README.md) to build Triton Server with the TensorRT-LLM Backend if you want to build a specialized container.
+
+Don't forget to allow GPU access when you launch the container.
+
+## Create Engines for each model [skip this step if you already have an engine]
+TensorRT-LLM requires each model to be compiled for your target configuration before it can run. Before serving the model on Triton for the first time, create a TensorRT-LLM engine for that configuration with the following steps:
+
+1. Install the TensorRT-LLM python package
+   ```bash
+   # Install CMake
+   bash /tensorrtllm_backend/tensorrt_llm/docker/common/install_cmake.sh
+   export PATH="/usr/local/cmake/bin:${PATH}"
+
+   # PyTorch needs to be built from source for aarch64
+   ARCH="$(uname -i)"
+   if [ "${ARCH}" = "aarch64" ]; then TORCH_INSTALL_TYPE="src_non_cxx11_abi"; \
+   else TORCH_INSTALL_TYPE="pypi"; fi && \
+   (cd /tensorrtllm_backend/tensorrt_llm &&
+   bash docker/common/install_pytorch.sh $TORCH_INSTALL_TYPE &&
+   python3 ./scripts/build_wheel.py --trt_root=/usr/local/tensorrt &&
+   pip3 install ./build/tensorrt_llm*.whl)
+   ```
+
+2. Compile model engines
+
+   The script to build Llama models is located in the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples). We use the copy inside the docker container at `/tensorrtllm_backend/tensorrt_llm/examples/llama/build.py`.
+   The command below compiles the model with inflight batching for 1 GPU. To run with more GPUs, change the build command to use `--world_size X`.
+   For more details on the build script, see the Llama example documentation [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama/README.md).
+
+   ```bash
+   python /tensorrtllm_backend/tensorrt_llm/examples/llama/build.py --model_dir /Llama-2-7b-hf/ \
+                --dtype bfloat16 \
+                --use_gpt_attention_plugin bfloat16 \
+                --use_inflight_batching \
+                --paged_kv_cache \
+                --remove_input_padding \
+                --use_gemm_plugin bfloat16 \
+                --output_dir /engines/1-gpu/ \
+                --world_size 1
+   ```
+
+   > Optional: You can test the output of the model with `run.py`
+   > located in the same llama examples folder.
+   >
+   > ```bash
+   > python3 /tensorrtllm_backend/tensorrt_llm/examples/llama/run.py --engine_dir=/engines/1-gpu/ --max_output_len 100 --tokenizer_dir /Llama-2-7b-hf --input_text "How do I count to ten in French?"
+   > ```
+
+## Serving with Triton
+
+The last step is to create a Triton-readable model. You can
+find a template of a model that uses inflight batching in [tensorrtllm_backend/all_models/inflight_batcher_llm](https://github.com/triton-inference-server/tensorrtllm_backend/tree/main/all_models/inflight_batcher_llm).
+To run our Llama2-7B model, you will need to:
+
+
+1. 
Copy over the inflight batcher models repository + + ```bash + cp -R /tensorrtllm_backend/all_models/inflight_batcher_llm /opt/tritonserver/. + ``` + +2. Modify config.pbtxt for the preprocessing, postprocessing and processing steps. See details in [documentation](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/README.md#create-the-model-repository): + + ```bash + # preprocessing + sed -i 's#${tokenizer_dir}#/Llama-2-7b-hf/#' /opt/tritonserver/inflight_batcher_llm/preprocessing/config.pbtxt + sed -i 's#${tokenizer_type}#auto#' /opt/tritonserver/inflight_batcher_llm/preprocessing/config.pbtxt + sed -i 's#${tokenizer_dir}#/Llama-2-7b-hf/#' /opt/tritonserver/inflight_batcher_llm/postprocessing/config.pbtxt + sed -i 's#${tokenizer_type}#auto#' /opt/tritonserver/inflight_batcher_llm/postprocessing/config.pbtxt + + sed -i 's#${decoupled_mode}#false#' /opt/tritonserver/inflight_batcher_llm/tensorrt_llm/config.pbtxt + sed -i 's#${engine_dir}#/engines/1-gpu/#' /opt/tritonserver/inflight_batcher_llm/tensorrt_llm/config.pbtxt + ``` + Also, ensure that the `gpt_model_type` parameter is set to `inflight_fused_batching` + +3. Launch Tritonserver + + Use the [launch_triton_server.py](https://github.com/triton-inference-server/tensorrtllm_backend/blob/release/0.5.0/scripts/launch_triton_server.py) script. This launches multiple instances of `tritonserver` with MPI. + ```bash + python3 /tensorrtllm_backend/scripts/launch_triton_server.py --world_size=<world size of the engine> --model_repo=/opt/tritonserver/inflight_batcher_llm + ``` + +## Client + +You can test the results of the run with: +1. The [inflight_batcher_llm_client.py](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/inflight_batcher_llm/client/inflight_batcher_llm_client.py) script. + +```bash +# Using the SDK container as an example +docker run --rm -it --net host --shm-size=2g \ + --ulimit memlock=-1 --ulimit stack=67108864 --gpus all \ + -v /path/to/tensorrtllm_backend:/tensorrtllm_backend \ + -v /path/to/Llama2/repo:/Llama-2-7b-hf \ + -v /path/to/engines:/engines \ + nvcr.io/nvidia/tritonserver:23.10-py3-sdk +# Install extra dependencies for the script +pip3 install transformers sentencepiece +python3 /tensorrtllm_backend/inflight_batcher_llm/client/inflight_batcher_llm_client.py --request-output-len 200 --tokenizer_type llama --tokenizer_dir /Llama-2-7b-hf +``` + +2. The [generate endpoint](https://github.com/triton-inference-server/tensorrtllm_backend/tree/release/0.5.0#query-the-server-with-the-triton-generate-endpoint) if you are using the Triton TensorRT-LLM Backend container with versions greater than `r23.10`. + + + diff --git a/Quick_Deploy/vLLM/Dockerfile b/Quick_Deploy/HuggingFaceTransformers/Dockerfile similarity index 91% rename from Quick_Deploy/vLLM/Dockerfile rename to Quick_Deploy/HuggingFaceTransformers/Dockerfile index 6358584b..cfb7a2b8 100644 --- a/Quick_Deploy/vLLM/Dockerfile +++ b/Quick_Deploy/HuggingFaceTransformers/Dockerfile @@ -23,6 +23,5 @@ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -FROM nvcr.io/nvidia/tritonserver:23.09-py3 -RUN pip install vllm==0.2.0 +FROM nvcr.io/nvidia/tritonserver:23.10-py3 +RUN pip install transformers==4.34.0 protobuf==3.20.3 sentencepiece==0.1.99 accelerate==0.23.0 einops==0.6.1 diff --git a/Quick_Deploy/HuggingFaceTransformers/README.md b/Quick_Deploy/HuggingFaceTransformers/README.md new file mode 100644 index 00000000..e7635058 --- /dev/null +++ b/Quick_Deploy/HuggingFaceTransformers/README.md @@ -0,0 +1,360 @@ +<!-- +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +--> + +# Deploying Hugging Face Transformer Models in Triton + +The following tutorial demonstrates how to deploy an arbitrary hugging face transformer +model on the Triton Inference Server using Triton's [Python backend](https://github.com/triton-inference-server/python_backend). For the purposes of this example, two transformer +models will be deployed: +- [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) +- [adept/persimmon-8b-base](https://huggingface.co/adept/persimmon-8b-base) + +These models were selected because of their popularity and consistent response quality. +However, this tutorial is also generalizable for any transformer model provided +sufficient infrastructure. + +*NOTE*: The tutorial is intended to be a reference example only. It may not be tuned for +optimal performance. + +## Step 1: Create a Model Repository + +The first step is to create a model repository containing the models we want the Triton +Inference Server to load and use for inference processing. To accomplish this, create a +directory called `model_repository` and copy the `falcon7b` model folder into it: + +``` +mkdir -p model_repository +cp -r falcon7b/ model_repository/ +``` + +The `falcon7b/` folder we copied is organized in the way Triton expects and contains +two important files needed to serve models in Triton: +- **config.pbtxt** - Outlines the backend to use, model input/output details, and custom +parameters to use for execution. 
More information on the full range of model configuration +properties Triton supports can be found [here](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html). +- **model.py** - Implements how Triton should handle the model during the initialization, +execution, and finalization stages. More information regarding python backend usage +can be found [here](https://github.com/triton-inference-server/python_backend#usage). + + +## Step 2: Build a Triton Container Image + +The second step is to create an image that includes all the dependencies necessary +to deploy hugging face transformer models on the Triton Inference Server. This can be done +by building an image from the provided Dockerfile: + +``` +docker build -t triton_transformer_server . +``` + +## Step 3: Launch the Triton Inference Server + +Once the ```triton_transformer_server``` image is created, you can launch the Triton Inference +Server in a container with the following command: + +```bash +docker run --gpus all -it --rm --net=host --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}/model_repository:/opt/tritonserver/model_repository triton_transformer_server tritonserver --model-repository=model_repository +``` + +The server has launched successfully when you see the following outputs in your console: + +``` +I0922 23:28:40.351809 1 grpc_server.cc:2451] Started GRPCInferenceService at 0.0.0.0:8001 +I0922 23:28:40.352017 1 http_server.cc:3558] Started HTTPService at 0.0.0.0:8000 +I0922 23:28:40.395611 1 http_server.cc:187] Started Metrics Service at 0.0.0.0:8002 +``` + +## Step 4: Query the Server + +Now we can query the server using curl, specifying the server address and input details: + +```bash +curl -X POST localhost:8000/v2/models/falcon7b/infer -d '{"inputs": [{"name":"text_input","datatype":"BYTES","shape":[1],"data":["I am going"]}]}' +``` +In our testing, the server returned the following result (formatted for legibility): +```json +{ + "model_name": "falcon7b", + "model_version": "1", + "outputs": [ + { + "name": "text", + "datatype": "BYTES", + "shape": [ + 1 + ], + "data": [ + "I am going to be in the market for a new laptop soon. I" + ] + } + ] +} +``` + +## Step 5: Host Multiple Models in Triton + +So far in this tutorial, we have only loaded a single model. However, Triton is capable +of hosting many models, simultaneously. To accomplish this, first ensure you have +exited the docker container by invoking `Ctrl+C` and waiting for the container to exit. + +Next copy the remaining model provided into the model repository: +``` +cp -r persimmon8b/ model_repository/ +``` +*NOTE*: The combined size of these two models is large. If your current hardware cannot +support hosting both models simultaneously, consider loading a smaller model, such as +[opt-125m](https://huggingface.co/facebook/opt-125m), by creating a folder for it +using the templates provided and copying it into `model_repository`. + +Again, launch the server by invoking the `docker run` command from above and wait for confirmation +that the server has launched successfully. 
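+
+Before querying the models, you can optionally confirm that both of them finished loading by
+polling Triton's readiness APIs. The short sketch below is one way to do this; it assumes the
+`tritonclient` Python package is installed wherever you run it (`pip install tritonclient[http]`)
+and that Triton's HTTP endpoint is reachable at `localhost:8000`:
+
+```python
+import tritonclient.http as httpclient
+
+# Connect to Triton's HTTP endpoint (port 8000 in this tutorial).
+client = httpclient.InferenceServerClient(url="localhost:8000")
+
+# Server-level readiness check.
+print("Server ready:", client.is_server_ready())
+
+# Per-model readiness checks for the two models hosted in this step.
+for model_name in ["falcon7b", "persimmon8b"]:
+    print(model_name, "ready:", client.is_model_ready(model_name))
+```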
+
+Query the server, making sure to change the model name in the address for each model:
+```bash
+curl -X POST localhost:8000/v2/models/falcon7b/infer -d '{"inputs": [{"name":"text_input","datatype":"BYTES","shape":[1],"data":["How can you be"]}]}'
+curl -X POST localhost:8000/v2/models/persimmon8b/infer -d '{"inputs": [{"name":"text_input","datatype":"BYTES","shape":[1],"data":["Where is the nearest"]}]}'
+```
+In our testing, these queries returned the following parsed results:
+```bash
+# falcon7b
+"How can you be sure that you are getting the best deal on your car"
+
+# persimmon8b
+"Where is the nearest starbucks?"
+```
+Beginning with the 23.10 release, users can interact with large language models (LLMs) hosted
+by Triton in a simplified fashion by using Triton's generate endpoint:
+
+```bash
+curl -X POST localhost:8000/v2/models/falcon7b/generate -d '{"text_input":"How can you be"}'
+```
+## 'Day Zero' Support
+
+The latest transformer models may not always be supported in the most recent, official
+release of the `transformers` package. In such a case, you should still be able to
+load these 'bleeding edge' models in Triton by building `transformers` from source.
+This can be done by replacing the transformers install directive in the provided
+Dockerfile with:
+```docker
+RUN pip install git+https://github.com/huggingface/transformers.git
+```
+Using this technique, you should be able to serve any transformer model supported by
+Hugging Face with Triton.
+
+
+# Next Steps
+The following sections expand on the base tutorial and provide guidance for future sandboxing.
+
+## Loading Cached Models
+In the previous steps, we downloaded the falcon-7b model from Hugging Face when we
+launched the Triton server. We can avoid this lengthy download process in subsequent runs
+by loading cached models into Triton. By default, the provided `model.py` files will cache
+the falcon and persimmon models in their respective directories within the `model_repository`
+folder. This is accomplished by setting the `TRANSFORMERS_CACHE` environment variable.
+To set this environment variable for an arbitrary model, include the following lines in
+your `model.py` **before** importing the 'transformers' module, making sure to replace
+`{MODEL}` with your target model.
+
+```python
+import os
+os.environ['TRANSFORMERS_CACHE'] = '/opt/tritonserver/model_repository/{MODEL}/hf_cache'
+```
+
+Alternatively, if your system has already cached a Hugging Face model you wish to deploy in Triton,
+you can mount it to the Triton container by adding the following mount option to the `docker run`
+command from earlier (making sure to replace `${HOME}` with the path to your home directory):
+
+```bash
+# Option to mount a specific cached model (falcon-7b in this case)
+-v ${HOME}/.cache/huggingface/hub/models--tiiuae--falcon-7b:/root/.cache/huggingface/hub/models--tiiuae--falcon-7b
+
+# Option to mount all cached models on the host system
+-v ${HOME}/.cache/huggingface:/root/.cache/huggingface
+```
+
+## Triton Tool Ecosystem
+Deploying models in Triton also comes with the benefit of access to a fully-supported suite
+of deployment analyzers to help you better understand and tailor your systems to fit your
+needs. Triton currently has two options for deployment analysis:
+- [Performance Analyzer](https://docs.nvidia.com/deeplearning/triton-inference-server/archives/triton-inference-server-2310/user-guide/docs/user_guide/perf_analyzer.html): An inference performance optimizer.
+- [Model Analyzer](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_analyzer.html) A GPU memory and compute utilization optimizer. + +### Performance Analyzer +To use the performance analyzer, please remove the persimmon8b model from `model_repository` and restart +the Triton server using the `docker run` command from above. + +Once Triton launches successfully, start a Triton SDK container by running the following in a separate window: + +```bash +docker run -it --net=host nvcr.io/nvidia/tritonserver:23.10-py3-sdk bash +``` +This container comes with all of Triton's deployment analyzers pre-installed, meaning +we can simply enter the following to get feedback on our model's inference performance: + +```bash +perf_analyzer -m falcon7b --collect-metrics +``` + +This command should run quickly and profile the performance of our falcon7b model. +As the analyzer runs, it will output useful metrics such as latency percentiles, +latency by stage of inference, and successful request count. A subset of the output +data is shown below: + +```bash +#Avg request latency +46307 usec (overhead 25 usec + queue 25 usec + compute input 26 usec + compute infer 46161 usec + compute output 68 usec) + +#Avg GPU Utilization +GPU-57c7b00e-ca04-3876-91e2-c1eae40a0733 : 66.0556% + +#Inferences/Second vs. Client Average Batch Latency +Concurrency: 1, throughput: 21.3841 infer/sec, latency 46783 usec +``` + +These metrics tell us that we are not fully utilizing our hardware and that our +throughput is low. We can immediately improve these results by batching our requests +instead of computing inferences one at a time. The `model.py` file for the falcon model +is already configured to handle batched requests. Enabling batching in Triton is as simple +as adding the following to falcon's `config.pbtxt` file: + +``` +dynamic_batching { } +max_batch_size: 8 +``` +The integer corresponding to the `max_batch_size`, can be any of your choosing, however, +for this example, we select 8. Now let's re-run the perf_analyzer with increasing levels +of concurrency and see how it impacts GPU utilization and throughput by executing: +```bash +perf_analyzer -m falcon7b --collect-metrics --concurrency-range=2:16:2 +``` +After executing for a few minutes, the performance analyzer should return +results similar to these (depending on hardware): +```bash +# Concurrency = 4 +GPU-57c7b00e-ca04-3876-91e2-c1eae40a0733 : 74.1111% +Throughput: 31.8264 infer/sec, latency 125174 usec + +# Concurrency = 8 +GPU-57c7b00e-ca04-3876-91e2-c1eae40a0733 : 81.7895% +Throughput: 46.2105 infer/sec, latency 172920 usec + +# Concurrency = 16 +GPU-57c7b00e-ca04-3876-91e2-c1eae40a0733 : 90.5556% +Throughput: 53.6549 infer/sec, latency 299178 usec +``` +Using the performance analyzer we were able to quickly profile different model configurations +to obtain better throughput and hardware utilization. In this case, we were able to +identify a configuration that nearly triples our throughput and increases GPU +utilization by ~24% in less than 5 minutes. + +This is a single, simple use case for the performance analyzer. For more information and +a more complete list of performance analyzer parameters and use cases, please see +[this](https://docs.nvidia.com/deeplearning/triton-inference-server/archives/triton-inference-server-2310/user-guide/docs/user_guide/perf_analyzer.html) +guide. 
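+
+You can also observe the effect of dynamic batching outside of `perf_analyzer` by keeping several
+requests in flight at once from your own client, so the scheduler has something to group. The
+sketch below is illustrative only; it assumes the `dynamic_batching` and `max_batch_size` settings
+above have been added to falcon's `config.pbtxt` and that the `tritonclient[http]` Python package
+is available (it comes pre-installed in the SDK container):
+
+```python
+import numpy as np
+import tritonclient.http as httpclient
+
+# Allow enough HTTP connection concurrency for the burst of requests.
+client = httpclient.InferenceServerClient(url="localhost:8000", concurrency=8)
+
+prompts = [f"Tell me fact number {i} about dogs:" for i in range(8)]
+pending = []
+for prompt in prompts:
+    # With max_batch_size set in config.pbtxt, each request carries a leading
+    # batch dimension, so a single prompt is shaped [1, 1].
+    text = httpclient.InferInput("text_input", [1, 1], "BYTES")
+    text.set_data_from_numpy(np.array([[prompt]], dtype=np.object_))
+    pending.append(client.async_infer("falcon7b", [text]))
+
+# While these requests are in flight, the dynamic batcher is free to group them.
+for request in pending:
+    result = request.get_result()
+    print(result.as_numpy("text_output")[0].decode("utf-8"))
+```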
+ +For more information regarding dynamic batching in Triton, please see [this](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#dynamic-batcher) +guide. + +### Model Analyzer + +In the performance analyzer section, we used intuition to increase our throughput by changing +a subset of variables and measuring the difference in performance. However, we only changed +a few variables across a wide search space. + +To sweep this parameter space in a more robust fashion, we can use Triton's model analyzer, which +not only sweeps a large spectrum of configuration parameters, but also generates visual reports +to analyze post-execution. + +To use the model analyzer, please terminate your Triton server by invoking `Ctrl+C` and relaunching +it with the following command (ensuring the dynamic_batching parameters from above have been added +to the falcon model's config.pbtxt): +```bash +docker run --gpus all -it --rm --net=host --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}/model_repository:/opt/tritonserver/model_repository triton_transformer_server +``` + +Next, to get the most accurate GPU metrics from the model analyzer, we will install and launch it from +our local server container. To accomplish this, first install the model analyzer: +```bash +pip3 install triton-model-analyzer +``` + +Once the model analyzer installs successfully, enter the following command (modifying the instance +count to something lower for your GPU, if necessary): +```bash +model-analyzer profile -m /opt/tritonserver/model_repository/ --profile-models falcon7b --run-config-search-max-instance-count=3 --run-config-search-min-model-batch-size=8 +``` +This tool will take longer to execute than the performance analyzer example (~40 minutes). +If this execution time is too long, you can also run the analyzer with the +`--run-config-search-mode quick` option. In our experimentation, enabling the quick search option +yielded fewer results but took half the time. Regardless, once the model analyzer is complete, +it will provide you a full summary relating to throughput, latency, and hardware utilization +in multiple formats. A snippet from the summary report produced by the model analyzer for +our run is ranked by performance and shown below: + +| Model Config Name | Max Batch Size | Dynamic Batching | Total Instance Count | p99 Latency (ms) | Throughput (infer/sec) | Max GPU Memory Usage (MB) | Average GPU Utilization (%) | +| :---: | :----: | :---: | :----: | :---: | :----: | :---: | :---: | +| falcon7b_config_7 | 16 | Enabled | 3:GPU | 1412.581 | 71.944 | 46226 | 100.0 | +| falcon7b_config_8 | 32 | Enabled | 3:GPU | 2836.225 | 63.9652 | 46268 | 100.0 | +| falcon7b_config_4 | 16 | Enabled | 2:GPU | 7601.437 | 63.9454 | 31331 | 100.0 | +| falcon7b_config_default | 8 | Enabled | 1:GPU | 4151.873 | 63.9384 | 16449 | 89.3 | + +We can examine the performance of any of these configurations with more granularity by viewing +their detailed reports. This subset of reports focuses on a single configuration's latency +and concurrency metrics as they relate to throughput and hardware utilization. 
A snippet from +the top performing configuration for our tests is shown below (abridged for brevity): + +| Request Concurrency | p99 Latency (ms) | Client Response Wait (ms) | Server Queue (ms) | Server Compute Input (ms) | Server Compute Infer (ms) | Throughput (infer/sec) | Max GPU Memory Usage (MB) | Average GPU Utilization (%) | +| :---: | :----: | :---: | :----: | :---: | :----: | :---: | :---: | :---: | +| 512 | 8689.491 | 8190.506 | 7397.975 | 0.166 | 778.565 | 63.954 | 46230.667264 | 100.0 | +| | | | | ... | | | | | +| 128 | 2289.118 | 2049.37 | 1277.34 | 0.159 | 770.771 | 61.2953 | 46230.667264 | 100.0 | +| 64 | 1412.581 | 896.924 | 227.108 | 0.157 | 667.757 | 71.944 | 46226.47296 | 100.0 | +| 32 | 781.362 | 546.35 | 86.078 | 0.103 | 459.257 | 57.7877 | 46226.47296 | 100.0 | +| | | | | ... | | | | | +| 1 | 67.12 | 49.707 | 0.049 | 0.024 | 49.121 | 20.0993 | 46207.598592 | 54.9 | + +Similarly, this is a single use case for the model analyzer. For more information and a more complete list +of model analyzer parameters and run options, please see [this](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_analyzer.html) guide. + +*Please note that both the performance and model analyzer experiments were conducted +on a system with an Intel i9 and NVIDIA A6000 GPU. Your results may vary depending on +you hardware.* + +## Customization + +The `model.py` files have been kept minimal in order to maximize generalizability. Should you wish +to modify the behavior of the transformer models, such as increasing the number of generated sequences +to return, be sure to modify the corresponding `config.pbtxt` and `model.py` files and copy them +into the `model_repository`. + +The transformers used in this tutorial were all suited for text-generation tasks, however, this +is not a limitation. The principles of this tutorial can be applied to serve models suited for +any other transformer task. + +Triton offers a rich variety of available server configuration options not mentioned in this tutorial. +For a more custom deployment, please see our [model configuration guide](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html) to see how the scope of this tutorial can be expanded to fit your needs. diff --git a/Quick_Deploy/HuggingFaceTransformers/falcon7b/1/model.py b/Quick_Deploy/HuggingFaceTransformers/falcon7b/1/model.py new file mode 100644 index 00000000..71bede0e --- /dev/null +++ b/Quick_Deploy/HuggingFaceTransformers/falcon7b/1/model.py @@ -0,0 +1,109 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import os + +os.environ[ + "TRANSFORMERS_CACHE" +] = "/opt/tritonserver/model_repository/falcon7b/hf_cache" +import json + +import numpy as np +import torch +import transformers +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def initialize(self, args): + self.logger = pb_utils.Logger + self.model_config = json.loads(args["model_config"]) + self.model_params = self.model_config.get("parameters", {}) + default_hf_model = "tiiuae/falcon-7b" + default_max_gen_length = "15" + # Check for user-specified model name in model config parameters + hf_model = self.model_params.get("huggingface_model", {}).get( + "string_value", default_hf_model + ) + # Check for user-specified max length in model config parameters + self.max_output_length = int( + self.model_params.get("max_output_length", {}).get( + "string_value", default_max_gen_length + ) + ) + + self.logger.log_info(f"Max sequence length: {self.max_output_length}") + self.logger.log_info(f"Loading HuggingFace model: {hf_model}...") + # Assume tokenizer available for same model + self.tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model) + self.pipeline = transformers.pipeline( + "text-generation", + model=hf_model, + torch_dtype=torch.float16, + tokenizer=self.tokenizer, + device_map="auto", + ) + self.pipeline.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + def execute(self, requests): + prompts = [] + for request in requests: + input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input") + multi_dim = input_tensor.as_numpy().ndim > 1 + if not multi_dim: + prompt = input_tensor.as_numpy()[0].decode("utf-8") + self.logger.log_info(f"Generating sequences for text_input: {prompt}") + prompts.append(prompt) + else: + # Implementation to accept dynamically batched inputs + num_prompts = input_tensor.as_numpy().shape[0] + for prompt_index in range(0, num_prompts): + prompt = input_tensor.as_numpy()[prompt_index][0].decode("utf-8") + prompts.append(prompt) + + batch_size = len(prompts) + return self.generate(prompts, batch_size) + + def generate(self, prompts, batch_size): + sequences = self.pipeline( + prompts, + max_length=self.max_output_length, + pad_token_id=self.tokenizer.eos_token_id, + batch_size=batch_size, + ) + responses = [] + texts = [] + for i, seq in enumerate(sequences): + output_tensors = [] + text = seq[0]["generated_text"] + texts.append(text) + tensor = pb_utils.Tensor("text_output", np.array(texts, dtype=np.object_)) + output_tensors.append(tensor) + responses.append(pb_utils.InferenceResponse(output_tensors=output_tensors)) + + return responses + + def finalize(self): + print("Cleaning up...") diff --git a/Quick_Deploy/HuggingFaceTransformers/falcon7b/config.pbtxt 
b/Quick_Deploy/HuggingFaceTransformers/falcon7b/config.pbtxt new file mode 100644 index 00000000..9949472d --- /dev/null +++ b/Quick_Deploy/HuggingFaceTransformers/falcon7b/config.pbtxt @@ -0,0 +1,36 @@ +# Triton backend to use +backend: "python" + +# Hugging face model path. Parameters must follow this +# key/value structure +parameters: { + key: "huggingface_model", + value: {string_value: "tiiuae/falcon-7b"} +} + +# The maximum number of tokens to generate in response +# to our input +parameters: { + key: "max_output_length", + value: {string_value: "15"} +} + +# Triton should expect as input a single string of set +# length named 'text_input' +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ 1 ] + } +] + +# Triton should expect to respond with a single string +# output of variable length named 'text_output' +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1 ] + } +] diff --git a/Quick_Deploy/HuggingFaceTransformers/persimmon8b/1/model.py b/Quick_Deploy/HuggingFaceTransformers/persimmon8b/1/model.py new file mode 100644 index 00000000..5119d406 --- /dev/null +++ b/Quick_Deploy/HuggingFaceTransformers/persimmon8b/1/model.py @@ -0,0 +1,103 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import os + +os.environ[ + "TRANSFORMERS_CACHE" +] = "/opt/tritonserver/model_repository/persimmon8b/hf_cache" + +import json + +import numpy as np +import torch +import transformers +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def initialize(self, args): + self.logger = pb_utils.Logger + self.model_config = json.loads(args["model_config"]) + self.model_params = self.model_config.get("parameters", {}) + default_hf_model = "adept/persimmon-8b-base" + default_max_gen_length = "15" + # Check for user-specified model name in model config parameters + hf_model = self.model_params.get("huggingface_model", {}).get( + "string_value", default_hf_model + ) + # Check for user-specified max length in model config parameters + self.max_output_length = int( + self.model_params.get("max_output_length", {}).get( + "string_value", default_max_gen_length + ) + ) + + self.logger.log_info(f"Max output length: {self.max_output_length}") + self.logger.log_info(f"Loading HuggingFace model: {hf_model}...") + # Assume tokenizer available for same model + self.tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model) + self.pipeline = transformers.pipeline( + "text-generation", + model=hf_model, + torch_dtype=torch.float16, + tokenizer=self.tokenizer, + device_map="auto", + ) + + def execute(self, requests): + responses = [] + for request in requests: + # Assume input named "prompt", specified in autocomplete above + input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input") + prompt = input_tensor.as_numpy()[0].decode("utf-8") + + self.logger.log_info(f"Generating sequences for text_input: {prompt}") + response = self.generate(prompt) + responses.append(response) + + return responses + + def generate(self, prompt): + sequences = self.pipeline( + prompt, + max_length=self.max_output_length, + pad_token_id=self.tokenizer.eos_token_id, + ) + + output_tensors = [] + texts = [] + for i, seq in enumerate(sequences): + text = seq["generated_text"] + self.logger.log_info(f"Sequence {i+1}: {text}") + texts.append(text) + + tensor = pb_utils.Tensor("text_output", np.array(texts, dtype=np.object_)) + output_tensors.append(tensor) + response = pb_utils.InferenceResponse(output_tensors=output_tensors) + return response + + def finalize(self): + print("Cleaning up...") diff --git a/Quick_Deploy/HuggingFaceTransformers/persimmon8b/config.pbtxt b/Quick_Deploy/HuggingFaceTransformers/persimmon8b/config.pbtxt new file mode 100644 index 00000000..5098c2a6 --- /dev/null +++ b/Quick_Deploy/HuggingFaceTransformers/persimmon8b/config.pbtxt @@ -0,0 +1,36 @@ +# Triton backend to use +backend: "python" + +# Hugging face model path. 
Parameters must follow this +# key/value structure +parameters: { + key: "huggingface_model", + value: {string_value: "adept/persimmon-8b-base"} +} + +# The maximum number of tokens to generate in response +# to our input +parameters: { + key: "max_output_length", + value: {string_value: "15"} +} + +# Triton should expect as input a single string of set +# length named 'text_input' +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ 1 ] + } +] + +# Triton should expect to respond with a single string +# output of variable length named 'text_output' +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1 ] + } +] diff --git a/Quick_Deploy/vLLM/README.md b/Quick_Deploy/vLLM/README.md index 5d292511..24e59eab 100644 --- a/Quick_Deploy/vLLM/README.md +++ b/Quick_Deploy/vLLM/README.md @@ -31,94 +31,161 @@ The following tutorial demonstrates how to deploy a simple [facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model on -Triton Inference Server using Triton's [Python backend](https://github.com/triton-inference-server/python_backend) and the -[vLLM](https://github.com/vllm-project/vllm) library. +Triton Inference Server using the Triton's +[Python-based](https://github.com/triton-inference-server/backend/blob/main/docs/python_based_backends.md#python-based-backends) +[vLLM](https://github.com/triton-inference-server/vllm_backend/tree/main) +backend. -*NOTE*: The tutorial is intended to be a reference example only. It is a work in progress with -[known limitations](#limitations). +*NOTE*: The tutorial is intended to be a reference example only and has [known limitations](#limitations). -## Step 1: Build a Triton Container Image with vLLM +## Step 1: Prepare your model repository -We will build a new container image derived from tritonserver:23.08-py3 with vLLM. +To use Triton, we need to build a model repository. A sample model repository for deploying `facebook/opt-125m` using vLLM in Triton is +included with this demo as `model_repository` directory. +The model repository should look like this: ``` -docker build -t tritonserver_vllm . +model_repository/ +└── vllm_model + ├── 1 + │ └── model.json + └── config.pbtxt ``` -The above command should create the tritonserver_vllm image with vLLM and all of its dependencies. - +The configuration of engineArgs is in config.pbtxt: -## Step 2: Start Triton Inference Server - -A sample model repository for deploying `facebook/opt-125m` using vLLM in Triton is -included with this demo as `model_repository` directory. -The model repository should look like this: -``` -model_repository/ -`-- vllm - |-- 1 - | `-- model.py - |-- config.pbtxt - |-- vllm_engine_args.json ``` +parameters { + key: "model" + value: { + string_value: "facebook/opt-125m", + } +} -The content of `vllm_engine_args.json` is: +parameters { + key: "disable_log_requests" + value: { + string_value: "true" + } +} -```json -{ - "model": "facebook/opt-125m", - "disable_log_requests": "true", - "gpu_memory_utilization": 0.5 +parameters { + key: "gpu_memory_utilization" + value: { + string_value: "0.8" + } } ``` + This file can be modified to provide further settings to the vLLM engine. See vLLM [AsyncEngineArgs](https://github.com/vllm-project/vllm/blob/32b6816e556f69f1672085a6267e8516bcb8e622/vllm/engine/arg_utils.py#L165) and [EngineArgs](https://github.com/vllm-project/vllm/blob/32b6816e556f69f1672085a6267e8516bcb8e622/vllm/engine/arg_utils.py#L11) -for supported key-value pairs. +for supported key-value pairs. 
Inflight batching and paged attention is handled +by the vLLM engine. -For multi-GPU support, EngineArgs like `tensor_parallel_size` can be specified in [`vllm_engine_args.json`](model_repository/vllm/vllm_engine_args.json). +For multi-GPU support, EngineArgs like `tensor_parallel_size` can be specified in [`config.pbtxt`](model_repository/vllm/config.pbtxt). *Note*: vLLM greedily consume up to 90% of the GPU's memory under default settings. This tutorial updates this behavior by setting `gpu_memory_utilization` to 50%. You can tweak this behavior using fields like `gpu_memory_utilization` and other settings -in [`vllm_engine_args.json`](model_repository/vllm/vllm_engine_args.json). +in [`config.pbtxt`](model_repository/vllm/config.pbtxt). Read through the documentation in [`model.py`](model_repository/vllm/1/model.py) to understand how to configure this sample for your use-case. -Run the following commands to start the server container: +## Step 2: Launch Triton Inference Server + +Once you have the model repository setup, it is time to launch the triton server. +Starting with 23.10 release, a dedicated container with vLLM pre-installed +is available on [NGC.](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver) +To use this container to launch Triton, you can use the docker command below. +``` +docker run --gpus all -it --net=host --rm -p 8001:8001 --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}:/work -w /work nvcr.io/nvidia/tritonserver:<xx.yy>-vllm-python-py3 tritonserver --model-store ./model_repository +``` +Throughout the tutorial, \<xx.yy\> is the version of Triton +that you want to use. Please note, that Triton's vLLM +container was first published in 23.10 release, so any prior version +will not work. + +After you start Triton you will see output on the console showing +the server starting up and loading the model. When you see output +like the following, Triton is ready to accept inference requests. ``` -docker run --gpus all -it --rm -p 8001:8001 --shm-size=1G --ulimit memlock=-1 --ulimit stack=67108864 -v ${PWD}:/work -w /work tritonserver_vllm tritonserver --model-store ./model_repository +I1030 22:33:28.291908 1 grpc_server.cc:2513] Started GRPCInferenceService at 0.0.0.0:8001 +I1030 22:33:28.292879 1 http_server.cc:4497] Started HTTPService at 0.0.0.0:8000 +I1030 22:33:28.335154 1 http_server.cc:270] Started Metrics Service at 0.0.0.0:8002 ``` -Upon successful start of the server, you should see the following at the end of the output. +## Step 3: Use a Triton Client to Send Your First Inference Request + +In this tutorial, we will show how to send an inference request to the +[facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model in 2 ways: +* [Using the generate endpoint](#using-generate-endpoint) +* [Using the gRPC asyncio client](#using-grpc-asyncio-client) + +### Using the Generate Endpoint +After you start Triton with the sample model_repository, +you can quickly run your first inference request with the +[generate](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md) +endpoint. 
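+
+The walkthrough below uses `curl` from Triton's SDK container, but the generate endpoint is plain
+HTTP, so any client can call it. For example, a minimal Python sketch that sends the same payload
+as the `curl` command below (assuming the `requests` package is installed and Triton is listening
+on `localhost:8000`) could look like this:
+
+```python
+import json
+
+import requests  # assumed to be installed: pip install requests
+
+# Same request body as the curl example that follows.
+payload = {
+    "text_input": "What is Triton Inference Server?",
+    "parameters": {"stream": False, "temperature": 0},
+}
+
+response = requests.post(
+    "http://localhost:8000/v2/models/vllm_model/generate", json=payload
+)
+response.raise_for_status()
+print(json.dumps(response.json(), indent=2))
+```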
+ +Start Triton's SDK container with the following command: ``` -I0901 23:39:08.729123 1 grpc_server.cc:2451] Started GRPCInferenceService at 0.0.0.0:8001 -I0901 23:39:08.729640 1 http_server.cc:3558] Started HTTPService at 0.0.0.0:8000 -I0901 23:39:08.772522 1 http_server.cc:187] Started Metrics Service at 0.0.0.0:8002 +docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:<xx.yy>-py3-sdk bash ``` -## Step 3: Use a Triton Client to Query the Server +Now, let's send an inference request: +``` +curl -X POST localhost:8000/v2/models/vllm_model/generate -d '{"text_input": "What is Triton Inference Server?", "parameters": {"stream": false, "temperature": 0}}' +``` + +Upon success, you should see a response from the server like this one: +``` +{"model_name":"vllm_model","model_version":"1","text_output":"What is Triton Inference Server?\n\nTriton Inference Server is a server that is used by many"} +``` -We will run the client within Triton's SDK container to issue multiple async requests using the +### Using the gRPC Asyncio Client +Now, we will see how to run the client within Triton's SDK container +to issue multiple async requests using the [gRPC asyncio client](https://github.com/triton-inference-server/client/blob/main/src/python/library/tritonclient/grpc/aio/__init__.py) library. +This method requires a +[client.py](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/client.py) +script and a set of +[prompts](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt), +which are provided in the +[samples](https://github.com/triton-inference-server/vllm_backend/tree/main/samples) +folder of +[vllm_backend](https://github.com/triton-inference-server/vllm_backend/tree/main) +repository. + +Use the following command to download `client.py` and `prompts.txt` to your +current directory: ``` -docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:23.08-py3-sdk bash +wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/client.py +wget https://raw.githubusercontent.com/triton-inference-server/vllm_backend/main/samples/prompts.txt ``` -Within the container, run [`client.py`](client.py) with: +Now, we are ready to start Triton's SDK container: +``` +docker run -it --net=host -v ${PWD}:/workspace/ nvcr.io/nvidia/tritonserver:<xx.yy>-py3-sdk bash +``` +Within the container, run +[`client.py`](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/client.py) +with: ``` python3 client.py ``` -The client reads prompts from the [prompts.txt](prompts.txt) file, sends them to Triton server for +The client reads prompts from the +[prompts.txt](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt) +file, sends them to Triton server for inference, and stores the results into a file named `results.txt` by default. The output of the client should look like below: @@ -129,15 +196,22 @@ Storing results into `results.txt`... PASS: vLLM example ``` -You can inspect the contents of the `results.txt` for the response from the server. The `--iterations` -flag can be used with the client to increase the load on the server by looping through the list of -provided prompts in [`prompts.txt`](prompts.txt). +You can inspect the contents of the `results.txt` for the response +from the server. 
The `--iterations` flag can be used with the client +to increase the load on the server by looping through the list of +provided prompts in +[prompts.txt](https://github.com/triton-inference-server/vllm_backend/blob/main/samples/prompts.txt). -When you run the client in verbose mode with the `--verbose` flag, the client will print more details -about the request/response transactions. +When you run the client in verbose mode with the `--verbose` flag, +the client will print more details about the request/response transactions. ## Limitations - We use decoupled streaming protocol even if there is exactly 1 response for each request. - The asyncio implementation is exposed to model.py. - Does not support providing specific subset of GPUs to be used. +- If you are running multiple instances of Triton server with +a Python-based vLLM backend, you need to specify a different +`shm-region-prefix-name` for each server. See +[here](https://github.com/triton-inference-server/python_backend#running-multiple-instances-of-triton-server) +for more information. diff --git a/Quick_Deploy/vLLM/client.py b/Quick_Deploy/vLLM/client.py deleted file mode 100644 index db1aa2db..00000000 --- a/Quick_Deploy/vLLM/client.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -import argparse -import asyncio -import json -import queue -import sys -from os import system - -import numpy as np -import tritonclient.grpc.aio as grpcclient -from tritonclient.utils import * - - -def create_request( - prompt, - stream, - request_id, - sampling_parameters, - model_name, - send_parameters_as_tensor=True, -): - inputs = [] - prompt_data = np.array([prompt.encode("utf-8")], dtype=np.object_) - try: - inputs.append(grpcclient.InferInput("PROMPT", [1], "BYTES")) - inputs[-1].set_data_from_numpy(prompt_data) - except Exception as e: - print(f"Encountered an error {e}") - - stream_data = np.array([stream], dtype=bool) - inputs.append(grpcclient.InferInput("STREAM", [1], "BOOL")) - inputs[-1].set_data_from_numpy(stream_data) - - # Request parameters are not yet supported via BLS. Provide an - # optional mechanism to send serialized parameters as an input - # tensor until support is added - - if send_parameters_as_tensor: - sampling_parameters_data = np.array( - [json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_ - ) - inputs.append(grpcclient.InferInput("SAMPLING_PARAMETERS", [1], "BYTES")) - inputs[-1].set_data_from_numpy(sampling_parameters_data) - - # Add requested outputs - outputs = [] - outputs.append(grpcclient.InferRequestedOutput("TEXT")) - - # Issue the asynchronous sequence inference. - return { - "model_name": model_name, - "inputs": inputs, - "outputs": outputs, - "request_id": str(request_id), - "parameters": sampling_parameters, - } - - -async def main(FLAGS): - model_name = "vllm" - sampling_parameters = {"temperature": "0.1", "top_p": "0.95"} - stream = FLAGS.streaming_mode - with open(FLAGS.input_prompts, "r") as file: - print(f"Loading inputs from `{FLAGS.input_prompts}`...") - prompts = file.readlines() - - results_dict = {} - - async with grpcclient.InferenceServerClient( - url=FLAGS.url, verbose=FLAGS.verbose - ) as triton_client: - # Request iterator that yields the next request - async def async_request_iterator(): - try: - for iter in range(FLAGS.iterations): - for i, prompt in enumerate(prompts): - prompt_id = FLAGS.offset + (len(prompts) * iter) + i - results_dict[str(prompt_id)] = [] - yield create_request( - prompt, stream, prompt_id, sampling_parameters, model_name - ) - except Exception as error: - print(f"caught error in request iterator: {error}") - - try: - # Start streaming - response_iterator = triton_client.stream_infer( - inputs_iterator=async_request_iterator(), - stream_timeout=FLAGS.stream_timeout, - ) - # Read response from the stream - async for response in response_iterator: - result, error = response - if error: - print(f"Encountered error while processing: {error}") - else: - output = result.as_numpy("TEXT") - for i in output: - results_dict[result.get_response().id].append(i) - - except InferenceServerException as error: - print(error) - sys.exit(1) - - with open(FLAGS.results_file, "w") as file: - for id in results_dict.keys(): - for result in results_dict[id]: - file.write(result.decode("utf-8")) - file.write("\n") - file.write("\n=========\n\n") - print(f"Storing results into `{FLAGS.results_file}`...") - - if FLAGS.verbose: - print(f"\nContents of `{FLAGS.results_file}` ===>") - system(f"cat {FLAGS.results_file}") - - print("PASS: vLLM example") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "-v", - "--verbose", - action="store_true", - required=False, - default=False, - help="Enable verbose output", - ) - parser.add_argument( - "-u", - "--url", - type=str, - 
required=False, - default="localhost:8001", - help="Inference server URL and it gRPC port. Default is localhost:8001.", - ) - parser.add_argument( - "-t", - "--stream-timeout", - type=float, - required=False, - default=None, - help="Stream timeout in seconds. Default is None.", - ) - parser.add_argument( - "--offset", - type=int, - required=False, - default=0, - help="Add offset to request IDs used", - ) - parser.add_argument( - "--input-prompts", - type=str, - required=False, - default="prompts.txt", - help="Text file with input prompts", - ) - parser.add_argument( - "--results-file", - type=str, - required=False, - default="results.txt", - help="The file with output results", - ) - parser.add_argument( - "--iterations", - type=int, - required=False, - default=1, - help="Number of iterations through the prompts file", - ) - parser.add_argument( - "-s", - "--streaming-mode", - action="store_true", - required=False, - default=False, - help="Enable streaming mode", - ) - FLAGS = parser.parse_args() - asyncio.run(main(FLAGS)) diff --git a/Quick_Deploy/vLLM/model_repository/vllm/config.pbtxt b/Quick_Deploy/vLLM/config.pbtxt similarity index 85% rename from Quick_Deploy/vLLM/model_repository/vllm/config.pbtxt rename to Quick_Deploy/vLLM/config.pbtxt index 243491a6..764df417 100644 --- a/Quick_Deploy/vLLM/model_repository/vllm/config.pbtxt +++ b/Quick_Deploy/vLLM/config.pbtxt @@ -41,18 +41,18 @@ model_transaction_policy { input [ { - name: "PROMPT" + name: "text_input" data_type: TYPE_STRING dims: [ 1 ] }, { - name: "STREAM" + name: "stream" data_type: TYPE_BOOL dims: [ 1 ] optional: true }, { - name: "SAMPLING_PARAMETERS" + name: "sampling_parameters" data_type: TYPE_STRING dims: [ 1 ] optional: true @@ -61,7 +61,7 @@ input [ output [ { - name: "TEXT" + name: "text_output" data_type: TYPE_STRING dims: [ -1 ] } @@ -74,3 +74,25 @@ instance_group [ kind: KIND_MODEL } ] + +# The configuration of engineArgs +parameters { + key: "model" + value: { + string_value: "facebook/opt-125m", + } +} + +parameters { + key: "disable_log_requests" + value: { + string_value: "true" + } +} + +parameters { + key: "gpu_memory_utilization" + value: { + string_value: "0.8" + } +} \ No newline at end of file diff --git a/Quick_Deploy/vLLM/model_repository/vllm/1/model.py b/Quick_Deploy/vLLM/model_repository/vllm/1/model.py index d70cad57..cd77a0b9 100644 --- a/Quick_Deploy/vLLM/model_repository/vllm/1/model.py +++ b/Quick_Deploy/vLLM/model_repository/vllm/1/model.py @@ -26,7 +26,6 @@ import asyncio import json -import os import threading from typing import AsyncGenerator @@ -37,8 +36,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.utils import random_uuid -_VLLM_ENGINE_ARGS_FILENAME = "vllm_engine_args.json" - class TritonPythonModel: def initialize(self, args): @@ -55,21 +52,19 @@ def initialize(self, args): self.using_decoupled ), "vLLM Triton backend must be configured to use decoupled model transaction policy" - engine_args_filepath = os.path.join( - args["model_repository"], _VLLM_ENGINE_ARGS_FILENAME - ) - assert os.path.isfile( - engine_args_filepath - ), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{args['model_repository']}'" - with open(engine_args_filepath) as file: - vllm_engine_config = json.load(file) + self.model_name = args["model_name"] + assert ( + self.model_name + ), "Parameter of [name] must be configured, and can not be empty in config.pbtxt" # Create an AsyncLLMEngine from the config from JSON self.llm_engine = 
AsyncLLMEngine.from_engine_args(
-            AsyncEngineArgs(**vllm_engine_config)
+            AsyncEngineArgs(**self.handle_initializing_config())
         )
 
-        output_config = pb_utils.get_output_config_by_name(self.model_config, "TEXT")
+        output_config = pb_utils.get_output_config_by_name(
+            self.model_config, "text_output"
+        )
         self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
 
         # Counter to keep track of ongoing request counts
@@ -83,6 +78,38 @@ def initialize(self, args):
         self._shutdown_event = asyncio.Event()
         self._loop_thread.start()
 
+    def handle_initializing_config(self):
+        model_params = self.model_config.get("parameters", {})
+        model_engine_args = {}
+        for key, value in model_params.items():
+            model_engine_args[key] = value["string_value"]
+
+        bool_keys = ["trust_remote_code", "use_np_weights", "use_dummy_weights",
+                     "worker_use_ray", "disable_log_stats", "disable_log_requests"]
+        for k in bool_keys:
+            if k in model_engine_args:
+                model_engine_args[k] = model_engine_args[k].lower() == "true"  # bool("false") is True
+
+        float_keys = ["gpu_memory_utilization"]
+        for k in float_keys:
+            if k in model_engine_args:
+                model_engine_args[k] = float(model_engine_args[k])
+
+        int_keys = ["seed", "pipeline_parallel_size", "tensor_parallel_size", "block_size",
+                    "swap_space", "max_num_batched_tokens", "max_num_seqs"]
+        for k in int_keys:
+            if k in model_engine_args:
+                model_engine_args[k] = int(model_engine_args[k])
+
+        # Check necessary parameter configuration in model config
+        model_param = model_engine_args.get("model")
+        assert (
+            model_param
+        ), "The [model] parameter must be configured and must not be empty in config.pbtxt"
+
+        self.logger.log_info(f"[vllm] Initializing engine with args: {model_engine_args}")
+        return model_engine_args
+
     def create_task(self, coro):
         """
         Creates a task on the engine's event loop which is running on a separate thread.
@@ -112,11 +139,17 @@ async def await_shutdown(self):
         # Wait for the ongoing_requests
         while self.ongoing_request_count > 0:
             self.logger.log_info(
-                "Awaiting remaining {} requests".format(self.ongoing_request_count)
+                "[vllm] Awaiting remaining {} requests".format(
+                    self.ongoing_request_count
+                )
             )
             await asyncio.sleep(5)
 
-        self.logger.log_info("Shutdown complete")
+        for task in asyncio.all_tasks(loop=self._loop):
+            if task is not asyncio.current_task():
+                task.cancel()
+
+        self.logger.log_info("[vllm] Shutdown complete")
 
     def get_sampling_params_dict(self, params_json):
         """
@@ -160,7 +193,7 @@ def create_response(self, vllm_output):
             (prompt + output.text).encode("utf-8") for output in vllm_output.outputs
         ]
         triton_output_tensor = pb_utils.Tensor(
-            "TEXT", np.asarray(text_outputs, dtype=self.output_dtype)
+            "text_output", np.asarray(text_outputs, dtype=self.output_dtype)
         )
 
         return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
@@ -172,22 +205,23 @@ async def generate(self, request):
         self.ongoing_request_count += 1
         try:
             request_id = random_uuid()
-
-            prompt = pb_utils.get_input_tensor_by_name(request, "PROMPT").as_numpy()[0]
+            prompt = pb_utils.get_input_tensor_by_name(
+                request, "text_input"
+            ).as_numpy()[0]
             if isinstance(prompt, bytes):
                 prompt = prompt.decode("utf-8")
-
-            # stream is an optional input
-            stream = False
-            stream_input_tensor = pb_utils.get_input_tensor_by_name(request, "STREAM")
-            if stream_input_tensor:
-                stream = stream_input_tensor.as_numpy()[0]
+            stream = pb_utils.get_input_tensor_by_name(request, "stream")
+            if stream:
+                stream = stream.as_numpy()[0]
+            else:
+                stream = False
 
             # Request parameters are not yet supported via
-            # BLS. 
Provide an optional mechanism to receive serialized # parameters as an input tensor until support is added + parameters_input_tensor = pb_utils.get_input_tensor_by_name( - request, "SAMPLING_PARAMETERS" + request, "sampling_parameters" ) if parameters_input_tensor: parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8") @@ -210,10 +244,10 @@ async def generate(self, request): response_sender.send(self.create_response(last_output)) except Exception as e: - self.logger.log_info(f"Error generating stream: {e}") + self.logger.log_info(f"[vllm] Error generating stream: {e}") error = pb_utils.TritonError(f"Error generating stream: {e}") triton_output_tensor = pb_utils.Tensor( - "TEXT", np.asarray(["N/A"], dtype=self.output_dtype) + "text_output", np.asarray(["N/A"], dtype=self.output_dtype) ) response = pb_utils.InferenceResponse( output_tensors=[triton_output_tensor], error=error @@ -242,7 +276,7 @@ def finalize(self): """ Triton virtual method; called when the model is unloaded. """ - self.logger.log_info("Issuing finalize to vllm backend") + self.logger.log_info("[vllm] Issuing finalize to vllm backend") self._shutdown_event.set() if self._loop_thread is not None: self._loop_thread.join() diff --git a/Quick_Deploy/vLLM/model_repository/vllm/vllm_engine_args.json b/Quick_Deploy/vLLM/model_repository/vllm/vllm_engine_args.json deleted file mode 100644 index e610c3cb..00000000 --- a/Quick_Deploy/vLLM/model_repository/vllm/vllm_engine_args.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "model":"facebook/opt-125m", - "disable_log_requests": "true", - "gpu_memory_utilization": 0.5 -} diff --git a/Quick_Deploy/vLLM/prompts.txt b/Quick_Deploy/vLLM/prompts.txt deleted file mode 100644 index 133800ec..00000000 --- a/Quick_Deploy/vLLM/prompts.txt +++ /dev/null @@ -1,4 +0,0 @@ -Hello, my name is -The most dangerous animal is -The capital of France is -The future of AI is diff --git a/README.md b/README.md index 43b95b75..f328ff07 100644 --- a/README.md +++ b/README.md @@ -10,10 +10,21 @@ For users experiencing the "Tensor in" & "Tensor out" approach to Deep Learning The focus of these examples is to demonstrate deployment for models trained with various frameworks. These are quick demonstrations made with an understanding that the user is somewhat familiar with Triton. -#### Deploy a ... +### Deploy a ... 
 | [PyTorch Model](./Quick_Deploy/PyTorch/README.md) | [TensorFlow Model](./Quick_Deploy/TensorFlow/README.md) | [ONNX Model](./Quick_Deploy/ONNX/README.md) | [TensorRT Accelerated Model](https://github.com/NVIDIA/TensorRT/tree/main/quickstart/deploy_to_triton) | [vLLM Model](./Quick_Deploy/vLLM/README.md)
 | --------------- | ------------ | --------------- | --------------- | --------------- |
 
+## LLM Tutorials
+The table below contains some popular models that are covered in our tutorials.
+| Example Models | Tutorial Link |
+| :-------------: | :------------------------------: |
+| [Llama-2-7B](https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main) | [TensorRT-LLM Tutorial](Popular_Models_Guide/Llama2/trtllm_guide.md) |
+| [Persimmon-8B](https://www.adept.ai/blog/persimmon-8b) | [HuggingFace Transformers Tutorial](https://github.com/triton-inference-server/tutorials/tree/main/Quick_Deploy/HuggingFaceTransformers) |
+| [Falcon-7B](https://huggingface.co/tiiuae/falcon-7b) | [HuggingFace Transformers Tutorial](https://github.com/triton-inference-server/tutorials/tree/main/Quick_Deploy/HuggingFaceTransformers) |
+
+**Note:**
+This is not an exhaustive list of what Triton supports, just what is included in the tutorials.
+
 ## What does this repository contain?
 This repository contains the following resources:
 * [Conceptual Guide](./Conceptual_Guide/): This guide focuses on building a conceptual understanding of the general challenges faced whilst building inference infrastructure and how to best tackle these challenges with Triton Inference Server.
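Below is a minimal client sketch for exercising the renamed vLLM model interface in this change. It is an illustrative example only, not part of the diff: it assumes the server is reachable at `localhost:8001` and that the model is registered under the name `vllm`, and it uses the renamed `text_input`, `stream`, `sampling_parameters`, and `text_output` tensors with the `tritonclient.grpc.aio` streaming API (streaming is required because the model uses a decoupled transaction policy).

```python
import asyncio
import json

import numpy as np
import tritonclient.grpc.aio as grpcclient


async def infer_once(prompt: str):
    # Assumed endpoint and model name; adjust to match your deployment.
    async with grpcclient.InferenceServerClient(url="localhost:8001") as client:

        async def request_iterator():
            inputs = []
            # "text_input" replaces the former "PROMPT" tensor.
            inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
            inputs[-1].set_data_from_numpy(
                np.array([prompt.encode("utf-8")], dtype=np.object_)
            )
            # "stream" replaces "STREAM"; False returns the full completion at once.
            inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
            inputs[-1].set_data_from_numpy(np.array([False], dtype=bool))
            # "sampling_parameters" replaces "SAMPLING_PARAMETERS" and carries
            # the sampling options as a serialized JSON string.
            params = {"temperature": "0.1", "top_p": "0.95"}
            inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
            inputs[-1].set_data_from_numpy(
                np.array([json.dumps(params).encode("utf-8")], dtype=np.object_)
            )
            yield {
                "model_name": "vllm",
                "inputs": inputs,
                "outputs": [grpcclient.InferRequestedOutput("text_output")],
                "request_id": "1",
            }

        # The model is decoupled, so responses arrive over a gRPC stream.
        async for result, error in client.stream_infer(
            inputs_iterator=request_iterator()
        ):
            if error:
                print(f"Encountered error: {error}")
            else:
                for text in result.as_numpy("text_output"):
                    print(text.decode("utf-8"))


if __name__ == "__main__":
    asyncio.run(infer_once("The capital of France is"))
```

With `stream` set to `True`, the same loop would receive incremental responses as the decoupled model produces them.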