
Commit da2566c (root commit, 0 parents)

Add first bigram model

File tree

3 files changed: 24110 additions & 0 deletions

.gitignore

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Cython debug symbols
+cython_debug/

prototype.ipynb

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch.nn import functional as F"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open('verne.txt', 'r') as f:\n",
+    "    text = f.read()\n",
+    "\n",
+    "vocab_size = len(set(text))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# construct a character-level tokenizer\n",
+    "ctoi = {c: i for i, c in enumerate(set(text))}\n",
+    "itoc = {i: c for i, c in enumerate(set(text))}\n",
+    "encode = lambda x: [ctoi[c] for c in x]\n",
+    "decode = lambda x: ''.join([itoc[i] for i in x])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = torch.tensor(encode(text), dtype=torch.long)\n",
+    "\n",
+    "n = int(len(data) * 0.9)\n",
+    "train_data = data[:n]\n",
+    "val_data = data[n:]\n",
+    "\n",
+    "batch_size = 32\n",
+    "block_size = 8\n",
+    "device = torch.device('mps')\n",
+    "\n",
+    "def get_batch(split):\n",
+    "    data = train_data if split == 'train' else val_data\n",
+    "    ix = torch.randint(0, len(data) - block_size, (batch_size,))\n",
+    "    x = torch.stack([data[i:i+block_size] for i in ix])\n",
+    "    y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
+    "    return x.to(device), y.to(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class BigramLanguageModel(nn.Module):\n",
+    "    def __init__(self, vocab_size: int):\n",
+    "        super().__init__()\n",
+    "        # construct a lookup table where each row corresponds to a token\n",
+    "        # and contains the logits for the next token\n",
+    "        self.embedding_table = nn.Embedding(vocab_size, vocab_size)\n",
+    "\n",
+    "    def forward(self, idx: torch.Tensor, target: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor | None]:\n",
+    "        # look up the logits for the next token\n",
+    "        logits = self.embedding_table(idx)\n",
+    "\n",
+    "        if target is None:\n",
+    "            loss = None\n",
+    "        else:\n",
+    "            # compute the loss\n",
+    "            B, T, C = logits.shape\n",
+    "            logits = logits.view(B*T, C)\n",
+    "            loss = F.cross_entropy(logits, target.view(-1))\n",
+    "        return logits, loss\n",
+    "\n",
+    "    def generate(self, idx: torch.Tensor, max_tokens: int) -> torch.Tensor:\n",
+    "        # generate tokens autoregressively\n",
+    "        with torch.no_grad():\n",
+    "            for _ in range(max_tokens):\n",
+    "                logits, loss = self.forward(idx)\n",
+    "                logits = logits[:, -1, :]\n",
+    "                probs = F.softmax(logits, dim=-1)\n",
+    "                next_token = torch.multinomial(probs, 1)\n",
+    "                idx = torch.cat((idx, next_token), dim=1)\n",
+    "        return idx"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([32, 8])\n",
+      "tensor(5.1389, device='mps:0', grad_fn=<NllLossBackward0>)\n",
+      "0£h Œi((“WI_+z:YyNXn=-1”_Tr5i£:oN“3$\n",
+      "°m/zfŒ\"EfYM5>:3&OgPŒ,‘J-6i1/_V_″vfS7I@FnCé=—A\n",
+      "N2:i57ï/)X1!nEb,>\n"
+     ]
+    }
+   ],
+   "source": [
+    "torch.manual_seed(1337)\n",
+    "bigram = BigramLanguageModel(vocab_size).to(device)\n",
+    "x, y = get_batch('train')\n",
+    "\n",
+    "print(x.shape)\n",
+    "\n",
+    "logits, loss = bigram(x, y)\n",
+    "print(loss)\n",
+    "\n",
+    "print(decode(bigram.generate(torch.zeros(1, 1, dtype=torch.long, device=device), 100)[0].tolist()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "optimizer = torch.optim.AdamW(bigram.parameters(), lr=1e-3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2.5752177238464355\n"
+     ]
+    }
+   ],
+   "source": [
+    "for i in range(10000):\n",
+    "    x, y = get_batch('train')\n",
+    "    logits, loss = bigram(x, y)\n",
+    "    optimizer.zero_grad()\n",
+    "    loss.backward()\n",
+    "    optimizer.step()\n",
+    "print(loss.item())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "0\n",
+      "The borven iove s theannokintwaim, we---trs to ar-o ad anted I ves, “Hibouerthedeloke ier Theatis l\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(decode(bigram.generate(torch.zeros(1, 1, dtype=torch.long, device=device), 100)[0].tolist()))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "verne-encoder-transformer",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
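A note on the outputs above, not part of the commit: before training, a bigram model should sit near a cross-entropy of ln(vocab_size), since the randomly initialized embedding table makes the softmax roughly uniform over the character vocabulary. A minimal sanity check, assuming the Verne text yields about 170 distinct characters (a figure inferred from the printed loss, not stated anywhere in the diff):

import math

# A uniform guess over a vocabulary of ~170 characters costs
# ln(170) per character, which lines up with the printed
# starting loss of 5.1389.
print(math.log(170))  # ~5.136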

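The commit carves out val_data but only ever trains and samples; a held-out loss check is a natural follow-up. A minimal sketch using the notebook's own get_batch and bigram objects (the choice of 200 evaluation batches is an arbitrary assumption):

with torch.no_grad():
    losses = []
    for _ in range(200):
        x, y = get_batch('val')     # sample from the held-out split
        _, loss = bigram(x, y)      # forward pass returns (logits, loss)
        losses.append(loss.item())
print(sum(losses) / len(losses))    # average validation loss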