From 6394d795eff27689a5eae0108f4726f39a871f64 Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 10 Feb 2025 01:40:49 -0800 Subject: [PATCH 1/6] Tensor parallel Llama3 tutorial illustrating use of torch.distributed and nccl ops --- docsrc/index.rst | 2 + .../tensor_parallel_llama3.py | 67 ++++++++++++++++++- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/docsrc/index.rst b/docsrc/index.rst index e7d5250e52..b3fa7c5004 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -68,6 +68,7 @@ Tutorials * :ref:`mutable_torchtrt_module_example` * :ref:`weight_streaming_example` * :ref:`pre_allocated_output_example` +* :ref:`tensor_parallel_llama` .. toctree:: :caption: Tutorials @@ -87,6 +88,7 @@ Tutorials tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example tutorials/_rendered_examples/dynamo/weight_streaming_example tutorials/_rendered_examples/dynamo/pre_allocated_output_example + tutorials/_rendered_examples/dynamo/tensor_parallel_llama Dynamo Frontend ---------------- diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py index 998c378be2..ee341e90df 100644 --- a/examples/distributed_inference/tensor_parallel_llama3.py +++ b/examples/distributed_inference/tensor_parallel_llama3.py @@ -1,10 +1,26 @@ # Taken and modified pytorch lightening # https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning +""" +.. _tensor_parallel_llama: + +Torch distributed example for llama3-7B model +====================================================== + +As model sizes are increasing, large models with billions of parameters are trained with many GPUs, where regular data parallel training is no longer possible. In this example, we illustrate the Llama3-7B model inference using Torch-TensorRT backend, split across multiple GPUs using a form of model parallelism called Tensor Parallelism. We make use of Pytorch Distributed Tensor Parallelism Module. Please refer to these tutorials- https://pytorch.org/tutorials/intermediate/TP_tutorial.html and https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning?section=featured""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + import logging import os import time import torch + +# %% +# Pytorch Tensor Parallel APIs offer set of module level primitives(ParallelStyle) to configure the sharding of tensors in each layer of the model +# ParallelTransformer creates the parallelize_plan for the FeedForward layer of the model from llama3_model import ModelArgs, ParallelTransformer from tensor_parallel_initialize_dist import initialize_distributed_env from torch.distributed._composable.fsdp import MixedPrecisionPolicy @@ -14,11 +30,24 @@ checkpoint_wrapper, ) +# %% +# Initialize the distributed environment +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Depending on the inputs/outputs sharded DTensors layout specified above, proper communication operations are required to transform DTensor layouts +# eg operations: allreduce, allgather, reduce_gather +# NCCL operations enable these operations. +# The below API does the following +# Initialize the communicators and the distributed environment +# Sets the path for the TRT-LLM plugin .so path which is required for the NCCL operations in Torch-TRT backend. Please note that if you are in python3.10 environment, `import tensorrt_llm` should be enough +# Initialize the logger. eg: In case of 2 GPUs, the log files are `./tensor_parallel_llama3_0.log` and `./tensor_parallel_llama3_1.log` device_mesh, _world_size, _rank, logger = initialize_distributed_env( "./tensor_parallel_llama3" ) -# Import should be after initialization of the TRT-LLM plugin .so path -import tensorrt_llm + +# %% +# Model initialization with torch distributed parallel plan +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ logger.info(f"Starting PyTorch TP example on rank {_rank}.") assert ( @@ -36,7 +65,39 @@ ) with torch.no_grad(): + # The plan is + # plan = { + # "attention": PrepareModuleInput( + # input_layouts=(Shard(1), None), + # desired_input_layouts=(Replicate(), None), + # ), + # "attention.wq": ColwiseParallel(), + # "attention.wk": ColwiseParallel(), + # "attention.wv": ColwiseParallel(), + # "attention.wo": RowwiseParallel(output_layouts=Shard(1)), + # "attention_norm": SequenceParallel(), + # "feed_forward": PrepareModuleInput( + # input_layouts=(Shard(1),), + # desired_input_layouts=(Replicate(),), + # ), + # "feed_forward.w1": ColwiseParallel(), + # "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), + # "feed_forward.w3": ColwiseParallel(), + # "ffn_norm": SequenceParallel(), + # } + model = ParallelTransformer(model_args, device_mesh) + + # %% + # Model inference with Torch-TensorRT backend + # ------------------------------------------- + # When we compile the distributed model using Torch-TensorRT backend, pytorch distributed libraries create the sharded model + # on multiple GPUs and the communicator operations are used for proper communication. In the above, + # `ColwiseParallel` and `RowwiseParallel` shard the attention layers in the column or row fashion. + # `SequenceParallel` performs sharded computations of the normalization layer + # `PrepareModuleInput` configures the model input with proper communication operations + # The NCCL operations used in the distributed backend is handled by the TensorRT-LLM NCCL plugins, which causes no graph breaks now + torch.manual_seed(0) inp = torch.randint(32000, (8, 256), device="cuda") python_result = model(inp) @@ -62,9 +123,11 @@ output = model(inp) end = time.time() if i == 0: + # Logging the Compilation time logger.info(f"Compilation time is {end-start}") assert ( python_result - output ).std() < 0.01, "Compilation result is not correct." elif _rank == 0: + # Logging the inference time logger.info(f"Inference time is {end-start}") From d2f83dee006d69daf18f90c2fc024b474438ee9c Mon Sep 17 00:00:00 2001 From: apbose Date: Mon, 10 Feb 2025 04:28:54 -0800 Subject: [PATCH 2/6] tensor_parallel_llama location change --- docsrc/index.rst | 2 +- examples/distributed_inference/tensor_parallel_llama3.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docsrc/index.rst b/docsrc/index.rst index b3fa7c5004..741148d5c3 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -88,7 +88,7 @@ Tutorials tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example tutorials/_rendered_examples/dynamo/weight_streaming_example tutorials/_rendered_examples/dynamo/pre_allocated_output_example - tutorials/_rendered_examples/dynamo/tensor_parallel_llama + tutorials/_rendered_examples/distributed_inference/tensor_parallel_llama Dynamo Frontend ---------------- diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py index ee341e90df..8a00d4631f 100644 --- a/examples/distributed_inference/tensor_parallel_llama3.py +++ b/examples/distributed_inference/tensor_parallel_llama3.py @@ -1,12 +1,11 @@ -# Taken and modified pytorch lightening -# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning """ .. _tensor_parallel_llama: Torch distributed example for llama3-7B model ====================================================== -As model sizes are increasing, large models with billions of parameters are trained with many GPUs, where regular data parallel training is no longer possible. In this example, we illustrate the Llama3-7B model inference using Torch-TensorRT backend, split across multiple GPUs using a form of model parallelism called Tensor Parallelism. We make use of Pytorch Distributed Tensor Parallelism Module. Please refer to these tutorials- https://pytorch.org/tutorials/intermediate/TP_tutorial.html and https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning?section=featured""" +As model sizes are increasing, large models with billions of parameters are trained with many GPUs, where regular data parallel training is no longer possible. In this example, we illustrate the Llama3-7B model inference using Torch-TensorRT backend, split across multiple GPUs using a form of model parallelism called Tensor Parallelism. We make use of Pytorch Distributed Tensor Parallelism Module. Please refer to these tutorials- https://pytorch.org/tutorials/intermediate/TP_tutorial.html and https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning?section=featured +""" # %% # Imports and Model Definition From 67115d5bfbc841250cbe6c22a5803afa5c3b75bc Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 3 Feb 2025 20:46:50 +0000 Subject: [PATCH 3/6] chore: Fix docs and example --- examples/distributed_inference/README.md | 47 +++++++++---------- .../tensor_parallel_simple_example.py | 1 - 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/examples/distributed_inference/README.md b/examples/distributed_inference/README.md index d4cf9508e1..4ff3126eec 100644 --- a/examples/distributed_inference/README.md +++ b/examples/distributed_inference/README.md @@ -2,49 +2,46 @@ Examples in this folder demonstrates doing distributed inference on multiple devices with Torch-TensorRT backend. -1. Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference) +## Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference) Using Accelerate users can achieve data parallel distributed inference with Torch-TensorRt backend. In this case, the entire model will be loaded onto each GPU and different chunks of batch input is processed on each device. -See the examples started with `data_parallel` for more details. +See the examples [data_parallel_gpt2.py](https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_gpt2.py) and [data_parallel_stable_diffusion.py](https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_stable_diffusion.py) for more details. -2. Tensor parallel distributed inference +## Tensor parallel distributed inference Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded. torchrun --nproc_per_node=2 tensor_parallel_llama2.py -3. Tensor parallel distributed inference using nccl ops plugin +## Tensor parallel distributed inference on a simple model using nccl ops plugin - apt install libmpich-dev + +We use [torch.distributed](https://pytorch.org/docs/stable/distributed.html) package to add shard the model with Tensor parallelism. The distributed ops (`all_gather` and `all_reduce`) are then expressed as TensorRT-LLM plugins to avoid graph breaks during Torch-TensorRT compilation. The [converters for these operators](https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py#L25-L55) are already available in Torch-TensorRT. The functional implementation of ops is imported from `tensorrt_llm` package (to be more specific, only `libnvinfer_plugin_tensorrt_llm.so` is required). So we have two options here - apt install libopenmpi-dev +### Option 1: Install TensorRT-LLM - #For python3.10 +Follow the instructions to [install TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/installation/linux.html) - pip install tensorrt-llm +If the default installation fails due to issues like library version mismatches or Python compatibility, it is recommended to use Option 2. After a successful installation, ensure you test by running `import torch_tensorrt` to confirm it works without errors. The import might fail if the `tensorrt_llm` installation overrides `torch_tensorrt` dependencies. Option 2 is particularly advisable if you prefer not to install `tensorrt_llm` and its associated dependencies. - For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so. Please set that in the environment variable export TRTLLM_PLUGINS_PATH={lib_path}. For example, we have already set the variable in initialize_distributed_env(). You can replace this with your TRTLLM_PLUGINS_PATH and unset it there +### Option 2: Link the TensorRT-LLM directly. - #then pip install the tensorrt and torch version compatible with installed torchTRT + Another alternative is to load the `libnvinfer_plugin_tensorrt_llm.so` directly. You can do this by + * Downloading the [tensorrt_llm-0.16.0](https://pypi.nvidia.com/tensorrt-llm/tensorrt_llm-0.16.0-cp310-cp310-linux_x86_64.whl#sha256=f86c6b89647802f49b26b4f6e40824701da14c0f053dbda3e1e7a8709d6939c7) wheel file from the NVIDIA python index. + * Extract the wheel file to a directory and you can find the `libnvinfer_plugin_tensorrt_llm.so` library under `tensorrt_llm/libs` directory. + * Please set the environment variable TRTLLM_PLUGINS_PATH to the above extracted path at the [initialize_distributed_env()](https://github.com/pytorch/TensorRT/blob/54e36dbafe567c75f36b3edb22d6f49d4278c12a/examples/distributed_inference/tensor_parallel_initialize_dist.py#L45) call. - mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py - #For other python +After configuring the TensorRT-LLM or the TensorRT-LLM plugin library path, please run the following command which illustrates tensor parallelism of a simple model and compilation with Torch-TensorRT -4. Tensor parallel distributed llama3 inference using nccl ops plugin +```py +mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py +``` - apt install libmpich-dev +We also provide a tensor paralellism compilation example on a more advanced model like `Llama-3`. Here's the command to run it - apt install libopenmpi-dev - -#For python3.10 - - pip install tensorrt-llm - - For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so - - #then pip install the tensorrt and torch version compatible with installed torchTRT - - mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py +```py +mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py +``` diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index 837648fdb4..9fe1a33bc5 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -15,7 +15,6 @@ device_mesh, _world_size, _rank, logger = initialize_distributed_env( "./tensor_parallel_simple_example" ) -import tensorrt_llm """ This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py From 61f57c007c44661bdfb29024e3d17aaab263dfe4 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 14 Feb 2025 07:47:24 -0800 Subject: [PATCH 4/6] README.rst for examples/distributed_inference/tensor_parallel_llama3 --- docsrc/index.rst | 2 +- examples/distributed_inference/README.md | 47 ----------- examples/distributed_inference/README.rst | 77 +++++++++++++++++++ .../tensor_parallel_llama3.py | 3 +- 4 files changed, 80 insertions(+), 49 deletions(-) delete mode 100644 examples/distributed_inference/README.md create mode 100644 examples/distributed_inference/README.rst diff --git a/docsrc/index.rst b/docsrc/index.rst index 741148d5c3..c775c4e349 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -68,7 +68,7 @@ Tutorials * :ref:`mutable_torchtrt_module_example` * :ref:`weight_streaming_example` * :ref:`pre_allocated_output_example` -* :ref:`tensor_parallel_llama` +* :ref:`tensor_parallel_llama3` .. toctree:: :caption: Tutorials diff --git a/examples/distributed_inference/README.md b/examples/distributed_inference/README.md deleted file mode 100644 index 4ff3126eec..0000000000 --- a/examples/distributed_inference/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# Torch-TensorRT parallelism for distributed inference - -Examples in this folder demonstrates doing distributed inference on multiple devices with Torch-TensorRT backend. - -## Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference) - -Using Accelerate users can achieve data parallel distributed inference with Torch-TensorRt backend. In this case, the entire model -will be loaded onto each GPU and different chunks of batch input is processed on each device. - -See the examples [data_parallel_gpt2.py](https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_gpt2.py) and [data_parallel_stable_diffusion.py](https://github.com/pytorch/TensorRT/blob/main/examples/distributed_inference/data_parallel_stable_diffusion.py) for more details. - -## Tensor parallel distributed inference - -Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded. - -torchrun --nproc_per_node=2 tensor_parallel_llama2.py - -## Tensor parallel distributed inference on a simple model using nccl ops plugin - - -We use [torch.distributed](https://pytorch.org/docs/stable/distributed.html) package to add shard the model with Tensor parallelism. The distributed ops (`all_gather` and `all_reduce`) are then expressed as TensorRT-LLM plugins to avoid graph breaks during Torch-TensorRT compilation. The [converters for these operators](https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py#L25-L55) are already available in Torch-TensorRT. The functional implementation of ops is imported from `tensorrt_llm` package (to be more specific, only `libnvinfer_plugin_tensorrt_llm.so` is required). So we have two options here - -### Option 1: Install TensorRT-LLM - -Follow the instructions to [install TensorRT-LLM](https://nvidia.github.io/TensorRT-LLM/installation/linux.html) - -If the default installation fails due to issues like library version mismatches or Python compatibility, it is recommended to use Option 2. After a successful installation, ensure you test by running `import torch_tensorrt` to confirm it works without errors. The import might fail if the `tensorrt_llm` installation overrides `torch_tensorrt` dependencies. Option 2 is particularly advisable if you prefer not to install `tensorrt_llm` and its associated dependencies. - -### Option 2: Link the TensorRT-LLM directly. - - Another alternative is to load the `libnvinfer_plugin_tensorrt_llm.so` directly. You can do this by - * Downloading the [tensorrt_llm-0.16.0](https://pypi.nvidia.com/tensorrt-llm/tensorrt_llm-0.16.0-cp310-cp310-linux_x86_64.whl#sha256=f86c6b89647802f49b26b4f6e40824701da14c0f053dbda3e1e7a8709d6939c7) wheel file from the NVIDIA python index. - * Extract the wheel file to a directory and you can find the `libnvinfer_plugin_tensorrt_llm.so` library under `tensorrt_llm/libs` directory. - * Please set the environment variable TRTLLM_PLUGINS_PATH to the above extracted path at the [initialize_distributed_env()](https://github.com/pytorch/TensorRT/blob/54e36dbafe567c75f36b3edb22d6f49d4278c12a/examples/distributed_inference/tensor_parallel_initialize_dist.py#L45) call. - - -After configuring the TensorRT-LLM or the TensorRT-LLM plugin library path, please run the following command which illustrates tensor parallelism of a simple model and compilation with Torch-TensorRT - -```py -mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py -``` - -We also provide a tensor paralellism compilation example on a more advanced model like `Llama-3`. Here's the command to run it - -```py -mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py -``` diff --git a/examples/distributed_inference/README.rst b/examples/distributed_inference/README.rst new file mode 100644 index 0000000000..93005691cc --- /dev/null +++ b/examples/distributed_inference/README.rst @@ -0,0 +1,77 @@ +Torch-TensorRT Parallelism for Distributed Inference +==================================================== + +Examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend. + +Data Parallel Distributed Inference based on `Accelerate `_ +--------------------------------------------------------------------------------------------------------------- + +Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend. +In this case, the entire model will be loaded onto each GPU, and different chunks of batch input are processed on each device. + +See the examples: + +- `data_parallel_gpt2.py `_ +- `data_parallel_stable_diffusion.py `_ + +for more details. + +Tensor Parallel Distributed Inference +-------------------------------------- + +Here, we use `torch.distributed` as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded. + +.. code-block:: bash + + torchrun --nproc_per_node=2 tensor_parallel_llama2.py + +Tensor Parallel Distributed Inference on a Simple Model using NCCL Ops Plugin +------------------------------------------------------------------------------ + +We use `torch.distributed `_ to shard the model with Tensor parallelism. +The distributed operations (`all_gather` and `all_reduce`) are then expressed as TensorRT-LLM plugins to avoid graph breaks during Torch-TensorRT compilation. +The `converters for these operators `_ are already available in Torch-TensorRT. +The functional implementation of ops is imported from the `tensorrt_llm` package (specifically, `libnvinfer_plugin_tensorrt_llm.so` is required). + +We have two options: + +Option 1: Install TensorRT-LLM +------------------------------- + +Follow the instructions to `install TensorRT-LLM `_. + +If the default installation fails due to issues like library version mismatches or Python compatibility, consider using Option 2. +After a successful installation, test by running: + +.. code-block:: python + + import torch_tensorrt + +to ensure it works without errors. +The import might fail if `tensorrt_llm` overrides `torch_tensorrt` dependencies. +Option 2 is preferable if you do not wish to install `tensorrt_llm` and its dependencies. + +Option 2: Link the TensorRT-LLM Directly +----------------------------------------- + +Alternatively, you can load `libnvinfer_plugin_tensorrt_llm.so` manually: + +1. Download the `tensorrt_llm-0.16.0 `_ wheel file from NVIDIA's Python index. +2. Extract the wheel file to a directory and locate `libnvinfer_plugin_tensorrt_llm.so` under the `tensorrt_llm/libs` directory. +3. Set the environment variable `TRTLLM_PLUGINS_PATH` to the extracted path at the `initialize_distributed_env() `_ call. + +After configuring TensorRT-LLM or the TensorRT-LLM plugin library path, run the following command to illustrate tensor parallelism of a simple model and compilation with Torch-TensorRT: + +.. code-block:: bash + + mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py + +We also provide a tensor parallelism compilation example on a more advanced model like `Llama-3`. Run the following command: + +.. code-block:: bash + + mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py + +Tutorials +----------------------------------------- +* :ref:`tensor_parallel_llama3`: Illustration of distributed inference on multiple devices with the Torch-TensorRT backend. \ No newline at end of file diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py index 8a00d4631f..377c26679c 100644 --- a/examples/distributed_inference/tensor_parallel_llama3.py +++ b/examples/distributed_inference/tensor_parallel_llama3.py @@ -1,5 +1,5 @@ """ -.. _tensor_parallel_llama: +.. _tensor_parallel_llama3: Torch distributed example for llama3-7B model ====================================================== @@ -16,6 +16,7 @@ import time import torch +import torch_tensorrt # %% # Pytorch Tensor Parallel APIs offer set of module level primitives(ParallelStyle) to configure the sharding of tensors in each layer of the model From fb0ba7f3461f1b5cc730a8135c59636f2214897b Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 18 Feb 2025 04:25:11 -0800 Subject: [PATCH 5/6] documentation changes to include examples/distributed_inference in rendered examples --- docsrc/index.rst | 2 +- examples/distributed_inference/README.rst | 8 ++- .../distributed_inference/llama3_model.py | 4 ++ .../tensor_parallel_initialize_dist.py | 6 ++ .../tensor_parallel_llama3.py | 55 +++++++++++++------ .../tensor_parallel_simple_example.py | 4 ++ 6 files changed, 58 insertions(+), 21 deletions(-) diff --git a/docsrc/index.rst b/docsrc/index.rst index c775c4e349..b4d96dbc8d 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -88,7 +88,7 @@ Tutorials tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example tutorials/_rendered_examples/dynamo/weight_streaming_example tutorials/_rendered_examples/dynamo/pre_allocated_output_example - tutorials/_rendered_examples/distributed_inference/tensor_parallel_llama + tutorials/_rendered_examples/distributed_inference/tensor_parallel_llama3 Dynamo Frontend ---------------- diff --git a/examples/distributed_inference/README.rst b/examples/distributed_inference/README.rst index 93005691cc..1993a111fb 100644 --- a/examples/distributed_inference/README.rst +++ b/examples/distributed_inference/README.rst @@ -1,10 +1,12 @@ +.. _tensor_parallel_llama3: + Torch-TensorRT Parallelism for Distributed Inference ==================================================== Examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend. Data Parallel Distributed Inference based on `Accelerate `_ ---------------------------------------------------------------------------------------------------------------- +----------------------------------------------------------------------------------------------------------------------------------------- Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend. In this case, the entire model will be loaded onto each GPU, and different chunks of batch input are processed on each device. @@ -36,7 +38,7 @@ The functional implementation of ops is imported from the `tensorrt_llm` package We have two options: Option 1: Install TensorRT-LLM -------------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Follow the instructions to `install TensorRT-LLM `_. @@ -52,7 +54,7 @@ The import might fail if `tensorrt_llm` overrides `torch_tensorrt` dependencies. Option 2 is preferable if you do not wish to install `tensorrt_llm` and its dependencies. Option 2: Link the TensorRT-LLM Directly ------------------------------------------ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Alternatively, you can load `libnvinfer_plugin_tensorrt_llm.so` manually: diff --git a/examples/distributed_inference/llama3_model.py b/examples/distributed_inference/llama3_model.py index 9fa59b5c49..e5e8e0ca6c 100644 --- a/examples/distributed_inference/llama3_model.py +++ b/examples/distributed_inference/llama3_model.py @@ -1,3 +1,7 @@ +""" +This file contains the Llama3 model example used for tensor parallel distribution +""" + # Taken and modified pytorch lightening # https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py index 21e4cbc282..9a662e92f7 100644 --- a/examples/distributed_inference/tensor_parallel_initialize_dist.py +++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py @@ -1,3 +1,9 @@ +""" +This script contains utility functions for Tensor Parallelism +using Torch-TensorRT. It sets up the necessary communication protocols, +environments and partitions the model across multiple GPUs. +""" + import logging import os from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py index 377c26679c..9ed985b702 100644 --- a/examples/distributed_inference/tensor_parallel_llama3.py +++ b/examples/distributed_inference/tensor_parallel_llama3.py @@ -18,7 +18,6 @@ import torch import torch_tensorrt -# %% # Pytorch Tensor Parallel APIs offer set of module level primitives(ParallelStyle) to configure the sharding of tensors in each layer of the model # ParallelTransformer creates the parallelize_plan for the FeedForward layer of the model from llama3_model import ModelArgs, ParallelTransformer @@ -32,15 +31,17 @@ # %% # Initialize the distributed environment -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -# Depending on the inputs/outputs sharded DTensors layout specified above, proper communication operations are required to transform DTensor layouts -# eg operations: allreduce, allgather, reduce_gather -# NCCL operations enable these operations. -# The below API does the following -# Initialize the communicators and the distributed environment -# Sets the path for the TRT-LLM plugin .so path which is required for the NCCL operations in Torch-TRT backend. Please note that if you are in python3.10 environment, `import tensorrt_llm` should be enough -# Initialize the logger. eg: In case of 2 GPUs, the log files are `./tensor_parallel_llama3_0.log` and `./tensor_parallel_llama3_1.log` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# The following steps are performed: +# +# - Initialize the communicators and the distributed environment +# - Set the path for the `TRT-LLM`` plugin `.so` file, which is required for the NCCL operations in Torch-TRT backend. +# - Initialize the logger: +# +# - Example: In a 2-GPU setup, the log files will be: +# - `./tensor_parallel_llama3_0.log` +# - `./tensor_parallel_llama3_1.log` +# device_mesh, _world_size, _rank, logger = initialize_distributed_env( "./tensor_parallel_llama3" ) @@ -91,13 +92,33 @@ # %% # Model inference with Torch-TensorRT backend # ------------------------------------------- - # When we compile the distributed model using Torch-TensorRT backend, pytorch distributed libraries create the sharded model - # on multiple GPUs and the communicator operations are used for proper communication. In the above, - # `ColwiseParallel` and `RowwiseParallel` shard the attention layers in the column or row fashion. - # `SequenceParallel` performs sharded computations of the normalization layer - # `PrepareModuleInput` configures the model input with proper communication operations - # The NCCL operations used in the distributed backend is handled by the TensorRT-LLM NCCL plugins, which causes no graph breaks now - + # When we compile the distributed model using the **Torch-TensorRT** backend, PyTorch's distributed libraries: + # + # - Create the **sharded model** across multiple GPUs. + # - Use **communicator operations** to ensure proper communication. + # + # The following components manage different aspects of parallelism: + # + # - **`ColwiseParallel`** and **`RowwiseParallel`**: + # - Shard the attention layers in **column-wise** or **row-wise** fashion. + # + # - **`SequenceParallel`**: + # - Performs **sharded computations** of the normalization layer. + # + # - **`PrepareModuleInput`**: + # - Configures the model input with proper **communication operations**. + # + # **NCCL Operations in TensorRT-LLM:** + # + # - The **TensorRT-LLM NCCL plugins** handle distributed backend NCCL operations, preventing **graph breaks**. + # - Depending on the **DTensor sharding layout**, proper **communication operations** are required to transform the DTensor layout. + # + # **Common NCCL Operations Used:** + # + # - `allreduce` + # - `allgather` + # - `reduce_scatter` + # torch.manual_seed(0) inp = torch.randint(32000, (8, 256), device="cuda") python_result = model(inp) diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index 9fe1a33bc5..ade8f0607d 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -1,3 +1,7 @@ +""" +This file contains the Tensor parallel simple model example used for tensor parallel distribution +""" + import time import tensorrt as trt From 58a4bb3082b5621e5181a10aaae9d3e44da2a648 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 18 Feb 2025 09:52:19 -0800 Subject: [PATCH 6/6] adding README instructions --- examples/distributed_inference/README.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/distributed_inference/README.rst b/examples/distributed_inference/README.rst index 1993a111fb..5f9e2ec99a 100644 --- a/examples/distributed_inference/README.rst +++ b/examples/distributed_inference/README.rst @@ -41,6 +41,10 @@ Option 1: Install TensorRT-LLM ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Follow the instructions to `install TensorRT-LLM `_. +Please note that before installing TensorRT-LLM, you need to + +1. apt install libmpich-dev +2. apt install libopenmpi-dev If the default installation fails due to issues like library version mismatches or Python compatibility, consider using Option 2. After a successful installation, test by running: