Skip to content

Commit a02f8c4

Browse files
authored
[NPUW] Fix LLMInferRequest to work with the input prompt of len 1 (#32267)
### Details:
- *Fix LLM inference on NPU for input prompt of length 1*

### Tickets:
- *N/A*
1 parent ca5e119 commit a02f8c4

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -629,13 +629,16 @@ void slice_out_embeds(std::shared_ptr<ov::Model> model,
629629

630630
if (embed_result) {
631631
auto shape = embed_result->input(0).get_shape();
632-
// If shape.size() is 3, then last axis should be the Vocab size.
632+
// If shape.size() is 3, then last axis should contain the rank of embedding dimension.
633633
// But 1st and 2nd axes can mean different things.
634634
// 1st axis can represent the batch size, while 2nd - the number of embeddings,
635635
// or vice-versa (in chatglm)
636636
if (shape.size() == 3) {
637+
OPENVINO_ASSERT(batch_dim <= 1, "Unexpected value of batch_dim: ", batch_dim, ", expected 0 or 1!");
637638
uint32_t num_embeds_dim = 1 - batch_dim;
638-
if (shape[num_embeds_dim] > max_generation_token_len) {
639+
OPENVINO_ASSERT(shape[num_embeds_dim] >= max_generation_token_len,
640+
"Number of output embeddings should be greater or equal to the slicing range!");
641+
if (shape[num_embeds_dim] != max_generation_token_len) {
639642
std::vector<int32_t> start_pos{
640643
static_cast<int32_t>(batch_dim * (shape[num_embeds_dim] - max_generation_token_len)),
641644
static_cast<int32_t>(num_embeds_dim * (shape[num_embeds_dim] - max_generation_token_len)),

src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -858,7 +858,9 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,
858858

859859
if (!m_generate_initialized) {
860860
LOG_DEBUG("Copy kv-cache from prefill to generate model.");
861-
copy_kvcache();
861+
if (kvcache_desc.num_stored_tokens > 0) {
862+
copy_kvcache();
863+
}
862864

863865
LOG_DEBUG("Prepare inputs.");
864866
namespace uu = ov::npuw::util;

0 commit comments

Comments
 (0)