
Commit 1cb791d: EXP: Use cross attention
neel04 committed Apr 14, 2024
1 parent 61f5c4c
Showing 2 changed files with 5 additions and 2 deletions.
1 change: 0 additions & 1 deletion ReAct/data/minipile.py
@@ -4,7 +4,6 @@
 
 import datasets
 import jax
-import jax.numpy as jnp
 import numpy as np
 from datasets import load_dataset, load_from_disk
 from jaxtyping import Array
6 changes: 5 additions & 1 deletion ReAct/model/react.py
@@ -47,7 +47,11 @@ def f(input_tup: Tuple[Array, int], _dynamic_bl: PyTree) -> Tuple[Tuple[Array, i
         input_arr, idx = input_tup  # idx is the iteration index
 
         block = eqx.combine(_dynamic_bl, static_part)  # reconstruct the block
-        output = block(x, x, pad_mask, enable_dropout, key).astype(jnp.bfloat16)  # self-attention
+
+        if idx == 0:
+            output = block(input_arr, input_arr, pad_mask, enable_dropout, key).astype(jnp.bfloat16)
+        else:
+            output = block(x, x, pad_mask, enable_dropout, key).astype(jnp.bfloat16)  # self-attention
 
         return (output, idx + 1), None
 
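The hunk above makes the loop body branch on the iteration counter: the first pass attends over the raw input array, while every later pass keeps self-attending over the evolving state. Below is a minimal, pure-JAX sketch of that control flow under stated assumptions: toy_attention, iterate, the toy shapes, and the dropped pad_mask/enable_dropout/key arguments are illustrative stand-ins, not the repository's actual block or call signature.

# Minimal sketch only: toy_attention is a hypothetical stand-in for the real
# attention block; only the branch on idx mirrors the diff above.
import jax
import jax.numpy as jnp


def toy_attention(q: jax.Array, kv: jax.Array) -> jax.Array:
    # Toy single-head dot-product attention (no masking, no dropout).
    scores = q @ kv.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ kv


def iterate(input_arr: jax.Array, n_iters: int) -> jax.Array:
    # Run n_iters refinement steps; the scan carry holds (state, iteration index).
    inp = input_arr.astype(jnp.bfloat16)

    def f(carry, _):
        x, idx = carry
        # Iteration 0 reads the raw input; later iterations refine the state x.
        src = jnp.where(idx == 0, inp, x)
        out = toy_attention(src, src).astype(jnp.bfloat16)
        return (out, idx + 1), None

    (out, _), _ = jax.lax.scan(f, (inp, 0), xs=None, length=n_iters)
    return out


out = iterate(jnp.ones((8, 16)), n_iters=4)  # toy (seq_len, dim) input

The sketch uses jnp.where instead of a Python conditional so the branch stays traceable inside jax.lax.scan; a concrete "if idx == 0:" as in the diff only works when the iteration index is a plain Python integer (for example in an unrolled or Python-level loop), which may or may not be how the surrounding ReAct code drives this function.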
