Skip to content

Commit a6a24ed

Browse files
committed
Round of fixes
1 parent e0be6dd commit a6a24ed

2 files changed

+9
-1
lines changed

chapter11_part04_sequence-to-sequence-learning.ipynb

+2
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,8 @@
405405
" padding_mask = tf.cast(\n",
406406
" mask[:, tf.newaxis, :], dtype=\"int32\")\n",
407407
" padding_mask = tf.minimum(padding_mask, causal_mask)\n",
408+
" else:\n",
409+
" padding_mask = mask\n",
408410
" attention_output_1 = self.attention_1(\n",
409411
" query=inputs,\n",
410412
" value=inputs,\n",

chapter12_part01_text-generation.ipynb

+7-1
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@
293293
" padding_mask = tf.cast(\n",
294294
" mask[:, tf.newaxis, :], dtype=\"int32\")\n",
295295
" padding_mask = tf.minimum(padding_mask, causal_mask)\n",
296+
" else:\n",
297+
" padding_mask = mask\n",
296298
" attention_output_1 = self.attention_1(\n",
297299
" query=inputs,\n",
298300
" value=inputs,\n",
@@ -391,6 +393,8 @@
391393
" self.model_input_length = model_input_length\n",
392394
" self.temperatures = temperatures\n",
393395
" self.print_freq = print_freq\n",
396+
" vectorized_prompt = text_vectorization([prompt])[0].numpy()\n",
397+
" self.prompt_length = np.nonzero(vectorized_prompt == 0)[0][0]\n",
394398
"\n",
395399
" def on_epoch_end(self, epoch, logs=None):\n",
396400
" if (epoch + 1) % self.print_freq != 0:\n",
@@ -401,7 +405,9 @@
401405
" for i in range(self.generate_length):\n",
402406
" tokenized_sentence = text_vectorization([sentence])\n",
403407
" predictions = self.model(tokenized_sentence)\n",
404-
" next_token = sample_next(predictions[0, i, :])\n",
408+
" next_token = sample_next(\n",
409+
" predictions[0, self.prompt_length - 1 + i, :]\n",
410+
" )\n",
405411
" sampled_token = tokens_index[next_token]\n",
406412
" sentence += \" \" + sampled_token\n",
407413
" print(sentence)\n",

0 commit comments

Comments
 (0)