diff --git a/data_utils/lm_datasets.py b/data_utils/lm_datasets.py index 65c14b8b..6aa4a2cf 100644 --- a/data_utils/lm_datasets.py +++ b/data_utils/lm_datasets.py @@ -55,8 +55,8 @@ def _process_lm(self, i, samp, model_data, no_model_data, gen_data): source_len = 1 prompt = None - if 65535 in input_ids: - source_len = np.where(input_ids==65535)[0][0] + if 65535 in input_ids or 4294967295 in input_ids: + source_len = np.where((input_ids==65535) | (input_ids==4294967295))[0][0] prompt = input_ids[:source_len] input_ids = np.concatenate([input_ids[:source_len], input_ids[source_len+1:]], axis=0) input_ids = input_ids[:self.max_length]