
Lookahead decoding and multimodal input support #3137

Open
maxilevi opened this issue Mar 28, 2025 · 7 comments
Labels: question (Further information is requested), triaged (Issue has been triaged by maintainers)

Comments

@maxilevi

Hi,

I get the following error when:

  • Lookahead decoding is enabled
  • Request has multimodal input (e.g. just a custom prompt table with fake-vocabulary token ids; see the sketch below)
  • batch size > 1
  • Inflight fused batching is enabled

The model is Llama 8B.
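For context, here is a minimal PyTorch-style sketch of how a prompt table with fake-vocabulary ids is typically consumed: real token ids go through the vocabulary embedding, while ids at or above vocab_size select rows of the prompt table. The function and tensor names are illustrative and not TensorRT-LLM's actual API.

```python
import torch

def embed_with_prompt_table(input_ids, word_embedding, prompt_table, vocab_size):
    """Illustrative lookup: real tokens use the vocab embedding,
    fake-vocabulary tokens (id >= vocab_size) index into the prompt table."""
    is_prompt = input_ids >= vocab_size
    # Clamp ids so both gathers stay in range; the wrong branch is masked out below.
    real_ids = torch.where(is_prompt, torch.zeros_like(input_ids), input_ids)
    prompt_ids = torch.where(is_prompt, input_ids - vocab_size,
                             torch.zeros_like(input_ids))
    word_embeds = word_embedding(real_ids)       # [batch, seq, hidden]
    prompt_embeds = prompt_table[prompt_ids]     # [batch, seq, hidden]
    return torch.where(is_prompt.unsqueeze(-1), prompt_embeds, word_embeds)

# Example: two fake-vocab tokens followed by three real tokens.
vocab_size, hidden = 100, 8
word_embedding = torch.nn.Embedding(vocab_size, hidden)
prompt_table = torch.randn(16, hidden)           # 16 "virtual" tokens
input_ids = torch.tensor([[100, 101, 5, 6, 7]])  # ids >= 100 hit the prompt table
out = embed_with_prompt_table(input_ids, word_embedding, prompt_table, vocab_size)
print(out.shape)                                 # torch.Size([1, 5, 8])
```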

[TensorRT-LLM][ERROR] IExecutionContext::inferShapes: Error Code 7: Internal Error (LLaMAForCausalLM/transformer/vocab_embedding/__add___L322/elementwise_binary_L2901/ELEMENTWISE_SUM_0: dimensions not compatible for elementwise. Broadcast has incompatible dimensions: 2 != 18 && 2 != 1 && 18 != 1. Instruction: CHECK_BROADCAST 2 18.)
[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: Invalid input shape (/home/jenkins/agent/workspace/LLM/release-0.17/L0_Test-x86_64/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmRuntime.cpp:574)
1       0x7f4097fd7277 /home/maximilianolevi/.cache/pypoetry/virtualenvs/tensorrt-inference-8MUMp6os-py3.10/lib/python3.10/site-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x6e3277) [0x7f4097fd7277]
2       0x7f4098cadc88 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::prepareBuffers(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int) + 184
3       0x7f4098cb71d6 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeStep(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int) + 1510
4       0x7f4098cb7abf tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeBatch(tensorrt_llm::batch_manager::ScheduledRequests const&) + 223
5       0x7f4098cc13aa tensorrt_llm::batch_manager::TrtGptModelInflightBatching::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 1802
6       0x7f4098d4df85 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 437
7       0x7f4098d59cb6 tensorrt_llm::executor::Executor::Impl::executionLoop() + 1206
8       0x7f43a6e215c0 /home/maximilianolevi/.cache/pypoetry/virtualenvs/tensorrt-inference-8MUMp6os-py3.10/lib/python3.10/site-packages/torch/lib/libtorch.so(+0x145c0) [0x7f43a6e215c0]
9       0x7f43aaea4ea7 /lib/x86_64-linux-gnu/libpthread.so.0(+0x7ea7) [0x7f43aaea4ea7]
10      0x7f43aafbaacf clone + 63

Do max_multimodal_len or the lookahead decoding parameters need to match a specific shape in this case?

@juney-nvidia
Collaborator

@lfr-0531 may be able to provide some quick comments on this issue.

June

juney-nvidia added the question and triaged labels on Mar 28, 2025
@maxilevi
Author

Thank you for the fast response. Do you know if it's a bug or an inherent limitation of the current implementation?

@lfr-0531
Collaborator

Currently, lookahead decoding does not support multimodal cases.

Can you share your command? We can try to put together a fix.

@maxilevi
Author

maxilevi commented Mar 31, 2025

@lfr-0531 Thank you for the reply.

Currently I am testing Llama 3.2 1B with the following command:

trtllm-build --max_batch_size 8 --max_seq_len 1024 --max_multimodal_len 131072 --gpt_attention_plugin auto --gemm_plugin auto --model_cls_name LLaMAForCausalLM --max_draft_len 83 --speculative_decoding_mode lookahead_decoding --checkpoint_dir /var/tmp/tmp2achasyx --output_dir /var/tmp/tmp3dvky5dj/engine

Compilation always succeeds, but when the engine is running and I submit a batch with batch_size > 1, it crashes with the following error:

IExecutionContext::inferShapes: Error Code 7: Internal Error (LLaMAForCausalLM/transformer/vocab_embedding/__add___L322/elementwise_binary_L2901/ELEMENTWISE_SUM_0: dimensions not compatible for elementwise. Broadcast has incompatible dimensions: 2 != 24 && 2 != 1 && 24 != 1. Instruction: CHECK_BROADCAST 2 24.)
[TensorRT-LLM][ERROR] Encountered an error in forwardAsync function: Invalid input shape (/home/jenkins/agent/workspace/LLM/release-0.17/L0_Test-x86_64/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmRuntime.cpp:574)
1       0x7f47d3710277 /home/maximilianolevi/.cache/pypoetry/virtualenvs/tensorrt-inference-8MUMp6os-py3.10/lib/python3.10/site-packages/tensorrt_llm/libs/libtensorrt_llm.so(+0x6e3277) [0x7f47d3710277]
2       0x7f47d43e6c88 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::prepareBuffers(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int) + 184
3       0x7f47d43f01d6 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeStep(std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, std::vector<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&, int) + 1510
4       0x7f47d43f0abf tensorrt_llm::batch_manager::TrtGptModelInflightBatching::executeBatch(tensorrt_llm::batch_manager::ScheduledRequests const&) + 223
5       0x7f47d43fa3aa tensorrt_llm::batch_manager::TrtGptModelInflightBatching::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > > const&) + 1802
6       0x7f47d4486f85 tensorrt_llm::executor::Executor::Impl::forwardAsync(std::list<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>, std::allocator<std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> > >&) + 437
7       0x7f47d4492cb6 tensorrt_llm::executor::Executor::Impl::executionLoop() + 1206
8       0x7f4ae255a5c0 /home/maximilianolevi/.cache/pypoetry/virtualenvs/tensorrt-inference-8MUMp6os-py3.10/lib/python3.10/site-packages/torch/lib/libtorch.so(+0x145c0) [0x7f4ae255a5c0]
9       0x7f4ae65ddea7 /lib/x86_64-linux-gnu/libpthread.so.0(+0x7ea7) [0x7f4ae65ddea7]
10      0x7f4ae66f3acf clone + 63

For batch_size == 1 it always works.

Maybe the shapes are not adjusted correctly for the batch_size > 1 case?

@lfr-0531
Collaborator

lfr-0531 commented Apr 2, 2025

I can reproduce this issue. It happens because TensorRT-LLM does not currently support prompt tuning/multimodal input together with lookahead decoding.

When max_multimodal_len > 0 is set, the PromptTuningEmbedding is used in the model. Then, when lookahead decoding is used, in the decoding phase the prompt_tokens have shape [batch_size, 1 + draft_len] while the tasks tensor has shape [batch_size, 1], which leads to an incompatible-shape error in this line.

For the batch_size = 1 case, the tasks tensor has shape [1], which broadcasts against anything, so there is no such error.
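The numbers in the first error message (2 != 18) are consistent with a packed token layout: with 2 requests and 1 + draft_len = 9 tokens per request per lookahead step, the flattened prompt_tokens tensor has 18 entries while the tasks tensor has only 2. A minimal PyTorch sketch of that mismatch (the packed layout, the draft_len value, and the offset multiplier are assumptions for illustration, not TensorRT-LLM's actual code):

```python
import torch

batch_size, draft_len = 2, 8           # lookahead adds draft_len extra positions per step
tokens_per_step = 1 + draft_len        # 9 tokens per request per decoding step

# Packed layout: all requests' tokens for this step flattened into one tensor.
prompt_tokens = torch.arange(batch_size * tokens_per_step)  # shape [18]
tasks = torch.tensor([0, 1])                                # shape [2], one task id per request

try:
    _ = prompt_tokens + tasks * 1000   # stand-in for the prompt-table offset add
except RuntimeError as err:
    print(err)                         # size mismatch: 18 vs 2, mirroring the TRT error

# With batch_size == 1 the tasks tensor has a single element and broadcasts silently,
# which is why only multi-request batches hit the crash.
```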

@maxilevi
Author

maxilevi commented Apr 2, 2025

So we just need to broadcast the second dimension when batch_size > 1? I can open a PR.

@lfr-0531
Collaborator

lfr-0531 commented Apr 3, 2025

> So we just need to broadcast the second dimension when batch_size > 1? I can open a PR.

Yes, we need to expand the tasks tensor (see the sketch below).

You are welcome to contribute the code to TensorRT-LLM directly.
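A minimal PyTorch-style sketch of the kind of expansion being discussed: repeat each request's task id across its 1 + draft_len decoding positions so that the tasks tensor lines up with the packed prompt_tokens. The function name and the packed layout are assumptions for illustration; the real change would live in TensorRT-LLM's prompt-tuning/runtime buffer code.

```python
import torch

def expand_tasks_for_draft(tasks: torch.Tensor, draft_len: int) -> torch.Tensor:
    """Repeat each request's task id across its 1 + draft_len decoding positions
    so the tasks tensor matches the packed prompt_tokens tensor."""
    # tasks: [batch_size] or [batch_size, 1] -> [batch_size * (1 + draft_len)]
    return tasks.reshape(-1, 1).expand(-1, 1 + draft_len).reshape(-1)

tasks = torch.tensor([0, 1])                  # one task id per request
expanded = expand_tasks_for_draft(tasks, draft_len=8)
print(expanded.shape)                         # torch.Size([18]), matches 18 packed tokens
print(expanded[:9])                           # first request's 9 positions (all task 0)
print(expanded[9:])                           # second request's 9 positions (all task 1)
```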
