Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def transformer_block(x, mlp, attn, ln_1, ln_2, n_head): # [n_seq, n_embd] -> [

def gpt2(inputs, wte, wpe, blocks, ln_f, n_head): # [n_seq] -> [n_seq, n_vocab]
# token + positional embeddings
x = wte[inputs] + wpe[range(len(inputs))] # [n_seq] -> [n_seq, n_embd]
x = wte[np.array(inputs)] + wpe[np.array(range(len(inputs)))] # [n_seq] -> [n_seq, n_embd]

# forward pass through n_layer transformer blocks
for block in blocks:
Expand Down
2 changes: 1 addition & 1 deletion gpt2_pico.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
return x

def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
x = wte[inputs] + wpe[range(len(inputs))]
x = wte[np.array(inputs)] + wpe[np.array(range(len(inputs)))]
for block in blocks:
x = transformer_block(x, **block, n_head=n_head)
return layer_norm(x, **ln_f) @ wte.T
Expand Down