Skip to content

Commit

Permalink
EXP: Upgrade to a jax.lax.cond
Browse files Browse the repository at this point in the history
  • Loading branch information
neel04 committed Apr 14, 2024
1 parent 1cb791d commit a572c47
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions ReAct/model/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,9 @@ def f(input_tup: Tuple[Array, int], _dynamic_bl: PyTree) -> Tuple[Tuple[Array, i

block = eqx.combine(_dynamic_bl, static_part) # reconstruct the block

if idx == 0:
output = block(input_arr, input_arr, pad_mask, enable_dropout, key).astype(jnp.bfloat16)
else:
output = block(x, x, pad_mask, enable_dropout, key).astype(jnp.bfloat16) # self-attention
output = jax.lax.cond(idx == 0,
lambda: block(input_arr, input_arr, pad_mask, enable_dropout, key).astype(jnp.bfloat16),
lambda: block(x, x, pad_mask, enable_dropout, key).astype(jnp.bfloat16))

return (output, idx + 1), None

Expand Down

0 comments on commit a572c47

Please sign in to comment.