|
30 | 30 | "source": [
|
31 | 31 | "## Stack representation\n",
|
32 | 32 | "\n",
|
33 |
| - "The `stack` variable has two variables, buffer and index. The buffer is the writable buffer where stack elements are stored. Index points to top of stack + 1." |
| 33 | + "The `stack` variable has two variables, buffer and index. The buffer is the writable buffer where stack elements are stored. Index points to top of stack + 1.\n", |
| 34 | + "\n", |
| 35 | + "Note: Since these functions can create variables, they must execute eagerly." |
34 | 36 | ]
|
35 | 37 | },
|
36 | 38 | {
|
|
248 | 250 | "print(tape.gradient(element, buffer))"
|
249 | 251 | ]
|
250 | 252 | },
|
| 253 | + { |
| 254 | + "cell_type": "markdown", |
| 255 | + "metadata": {}, |
| 256 | + "source": [ |
| 257 | + "## Stack peek\n", |
| 258 | + "\n", |
| 259 | + "Get the stack top without any modification" |
| 260 | + ] |
| 261 | + }, |
| 262 | + { |
| 263 | + "cell_type": "code", |
| 264 | + "execution_count": 7, |
| 265 | + "metadata": {}, |
| 266 | + "outputs": [ |
| 267 | + { |
| 268 | + "name": "stdout", |
| 269 | + "output_type": "stream", |
| 270 | + "text": [ |
| 271 | + "tf.Tensor([3. 3. 3.], shape=(3,), dtype=float32)\n", |
| 272 | + "None\n" |
| 273 | + ] |
| 274 | + } |
| 275 | + ], |
| 276 | + "source": [ |
| 277 | + "@tf.function\n", |
| 278 | + "def stack_peek(stack):\n", |
| 279 | + " buffer, index = stack\n", |
| 280 | + " index = tf.roll(index, shift=-1, axis=0)\n", |
| 281 | + " element = superposition_lookup_vectored(buffer, index)\n", |
| 282 | + " return element\n", |
| 283 | + "\n", |
| 284 | + "buffer = tf.Variable([\n", |
| 285 | + " [1,1,1],\n", |
| 286 | + " [2,2,2],\n", |
| 287 | + " [3,3,3]\n", |
| 288 | + "],dtype=tf.float32)\n", |
| 289 | + "stack = new_stack_from_buffer(buffer, True)\n", |
| 290 | + "\n", |
| 291 | + "with tf.GradientTape() as tape:\n", |
| 292 | + " element = stack_peek(stack)\n", |
| 293 | + "\n", |
| 294 | + "print(element)\n", |
| 295 | + "print(tape.gradient(element, buffer))" |
| 296 | + ] |
| 297 | + }, |
251 | 298 | {
|
252 | 299 | "cell_type": "markdown",
|
253 | 300 | "metadata": {},
|
|
262 | 309 | },
|
263 | 310 | {
|
264 | 311 | "cell_type": "code",
|
265 |
| - "execution_count": 7, |
| 312 | + "execution_count": 8, |
266 | 313 | "metadata": {},
|
267 | 314 | "outputs": [
|
268 | 315 | {
|
|
328 | 375 | },
|
329 | 376 | {
|
330 | 377 | "cell_type": "code",
|
331 |
| - "execution_count": 8, |
| 378 | + "execution_count": 9, |
332 | 379 | "metadata": {},
|
333 | 380 | "outputs": [
|
334 | 381 | {
|
|
0 commit comments