diff --git a/ReAct/data/minipile.py b/ReAct/data/minipile.py index 38779a5..434ed9c 100644 --- a/ReAct/data/minipile.py +++ b/ReAct/data/minipile.py @@ -4,7 +4,6 @@ import datasets import jax -import jax.numpy as jnp import numpy as np from datasets import load_dataset, load_from_disk from jaxtyping import Array diff --git a/ReAct/model/react.py b/ReAct/model/react.py index 9347edf..4858647 100644 --- a/ReAct/model/react.py +++ b/ReAct/model/react.py @@ -47,7 +47,11 @@ def f(input_tup: Tuple[Array, int], _dynamic_bl: PyTree) -> Tuple[Tuple[Array, i input_arr, idx = input_tup # i is the iteration index block = eqx.combine(_dynamic_bl, static_part) # reconstruct the block - output = block(x, x, pad_mask, enable_dropout, key).astype(jnp.bfloat16) # self-attention + + if idx == 0: + output = block(input_arr, input_arr, pad_mask, enable_dropout, key).astype(jnp.bfloat16) + else: + output = block(x, x, pad_mask, enable_dropout, key).astype(jnp.bfloat16) # self-attention return (output, idx + 1), None