From d7bbd714c4c61df6b34c513c017fe55fedac9559 Mon Sep 17 00:00:00 2001 From: Kaixuan Huang Date: Fri, 4 Aug 2023 01:08:11 -0400 Subject: [PATCH] fix bugs to enable key_value cache for generation --- progen2/models/progen/modeling_progen.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/progen2/models/progen/modeling_progen.py b/progen2/models/progen/modeling_progen.py index f6e7063..26f6020 100644 --- a/progen2/models/progen/modeling_progen.py +++ b/progen2/models/progen/modeling_progen.py @@ -575,10 +575,10 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): return - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past: + # only last token for inputs_ids if past_key_values is defined in kwargs + if past_key_values: input_ids = input_ids[:, -1].unsqueeze(-1) if token_type_ids is not None: token_type_ids = token_type_ids[:, -1].unsqueeze(-1) @@ -590,13 +590,13 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - if past: + if past_key_values: position_ids = position_ids[:, -1].unsqueeze(-1) else: position_ids = None return { "input_ids": input_ids, - "past_key_values": past, + "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, "attention_mask": attention_mask,