Skip to content

Commit faf82a1

Browse files
trees forward pass done
1 parent c6b0293 commit faf82a1

File tree

3 files changed

+278
-85
lines changed

3 files changed

+278
-85
lines changed

notebooks/library/stacks.py

+8
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,11 @@ def stack_pop(state):
4343
element = superposition_lookup_vectored(buffer, index)
4444
state = (buffer, index)
4545
return state, element
46+
47+
48+
@tf.function
49+
def stack_peek(stack):
50+
buffer, index = stack
51+
index = tf.roll(index, shift=-1, axis=0)
52+
element = superposition_lookup_vectored(buffer, index)
53+
return element

notebooks/stacks.ipynb

+50-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
"source": [
3131
"## Stack representation\n",
3232
"\n",
33-
"The `stack` variable has two variables, buffer and index. The buffer is the writable buffer where stack elements are stored. Index points to top of stack + 1."
33+
"The `stack` variable has two variables, buffer and index. The buffer is the writable buffer where stack elements are stored. Index points to top of stack + 1.\n",
34+
"\n",
35+
"Note: Since these functions can create variables, they must execute eagerly."
3436
]
3537
},
3638
{
@@ -248,6 +250,51 @@
248250
"print(tape.gradient(element, buffer))"
249251
]
250252
},
253+
{
254+
"cell_type": "markdown",
255+
"metadata": {},
256+
"source": [
257+
"## Stack peek\n",
258+
"\n",
259+
"Get the stack top without any modification"
260+
]
261+
},
262+
{
263+
"cell_type": "code",
264+
"execution_count": 7,
265+
"metadata": {},
266+
"outputs": [
267+
{
268+
"name": "stdout",
269+
"output_type": "stream",
270+
"text": [
271+
"tf.Tensor([3. 3. 3.], shape=(3,), dtype=float32)\n",
272+
"None\n"
273+
]
274+
}
275+
],
276+
"source": [
277+
"@tf.function\n",
278+
"def stack_peek(stack):\n",
279+
" buffer, index = stack\n",
280+
" index = tf.roll(index, shift=-1, axis=0)\n",
281+
" element = superposition_lookup_vectored(buffer, index)\n",
282+
" return element\n",
283+
"\n",
284+
"buffer = tf.Variable([\n",
285+
" [1,1,1],\n",
286+
" [2,2,2],\n",
287+
" [3,3,3]\n",
288+
"],dtype=tf.float32)\n",
289+
"stack = new_stack_from_buffer(buffer, True)\n",
290+
"\n",
291+
"with tf.GradientTape() as tape:\n",
292+
" element = stack_peek(stack)\n",
293+
"\n",
294+
"print(element)\n",
295+
"print(tape.gradient(element, buffer))"
296+
]
297+
},
251298
{
252299
"cell_type": "markdown",
253300
"metadata": {},
@@ -262,7 +309,7 @@
262309
},
263310
{
264311
"cell_type": "code",
265-
"execution_count": 7,
312+
"execution_count": 8,
266313
"metadata": {},
267314
"outputs": [
268315
{
@@ -328,7 +375,7 @@
328375
},
329376
{
330377
"cell_type": "code",
331-
"execution_count": 8,
378+
"execution_count": 9,
332379
"metadata": {},
333380
"outputs": [
334381
{

0 commit comments

Comments
 (0)