diff --git a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp index 763ca2f069..24875df01e 100644 --- a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp @@ -48,6 +48,7 @@ LoraPlugin::LoraPlugin(int in_hidden_size, std::vector out_hidden_sizes, in , mPluginProfiler(pluginProfiler) { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); + TLLM_LOG_DEBUG("Plugin type: %d", static_cast(mType)); mOutHiddenSizes.resize(mNumLoraModules); mOutHiddenSizes.assign(out_hidden_sizes.begin(), out_hidden_sizes.end()); init(); @@ -292,6 +293,82 @@ int LoraPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P // only used for unified gemm auto bestTactic = mPluginProfiler->getBestConfig(numTokens, mGemmId); mLoraImpl->setBestTactic(bestTactic); + + // Add debug prints for LoraImpl->run parameters + TLLM_LOG_INFO("DEBUG: loraPlugin.cpp - LoraImpl->run parameters:"); + TLLM_LOG_INFO("DEBUG: numTokens = %d", numTokens); + TLLM_LOG_INFO("DEBUG: numReqs = %d", numReqs); + TLLM_LOG_INFO("DEBUG: input shape = [%d, %d, %d]", inputDesc[getInputTensorIdx()].dims.d[0], + inputDesc[getInputTensorIdx()].dims.d[1], inputDesc[getInputTensorIdx()].dims.d[2]); + TLLM_LOG_INFO("DEBUG: mExpandLoraRanks size = %zu", mExpandLoraRanks.size()); + TLLM_LOG_INFO("DEBUG: mExpandLoraWeightPtrs size = %zu", mExpandLoraWeightPtrs.size()); + TLLM_LOG_INFO("DEBUG: mWeightIndex = %d", mWeightIndex); + TLLM_LOG_INFO("DEBUG: outputs size = %d", mNumLoraModules); + TLLM_LOG_INFO("DEBUG: workspace size = %zu", mLoraImpl->getWorkspaceSize(numTokens, numReqs, mType)); + TLLM_LOG_INFO("DEBUG: stream = %p", (void*) stream); + + // Add more detailed debug prints + TLLM_LOG_INFO("DEBUG: loraPlugin.cpp - Additional details:"); + TLLM_LOG_INFO("DEBUG: input dtype = %d", (int) mType); + TLLM_LOG_INFO("DEBUG: mTransA = %d, mTransB = %d", mTransA, mTransB); + TLLM_LOG_INFO("DEBUG: mMaxLowRank = %d", mMaxLowRank); + TLLM_LOG_INFO("DEBUG: mRemoveInputPadding = %d", mRemoveInputPadding); + + // Print output hidden sizes + std::string outHiddenSizesStr; + for (size_t i = 0; i < mOutHiddenSizes.size(); ++i) + { + outHiddenSizesStr += std::to_string(mOutHiddenSizes[i]); + if (i < mOutHiddenSizes.size() - 1) + outHiddenSizesStr += ", "; + } + TLLM_LOG_INFO("DEBUG: mOutHiddenSizes = [%s]", outHiddenSizesStr.c_str()); + + // Print first few values of mExpandLoraRanks + TLLM_LOG_INFO("DEBUG: First 5 mExpandLoraRanks values: [%d, %d, %d, %d, %d]", + mExpandLoraRanks.size() > 0 ? mExpandLoraRanks[0] : -1, mExpandLoraRanks.size() > 1 ? mExpandLoraRanks[1] : -1, + mExpandLoraRanks.size() > 2 ? mExpandLoraRanks[2] : -1, mExpandLoraRanks.size() > 3 ? mExpandLoraRanks[3] : -1, + mExpandLoraRanks.size() > 4 ? mExpandLoraRanks[4] : -1); + + // Print actual tensor values - SAFE VERSION + TLLM_LOG_INFO("DEBUG: loraPlugin.cpp - Tensor values:"); + TLLM_LOG_INFO("DEBUG: Input tensor pointer: %p", input); + + // Print first few weight pointers - more defensive approach + TLLM_LOG_INFO("DEBUG: First 5 weight pointers: [%p, %p, %p, %p, %p]", + mExpandLoraWeightPtrs.size() > 0 ? mExpandLoraWeightPtrs[0] : nullptr, + mExpandLoraWeightPtrs.size() > 1 ? mExpandLoraWeightPtrs[1] : nullptr, + mExpandLoraWeightPtrs.size() > 2 ? mExpandLoraWeightPtrs[2] : nullptr, + mExpandLoraWeightPtrs.size() > 3 ? mExpandLoraWeightPtrs[3] : nullptr, + mExpandLoraWeightPtrs.size() > 4 ? 
mExpandLoraWeightPtrs[4] : nullptr); + + // Check if all tokens use the same LoRA weights and ranks (for unified GEMM path) + bool useUnifiedGemm = true; + if (mExpandLoraRanks.size() > 1) + { + int32_t firstRank = mExpandLoraRanks[0]; + void const* firstWeightPtr = mExpandLoraWeightPtrs[0]; + + TLLM_LOG_INFO( + "DEBUG: Checking for unified GEMM path - First rank: %d, First weight ptr: %p", firstRank, firstWeightPtr); + + for (size_t i = 1; i < mExpandLoraRanks.size(); ++i) + { + if (mExpandLoraRanks[i] != firstRank || mExpandLoraWeightPtrs[i] != firstWeightPtr) + { + useUnifiedGemm = false; + TLLM_LOG_INFO("DEBUG: Found different rank or weight ptr at index %zu: rank=%d, ptr=%p", i, + mExpandLoraRanks[i], mExpandLoraWeightPtrs[i]); + break; + } + } + } + + TLLM_LOG_INFO("DEBUG: GEMM path selection: %s", useUnifiedGemm ? "UNIFIED GEMM" : "SPLIT GEMM (CUTLASS)"); + + // Add a debug print to show the actual GEMM path being used + TLLM_LOG_INFO("DEBUG: Calling LoraImpl->run with useUnifiedGemm=%d", useUnifiedGemm); + mLoraImpl->run(numTokens, numReqs, input, mExpandLoraRanks.data(), mExpandLoraWeightPtrs.data(), mWeightIndex, outputs, workspace, stream); @@ -433,21 +510,26 @@ IPluginV2* LoraPluginCreator::createPlugin(char const* name, PluginFieldCollecti { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); in_hidden_size = *(static_cast(fields[i].data)); + TLLM_LOG_DEBUG("Read in_hidden_size: %d", in_hidden_size); } else if (!strcmp(attrName, "transa")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); transA = *(static_cast(fields[i].data)); + TLLM_LOG_DEBUG("Read transA: %d", transA); } else if (!strcmp(attrName, "transb")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); transB = *(static_cast(fields[i].data)); + TLLM_LOG_DEBUG("Read transB: %d", transB); } else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); type = static_cast(*(static_cast(fields[i].data))); + TLLM_LOG_DEBUG( + "Received type_id from Python: %d (kFLOAT=0, kHALF=1, kINT8=2, kINT32=3)", static_cast(type)); } else if (!strcmp(attrName, "remove_input_padding")) { diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt index 8784445119..69744b6391 100644 --- a/cpp/tensorrt_llm/thop/CMakeLists.txt +++ b/cpp/tensorrt_llm/thop/CMakeLists.txt @@ -54,6 +54,7 @@ add_library( fp8Quantize.cpp fusedTopkSoftmax.cpp gatherTreeOp.cpp + loraOp.cpp logitsBitmaskOp.cpp mambaConv1dOp.cpp moeOp.cpp diff --git a/cpp/tensorrt_llm/thop/loraOp.cpp b/cpp/tensorrt_llm/thop/loraOp.cpp new file mode 100644 index 0000000000..41d3f88ea2 --- /dev/null +++ b/cpp/tensorrt_llm/thop/loraOp.cpp @@ -0,0 +1,285 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/cublasMMWrapper.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/kernels/lora/lora.h" +#include "tensorrt_llm/kernels/selectiveScan/selectiveScan.h" +#include "tensorrt_llm/thop/thUtils.h" + +namespace th = torch; +namespace tk = tensorrt_llm::kernels; +using tensorrt_llm::common::fmtstr; + +namespace torch_ext +{ + +enum class RequestType : int32_t +{ + kCONTEXT = 0, + kGENERATION = 1 +}; + +int64_t getNumTokens(th::Tensor const& input) +{ + int ndim = input.sizes().size(); + TLLM_CHECK_WITH_INFO( + 3 == ndim || 2 == ndim, "hidden_state dimension should be either 2 [numTokens, hidden], or 3 [b, s, hidden]"); + int64_t num_tokens = input.sizes()[0]; + if (ndim == 3) + { + num_tokens *= input.sizes()[1]; + } + return num_tokens; +} + +std::vector lora_grouped_gemm(th::Tensor const& input, th::Tensor const& host_request_types, + std::vector const& lora_ranks, // numModules tensors, each tensors has single value + std::vector const& lora_weights_pointers, th::Tensor const& host_context_lengths, + std::vector const& output_hidden_sizes, bool transA, bool transB, int64_t const max_low_rank, + int64_t const& weight_index, bool isRemoveInputPadding) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto stream + = at::cuda::getCurrentCUDAStream().stream(); // todo(dafrimi): manages its own CUDA stream and synchronization, + // while the TensorRT plugin uses the stream provided by TensorRT. + auto const numReqs = lora_ranks[0].sizes()[0]; + auto const out_shape = input.sizes(); + int const numLoraModules = lora_ranks.size(); + TLLM_CHECK_WITH_INFO(lora_ranks.size() == lora_weights_pointers.size(), "both should be numLoraModules"); + + // Add debug prints + TLLM_LOG_INFO("DEBUG: numLoraModules = %d", numLoraModules); + TLLM_LOG_INFO("DEBUG: input shape = [%ld, %ld, %ld]", input.sizes()[0], input.sizes()[1], input.sizes()[2]); + TLLM_LOG_INFO("DEBUG: numReqs = %ld", numReqs); + + // todo(dafrimi): explicitly allocates output tensors, while the TensorRT plugin relies on TensorRT's memory + // management system to provide pre-allocated output buffers. + TLLM_LOG_INFO("DEBUG: output_hidden_sizes = %d", output_hidden_sizes[0]); + TLLM_LOG_INFO("DEBUG: input options = %d", input.options()); + TLLM_LOG_INFO("DEBUG: out_shape = [%ld, %ld, %ld]", out_shape[0], out_shape[1], out_shape[2]); + + std::vector output_torch; + for (int i = 0; i < numLoraModules; i++) + { + std::vector output_shape = {out_shape[0], out_shape[1]}; + + if (!isRemoveInputPadding) + { + output_shape = {out_shape[0], out_shape[1], output_hidden_sizes[i]}; + } + + output_torch.push_back(torch::empty(output_shape, input.options())); + } + std::vector output; + for (auto tensor_it = output_torch.begin(); tensor_it != output_torch.end(); tensor_it++) + { + output.push_back(tensor_it->data_ptr()); + } + int const seqLen = isRemoveInputPadding ? 0 : input.sizes()[1]; + int32_t const* reqTypes = static_cast(host_request_types.data_ptr()); + int32_t const* hostContextLengths + = isRemoveInputPadding ? 
static_cast(host_context_lengths.data_ptr()) : nullptr; + + int64_t numTokens = getNumTokens(input); + + // todo(dafrimi): not reusing the vectors below + std::vector expandLoraWeightPtrs{}; + std::vector expandLoraRanks{}; + expandLoraWeightPtrs.reserve(numLoraModules * numTokens * 2); + expandLoraRanks.reserve(numLoraModules * numTokens); + + for (int loraModuleIdx = 0; loraModuleIdx < numLoraModules; loraModuleIdx++) + { + auto const loraRankModule = static_cast(lora_ranks[loraModuleIdx].data_ptr()); + auto const loraWeightModulePtrs = static_cast(lora_weights_pointers[loraModuleIdx].data_ptr()); + + int idx = 0; + for (int reqId = 0; reqId < numReqs; reqId++) + { + // loraWeightModulePtrs has 3 pointers for each module: A,B, and an optional DoRA magnitude + // the current LoRA plugin does not apply DoRA scaling, so the magnitude is ignored + RequestType const reqType = static_cast(reqTypes[reqId]); + + if (reqType == RequestType::kGENERATION) + { + expandLoraWeightPtrs.push_back(reinterpret_cast(loraWeightModulePtrs[reqId * 3])); + expandLoraWeightPtrs.push_back(reinterpret_cast(loraWeightModulePtrs[reqId * 3 + 1])); + expandLoraRanks.push_back(loraRankModule[reqId]); + idx += 1; + } + else + { + int contextLen = (isRemoveInputPadding ? hostContextLengths[reqId] : seqLen); + + for (int contextId = 0; contextId < contextLen; contextId++) + { + expandLoraWeightPtrs.push_back(reinterpret_cast(loraWeightModulePtrs[reqId * 3])); + expandLoraWeightPtrs.push_back(reinterpret_cast(loraWeightModulePtrs[reqId * 3 + 1])); + expandLoraRanks.push_back(loraRankModule[reqId]); + idx += 1; + } + } + } + + // In 1st generation phase cross attention qkv lora, cross qkv is skipped by passing an empty encoder_output + // (passing 0 to dim) getNumTokens() will get in cross qkv_lora. Skipping the check for this case. 
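+        // Illustrative sketch (request sizes are made up for this example): with numReqs == 2,
+        // where request 0 is a context request of 3 tokens and request 1 is a generation request,
+        // the loop above appends, per LoRA module,
+        //     expandLoraRanks      : [r0, r0, r0, r1]              (one rank per token)
+        //     expandLoraWeightPtrs : [A0,B0, A0,B0, A0,B0, A1,B1]  (an A/B pointer pair per token)
+        // giving idx == 4, which must equal numTokens derived from the input shape; the check
+        // below enforces this unless numTokens == 0 (the skipped cross-attention qkv case).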
+ if (numTokens > 0) + { + TLLM_CHECK_WITH_INFO(idx == numTokens, + fmtstr("LoraParams and input dims don't match, lora tokens %d input tokens %ld", idx, numTokens)); + } + } + + auto cublasHandle = getCublasHandle(); + auto cublasLtHandle = getCublasLtHandle(); + auto cublasWraper + = std::make_shared(cublasHandle, cublasLtHandle, nullptr, nullptr); + + int const inHiddenSize = input.sizes()[input.sizes().size() - 1]; + + std::vector outHiddenSizes(output_hidden_sizes.size()); + for (int i = 0; i < numLoraModules; i++) + { + outHiddenSizes[i] = output_hidden_sizes[i]; + } + nvinfer1::DataType loraRuntimeDataType; + switch (input.scalar_type()) + { + case torch::kFloat16: loraRuntimeDataType = nvinfer1::DataType::kHALF; break; + case torch::kBFloat16: loraRuntimeDataType = nvinfer1::DataType::kBF16; break; + default: throw std::invalid_argument("Invalid dtype, only supports float16, bfloat16"); + } + + auto mLoraImpl = std::make_shared( + inHiddenSize, outHiddenSizes, transA, transB, numLoraModules, loraRuntimeDataType, max_low_rank, cublasWraper); + + mLoraImpl->setBestTactic(std::nullopt); // todo(dafrimi): TensorRT plugin uses a profiler to determine the best + // tactic for the current input size + + auto const workspace_size = mLoraImpl->getWorkspaceSize(numTokens, numReqs, loraRuntimeDataType); + + auto workspace = torch::empty(std::vector{static_cast(workspace_size)}, input.options()); + + // Add debug prints for LoraImpl->run parameters + TLLM_LOG_INFO("DEBUG: loraOp.cpp - LoraImpl->run parameters:"); + TLLM_LOG_INFO("DEBUG: numTokens = %ld", numTokens); + TLLM_LOG_INFO("DEBUG: numReqs = %ld", numReqs); + TLLM_LOG_INFO("DEBUG: input shape = [%ld, %ld, %ld]", input.sizes()[0], input.sizes()[1], input.sizes()[2]); + TLLM_LOG_INFO("DEBUG: expandLoraRanks size = %zu", expandLoraRanks.size()); + TLLM_LOG_INFO("DEBUG: expandLoraWeightPtrs size = %zu", expandLoraWeightPtrs.size()); + TLLM_LOG_INFO("DEBUG: weight_index = %ld", weight_index); + TLLM_LOG_INFO("DEBUG: output size = %zu", output.size()); + TLLM_LOG_INFO("DEBUG: workspace size = %ld", workspace_size); + TLLM_LOG_INFO("DEBUG: stream = %p", (void*) stream); + + // Add more detailed debug prints + TLLM_LOG_INFO("DEBUG: loraOp.cpp - Additional details:"); + TLLM_LOG_INFO("DEBUG: input dtype = %d", (int) loraRuntimeDataType); + TLLM_LOG_INFO("DEBUG: transA = %d, transB = %d", transA, transB); + TLLM_LOG_INFO("DEBUG: max_low_rank = %ld", max_low_rank); + TLLM_LOG_INFO("DEBUG: isRemoveInputPadding = %d", isRemoveInputPadding); + TLLM_LOG_INFO("DEBUG: output_hidden_sizes = [%s]", + [&output_hidden_sizes]() + { + std::string result; + for (size_t i = 0; i < output_hidden_sizes.size(); ++i) + { + result += std::to_string(output_hidden_sizes[i]); + if (i < output_hidden_sizes.size() - 1) + result += ", "; + } + return result; + }() + .c_str()); + + // Print first few values of expandLoraRanks + TLLM_LOG_INFO("DEBUG: First 5 expandLoraRanks values: [%d, %d, %d, %d, %d]", + expandLoraRanks.size() > 0 ? expandLoraRanks[0] : -1, expandLoraRanks.size() > 1 ? expandLoraRanks[1] : -1, + expandLoraRanks.size() > 2 ? expandLoraRanks[2] : -1, expandLoraRanks.size() > 3 ? expandLoraRanks[3] : -1, + expandLoraRanks.size() > 4 ? 
expandLoraRanks[4] : -1); + + // Print actual tensor values - SAFE VERSION + TLLM_LOG_INFO("DEBUG: loraOp.cpp - Tensor values:"); + TLLM_LOG_INFO("DEBUG: Input tensor pointer: %p", input.data_ptr()); + + // Print first few weight pointers - more defensive approach + TLLM_LOG_INFO("DEBUG: First 5 weight pointers: [%p, %p, %p, %p, %p]", + expandLoraWeightPtrs.size() > 0 ? expandLoraWeightPtrs[0] : nullptr, + expandLoraWeightPtrs.size() > 1 ? expandLoraWeightPtrs[1] : nullptr, + expandLoraWeightPtrs.size() > 2 ? expandLoraWeightPtrs[2] : nullptr, + expandLoraWeightPtrs.size() > 3 ? expandLoraWeightPtrs[3] : nullptr, + expandLoraWeightPtrs.size() > 4 ? expandLoraWeightPtrs[4] : nullptr); + + // Check if all tokens use the same LoRA weights and ranks (for unified GEMM path) + bool useUnifiedGemm = true; + if (expandLoraRanks.size() > 1) + { + int32_t firstRank = expandLoraRanks[0]; + void const* firstWeightPtr = expandLoraWeightPtrs[0]; + + TLLM_LOG_INFO( + "DEBUG: Checking for unified GEMM path - First rank: %d, First weight ptr: %p", firstRank, firstWeightPtr); + + for (size_t i = 1; i < expandLoraRanks.size(); ++i) + { + if (expandLoraRanks[i] != firstRank || expandLoraWeightPtrs[i] != firstWeightPtr) + { + useUnifiedGemm = false; + TLLM_LOG_INFO("DEBUG: Found different rank or weight ptr at index %zu: rank=%d, ptr=%p", i, + expandLoraRanks[i], expandLoraWeightPtrs[i]); + break; + } + } + } + + TLLM_LOG_INFO("DEBUG: GEMM path selection: %s", useUnifiedGemm ? "UNIFIED GEMM" : "SPLIT GEMM (CUTLASS)"); + + // Add a debug print to show the actual GEMM path being used + TLLM_LOG_INFO("DEBUG: Calling LoraImpl->run with useUnifiedGemm=%d", useUnifiedGemm); + + mLoraImpl->run(numTokens, numReqs, input.data_ptr(), expandLoraRanks.data(), expandLoraWeightPtrs.data(), + weight_index, output.data(), workspace.data_ptr(), stream); + sync_check_cuda_error(stream); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return output_torch; +} + +} // namespace torch_ext + +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def( + "lora_grouped_gemm(Tensor input, " + "Tensor host_request_types, " + "Tensor [] lora_ranks, " + "Tensor [] lora_weights_pointers, " + "Tensor host_context_lengths, " + "int [] output_hidden_sizes, " + "bool transA, " + "bool transB, " + "int max_low_rank, " + "int weight_index, " + "bool isRemoveInputPadding) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("lora_grouped_gemm", &torch_ext::lora_grouped_gemm); +} diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 5beaec0cb8..778c40f736 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -11,6 +11,7 @@ from ..attention_backend.utils import create_attention from ..distributed import AllReduceParams, ParallelConfig, TensorParallelMode from ..model_config import ModelConfig +from ..peft.lora.layer import LoraLayer, LoraModuleType from .linear import Linear, WeightMode, WeightsLoadingConfig from .rms_norm import RMSNorm from .rotary_embedding import RotaryEmbedding @@ -134,6 +135,19 @@ def __init__( self.pos_embd_params = pos_embd_params self.rotary_emb = rotary_emb + # These two modules are mutually exclusive - either splitted_qkv_lora or fused_qkv_lora will be used, + # but never both at the same time. splitted_qkv_lora handles Q,K,V separately while fused_qkv_lora + # handles them as a single fused operation. 
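+        # A LoraLayer contributes only when lora_params[layer_idx] contains one of its module IDs,
+        # so a given adapter activates at most one of the two layers below. Sketch of the expected
+        # lora_params structure (key names as used by the unit tests; values are placeholders):
+        #   lora_params = {
+        #       'num_seqs': num_seqs,
+        #       'host_request_types': ...,   # int32 per sequence, 0 = context, 1 = generation
+        #       'prompt_lens_cpu': ...,
+        #       layer_idx: {
+        #           LoraModuleType.ATTENTION_Q: {'adapter_size': ..., 'weight_pointers': ..., 'is_dora': False},
+        #           # ATTENTION_K / ATTENTION_V entries for split LoRA, or one ATTENTION_QKV entry for fused LoRA
+        #       },
+        #   }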
+ self.splitted_qkv_lora = LoraLayer([ + LoraModuleType.ATTENTION_Q, LoraModuleType.ATTENTION_K, + LoraModuleType.ATTENTION_V + ], [self.q_size, self.kv_size, self.kv_size]) + self.fused_qkv_lora = LoraLayer([LoraModuleType.ATTENTION_QKV], + [self.q_size + 2 * self.kv_size]) + + self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE], + [self.hidden_size]) + if not config.skip_create_weights: self.create_weights() @@ -157,6 +171,7 @@ def forward( CAUSAL, mrope_config: Optional[dict] = None, all_reduce_params: Optional[AllReduceParams] = None, + lora_params: Optional[dict] = None, **kwargs, ) -> torch.Tensor: qkv = self.qkv_proj(hidden_states) @@ -183,6 +198,17 @@ def forward( k = self.qk_norm(k) qkv = torch.concat([q, k, v], dim=-1) + if lora_params is not None: + qkv_lora = self.splitted_qkv_lora(hidden_states, lora_params, + self.layer_idx) + if qkv_lora is not None: + qkv = qkv + qkv_lora + + qkv_lora = self.fused_qkv_lora(hidden_states, lora_params, + self.layer_idx) + if qkv_lora is not None: + qkv = qkv + qkv_lora + if is_fused_qkv: if self.pos_embd_params is None and position_ids is not None: q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], @@ -216,6 +242,11 @@ def forward( attn_output = self.o_proj(attn_output, all_reduce_params=all_reduce_params) + if lora_params is not None: + attn_lora_output = self.o_lora(attn_output, lora_params, + self.layer_idx) + if attn_lora_output is not None: + attn_output = attn_output + attn_lora_output return attn_output diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 409fe55538..31213cdc68 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -14,6 +14,7 @@ from ...models.modeling_utils import QuantConfig from ..distributed import ParallelConfig, TensorParallelMode from ..utils import Fp4QuantizedTensor +from ..peft.lora.layer import LoraLayer, LoraModuleType E2M1_MAX = 6.0 @@ -56,7 +57,7 @@ def maybe_convert_to_torch_tensor(tensor: torch.Tensor, tensor_shape = weight.get_shape() def maybe_convert_to_torch_tensor( - tensor, indices: Union[slice | tuple[slice]] = slice(None)): + tensor, indices: Union[slice, tuple[slice]] = slice(None)): return tensor[indices].to(device) else: raise ValueError(f'unsupported weight type: {type(weight)}') @@ -139,10 +140,12 @@ def __init__(self, weights_loading_config: Optional[WeightsLoadingConfig] = None, is_expert: bool = False, skip_create_weights: bool = False, - use_custom_cublas_mm: bool = False): + use_custom_cublas_mm: bool = False, + layer_idx: Optional[int] = None): from ..distributed import AllReduce super().__init__() + self.layer_idx = layer_idx self.has_bias = bias self.dtype = dtype self.parallel_config = parallel_config or ParallelConfig() @@ -181,6 +184,11 @@ def __init__(self, self.is_expert = is_expert self.use_custom_cublas_mm = use_custom_cublas_mm + self.linear_lora = LoraLayer( + [LoraModuleType.DENSE], + [self.out_features + ]) # todo (dafrimi) didn't add binding to module type + if not skip_create_weights: self.create_weights() @@ -354,10 +362,11 @@ def apply_linear(self, input, weight, bias): return output def forward( - self, - input: Union[torch.Tensor, Fp4QuantizedTensor], - *, - all_reduce_params: Optional[AllReduceParams] = None + self, + input: Union[torch.Tensor, Fp4QuantizedTensor], + *, + all_reduce_params: Optional[AllReduceParams] = None, + lora_params: Optional[dict] = None, ) -> torch.Tensor: from ..distributed import allgather @@ -388,6 +397,12 @@ def forward( else: output 
= self.apply_linear(input, self.weight, self.bias) + if lora_params is not None: + linear_lora_output = self.linear_lora(input, lora_params, + self.layer_idx) + if linear_lora_output is not None: + output = output + linear_lora_output + return output def load_weights(self, weights: List[Dict]): diff --git a/tensorrt_llm/_torch/modules/mlp.py b/tensorrt_llm/_torch/modules/mlp.py index 0e5d0d0ab6..4e824f6251 100644 --- a/tensorrt_llm/_torch/modules/mlp.py +++ b/tensorrt_llm/_torch/modules/mlp.py @@ -6,6 +6,7 @@ from ..distributed import ParallelConfig, TensorParallelMode from ..model_config import ModelConfig +from ..peft.lora.layer import LoraLayer, LoraModuleType from .linear import Linear, WeightMode, WeightsLoadingConfig @@ -18,8 +19,11 @@ def __init__(self, bias: bool, activation: Callable[[torch.Tensor], torch.Tensor] = None, dtype: Optional[torch.dtype] = None, - config: Optional[ModelConfig] = None): + config: Optional[ModelConfig] = None, + layer_idx: Optional[int] = None): + super().__init__() + self.layer_idx = layer_idx self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.activation = activation @@ -43,8 +47,8 @@ def __init__(self, weights_loading_config=WeightsLoadingConfig( weight_mode=WeightMode.VANILLA), quant_config=config.get_quant_config(), - skip_create_weights=config.skip_create_weights, - ) + skip_create_weights=config.skip_create_weights) + self.down_proj = Linear( self.intermediate_size, self.hidden_size, @@ -58,8 +62,32 @@ def __init__(self, pipeline_parallel_size=config.mapping.pp_size, parallel_rank=config.mapping.rank), quant_config=config.get_quant_config(), - skip_create_weights=config.skip_create_weights, - ) + skip_create_weights=config.skip_create_weights) + + self.up_lora = LoraLayer([LoraModuleType.MLP_H_TO_4H], + [self.intermediate_size]) + self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], + [self.hidden_size]) + + def forward( + self, + x: torch.Tensor, + lora_params: Optional[dict] = None, + ) -> torch.Tensor: + + x_up = self.up_proj(x) + if lora_params is not None: + assert self.layer_idx is not None, "layer_idx is required for lora" + x_up_lora = self.up_lora(x, lora_params, self.layer_idx) + if x_up_lora is not None: + x_up = x_up + x_up_lora + + x_act = self.activation(x_up) + x_down = self.down_proj(x_act) + + if lora_params is not None: + x_down_lora = self.down_lora(x_act, lora_params, self.layer_idx) + if x_down_lora is not None: + x_down = x_down + x_down_lora - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(self.activation(self.up_proj(x))) + return x_down diff --git a/tensorrt_llm/_torch/peft/__init__.py b/tensorrt_llm/_torch/peft/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorrt_llm/_torch/peft/lora/__init__.py b/tensorrt_llm/_torch/peft/lora/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorrt_llm/_torch/peft/lora/layer.py b/tensorrt_llm/_torch/peft/lora/layer.py new file mode 100644 index 0000000000..d78d4e1608 --- /dev/null +++ b/tensorrt_llm/_torch/peft/lora/layer.py @@ -0,0 +1,161 @@ +from enum import IntEnum +from typing import Dict, List, Optional + +import torch + + +class LoraModuleType(IntEnum): + """Enum class representing different types of modules that can have LoRA adapters. + + This enum maps to the different attention and MLP components in a transformer model + that can be adapted using LoRA weights. 
+ """ + ATTENTION_QKV = 0 # Combined QKV projection + ATTENTION_Q = 1 # Query projection + ATTENTION_K = 2 # Key projection + ATTENTION_V = 3 # Value projection + ATTENTION_DENSE = 4 # Output projection after attention + + MLP_H_TO_4H = 5 # First MLP projection (hidden to 4x hidden) + MLP_4H_TO_H = 6 # Second MLP projection (4x hidden back to hidden) + MLP_GATE = 7 # Gate projection in MLP + + CROSS_ATTENTION_QKV = 8 # Cross-attention QKV projection + CROSS_ATTENTION_Q = 9 # Cross-attention Query projection + CROSS_ATTENTION_K = 10 # Cross-attention Key projection + CROSS_ATTENTION_V = 11 # Cross-attention Value projection + CROSS_ATTENTION_DENSE = 12 # Cross-attention output projection + + MOE_H_TO_4H = 13 # MoE first projection + MOE_4H_TO_H = 14 # MoE second projection + MOE_GATE = 15 # MoE gate projection + MOE_ROUTER = 16 # MoE router + + MLP_ROUTER = 17 # MLP router + MLP_GATE_UP = 18 # Combined gate and up projections + + # TODO (dafrimi): added dense layer + DENSE = 19 # Dense layer + + def __str__(self): + """Return the name of the enum value.""" + return self.name + + @classmethod + def from_string(cls, name: str) -> "LoraModuleType": + """Convert a string to the corresponding LoraModuleType. + + Args: + name: The string name of the module type + + Returns: + The corresponding LoraModuleType enum value + + Raises: + ValueError: If the name doesn't match any LoraModuleType + """ + try: + return cls[name.upper()] + except KeyError: + raise ValueError(f"Unknown LoRA module type: {name}") + + @property + def is_attention(self) -> bool: + """Check if this is an attention module type.""" + return self in { + self.ATTENTION_QKV, self.ATTENTION_Q, self.ATTENTION_K, + self.ATTENTION_V, self.ATTENTION_DENSE, self.CROSS_ATTENTION_QKV, + self.CROSS_ATTENTION_Q, self.CROSS_ATTENTION_K, + self.CROSS_ATTENTION_V, self.CROSS_ATTENTION_DENSE + } + + @property + def is_mlp(self) -> bool: + """Check if this is an MLP module type.""" + return self in { + self.MLP_H_TO_4H, self.MLP_4H_TO_H, self.MLP_GATE, self.MLP_GATE_UP, + self.MLP_ROUTER + } + + @property + def is_moe(self) -> bool: + """Check if this is a Mixture of Experts (MoE) module type.""" + return self in { + self.MOE_H_TO_4H, self.MOE_4H_TO_H, self.MOE_GATE, self.MOE_ROUTER + } + + +class LoraLayer(torch.nn.Module): + + def __init__(self, lora_module_types: List[LoraModuleType], + output_hidden_sizes: List[int]): + super().__init__() + + self.lora_module_types = lora_module_types + self.output_hidden_sizes = output_hidden_sizes + assert len(lora_module_types) == len(output_hidden_sizes) + + def forward(self, x, lora_params: Dict, + layer_idx: int) -> Optional[torch.Tensor]: + if bool(lora_params): + + lora_ranks = [] + lora_weight_pointers = [] + + active_lora_module_ids = [] + for module_idx in self.lora_module_types: + module_idx = int(module_idx) + if module_idx in lora_params[layer_idx]: + active_lora_module_ids.append(module_idx) + + lora_params[layer_idx][module_idx]['is_dora'] + lora_ranks.append( + lora_params[layer_idx][module_idx]['adapter_size']) + lora_weight_pointers.append( + lora_params[layer_idx][module_idx]['weight_pointers']) + + num_seqs = lora_params['num_seqs'] + + if len(active_lora_module_ids) == 0: + return None + else: + print(f"x dtype: {x.dtype}") + lora_outputs = torch.ops.trtllm.lora_grouped_gemm( + x, + lora_params['host_request_types'][:num_seqs], + lora_ranks, + lora_weight_pointers, + lora_params['prompt_lens_cpu'][:num_seqs], + self.output_hidden_sizes, + False, # transA + True, # transB + max([r.max() 
for r in lora_ranks]), + 0, + False, # isRemoveInputPadding (set to False for fixed sequence length, if its True inputs needs to be with shape [numTokens, dim] and to pass prompt_lens_cpu) + # TODO (dafrimi): do we need to pass isRemoveInputPadding params in the lora_parms? + ) + if isinstance(lora_outputs, torch.Tensor): + return lora_outputs + else: + # For multiple LoRA modules, some might not be executed in grouped gemm. + # For those modules not executed, we create zero tensors with matching dimensions. + # Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order. + lora_output = [] + for module_idx in self.lora_module_types: + if int(module_idx) in active_lora_module_ids: + lora_output.append(lora_outputs.pop(0)) + else: + lora_output.append( + torch.zeros(list(x.shape[:-1]) + [ + self.output_hidden_sizes[ + self.lora_module_types.index( + module_idx)] + ], + dtype=x.dtype, + device=x.device)) + + lora_output = torch.cat(lora_output, dim=-1) + return lora_output + + else: + return None diff --git a/tensorrt_llm/thop/loraOp.cpp b/tensorrt_llm/thop/loraOp.cpp new file mode 100644 index 0000000000..c304f720ad --- /dev/null +++ b/tensorrt_llm/thop/loraOp.cpp @@ -0,0 +1,7 @@ +mLoraImpl->run(numTokens, numReqs, input.data_ptr(), expandLoraRanks.data(), expandLoraWeightPtrs.data(), weight_index, + output.data(), workspace.data_ptr(), stream); + +sync_check_cuda_error(stream); + +TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +return output_torch; diff --git a/tests/unittest/_torch/modules/tests_lora_modules/gated_mlp_lora_example.py b/tests/unittest/_torch/modules/tests_lora_modules/gated_mlp_lora_example.py new file mode 100644 index 0000000000..44b60b8cff --- /dev/null +++ b/tests/unittest/_torch/modules/tests_lora_modules/gated_mlp_lora_example.py @@ -0,0 +1,130 @@ +import torch + +from tensorrt_llm._torch.modules.gated_mlp import GatedMLP +from tensorrt_llm._torch.peft.lora.layer import LoraModuleType + + +def create_gated_mlp_example(): + # Configuration + hidden_size = 64 + intermediate_size = hidden_size * 4 + batch_size = 1 + seq_len = 16 + dtype = torch.float16 + device = torch.device('cuda') + + # Create GatedMLP module + gated_mlp = GatedMLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + bias=True, + dtype=dtype, + layer_idx=0 # Important for LoRA + ).to(device) + + # Create input tensor + hidden_states = torch.randn(size=[batch_size, seq_len, hidden_size], + dtype=dtype, + device=device) + + # Create LoRA parameters + lora_rank = 8 + + # Create weights for gate projection + gate_weight_in = torch.randn(hidden_size, + lora_rank, + device=device, + dtype=dtype).T + gate_weight_out = torch.randn( + lora_rank, + intermediate_size, # Gate projection size + device=device, + dtype=dtype).T + + # Create weights for up projection + up_weight_in = torch.randn(hidden_size, + lora_rank, + device=device, + dtype=dtype).T + up_weight_out = torch.randn( + lora_rank, + intermediate_size, # Up projection size + device=device, + dtype=dtype).T + + # Create weights for down projection + down_weight_in = torch.randn(intermediate_size, + lora_rank, + device=device, + dtype=dtype).T + down_weight_out = torch.randn(lora_rank, + hidden_size, + device=device, + dtype=dtype).T + + # Make weights contiguous + gate_weight_in = gate_weight_in.contiguous() + gate_weight_out = gate_weight_out.contiguous() + up_weight_in = up_weight_in.contiguous() + up_weight_out = up_weight_out.contiguous() + down_weight_in = down_weight_in.contiguous() + down_weight_out = 
down_weight_out.contiguous() + + # Create LoRA parameters dictionary + lora_params = { + 'num_seqs': batch_size, + 'host_request_types': torch.zeros(batch_size, dtype=torch.int32), + 'prompt_lens_cpu': torch.tensor([seq_len] * batch_size), + 0: { # layer_idx + LoraModuleType.MLP_H_TO_4H: { # Up projection + 'adapter_size': + torch.tensor([lora_rank]), + 'weight_pointers': + torch.tensor( + [[up_weight_out.data_ptr(), + up_weight_in.data_ptr()]]), + 'is_dora': + False, + 'weight_tensors': [up_weight_out, up_weight_in] + }, + LoraModuleType.MLP_GATE: { # Gate projection + 'adapter_size': + torch.tensor([lora_rank]), + 'weight_pointers': + torch.tensor( + [[gate_weight_out.data_ptr(), + gate_weight_in.data_ptr()]]), + 'is_dora': + False, + 'weight_tensors': [gate_weight_out, gate_weight_in] + }, + LoraModuleType.MLP_4H_TO_H: { # Down projection + 'adapter_size': + torch.tensor([lora_rank]), + 'weight_pointers': + torch.tensor( + [[down_weight_out.data_ptr(), + down_weight_in.data_ptr()]]), + 'is_dora': + False, + 'weight_tensors': [down_weight_out, down_weight_in] + } + } + } + + # Run forward pass + output = gated_mlp( + hidden_states.squeeze( + 0), # Remove batch dimension as expected by the module + lora_params=lora_params) + + print(f"Input shape: {hidden_states.shape}") + print(f"Output shape: {output.shape}") + print(f"Output dtype: {output.dtype}") + print(f"Output device: {output.device}") + + return output + + +if __name__ == "__main__": + output = create_gated_mlp_example() diff --git a/tests/unittest/_torch/modules/tests_lora_modules/test_lora_attention_pivot_vs_trt.py b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_attention_pivot_vs_trt.py new file mode 100644 index 0000000000..6f4c477d57 --- /dev/null +++ b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_attention_pivot_vs_trt.py @@ -0,0 +1,525 @@ +import unittest + +import numpy as np +import torch +from transformers import LlamaConfig + +import tensorrt_llm +from tensorrt_llm import Tensor +from tensorrt_llm._torch.attention_backend.interface import \ + PredefinedAttentionMask +from tensorrt_llm._torch.attention_backend.utils import get_attention_backend +from tensorrt_llm._torch.metadata import KVCacheParams +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.modeling_llama import LlamaAttention +from tensorrt_llm._torch.peft.lora.layer import LoraModuleType +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm._utils import str_dtype_to_torch +from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm.layers import (Attention, AttentionMaskType, AttentionParams, + KeyValueCacheParams) +from tensorrt_llm.layers.lora import Lora, LoraParams +from tensorrt_llm.mapping import Mapping + + +class TestLoraAttentionPivotVsTRT(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.batch_size = 1 + cls.seq_len = 8 + cls.hidden_size = 64 + cls.head_num = 1 + cls.num_hidden_layers = 1 + cls.dtype = 'float16' + cls.torch_dtype = str_dtype_to_torch(cls.dtype) + cls.device = torch.device('cuda') + + # KV cache parameters + cls.num_blocks = 4 + cls.tokens_per_block = 32 + + cls.llama_config = LlamaConfig(hidden_size=cls.hidden_size, + num_attention_heads=cls.head_num, + num_hidden_layers=cls.num_hidden_layers, + intermediate_size=256, + max_position_embeddings=512, + rms_norm_eps=1e-5, + vocab_size=32000, + num_key_value_heads=cls.head_num, + torch_dtype=cls.torch_dtype) + + # Create KV cache manager + head_dim 
= cls.llama_config.hidden_size // cls.llama_config.num_attention_heads + mapping = Mapping(world_size=1, tp_size=1, rank=0) + kv_cache_config = KvCacheConfig(max_tokens=cls.num_blocks * + cls.tokens_per_block) + cls.kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + kv_cache_type=tensorrt_llm.bindings.internal.batch_manager. + CacheType.SELF, + num_layers=cls.llama_config.num_hidden_layers, + num_kv_heads=cls.llama_config.num_key_value_heads, + head_dim=head_dim, + tokens_per_block=cls.tokens_per_block, + max_seq_len=cls.num_blocks * cls.tokens_per_block, + max_batch_size=cls.batch_size, + mapping=mapping, + dtype=tensorrt_llm.bindings.DataType.HALF) + + @classmethod + def tearDownClass(cls): + cls.kv_cache_manager.shutdown() + + def _create_attention_inputs(self): + hidden_states = torch.empty( + size=[self.batch_size, self.seq_len, self.hidden_size], + dtype=self.torch_dtype, + device='cuda') + hidden_states.normal_(0.0, 0.02) + + # Create weights + q_weight = torch.empty(size=[self.hidden_size, self.hidden_size], + dtype=self.torch_dtype) + torch.nn.init.xavier_uniform_(q_weight) + + # Set K and V and O weights to identity matrix + eye_weight = torch.eye(self.hidden_size, dtype=self.torch_dtype) + qkv_weight = torch.cat([q_weight, eye_weight, eye_weight], dim=-1) + out_weight = eye_weight + + return hidden_states, qkv_weight, out_weight + + def _create_lora_params(self): + lora_ranks_list = [8] + + host_context_lengths = torch.Tensor( + [self.seq_len for _ in range(self.batch_size)]).to(torch.int32) + lora_ranks = torch.Tensor(lora_ranks_list * self.batch_size).to( + torch.int32) + host_request_types = torch.zeros_like(host_context_lengths, + device='cpu').int() + + lora_weight_ins = [ + torch.randn(self.hidden_size, lora_rank, device=self.device).to( + self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list + ] + lora_weight_outs = [ + torch.randn(lora_rank, self.hidden_size, device=self.device).to( + self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list + ] + + lora_weight_ins = [tmp.contiguous() for tmp in lora_weight_ins] + lora_weight_outs = [ + tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs + ] + + # Create weight pointers for TensorRT + lora_weights_pointers = [] + for in_ptr, out_ptr in zip(lora_weight_ins, lora_weight_outs): + lora_weights_pointers.append(in_ptr.data_ptr()) + lora_weights_pointers.append(out_ptr.data_ptr()) + + lora_weights_pointers = torch.LongTensor(lora_weights_pointers).to( + torch.int64).reshape([self.batch_size, 2]) + + return { + 'lora_ranks': lora_ranks, + 'host_context_lengths': host_context_lengths, + 'host_request_types': host_request_types, + 'lora_weights_pointers': lora_weights_pointers, + 'lora_weight_ins': lora_weight_ins, + 'lora_weight_outs': lora_weight_outs + } + + def _setup_attention_module(self, qkv_weight, out_weight): + """Set up the attention module with weights.""" + model_config = ModelConfig(pretrained_config=self.llama_config, + attn_backend="VANILLA") + layer_idx = 0 + attention_module = LlamaAttention(model_config, layer_idx=layer_idx).to( + self.device).to(self.torch_dtype) + + # Set weights + attention_module.qkv_proj.weight.data = torch.from_numpy( + np.ascontiguousarray(qkv_weight.cpu().numpy().transpose( + [1, 0]))).to(self.device) + attention_module.o_proj.weight.data = torch.from_numpy( + np.ascontiguousarray(out_weight.cpu().numpy().transpose( + [1, 0]))).to(self.device) + + return attention_module, model_config + + def _create_attention_metadata(self, model_config): + 
sequence_lengths = [self.seq_len] + past_seen_tokens = [0] + request_ids = [0] + token_nums = [self.seq_len] + prompt_lens = token_nums + + self.kv_cache_manager.add_dummy_requests(request_ids, token_nums) + + metadata_cls = get_attention_backend(model_config.attn_backend).Metadata + return metadata_cls( + seq_lens=torch.tensor(sequence_lengths, dtype=torch.int32), + num_contexts=len(sequence_lengths), + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=past_seen_tokens, + ), + kv_cache_manager=self.kv_cache_manager, + request_ids=request_ids, + prompt_lens=prompt_lens, + max_num_requests=self.batch_size, + max_num_tokens=self.batch_size * self.seq_len, + ) + + def _setup_trt_network(self, hidden_states, lora_params, attention_module): + builder = tensorrt_llm.Builder() + net = builder.create_network() + net.plugin_config.gpt_attention_plugin = self.dtype + net.plugin_config.lora_plugin = self.dtype + net.plugin_config.remove_input_padding = True + net.plugin_config.paged_kv_cache = True + + with tensorrt_llm.net_guard(net): + # Create LoRA tensors + host_request_types_tensor = Tensor( + name='host_request_types', + shape=[lora_params['host_request_types'].shape[0]], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_context_lengths_tensor = Tensor( + name='host_context_lengths', + shape=[lora_params['host_context_lengths'].shape[0]], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + lora_ranks_tensor = Tensor( + name='lora_ranks', + shape=(lora_params['lora_ranks'].shape[0], ), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + lora_weights_pointers_tensor = Tensor( + name='lora_weights_pointers', + shape=lora_params['lora_weights_pointers'].shape, + dtype=tensorrt_llm.str_dtype_to_trt('int64')) + + # Create tensors for GPT Attention Plugin + sequence_length_tensor = Tensor( + name='sequence_length', + shape=[self.batch_size], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + context_lengths_tensor = Tensor( + name='context_lengths', + shape=[self.batch_size], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_past_key_value_lengths_tensor = Tensor( + name='host_past_key_value_lengths', + shape=[self.batch_size], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_max_attention_window_sizes_tensor = Tensor( + name='host_max_attention_window_sizes', + shape=[self.num_hidden_layers], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_sink_token_length_tensor = Tensor( + name='host_sink_token_length', + shape=[self.num_hidden_layers], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + cache_indirection_tensor = Tensor( + name='cache_indirection', + shape=[self.batch_size, 1, 1], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + # Use dummy block offsets as we are in context phase and don't actually need KV cache values yet + # Shape: [num_layers, batch_size, 2, max_blocks_per_seq] + max_blocks_per_seq = (self.seq_len + self.tokens_per_block - + 1) // self.tokens_per_block + kv_cache_block_offsets_tensor = Tensor( + name='kv_cache_block_offsets', + shape=[ + self.num_hidden_layers, self.batch_size, 2, + max_blocks_per_seq + ], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_kv_cache_block_offsets_tensor = Tensor( + name='host_kv_cache_block_offsets', + shape=[ + self.num_hidden_layers, self.batch_size, 2, + max_blocks_per_seq + ], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + # Add tensors for perf knobs and context progress + host_runtime_perf_knobs_tensor = Tensor( + name='host_runtime_perf_knobs', + shape=[1], # Typically a single int64 
value + dtype=tensorrt_llm.str_dtype_to_trt('int64')) + host_context_progress_tensor = Tensor( + name='host_context_progress', + shape=[self.batch_size], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + # Add tensors for paged kv cache pool management + host_kv_cache_pool_pointers_tensor = Tensor( + name='host_kv_cache_pool_pointers', + shape=[2], # Pointers to K and V pools + dtype=tensorrt_llm.str_dtype_to_trt('int64')) + host_kv_cache_pool_mapping_tensor = Tensor( + name='host_kv_cache_pool_mapping', + shape=[self.batch_size], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + + # Create LoRA parameters object + lora_layer_params = LoraParams( + lora_ranks=[{ + "attn_q_lora_ranks": lora_ranks_tensor, + "attn_k_lora_ranks": lora_ranks_tensor, + "attn_v_lora_ranks": lora_ranks_tensor, + "attn_dense_lora_ranks": lora_ranks_tensor, + }], + lora_weights_pointers=[{ + "attn_q_lora_weights_pointers": + lora_weights_pointers_tensor, + "attn_k_lora_weights_pointers": + lora_weights_pointers_tensor, + "attn_v_lora_weights_pointers": + lora_weights_pointers_tensor, + "attn_dense_lora_weights_pointers": + lora_weights_pointers_tensor, + }], + host_context_lengths=host_context_lengths_tensor, + host_request_types=host_request_types_tensor, + ) + + # Create AttentionParams and KeyValueCacheParams + attention_params = AttentionParams( + sequence_length=sequence_length_tensor, + context_lengths=context_lengths_tensor, + host_context_lengths= + host_context_lengths_tensor, # Use the same tensor on host + max_context_length=self. + seq_len, # Use current seq_len as max for context phase + host_request_types= + host_request_types_tensor, # Use the same tensor on host + host_runtime_perf_knobs=host_runtime_perf_knobs_tensor, + host_context_progress=host_context_progress_tensor) + + kv_cache_params = KeyValueCacheParams( + host_past_key_value_lengths=host_past_key_value_lengths_tensor, + host_max_attention_window_sizes= + host_max_attention_window_sizes_tensor, + host_sink_token_length=host_sink_token_length_tensor, + kv_cache_block_offsets=kv_cache_block_offsets_tensor, + host_kv_cache_block_offsets=host_kv_cache_block_offsets_tensor, + cache_indirection=cache_indirection_tensor, + # past_key_value needs to be None for context phase + past_key_value=None, + # Add pool pointers and mapping + host_kv_cache_pool_pointers=host_kv_cache_pool_pointers_tensor, + host_kv_cache_pool_mapping=host_kv_cache_pool_mapping_tensor) + + attn_layer = Attention( + local_layer_idx=0, + hidden_size=hidden_states.shape[-1], + num_attention_heads=1, + num_kv_heads=1, # Added num_kv_heads + max_position_embeddings=self.llama_config. 
+ max_position_embeddings, # Use config value + attention_mask_type=AttentionMaskType.causal, + bias=False) + + attn_layer.qkv_lora = Lora( + in_hidden_size=attn_layer.hidden_size, + out_hidden_sizes=[ + attn_layer.num_attention_heads * + attn_layer.attention_head_size, + attn_layer.num_attention_kv_heads * + attn_layer.attention_head_size, + attn_layer.num_attention_kv_heads * + attn_layer.attention_head_size + ], + max_low_rank=8, + ) + + attn_layer.dense.lora = Lora( + in_hidden_size=attn_layer.dense.in_features, + out_hidden_sizes=[attn_layer.dense.out_features], + max_low_rank=8, + ) + + # Set attention layer weights + attn_layer.qkv.weight.value = attention_module.qkv_proj.weight.data + attn_layer.dense.weight.value = attention_module.o_proj.weight.data + + # Create input tensor - already flattened to [numToken, dim] + trt_hidden_states = Tensor( + name='hidden_states', + shape=hidden_states.reshape(-1, hidden_states.shape[-1]).shape, + dtype=tensorrt_llm.str_dtype_to_trt(self.dtype)) + + # Update forward call for GPT Attention Plugin + output, _ = attn_layer( # GPT Attention Plugin returns a tuple (context, past_key_value) + hidden_states=trt_hidden_states, + lora_layer_params=lora_layer_params, # Use the renamed object + attention_params=attention_params, + kv_cache_params=kv_cache_params, + use_cache=True # Must be True for GPT Attention Plugin + ) + output.mark_output('output', + tensorrt_llm.str_dtype_to_trt(self.dtype)) + + return builder, net + + def _run_trt_inference(self, builder, net, hidden_states, lora_params): + builder_config = builder.create_builder_config(name='attention', + precision=self.dtype) + engine_buffer = builder.build_engine(net, builder_config) + session = tensorrt_llm.runtime.Session.from_serialized_engine( + engine_buffer) + + stream = torch.cuda.current_stream().cuda_stream + + # Prepare inputs for GPT Attention Plugin + sequence_length_tensor = torch.tensor([self.seq_len] * self.batch_size, + dtype=torch.int32, + device='cuda') + context_lengths_tensor = torch.tensor([self.seq_len] * self.batch_size, + dtype=torch.int32, + device='cuda') + host_past_key_value_lengths_tensor = torch.tensor( + [0] * self.batch_size, + dtype=torch.int32) # Start from 0 for context phase + max_seq_len = self.num_blocks * self.tokens_per_block + host_max_attention_window_sizes_tensor = torch.tensor( + [max_seq_len] * self.num_hidden_layers, dtype=torch.int32) + host_sink_token_length_tensor = torch.tensor([0] * + self.num_hidden_layers, + dtype=torch.int32) + cache_indirection_tensor = torch.arange(self.batch_size, + dtype=torch.int32, + device='cuda').reshape( + self.batch_size, 1, 1) + # Create dummy block offsets for context phase + max_blocks_per_seq = (self.seq_len + self.tokens_per_block - + 1) // self.tokens_per_block + shape = (self.num_hidden_layers, self.batch_size, 2, max_blocks_per_seq) + kv_cache_block_offsets_tensor = torch.zeros(shape, + dtype=torch.int32, + device='cuda') + host_kv_cache_block_offsets_tensor = torch.zeros( + shape, dtype=torch.int32) # Host copy + # Add tensors for paged kv cache pool management (dummy values for context phase) + # Get the actual pointers from the cache manager if needed for generation phase + dummy_pool_pointers = torch.tensor([0, 0], + dtype=torch.int64) # Dummy pointers + host_kv_cache_pool_pointers_tensor = dummy_pool_pointers + host_kv_cache_pool_mapping_tensor = torch.zeros( + [self.batch_size], dtype=torch.int32) # Map all to pool 0 + host_runtime_perf_knobs_tensor = torch.tensor( + [0], dtype=torch.int64) # Default 
value + host_context_progress_tensor = torch.zeros( + [self.batch_size], + dtype=torch.int32) # Default value for context phase + + inputs = { + 'hidden_states': hidden_states.reshape(-1, hidden_states.shape[-1]), + 'host_request_types': lora_params['host_request_types'], + 'host_context_lengths': lora_params['host_context_lengths'], + 'lora_ranks': lora_params['lora_ranks'], + 'lora_weights_pointers': lora_params['lora_weights_pointers'], + # Inputs for GPT Attention Plugin + 'sequence_length': sequence_length_tensor, + 'context_lengths': context_lengths_tensor, + 'host_past_key_value_lengths': host_past_key_value_lengths_tensor, + 'host_max_attention_window_sizes': + host_max_attention_window_sizes_tensor, + 'host_sink_token_length': host_sink_token_length_tensor, + 'cache_indirection': cache_indirection_tensor, + 'kv_cache_block_offsets': kv_cache_block_offsets_tensor, + 'host_kv_cache_block_offsets': host_kv_cache_block_offsets_tensor, + 'host_runtime_perf_knobs': host_runtime_perf_knobs_tensor, + 'host_context_progress': host_context_progress_tensor, + # Add pool pointers and mapping to inputs + 'host_kv_cache_pool_pointers': host_kv_cache_pool_pointers_tensor, + 'host_kv_cache_pool_mapping': host_kv_cache_pool_mapping_tensor, + } + + outputs = { + 'output': + # Output shape is [num_tokens, hidden_size] when remove_input_padding is True + torch.empty( + hidden_states.reshape(-1, hidden_states.shape[-1]).shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch(self.dtype), + device='cuda'), + } + + session.run(inputs=inputs, outputs=outputs, stream=stream) + torch.cuda.synchronize() + + # Reshape output back to [batch_size, seq_len, hidden_size] for comparison + return outputs['output'].reshape(hidden_states.shape) + + def test_attention_with_lora(self): + hidden_states, qkv_weight, out_weight = self._create_attention_inputs() + + lora_params = self._create_lora_params() + + attention_module, model_config = self._setup_attention_module( + qkv_weight, out_weight) + + attn_metadata = self._create_attention_metadata(model_config) + builder, net = self._setup_trt_network(hidden_states, lora_params, + attention_module) + trt_output = self._run_trt_inference(builder, net, hidden_states, + lora_params) + + lora_params_pivot = { + 'num_seqs': self.batch_size, + 'host_request_types': lora_params['host_request_types'], + 'prompt_lens_cpu': lora_params['host_context_lengths'], + 0: { # layer_idx + LoraModuleType.ATTENTION_Q: { # Module type + 'adapter_size': + lora_params['lora_ranks'], + 'weight_pointers': + lora_params['lora_weights_pointers'], + 'is_dora': + False, + }, + LoraModuleType.ATTENTION_K: { + 'adapter_size': + lora_params['lora_ranks'], + 'weight_pointers': lora_params['lora_weights_pointers'], + 'is_dora': + False, + }, + LoraModuleType.ATTENTION_V: { + 'adapter_size': + lora_params['lora_ranks'], + 'weight_pointers': + lora_params['lora_weights_pointers'], + 'is_dora': + False, + }, + LoraModuleType.ATTENTION_DENSE: { + 'adapter_size': + lora_params['lora_ranks'], + 'weight_pointers': + lora_params['lora_weights_pointers'], + 'is_dora': + False, + } + } + } + + with torch.inference_mode(): + attn_metadata.prepare() + hidden_states_reshaped = hidden_states.reshape( + -1, hidden_states.shape[-1]) + + pivot_output = attention_module( + position_ids=None, + hidden_states=hidden_states_reshaped, + attn_metadata=attn_metadata, + attention_mask=PredefinedAttentionMask.CAUSAL, + lora_params=lora_params_pivot) + + torch.testing.assert_close(pivot_output, trt_output, atol=2e-3, rtol=0) + + +if 
__name__ == "__main__": + unittest.main() diff --git a/tests/unittest/_torch/modules/tests_lora_modules/test_lora_attention_pivot_vs_vanilla_torch.py b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_attention_pivot_vs_vanilla_torch.py new file mode 100644 index 0000000000..764a469c42 --- /dev/null +++ b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_attention_pivot_vs_vanilla_torch.py @@ -0,0 +1,303 @@ +import os +import sys +import unittest + +import numpy as np +import torch +from transformers import LlamaConfig + +import tensorrt_llm +from tensorrt_llm._torch.attention_backend.interface import \ + PredefinedAttentionMask +from tensorrt_llm._torch.attention_backend.utils import get_attention_backend +from tensorrt_llm._torch.metadata import KVCacheParams +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models.modeling_llama import LlamaAttention +from tensorrt_llm._torch.peft.lora.layer import LoraModuleType +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm.mapping import Mapping + +sys.path.append(os.path.join(os.path.dirname(__file__), '../../../utils')) +from torch_ref import attention_qkvpacked_ref + + +class TestLoraAttentionPivotVsVanilla(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.batch_size = 1 + cls.seq_len = 16 + cls.hidden_size = 64 + cls.head_num = 1 + cls.num_hidden_layers = 1 + cls.dtype = torch.float16 + cls.device = torch.device('cuda') + + # KV cache parameters + cls.num_blocks = 4 + cls.tokens_per_block = 32 + + cls.llama_config = LlamaConfig(hidden_size=cls.hidden_size, + num_attention_heads=cls.head_num, + num_hidden_layers=cls.num_hidden_layers, + intermediate_size=256, + max_position_embeddings=512, + rms_norm_eps=1e-5, + vocab_size=32000, + num_key_value_heads=cls.head_num, + torch_dtype=cls.dtype) + + head_dim = cls.llama_config.hidden_size // cls.llama_config.num_attention_heads + mapping = Mapping(world_size=1, tp_size=1, rank=0) + kv_cache_config = KvCacheConfig(max_tokens=cls.num_blocks * + cls.tokens_per_block) + cls.kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + kv_cache_type=tensorrt_llm.bindings.internal.batch_manager. 
+ CacheType.SELF, + num_layers=cls.llama_config.num_hidden_layers, + num_kv_heads=cls.llama_config.num_key_value_heads, + head_dim=head_dim, + tokens_per_block=cls.tokens_per_block, + max_seq_len=cls.num_blocks * cls.tokens_per_block, + max_batch_size=cls.batch_size, + mapping=mapping, + dtype=tensorrt_llm.bindings.DataType.HALF) + + @classmethod + def tearDownClass(cls): + cls.kv_cache_manager.shutdown() + + def _get_lora_params(self, in_dim, out_dim): + lora_rank = 8 + lora_weight_ins = torch.randn(in_dim, + lora_rank, + device=self.device, + dtype=self.dtype) + lora_weight_outs = torch.randn(lora_rank, + out_dim, + device=self.device, + dtype=self.dtype) + return lora_weight_ins, lora_weight_outs + + def _create_attention_inputs(self): + hidden_states = torch.randn(self.seq_len, + self.llama_config.hidden_size, + dtype=self.dtype, + device=self.device) + + q_weight = torch.empty(size=[self.hidden_size, self.hidden_size], + dtype=self.dtype) + torch.nn.init.xavier_uniform_(q_weight) + + # Set K and V weights to identity matrix + eye_weight = torch.eye(self.hidden_size, dtype=self.dtype) + qkv_weight = torch.cat([q_weight, eye_weight, eye_weight], dim=-1) + out_weight = eye_weight + + return hidden_states, qkv_weight, out_weight + + def _setup_attention_module(self, qkv_weight, out_weight): + """Set up the attention module with weights.""" + model_config = ModelConfig(pretrained_config=self.llama_config, + attn_backend="VANILLA") + attention_module = LlamaAttention(model_config, layer_idx=0).to( + self.device).to(self.dtype) + + # Set weights + attention_module.qkv_proj.weight.data = torch.from_numpy( + np.ascontiguousarray(qkv_weight.cpu().numpy().transpose( + [1, 0]))).to(self.device) + attention_module.o_proj.weight.data = torch.from_numpy( + np.ascontiguousarray(out_weight.cpu().numpy().transpose( + [1, 0]))).to(self.device) + + return attention_module, model_config + + def _create_attention_metadata(self, model_config): + sequence_lengths = [self.seq_len] + past_seen_tokens = [0] + request_ids = [0] + token_nums = [self.seq_len] + prompt_lens = token_nums + + # Add requests to KV cache manager + self.kv_cache_manager.add_dummy_requests(request_ids, token_nums) + + metadata_cls = get_attention_backend(model_config.attn_backend).Metadata + return metadata_cls( + seq_lens=torch.tensor(sequence_lengths, dtype=torch.int32), + num_contexts=len(sequence_lengths), + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=past_seen_tokens, + ), + kv_cache_manager=self.kv_cache_manager, + request_ids=request_ids, + prompt_lens=prompt_lens, + max_num_requests=self.batch_size, + max_num_tokens=self.batch_size * self.seq_len, + ) + + def _run_vanilla_attention(self, + hidden_states, + qkv_weight, + lora_params=None): + head_dim = self.hidden_size // self.head_num + + # Base QKV computation + packed_torch_qkv = hidden_states.to("cuda") @ qkv_weight.to("cuda") + + if lora_params: + # Get the LoRA weights from the new structure + dense_params = lora_params[0][ + LoraModuleType. 
+ ATTENTION_DENSE] # TODO (dafrimi) 0 is the layer_idx, needs to pass it here somehow + Q_params = lora_params[0][LoraModuleType.ATTENTION_Q] + K_params = lora_params[0][LoraModuleType.ATTENTION_K] + V_params = lora_params[0][LoraModuleType.ATTENTION_V] + + A_q, B_q = Q_params['weight_tensors'][0], Q_params[ + 'weight_tensors'][1] + A_k, B_k = K_params['weight_tensors'][0], K_params[ + 'weight_tensors'][1] + A_v, B_v = V_params['weight_tensors'][0], V_params[ + 'weight_tensors'][1] + A_o, B_o = dense_params['weight_tensors'][0], dense_params[ + 'weight_tensors'][1] + + # Apply LoRA + lora_output_q = (hidden_states @ B_q.T) @ A_q.T + lora_output_k = (hidden_states @ B_k.T) @ A_k.T + lora_output_v = (hidden_states @ B_v.T) @ A_v.T + + packed_lora_torch_qkv = torch.cat( + [lora_output_q, lora_output_k, lora_output_v], dim=-1) + packed_lora_torch_qkv = packed_torch_qkv + packed_lora_torch_qkv + + packed_lora_torch_qkv = packed_lora_torch_qkv.reshape( + [self.batch_size, self.seq_len, 3, self.head_num, head_dim]) + + mha_out_lora, _ = attention_qkvpacked_ref(packed_lora_torch_qkv, + causal=True, + upcast=False, + bias=None) + + torch_out = mha_out_lora.reshape( + [self.batch_size, self.seq_len, self.hidden_size]) + torch_out = torch_out.squeeze(0) + + # Apply output LoRA and skip projection of O matrix since it's identity + lora_o = (torch_out @ B_o.T) @ A_o.T + torch_out = torch_out + lora_o + else: + # Run vanilla attention without LoRA + packed_torch_qkv = packed_torch_qkv.reshape( + [self.batch_size, self.seq_len, 3, self.head_num, head_dim]) + + mha_out, _ = attention_qkvpacked_ref(packed_torch_qkv, + causal=True, + upcast=False, + bias=None) + + torch_out = mha_out.reshape( + [self.batch_size, self.seq_len, self.hidden_size]) + torch_out = torch_out.squeeze(0) + + return torch_out + + def _run_pivot_attention(self, attention_module, hidden_states, + attn_metadata, lora_params): + with torch.inference_mode(): + attn_metadata.prepare() + return attention_module( + position_ids=None, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + attention_mask=PredefinedAttentionMask.CAUSAL, + lora_params=lora_params) + + def test_attention_with_lora(self): + hidden_states, qkv_weight, out_weight = self._create_attention_inputs() + + # Create LoRA parameters + A_q, B_q = self._get_lora_params(self.hidden_size, self.hidden_size) + A_k, B_k = self._get_lora_params(self.hidden_size, self.hidden_size) + A_v, B_v = self._get_lora_params(self.hidden_size, self.hidden_size) + A_o, B_o = self._get_lora_params(self.hidden_size, self.hidden_size) + + attention_module, model_config = self._setup_attention_module( + qkv_weight, out_weight) + + attn_metadata = self._create_attention_metadata(model_config) + + # Verify QKV projection + self.assertTrue( + torch.allclose(attention_module.qkv_proj.forward(hidden_states), + hidden_states.to("cuda") @ qkv_weight.to("cuda"), + atol=2e-1)) + + # Create lora_params in the new format + lora_params = { + 'num_seqs': self.batch_size, + 'host_request_types': torch.zeros(self.batch_size, + dtype=torch.int32), + 'prompt_lens_cpu': torch.tensor([self.seq_len] * self.batch_size), + 0: { # layer_idx + LoraModuleType.ATTENTION_Q: { # Q module + 'adapter_size': + torch.tensor([8]), # lora_rank + 'weight_pointers': + torch.tensor([[A_q.data_ptr(), + B_q.data_ptr()]]), + 'is_dora': + False, + 'weight_tensors': [A_q, B_q] + }, + LoraModuleType.ATTENTION_K: { # K module + 'adapter_size': + torch.tensor([8]), # lora_rank + 'weight_pointers': + torch.tensor([[A_k.data_ptr(), + 
B_k.data_ptr()]]), + 'is_dora': + False, + 'weight_tensors': [A_k, B_k] + }, + LoraModuleType.ATTENTION_V: { # V module + 'adapter_size': + torch.tensor([8]), # lora_rank + 'weight_pointers': + torch.tensor([[A_v.data_ptr(), + B_v.data_ptr()]]), + 'is_dora': + False, + 'weight_tensors': [A_v, B_v] + }, + LoraModuleType.ATTENTION_DENSE: { # Output projection module + 'adapter_size': + torch.tensor([8]), # lora_rank + 'weight_pointers': + torch.tensor([[A_o.data_ptr(), + B_o.data_ptr()]]), + 'is_dora': + False, + 'weight_tensors': [A_o, B_o] + } + } + } + + # Run vanilla attention with LoRA + vanilla_output = self._run_vanilla_attention(hidden_states, qkv_weight, + lora_params) + + # Run pivot attention with LoRA + pivot_output = self._run_pivot_attention(attention_module, + hidden_states, attn_metadata, + lora_params) + + torch.testing.assert_close(pivot_output, + vanilla_output, + atol=2e-2, + rtol=0) diff --git a/tests/unittest/_torch/modules/tests_lora_modules/test_lora_linear_pivot_vs_trt.py b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_linear_pivot_vs_trt.py new file mode 100644 index 0000000000..02302f4a54 --- /dev/null +++ b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_linear_pivot_vs_trt.py @@ -0,0 +1,178 @@ +import os +import sys +import unittest + +import numpy as np +import tensorrt as trt +import torch + +import tensorrt_llm +from tensorrt_llm._torch.modules.linear import Linear as PivotLinear +from tensorrt_llm._torch.peft.lora.layer import LoraModuleType +from tensorrt_llm._utils import torch_to_numpy +from tensorrt_llm.functional import Tensor +from tensorrt_llm.layers.linear import Linear +from tensorrt_llm.layers.lora import Lora, LoraRuntimeParams +from tensorrt_llm.runtime.session import Session + +sys.path.append(os.path.join(os.path.dirname(__file__), '../../../')) +from utils.util import create_session, run_session + + +class TestLoraLinearPivotVsTRT(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.dtype = "float16" + cls.batch_size = 4 + cls.seq_len = 8 + cls.hidden_size = 1024 + cls.lora_rank = 8 + + # Create input tensors + cls.input_tensor = torch.randn(cls.batch_size, + cls.seq_len, + cls.hidden_size, + dtype=tensorrt_llm.str_dtype_to_torch( + cls.dtype), + device="cuda") + + cls.weight = torch.randn(cls.hidden_size, + cls.hidden_size, + dtype=tensorrt_llm.str_dtype_to_torch( + cls.dtype), + device="cuda") + + cls.A = torch.randn(cls.lora_rank, + cls.hidden_size, + dtype=tensorrt_llm.str_dtype_to_torch(cls.dtype), + device="cuda") + cls.B = torch.randn(cls.hidden_size, + cls.lora_rank, + dtype=tensorrt_llm.str_dtype_to_torch(cls.dtype), + device="cuda") + + def _create_linear_lora_trt_session(self) -> Session: + builder = tensorrt_llm.Builder() + network = builder.create_network() + + with tensorrt_llm.net_guard(network): + network.plugin_config.lora_plugin = self.dtype + network.plugin_config.remove_input_padding = False + + linear = Linear(in_features=self.hidden_size, + out_features=self.hidden_size, + dtype=self.dtype, + bias=False) + linear.lora = Lora(in_hidden_size=self.hidden_size, + out_hidden_sizes=[self.hidden_size], + max_low_rank=self.lora_rank) + + linear.weight.value = np.ascontiguousarray( + torch_to_numpy(self.weight.cpu())) + + inp = Tensor( + name="input_tensor", + shape=[self.batch_size, self.seq_len, self.hidden_size], + dtype=tensorrt_llm.str_dtype_to_trt(self.dtype)) + + lora_weights_pointers = Tensor(name="lora_weights_pointers", + shape=[self.batch_size, 3], + dtype=trt.int64) + + 
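+            # Shape [batch_size, 3]: each row holds the device addresses of the
+            # two LoRA weight tensors (A and B, in the order built by
+            # _create_trt_inputs below) plus a DoRA magnitude pointer, which this
+            # test points at an all-zero tensor since DoRA is not exercised.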
host_request_types = Tensor(name="host_request_types", + shape=[self.batch_size], + dtype=trt.int32) + + lora_ranks = Tensor(name="lora_ranks", + shape=(self.batch_size, ), + dtype=trt.int32) + + lora_params = LoraRuntimeParams( + lora_ranks=[lora_ranks], + lora_weights_pointers=[lora_weights_pointers], + host_request_types=host_request_types, + weight_index=0) + + output = linear(inp, lora_runtime_params=lora_params) + output.mark_output("output", self.dtype) + + return create_session(builder, network, precision=self.dtype) + + def _create_trt_inputs(self): + host_request_types = torch.zeros(self.batch_size, dtype=torch.int32) + magnitude_dora = torch.zeros(self.hidden_size, + dtype=tensorrt_llm.str_dtype_to_torch( + self.dtype), + device="cuda") + + inputs = { + "input_tensor": self.input_tensor, + "host_request_types": host_request_types + } + + # Create LoRA weight pointers + weights_ptrs = torch.tensor( + [[self.A.data_ptr(), + self.B.data_ptr(), + magnitude_dora.data_ptr()] for _ in range(self.batch_size)], + dtype=torch.int64) + inputs["lora_weights_pointers"] = weights_ptrs + inputs["lora_ranks"] = torch.tensor([self.lora_rank] * self.batch_size, + dtype=torch.int32) + + return inputs + + def _setup_pivot_linear(self): + pivot_linear = PivotLinear(in_features=self.hidden_size, + out_features=self.hidden_size, + bias=False, + dtype=tensorrt_llm.str_dtype_to_torch( + self.dtype), + layer_idx=0) + + pivot_linear.weight.data = self.weight + return pivot_linear + + def test_lora_linear_layer(self): + session = self._create_linear_lora_trt_session() + + inputs = self._create_trt_inputs() + outputs = run_session(session, inputs) + torch.cuda.synchronize() + + pivot_linear = self._setup_pivot_linear() + + lora_params = { + 'num_seqs': + self.batch_size, + 'host_request_types': + inputs["host_request_types"], + 'prompt_lens_cpu': + torch.tensor([self.seq_len] * self.batch_size, dtype=torch.int32), + 0: { # layer_idx + LoraModuleType.DENSE: { # module_type + 'adapter_size': inputs["lora_ranks"], + 'weight_pointers': inputs["lora_weights_pointers"], + 'is_dora': False, + } + } + } + + outputs_pivot = pivot_linear(self.input_tensor, lora_params=lora_params) + + print(f"outputs: {outputs['output']}") + print(f"outputs_pivot: {outputs_pivot}") + + torch.testing.assert_close(outputs["output"], + outputs_pivot, + atol=2e-3, + rtol=0) + + +if __name__ == "__main__": + unittest.main() + # x = 0 + # for i in range(100): + # x += 1 + # print(x) diff --git a/tests/unittest/_torch/modules/tests_lora_modules/test_lora_linear_pivot_vs_vanilla_torch.py b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_linear_pivot_vs_vanilla_torch.py new file mode 100644 index 0000000000..09b2e1707c --- /dev/null +++ b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_linear_pivot_vs_vanilla_torch.py @@ -0,0 +1,92 @@ +import unittest + +import torch +from torch.nn import Linear + +from tensorrt_llm._torch.modules.linear import Linear as PivotLinear +from tensorrt_llm._torch.peft.lora.layer import LoraModuleType + + +class TestLoraLinearPivotVsVanilla(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.in_dim = 64 + cls.out_dim = 64 + cls.input_length = 10 + cls.batch_size = 4 + cls.device = "cuda" + cls.input_tensor = torch.randn(cls.batch_size, + cls.input_length, + cls.in_dim, + device=cls.device) + + def _get_lora_params(self): + lora_rank = 8 + lora_weight_ins = torch.randn(self.in_dim, + lora_rank, + device=self.device) + lora_weight_outs = torch.randn(lora_rank, + self.out_dim, + 
device=self.device) + + lora_params = { + 'num_seqs': self.batch_size, + 'host_request_types': torch.zeros(self.batch_size, + dtype=torch.int32), + 'prompt_lens_cpu': + torch.tensor([self.input_length] * self.batch_size), + 0: { # layer_idx + LoraModuleType.DENSE: { # module_type + 'adapter_size': + torch.tensor([lora_rank]), + 'weight_pointers': + torch.tensor([ + lora_weight_ins.data_ptr(), + lora_weight_outs.data_ptr() + ]), + 'weight_tensors': [lora_weight_ins, lora_weight_outs], + 'is_dora': + False + } + } + } + return lora_params + + def _setup_linear_layers(self): + torch_linear = Linear(self.in_dim, self.out_dim).to(self.device) + pivot_linear = PivotLinear(in_features=self.in_dim, + out_features=self.out_dim, + layer_idx=0) + + # Initialize pivot linear with same weights as torch linear + pivot_linear.weight.data = torch_linear.weight.data + pivot_linear.bias.data = torch_linear.bias.data + + return torch_linear, pivot_linear + + def test_compare_linear_torch_pivot_lora(self): + lora_params = self._get_lora_params() + torch_linear, pivot_linear = self._setup_linear_layers() + + lora_weight_ins = lora_params[0][ + LoraModuleType.DENSE]['weight_tensors'][0] + lora_weight_outs = lora_params[0][ + LoraModuleType.DENSE]['weight_tensors'][1] + lora_output = ( + self.input_tensor @ lora_weight_outs.T) @ lora_weight_ins.T + + torch_output = torch_linear(self.input_tensor) + torch_output = torch_output + lora_output + + pivot_output = pivot_linear(self.input_tensor, lora_params=lora_params) + + self.assertTrue(torch.allclose(torch_output, pivot_output)) + + def test_compare_linear_torch_pivot(self): + torch_linear, pivot_linear = self._setup_linear_layers() + + torch_output = torch_linear(self.input_tensor) + pivot_output = pivot_linear(self.input_tensor) + + torch.testing.assert_close(torch_output, pivot_output) diff --git a/tests/unittest/_torch/modules/tests_lora_modules/test_lora_mlp_pivot_vs_trt.py b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_mlp_pivot_vs_trt.py new file mode 100644 index 0000000000..40d76c4924 --- /dev/null +++ b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_mlp_pivot_vs_trt.py @@ -0,0 +1,282 @@ +import unittest + +import torch + +import tensorrt_llm +from tensorrt_llm import Tensor +from tensorrt_llm._torch.modules.mlp import MLP as PivotMLP +from tensorrt_llm._torch.peft.lora.layer import LoraModuleType +from tensorrt_llm._utils import str_dtype_to_torch +from tensorrt_llm.layers import MLP as TRTMLP +from tensorrt_llm.layers.lora import Lora, LoraParams + + +class TestLoraMLPPivotVsTRT(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.batch_size = 1 + cls.seq_len = 16 + cls.hidden_size = 64 + cls.intermediate_size = cls.hidden_size * 4 + cls.num_hidden_layers = 1 + cls.dtype = 'float16' + cls.torch_dtype = str_dtype_to_torch(cls.dtype) + cls.device = torch.device('cuda') + + def _create_mlp_inputs(self): + hidden_states = torch.empty( + size=[self.batch_size, self.seq_len, self.hidden_size], + dtype=self.torch_dtype, + device='cuda') + hidden_states.normal_(0.0, 0.02) + + return hidden_states + + def _get_lora_params(self, in_dim, out_dim): + lora_rank = 8 + A = torch.randn(in_dim, + lora_rank, + device=self.device, + dtype=self.torch_dtype) + B = torch.randn(lora_rank, + out_dim, + device=self.device, + dtype=self.torch_dtype) + return A, B + + def _create_lora_params(self): + lora_ranks_list = [8] + + host_context_lengths = torch.Tensor( + [self.seq_len for _ in range(self.batch_size)]).to(torch.int32) + + lora_ranks 
= torch.tensor(lora_ranks_list * self.batch_size, + dtype=torch.int32) + + host_request_types = torch.zeros_like(host_context_lengths, + device='cpu').int() + + # Create weights for up projection + lora_weight_ins_up = [ + torch.randn(self.hidden_size, lora_rank, device=self.device).to( + self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list + ] + lora_weight_outs_up = [ + torch.randn(lora_rank, self.intermediate_size, + device=self.device).to(self.torch_dtype) * 0.1 + for lora_rank in lora_ranks_list + ] + + # Create weights for down projection + lora_weight_ins_down = [ + torch.randn(self.intermediate_size, lora_rank, + device=self.device).to(self.torch_dtype) * 0.1 + for lora_rank in lora_ranks_list + ] + lora_weight_outs_down = [ + torch.randn(lora_rank, self.hidden_size, device=self.device).to( + self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list + ] + + lora_weight_ins_up = [tmp.contiguous() for tmp in lora_weight_ins_up] + lora_weight_outs_up = [ + tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs_up + ] + lora_weight_ins_down = [ + tmp.contiguous() for tmp in lora_weight_ins_down + ] + lora_weight_outs_down = [ + tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs_down + ] + + # Create weight pointers for TensorRT + lora_weights_pointers_up = [] + for in_ptr, out_ptr in zip(lora_weight_ins_up, lora_weight_outs_up): + lora_weights_pointers_up.append(in_ptr.data_ptr()) + lora_weights_pointers_up.append(out_ptr.data_ptr()) + + lora_weights_pointers_down = [] + for in_ptr, out_ptr in zip(lora_weight_ins_down, lora_weight_outs_down): + lora_weights_pointers_down.append(in_ptr.data_ptr()) + lora_weights_pointers_down.append(out_ptr.data_ptr()) + + lora_weights_pointers_up = torch.LongTensor( + lora_weights_pointers_up).to(torch.int64).reshape( + [self.batch_size, 2]) + lora_weights_pointers_down = torch.LongTensor( + lora_weights_pointers_down).to(torch.int64).reshape( + [self.batch_size, 2]) + + return { + 'lora_ranks': lora_ranks, + 'host_context_lengths': host_context_lengths, + 'host_request_types': host_request_types, + 'lora_weights_pointers_up': lora_weights_pointers_up, + 'lora_weights_pointers_down': lora_weights_pointers_down, + 'lora_weight_ins_up': lora_weight_ins_up, + 'lora_weight_outs_up': lora_weight_outs_up, + 'lora_weight_ins_down': lora_weight_ins_down, + 'lora_weight_outs_down': lora_weight_outs_down + } + + def _setup_mlp_module(self): + + mlp_module = PivotMLP(hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + bias=True, + activation=torch.nn.functional.silu, + dtype=self.torch_dtype, + layer_idx=0).to(self.device) + return mlp_module + + def _setup_trt_network(self, hidden_states, mlp_module, lora_params): + builder = tensorrt_llm.Builder() + net = builder.create_network() + net.plugin_config.to_legacy_setting() + net.plugin_config.lora_plugin = self.dtype + net.plugin_config.remove_input_padding = False + + with tensorrt_llm.net_guard(net): + trt_hidden_states = Tensor(name='hidden_states', + shape=hidden_states.shape, + dtype=tensorrt_llm.str_dtype_to_trt( + self.dtype)) + + host_request_types = Tensor( + name='host_request_types', + shape=[lora_params['host_request_types'].shape[0]], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_context_lengths = Tensor( + name='host_context_lengths', + shape=[lora_params['host_context_lengths'].shape[0]], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + lora_ranks = Tensor(name='lora_ranks', + shape=(lora_params['lora_ranks'].shape[0], ), + 
dtype=tensorrt_llm.str_dtype_to_trt('int32')) + lora_weights_pointers_up = Tensor( + name='lora_weights_pointers_up', + shape=lora_params['lora_weights_pointers_up'].shape, + dtype=tensorrt_llm.str_dtype_to_trt('int64')) + lora_weights_pointers_down = Tensor( + name='lora_weights_pointers_down', + shape=lora_params['lora_weights_pointers_down'].shape, + dtype=tensorrt_llm.str_dtype_to_trt('int64')) + + mlp_layer = TRTMLP( + hidden_size=self.hidden_size, + ffn_hidden_size=self.intermediate_size, + hidden_act='silu', # not tested gated activations + bias=True, + dtype=self.dtype) + + # Create LoRA layers for both linear projections + mlp_layer.fc.lora = Lora( + in_hidden_size=self.hidden_size, + out_hidden_sizes=[self.intermediate_size], + max_low_rank=8, + ) + mlp_layer.proj.lora = Lora( + in_hidden_size=self.intermediate_size, + out_hidden_sizes=[self.hidden_size], + max_low_rank=8, + ) + + # Set weights + mlp_layer.fc.weight.value = mlp_module.up_proj.weight.data + mlp_layer.fc.bias.value = mlp_module.up_proj.bias.data + + mlp_layer.proj.weight.value = mlp_module.down_proj.weight.data + mlp_layer.proj.bias.value = mlp_module.down_proj.bias.data + + # Create LoRA parameters for TensorRT + trt_lora_params = LoraParams( + lora_ranks=[{ + "mlp_h_to_4h_lora_ranks": lora_ranks, + "mlp_4h_to_h_lora_ranks": lora_ranks, + }], + lora_weights_pointers=[{ + "mlp_h_to_4h_lora_weights_pointers": + lora_weights_pointers_up, + "mlp_4h_to_h_lora_weights_pointers": + lora_weights_pointers_down, + }], + host_context_lengths=host_context_lengths, + host_request_types=host_request_types) + + output = mlp_layer(trt_hidden_states, + lora_layer_params=trt_lora_params) + output.mark_output('output', + tensorrt_llm.str_dtype_to_trt(self.dtype)) + + return builder, net + + def _run_trt_inference(self, builder, net, hidden_states, lora_params): + builder_config = builder.create_builder_config(name='mlp', + precision=self.dtype) + engine_buffer = builder.build_engine(net, builder_config) + session = tensorrt_llm.runtime.Session.from_serialized_engine( + engine_buffer) + + stream = torch.cuda.current_stream().cuda_stream + inputs = { + 'hidden_states': hidden_states, + 'host_request_types': lora_params['host_request_types'], + 'host_context_lengths': lora_params['host_context_lengths'], + 'lora_ranks': lora_params['lora_ranks'], + 'lora_weights_pointers_up': lora_params['lora_weights_pointers_up'], + 'lora_weights_pointers_down': + lora_params['lora_weights_pointers_down'], + } + + outputs = { + 'output': + torch.empty(hidden_states.shape, + dtype=tensorrt_llm._utils.str_dtype_to_torch( + self.dtype), + device='cuda'), + } + + session.run(inputs=inputs, outputs=outputs, stream=stream) + torch.cuda.synchronize() + + return outputs['output'] + + def test_mlp(self): + hidden_states = self._create_mlp_inputs() + lora_params = self._create_lora_params() + + mlp_module = self._setup_mlp_module() + + builder, net = self._setup_trt_network(hidden_states, mlp_module, + lora_params) + trt_output = self._run_trt_inference(builder, net, hidden_states, + lora_params) + + # Create LoRA parameters for PyTorch MLP + lora_params_pivot = { + 'num_seqs': self.batch_size, + 'host_request_types': lora_params['host_request_types'], + 'prompt_lens_cpu': lora_params['host_context_lengths'], + 0: { + LoraModuleType.MLP_H_TO_4H: { + 'adapter_size': lora_params['lora_ranks'], + 'weight_pointers': lora_params['lora_weights_pointers_up'], + 'is_dora': False, + }, + LoraModuleType.MLP_4H_TO_H: { + 'adapter_size': lora_params['lora_ranks'], + 
'weight_pointers': + lora_params['lora_weights_pointers_down'], + 'is_dora': False, + } + } + } + + pivot_output = mlp_module(hidden_states, lora_params_pivot) + + torch.testing.assert_close(pivot_output, trt_output, atol=2e-3, rtol=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittest/_torch/modules/tests_lora_modules/test_lora_plugin_vs_layer.py b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_plugin_vs_layer.py new file mode 100644 index 0000000000..21c9370a97 --- /dev/null +++ b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_plugin_vs_layer.py @@ -0,0 +1,180 @@ +import os +import sys +import unittest + +import torch + +import tensorrt_llm +from tensorrt_llm import Tensor + +sys.path.append(os.path.join(os.path.dirname(__file__), '../../../')) +from utils.util import create_session, run_session + + +class TestLoraPluginVsLayer(unittest.TestCase): + + def setUp(self): + tensorrt_llm.logger.set_level('info') + torch.random.manual_seed(0) + self.dtype = 'float16' + self.torch_dtype = torch.float16 + self.device = 'cuda' + self.batch_size = 4 + self.seq_len = 8 + self.hidden_size = 1024 + self.lora_rank = 8 + + def _create_input_tensors(self, batch_size, seq_len, hidden_size, + lora_ranks_list): + input_tensor = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=self.torch_dtype, + device=self.device) * 0.1 + + lora_weight_ins = [ + torch.randn(hidden_size, lora_rank, device=self.device).to( + self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list + ] + lora_weight_outs = [ + torch.randn(lora_rank, hidden_size, device=self.device).to( + self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list + ] + + lora_weight_ins = [tmp.contiguous() for tmp in lora_weight_ins] + lora_weight_outs = [ + tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs + ] + + # Create LoRA weight pointers + lora_weights_pointers = [] + for in_ptr, out_ptr in zip(lora_weight_ins, lora_weight_outs): + lora_weights_pointers.append(in_ptr.data_ptr()) + lora_weights_pointers.append(out_ptr.data_ptr()) + # null dora scale + lora_weights_pointers.append(0) + + lora_weights_pointers = torch.LongTensor(lora_weights_pointers).to( + torch.int64).reshape([batch_size, 3]) + + # Create other tensors + host_context_lengths = torch.Tensor( + [seq_len for _ in range(batch_size)]).to(torch.int32) + lora_ranks = torch.Tensor(lora_ranks_list).to(torch.int32) + host_request_types = torch.zeros_like(host_context_lengths, + device='cpu').int() + + return { + 'input_tensor': input_tensor, + 'lora_weight_ins': lora_weight_ins, + 'lora_weight_outs': lora_weight_outs, + 'lora_weights_pointers': lora_weights_pointers, + 'host_context_lengths': host_context_lengths, + 'lora_ranks': lora_ranks, + 'host_request_types': host_request_types, + 'batch_size': batch_size, + 'seq_len': seq_len, + 'hidden_size': hidden_size, + 'max_lora_rank': max(max(lora_ranks_list), 8) + } + + def _create_lora_plugin_session(self, tensors): + # Construct TensorRT network + builder = tensorrt_llm.Builder() + network = builder.create_network() + network.plugin_config.set_lora_plugin(self.dtype) + network.plugin_config.remove_input_padding = False + + with tensorrt_llm.net_guard(network): + input_tensor = Tensor(name='input_tensor', + shape=[ + tensors['batch_size'], tensors['seq_len'], + tensors['hidden_size'] + ], + dtype=tensorrt_llm.str_dtype_to_trt( + self.dtype)) + host_request_types_tensor = Tensor( + name='host_request_types', + shape=[tensors['batch_size']], + 
dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_context_lengths_tensor = Tensor( + name='host_context_lengths', + shape=[tensors['batch_size']], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + lora_ranks_tensor = Tensor( + name='lora_ranks', + shape=[tensors['batch_size']], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + lora_weights_pointers_tensor = Tensor( + name='lora_weights_pointers', + shape=[tensors['batch_size'], 3], + dtype=tensorrt_llm.str_dtype_to_trt('int64')) + + output = tensorrt_llm.functional.lora_plugin( + input_tensor, + tensors['hidden_size'], + [tensors['hidden_size']], + host_request_types_tensor, + False, # transA + True, # transB + host_context_lengths_tensor, + tensors['max_lora_rank'], + [lora_ranks_tensor], + [lora_weights_pointers_tensor], + weight_index=0, + ) + output.mark_output('output') + + return create_session(builder, network, precision=self.dtype) + + def _run_lora_grouped_gemm(self, tensors): + """Run the lora_grouped_gemm operation directly""" + # Prepare parameters for lora_grouped_gemm + x = tensors['input_tensor'] + host_request_types = tensors[ + 'host_request_types'][:tensors['batch_size']] + lora_ranks = tensors['lora_ranks'] + lora_weight_pointers = tensors['lora_weights_pointers'] + prompt_lens_cpu = tensors[ + 'host_context_lengths'][:tensors['batch_size']] + output_hidden_sizes = [tensors['hidden_size']] + transA = False + transB = True + max_rank = max([r.item() for r in lora_ranks]) + weight_index = 0 + is_remove_input_padding = False + + lora_outputs = torch.ops.trtllm.lora_grouped_gemm( + x, host_request_types, [lora_ranks], [lora_weight_pointers], + prompt_lens_cpu, output_hidden_sizes, transA, transB, max_rank, + weight_index, is_remove_input_padding) + + return lora_outputs[0] + + def test_lora_plugin_vs_lora_op(self): + lora_ranks_list = [self.lora_rank] * self.batch_size + + tensors = self._create_input_tensors(self.batch_size, self.seq_len, + self.hidden_size, lora_ranks_list) + + session = self._create_lora_plugin_session(tensors) + inputs = { + 'input_tensor': tensors['input_tensor'], + 'host_request_types': tensors['host_request_types'], + 'host_context_lengths': tensors['host_context_lengths'], + 'lora_ranks': tensors['lora_ranks'], + 'lora_weights_pointers': tensors['lora_weights_pointers'], + } + outputs = run_session(session, inputs) + torch.cuda.synchronize() + + lora_outputs = self._run_lora_grouped_gemm(tensors) + + torch.testing.assert_close(outputs['output'], + lora_outputs, + atol=5e-3, + rtol=0.3) + + +if __name__ == "__main__": + unittest.main()
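Note for reviewers: every PyTorch-path ("pivot") module under test consumes the same lora_params dictionary. The sketch below condenses that layout from test_lora_linear_pivot_vs_vanilla_torch.py above; it is illustrative only (module paths, key names, and the call signature are copied from these tests rather than from a documented public API), and it assumes a CUDA device is available.

    import torch

    from tensorrt_llm._torch.modules.linear import Linear as PivotLinear
    from tensorrt_llm._torch.peft.lora.layer import LoraModuleType

    batch_size, seq_len, hidden, rank = 4, 10, 64, 8
    device = "cuda"

    x = torch.randn(batch_size, seq_len, hidden, device=device)

    # LoRA factors, mirroring _get_lora_params in the vanilla-linear test:
    # the "in" weight is [hidden, rank], the "out" weight is [rank, hidden].
    lora_weight_ins = torch.randn(hidden, rank, device=device)
    lora_weight_outs = torch.randn(rank, hidden, device=device)

    lora_params = {
        'num_seqs': batch_size,
        'host_request_types': torch.zeros(batch_size, dtype=torch.int32),
        'prompt_lens_cpu': torch.tensor([seq_len] * batch_size),
        0: {                                  # layer_idx
            LoraModuleType.DENSE: {           # module type
                'adapter_size': torch.tensor([rank]),
                'weight_pointers': torch.tensor([lora_weight_ins.data_ptr(),
                                                 lora_weight_outs.data_ptr()]),
                'weight_tensors': [lora_weight_ins, lora_weight_outs],
                'is_dora': False,
            }
        }
    }

    # Base weights copied from a reference torch.nn.Linear, as the test does.
    torch_linear = torch.nn.Linear(hidden, hidden).to(device)
    pivot_linear = PivotLinear(in_features=hidden, out_features=hidden,
                               layer_idx=0)
    pivot_linear.weight.data = torch_linear.weight.data
    pivot_linear.bias.data = torch_linear.bias.data

    # Per test_compare_linear_torch_pivot_lora, this should match
    # torch_linear(x) + (x @ lora_weight_outs.T) @ lora_weight_ins.T.
    output = pivot_linear(x, lora_params=lora_params)

The same dictionary shape (with ATTENTION_Q/K/V/DENSE or MLP_H_TO_4H/MLP_4H_TO_H module keys) is what the attention and MLP tests above construct.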