Index the embedding tables with arrays instead of raw Python sequences.

In both gpt2.py and gpt2_pico.py, gpt2() now wraps the token ids in
np.array(...) and builds the position indices with np.arange(...), rather
than indexing wte/wpe directly with `inputs` and a `range` object.

diff --git a/gpt2.py b/gpt2.py
index c85e8b8..a99d8ef 100644
--- a/gpt2.py
+++ b/gpt2.py
@@ -72,7 +72,7 @@ def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):  # [n_seq, n_embd] -> [
 
 def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):  # [n_seq] -> [n_seq, n_vocab]
     # token + positional embeddings
-    x = wte[inputs] + wpe[range(len(inputs))]  # [n_seq] -> [n_seq, n_embd]
+    x = wte[np.array(inputs)] + wpe[np.arange(len(inputs))]  # [n_seq] -> [n_seq, n_embd]
 
     # forward pass through n_layer transformer blocks
     for block in blocks:
diff --git a/gpt2_pico.py b/gpt2_pico.py
index 762ed12..6afee28 100644
--- a/gpt2_pico.py
+++ b/gpt2_pico.py
@@ -35,7 +35,7 @@ def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
     return x
 
 def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
-    x = wte[inputs] + wpe[range(len(inputs))]
+    x = wte[np.array(inputs)] + wpe[np.arange(len(inputs))]
     for block in blocks:
         x = transformer_block(x, **block, n_head=n_head)
     return layer_norm(x, **ln_f) @ wte.T