```diff
@@ -629,13 +629,16 @@ void slice_out_embeds(std::shared_ptr<ov::Model> model,

     if (embed_result) {
         auto shape = embed_result->input(0).get_shape();
-        // If shape.size() is 3, then last axis should be the Vocab size.
+        // If shape.size() is 3, then last axis should contain the rank of embedding dimension.
+        // But 1st and 2nd axes can mean different things.
+        // 1st axis can represent the batch size, while 2nd - the number of embeddings,
+        // or vice-versa (in chatglm)
         if (shape.size() == 3) {
+            OPENVINO_ASSERT(batch_dim <= 1, "Unexpected value of batch_dim: ", batch_dim, ", expected 0 or 1!");
             uint32_t num_embeds_dim = 1 - batch_dim;
-            if (shape[num_embeds_dim] > max_generation_token_len) {
```
A review thread was attached to this line:

Contributor: Somehow I overlooked it in the past, but what is `batch_dim`? Can this `1 - x` underflow to some hugely positive value here?

Contributor Author: No, because `batch_dim` is either 0 or 1 (for chat-glm), but great catch!! Let me add an assert!
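To make the reviewer's concern concrete, here is a minimal sketch, assuming `batch_dim` is a `uint32_t` (its declaration is outside this hunk). Unsigned subtraction cannot go negative, it wraps modulo 2^32, so any `batch_dim` above 1 would silently turn `1 - batch_dim` into an enormous axis index:

```cpp
#include <cstdint>
#include <iostream>

int main() {
    uint32_t batch_dim = 1;              // expected values: 0 or 1
    std::cout << 1 - batch_dim << "\n";  // prints 0, the intended axis flip

    uint32_t bogus = 2;                  // anything above 1 breaks the assumption
    std::cout << 1 - bogus << "\n";      // prints 4294967295, i.e. 2^32 - 1
    return 0;
}
```

The `OPENVINO_ASSERT(batch_dim <= 1, ...)` added above turns that silent wrap into an explicit failure before the bad index can reach `shape[num_embeds_dim]`.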
The hunk continues with the new assert and the tightened slicing condition:

```diff
+            OPENVINO_ASSERT(shape[num_embeds_dim] >= max_generation_token_len,
+                            "Number of output embeddings should be greater or equal to the slicing range!");
+            if (shape[num_embeds_dim] != max_generation_token_len) {
                 std::vector<int32_t> start_pos{
                     static_cast<int32_t>(batch_dim * (shape[num_embeds_dim] - max_generation_token_len)),
                     static_cast<int32_t>(num_embeds_dim * (shape[num_embeds_dim] - max_generation_token_len)),
```
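The `start_pos` construction deserves a note: since the assert guarantees that exactly one of `batch_dim` and `num_embeds_dim` is 1 and the other is 0, multiplying both by the same offset places the offset on the embeddings axis and 0 on the batch axis, whichever layout the model uses. A standalone sketch of that arithmetic (the helper name and example numbers are made up, and the rest of the vector is truncated in the diff above):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper, not the actual OpenVINO code: it reproduces only the
// visible start_pos arithmetic for the first two axes of the 3D shape.
std::vector<int32_t> start_for(uint32_t batch_dim, int32_t num_embeds, int32_t keep) {
    uint32_t num_embeds_dim = 1 - batch_dim;  // safe once batch_dim is 0 or 1
    int32_t offset = num_embeds - keep;       // leading embeddings to skip
    return {static_cast<int32_t>(batch_dim) * offset,
            static_cast<int32_t>(num_embeds_dim) * offset};
}

// start_for(0, 1024, 1) == {0, 1023}   batch-first layout
// start_for(1, 1024, 1) == {1023, 0}   chatglm-style layout
// Either way the slice starts where the last `keep` embeddings begin.
```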
In `LLMInferRequest::infer_generate`, the KV-cache copy is now guarded:

```diff
@@ -858,7 +858,9 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,

     if (!m_generate_initialized) {
         LOG_DEBUG("Copy kv-cache from prefill to generate model.");
-        copy_kvcache();
+        if (kvcache_desc.num_stored_tokens > 0) {
+            copy_kvcache();
+        }

         LOG_DEBUG("Prepare inputs.");
         namespace uu = ov::npuw::util;
```
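Read literally, the guard says: do not copy when prefill has stored nothing yet. A minimal sketch of that logic, where every name except `copy_kvcache()` and `num_stored_tokens` is an assumption made for illustration:

```cpp
#include <cstdint>

// The real types live in LLMInferRequest; this struct is a stand-in.
struct KVCacheDesc {
    uint32_t num_stored_tokens = 0;  // tokens the prefill stage has stored so far
};

void copy_kvcache() { /* transfers prefill KV entries to the generate model */ }

void prepare_generate(const KVCacheDesc& kvcache_desc) {
    // With an empty cache there is nothing to transfer, so the copy is
    // skipped instead of running unconditionally as it did before the fix.
    if (kvcache_desc.num_stored_tokens > 0) {
        copy_kvcache();
    }
}

int main() {
    KVCacheDesc desc;        // num_stored_tokens == 0: the copy is skipped
    prepare_generate(desc);
    return 0;
}
```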