Commit e97f13c

Refactor train and generate
1 parent 7eaabaf commit e97f13c

File tree

6 files changed: +227 −108 lines


.gitignore

+2
@@ -1,3 +1,5 @@
+checkpoints/
+.DS_Store
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

README.md

+32 −4

@@ -140,9 +140,37 @@ Next, install the dependencies:
 ```bash
 pip install -r requirements.txt
 ```
+
+### Training the Model
 
-You can then run the main script:
+The `train.py` script trains the model. It accepts the following command line arguments:
 
-```bash
-python main.py
-```
+- `--iters`: Total iterations to train. Default is 5000.
+- `--lr`: Learning rate. Default is 3e-4.
+- `--device`: Device to use for training. Default is "cuda" if CUDA is available, otherwise "mps".
+- `--checkpoint_dir`: Directory to save the model checkpoints. Default is "checkpoints".
+
+Example usage:
+
+```shell
+python train.py --iters 10000 --lr 1e-4 --device cuda --checkpoint_dir my_checkpoints
+```
+
+This will train the model for 10000 iterations with a learning rate of 1e-4, using a CUDA device for training. The model checkpoints will be saved in the `my_checkpoints` directory.
+
+
+
+### Generating New Text
+
+The `generate.py` script generates new text from a trained model. It accepts the following command line arguments:
+
+- `--checkpoint_path`: Path to the model checkpoint. This argument is required.
+- `--num_tokens`: Number of tokens to generate. Default is 100.
+
+Example usage:
+
+```shell
+python generate.py --checkpoint_path my_checkpoints/model_state_10000.pt --num_tokens 500
+```
+
+This will generate 500 new tokens from the model checkpoint at `my_checkpoints/model_state_10000.pt`.

decoder_transformer.py

+14 −15

@@ -1,28 +1,27 @@
-from typing import Optional, Tuple
-
+from __future__ import annotations
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
-class EncoderTransformer(nn.Module):
+class DecoderTransformer(nn.Module):
     def __init__(
         self,
         num_blocks: int,
         num_heads: int,
         embed_size: int,
-        block_size: int,
+        context_size: int,
         vocab_size: int,
     ):
         super().__init__()
-        self.block_size = block_size
+        self.context_size = context_size
         self.vocab_size = vocab_size
         self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
-        self.position_embedding_table = nn.Embedding(block_size, embed_size)
+        self.position_embedding_table = nn.Embedding(context_size, embed_size)
         head_size = embed_size // num_heads
         self.blocks = nn.Sequential(
             *[
-                Block(num_heads, head_size, embed_size, block_size)
+                Block(num_heads, head_size, embed_size, context_size)
                 for _ in range(num_blocks)
             ]
             + [nn.LayerNorm(embed_size)]
@@ -49,23 +48,23 @@ def forward(
 
         return logits, loss
 
-    def generate(self, idx: torch.Tensor, max_tokens: int) -> torch.Tensor:
+    def generate(self, context: torch.Tensor, num_tokens: int) -> torch.Tensor:
         # generate tokens
         with torch.no_grad():
-            for i in range(max_tokens):
-                cond_idx = idx[:, -self.block_size :]
-                logits, _ = self.forward(cond_idx)
+            for _ in range(num_tokens):
+                cond_context = context[:, -self.context_size :]
+                logits, _ = self.forward(cond_context)
                 logits = logits[:, -1, :]
                 probs = F.softmax(logits, dim=-1)
                 next_token = torch.multinomial(probs, 1)
-                idx = torch.cat((idx, next_token), dim=1)
-        return idx
+                context = torch.cat((context, next_token), dim=1)
+        return context
 
 
 class MultiHeadAttention(nn.Module):
     """
     A multi-head attention layer.
-    Takees in a number of heads retruen a concatenated output of all heads.
+    Takes in a number of heads and returns a concatenated output of all heads.
     """
 
     def __init__(
@@ -148,7 +147,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class Block(nn.Module):
     """
-    A single block of the Transformer.
+    A single transformer block.
     """
 
     def __init__(
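To make the renamed interface concrete, here is a minimal usage sketch of the refactored `DecoderTransformer` (not part of the commit), using the same hyperparameters that `train.py` and `generate.py` hard-code; the vocabulary size and the random context below are placeholders for illustration only.

```python
import torch

from decoder_transformer import DecoderTransformer

# Hyperparameters mirror the values hard-coded in train.py / generate.py.
model = DecoderTransformer(
    num_blocks=6,
    num_heads=6,
    embed_size=384,
    context_size=256,
    vocab_size=100,  # placeholder; the scripts use len(set(text))
)

# One sequence of 8 placeholder token ids, shape (batch, time) = (1, 8).
context = torch.randint(0, 100, (1, 8), dtype=torch.long)

# generate() crops the context to the last context_size tokens, samples
# num_tokens new tokens one at a time, and returns the full sequence,
# so the output shape here is (1, 8 + 20).
tokens = model.generate(context, num_tokens=20)
print(tokens.shape)
```

With untrained weights the sampled tokens are of course meaningless; the point is only the shape contract of the renamed `generate(context, num_tokens)` method.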

generate.py

+64
@@ -0,0 +1,64 @@
+import argparse
+
+import torch
+
+from decoder_transformer import DecoderTransformer
+
+
+def main():
+    # Define command line arguments
+    parser = argparse.ArgumentParser(
+        description="Generate text from a trained transformer model."
+    )
+    parser.add_argument(
+        "--checkpoint_path",
+        type=str,
+        required=True,
+        help="Path to the model checkpoint.",
+    )
+    parser.add_argument(
+        "--num_tokens", type=int, default=100, help="Number of tokens to generate."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "mps",
+        help="Device to use for training.",
+    )
+    args = parser.parse_args()
+
+    with open("verne.txt", "r") as f:
+        text = f.read()
+
+    device = torch.device(args.device)
+    vocab_size = len(set(text))
+    embed_size = 384
+    context_size = 256
+    num_heads = 6
+    num_blocks = 6
+
+    state_dict = torch.load(args.checkpoint_path)
+    encode = lambda x: [state_dict["encoder_dictionary"][c] for c in x]
+    decode = lambda x: "".join([state_dict["decoder_dictionary"][i] for i in x])
+
+    # Load the model from the checkpoint
+    model = DecoderTransformer(
+        num_blocks=num_blocks,
+        num_heads=num_heads,
+        embed_size=embed_size,
+        context_size=context_size,
+        vocab_size=vocab_size,
+    ).to(device)
+    model.load_state_dict(state_dict, strict=False)
+
+    encoded_context = (
+        torch.tensor(encode("The "), dtype=torch.long).unsqueeze(0).to(device)
+    )
+
+    # Generate text
+    generated_text = model.generate(encoded_context, args.num_tokens)
+    print(decode(generated_text.tolist()[0]))
+
+
+if __name__ == "__main__":
+    main()

main.py

-89
This file was deleted.

train.py

+115
@@ -0,0 +1,115 @@
+import argparse
+import os
+from datetime import datetime
+
+import torch
+from tqdm import tqdm
+
+from decoder_transformer import DecoderTransformer
+
+with open("verne.txt", "r") as f:
+    text = f.read()
+
+vocab_size = len(set(text))
+batch_size = 64
+embed_size = 384
+context_size = 256
+num_heads = 6
+num_blocks = 6
+
+# Define command line arguments
+parser = argparse.ArgumentParser(description="Train a transformer model.")
+parser.add_argument(
+    "--iters", type=int, default=5000, help="Total iterations to train."
+)
+parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate.")
+parser.add_argument(
+    "--device",
+    type=str,
+    default="cuda" if torch.cuda.is_available() else "mps",
+    help="Device to use for training.",
+)
+parser.add_argument(
+    "--checkpoint_dir",
+    type=str,
+    default="checkpoints",
+    help="Directory to save the model checkpoints.",
+)
+args = parser.parse_args()
+
+device = torch.device(args.device)
+learning_rate = args.lr
+total_iters = args.iters
+eval_iters = total_iters // 10
+
+# construct a character level tokenizer
+ctoi = {c: i for i, c in enumerate(set(text))}
+itoc = {i: c for i, c in enumerate(set(text))}
+encode = lambda x: [ctoi[c] for c in x]
+decode = lambda x: "".join([itoc[i] for i in x])
+
+data = torch.tensor(encode(text), dtype=torch.long)
+n = int(len(data) * 0.9)
+train_data = data[:n]
+val_data = data[n:]
+
+
+def get_batch(split):
+    data = train_data if split == "train" else val_data
+    ix = torch.randint(0, len(data) - context_size, (batch_size,))
+    x = torch.stack([data[i : i + context_size] for i in ix])
+    y = torch.stack([data[i + 1 : i + context_size + 1] for i in ix])
+    return x.to(device), y.to(device)
+
+
+@torch.no_grad()
+def eval_loss(model):
+    model.eval()
+    out = {}
+    for split in ["train", "val"]:
+        losses = torch.zeros(eval_iters)
+        for i in range(eval_iters):
+            x, y = get_batch(split)
+            _, loss = model(x, y)
+            losses[i] = loss.item()
+        out[split] = losses.mean().item()
+    model.train()
+    return out
+
+
+start_time = datetime.now().strftime("%Y%m%d_%H%M")
+
+model = DecoderTransformer(
+    num_blocks=num_blocks,
+    num_heads=num_heads,
+    embed_size=embed_size,
+    context_size=context_size,
+    vocab_size=vocab_size,
+).to(device)
+optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+progress_bar = tqdm(range(total_iters))
+checkpoint_dir = args.checkpoint_dir
+os.makedirs(checkpoint_dir, exist_ok=True)
+
+# train the model
+for i in progress_bar:
+    model.train()
+    x, y = get_batch("train")
+    logits, loss = model(x, y)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+    if i % eval_iters == 0 and i > 0:
+        # Save the model state
+        state_dict = model.state_dict()
+        state_dict["encoder_dictionary"] = ctoi
+        state_dict["decoder_dictionary"] = itoc
+        torch.save(
+            state_dict,
+            os.path.join(checkpoint_dir, f"{start_time}_model_state_{i}.pt"),
+        )
+
+        # Log the losses
+        losses = eval_loss(model)
+        progress_bar.set_postfix(losses)
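A note on the checkpoint format produced above: the tokenizer dictionaries are stored inside the model's `state_dict` under the extra keys `encoder_dictionary` and `decoder_dictionary`, which is why `generate.py` loads the file with `strict=False`. Below is a minimal sketch (not part of the commit) of reading such a checkpoint and separating the tokenizer from the weights; the file name is a hypothetical example following the `{start_time}_model_state_{i}.pt` pattern used in `train.py`.

```python
import torch

from decoder_transformer import DecoderTransformer

# Hypothetical checkpoint path following the f"{start_time}_model_state_{i}.pt" naming.
checkpoint = torch.load(
    "checkpoints/20240101_1200_model_state_1000.pt", map_location="cpu"
)

# Pop the tokenizer dictionaries out of the saved dict...
ctoi = checkpoint.pop("encoder_dictionary")
itoc = checkpoint.pop("decoder_dictionary")

# ...so that what remains is a plain model state_dict and strict loading works.
model = DecoderTransformer(
    num_blocks=6,
    num_heads=6,
    embed_size=384,
    context_size=256,
    vocab_size=len(ctoi),
)
model.load_state_dict(checkpoint)

encode = lambda s: [ctoi[c] for c in s]
decode = lambda ids: "".join(itoc[i] for i in ids)
```

Popping the extra keys first avoids relying on `strict=False`, which would also silence genuinely missing or unexpected weight keys.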
