|
@@ -293,6 +293,8 @@
 "            padding_mask = tf.cast(\n",
 "                mask[:, tf.newaxis, :], dtype=\"int32\")\n",
 "            padding_mask = tf.minimum(padding_mask, causal_mask)\n",
+"        else:\n",
+"            padding_mask = mask\n",
 "        attention_output_1 = self.attention_1(\n",
 "            query=inputs,\n",
 "            value=inputs,\n",
|
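The added `else` branch ensures `padding_mask` is defined even when the layer is called without a mask, presumably so a later attention call that consumes `padding_mask` does not hit an unbound local. A minimal sketch of how the two masks combine, with toy shapes as an assumption (`mask` is `(batch, seq_len)`, the causal mask `(1, seq_len, seq_len)`):

```python
import tensorflow as tf

# Minimal sketch of the mask combination above; the toy sizes are assumptions.
seq_len = 4
# Causal mask: lower-triangular, shape (1, seq_len, seq_len).
# Entry [0, t, s] == 1 means query position t may attend to key position s <= t.
causal_mask = tf.linalg.band_part(
    tf.ones((1, seq_len, seq_len), dtype="int32"), -1, 0)
# Padding mask for one sample whose last token is padding: shape (batch, seq_len).
mask = tf.constant([[1, 1, 1, 0]])
padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")  # (batch, 1, seq_len)
# Element-wise minimum: a key position is attendable only if BOTH masks allow it.
combined = tf.minimum(padding_mask, causal_mask)
print(combined.numpy()[0])
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]
```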
@@ -391,6 +393,8 @@
 "        self.model_input_length = model_input_length\n",
 "        self.temperatures = temperatures\n",
 "        self.print_freq = print_freq\n",
+"        vectorized_prompt = text_vectorization([prompt])[0].numpy()\n",
+"        self.prompt_length = np.nonzero(vectorized_prompt == 0)[0][0]\n",
 "\n",
 "    def on_epoch_end(self, epoch, logs=None):\n",
 "        if (epoch + 1) % self.print_freq != 0:\n",
|
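The two added `__init__` lines record how many real tokens the prompt contains. Assuming `text_vectorization` is a Keras `TextVectorization` layer with a fixed `output_sequence_length`, it right-pads its output with zeros, so the index of the first zero equals the prompt's token count. A standalone sketch with made-up token ids:

```python
import numpy as np

# Illustrative token ids (assumption): a TextVectorization layer with a fixed
# output length right-pads with 0, so the first 0 marks the end of the prompt.
vectorized_prompt = np.array([23, 5, 187, 0, 0, 0, 0, 0])
prompt_length = np.nonzero(vectorized_prompt == 0)[0][0]
print(prompt_length)  # 3 -> the prompt occupies positions 0..2
```

Note this assumes the prompt is shorter than the vectorizer's output length; a prompt that fills the whole window would leave no zero to find.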
@@ -401,7 +405,9 @@
 "            for i in range(self.generate_length):\n",
 "                tokenized_sentence = text_vectorization([sentence])\n",
 "                predictions = self.model(tokenized_sentence)\n",
-"                next_token = sample_next(predictions[0, i, :])\n",
+"                next_token = sample_next(\n",
+"                    predictions[0, self.prompt_length - 1 + i, :]\n",
+"                )\n",
 "                sampled_token = tokens_index[next_token]\n",
 "                sentence += \" \" + sampled_token\n",
 "            print(sentence)\n",
|
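The replaced sampling line fixes which output position is read at each generation step. The decoder predicts, at position `t`, the token for position `t + 1`, and the prompt already fills positions `0 .. prompt_length - 1`, so the distribution for the first generated token sits at index `prompt_length - 1`, not at index `i`. A toy walk-through of the corrected index arithmetic (values are assumptions):

```python
# At step i the sentence holds prompt_length + i real tokens, so the
# distribution over the NEXT token is at output position prompt_length - 1 + i.
prompt_length = 3
for i in range(3):
    read_index = prompt_length - 1 + i
    print(f"step {i}: sample from predictions[0, {read_index}, :]")
# step 0: sample from predictions[0, 2, :]  (position of the last prompt token)
# step 1: sample from predictions[0, 3, :]
# step 2: sample from predictions[0, 4, :]
```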