
Commit a3d73ef

impl & test for all cases by adding a new initer module
1 parent 2b58fc1 commit a3d73ef

12 files changed: +599 −563 lines

cnn_cifar10.ipynb   +36 −66   (large diff not rendered by default)
cnn_mnist.ipynb     +33 −72   (large diff not rendered by default)
gru_ucihar.ipynb    +16 −23   (large diff not rendered by default)
lstm_ucihar.ipynb   +26 −34   (large diff not rendered by default)
mlp_cifar10.ipynb   +19 −38
@@ -129,38 +129,24 @@
 "from plugins.minitorch.optimizer import Adam\n",
 "from plugins.minitorch.utils import softmax, cross_entropy_loss\n",
 "from plugins.minitorch.initer import Initer\n",
-"from plugins.minitorch.nn import dropout\n",
+"from plugins.minitorch.nn import Dense\n",
+"from plugins.minitorch.loss import CrossEntropyLoss\n",
 "\n",
 "class mlp_clf:\n",
 "    def __init__(self, lr=0.01):\n",
 "        self.config = {\n",
-"            'fc4relu:0' : {\n",
-"                'input_dim': 32 * 32 * 3,\n",
-"                'output_dim': 128 * 3,\n",
-"            },\n",
-"            'relu': {},\n",
-"            'fc4relu:1' : {\n",
-"                'input_dim': 128 * 3,\n",
-"                'output_dim': 64 * 3,\n",
-"            },\n",
-"            'fc4relu:2' : {\n",
-"                'input_dim': 64 * 3,\n",
-"                'output_dim': 32 * 3,\n",
-"            },\n",
-"            'fc4relu:3' : {\n",
-"                'input_dim': 32 * 3,\n",
-"                'output_dim': 16 * 3,\n",
-"            },\n",
-"            'fc4relu:4' : {\n",
-"                'input_dim': 16 * 3,\n",
-"                'output_dim': 10,\n",
-"            }\n",
+"            'fc:0': Dense.get_linear(32 * 32 * 3, 128 * 3),\n",
+"            'fc:1': Dense.get_linear(128 * 3, 64 * 3),\n",
+"            'fc:2': Dense.get_linear(64 * 3, 32 * 3),\n",
+"            'fc:3': Dense.get_linear(32 * 3, 16 * 3),\n",
+"            'fc:4': Dense.get_linear(16 * 3, 10),\n",
 "        }\n",
 "\n",
 "        initer = Initer(self.config, key)\n",
 "        self.optr = Adam(initer(), lr=lr, batch_size=512)\n",
+"        self.losser = CrossEntropyLoss(self.forward)\n",
 "\n",
-"    def predict_proba(self, x: jnp.ndarray, params, train=False):\n",
+"    def forward(self, x: jnp.ndarray, params, train=False):\n",
 "        res = x\n",
 "        key = random.PRNGKey(42)\n",
 "        for p in params.values():\n",
@@ -172,42 +158,37 @@
 "            # Later we found that the same problem appears even without JIT, because other multithreaded optimizations still happen without it, which is what triggers the issue here.\n",
 "            res = res @ p['w'] + p['b'] \n",
 "            res = jnp.maximum(0, res) # use relu activation function\n",
-"            res, key = dropout(res, key, p=0.1, train=train)\n",
+"            res, key = Dense.dropout(res, key, p=0.1, train=train)\n",
 "\n",
 "        return softmax(res)\n",
 "\n",
 "    def fit(self, x_train, y_train_proba, x_test, y_test_proba, epoches=100): \n",
-"        cnt = 0\n",
-"\n",
 "        @jit\n",
 "        def _acc(y_true_proba, y_pred_proba):\n",
 "            y_true = jnp.argmax(y_true_proba, axis=1)\n",
 "            y_pred = jnp.argmax(y_pred_proba, axis=1)\n",
 "            return jnp.mean(y_true == y_pred)\n",
 "\n",
-"        _loss = lambda params, x, y_true: cross_entropy_loss(y_true, self.predict_proba(x, params, True)) \n",
-"        _loss = jit(_loss) # accelerate loss function by JIT\n",
+"        _loss = self.losser.get_loss(train=True)\n",
+"        _loss = jit(_loss)\n",
 "        self.optr.open(_loss, x_train, y_train_proba)\n",
 "        \n",
-"        _tloss = lambda params: cross_entropy_loss(y_test_proba, self.predict_proba(x_test, params, False)) \n",
-"        _tloss = jit(_tloss) # accelerate loss function by JIT\n",
+"        _tloss = self.losser.get_embed_loss(x_test, y_test_proba, train=False)\n",
+"        _tloss = jit(_tloss)\n",
 "        \n",
 "\n",
 "        acc, loss, tacc, tloss = [], [], [], [] # train acc, train loss, test acc, test loss\n",
 "\n",
-"        for _ in range(epoches):\n",
+"        for cnt in range(epoches):\n",
 "            loss.append(_loss(self.optr.get_params(), x_train, y_train_proba))\n",
 "            tloss.append(_tloss(self.optr.get_params()))\n",
 "\n",
-"            self.train = True # use dropout only while updating grads\n",
 "            self.optr.update()\n",
-"            self.train = False\n",
 "            \n",
-"            acc.append(_acc(y_train_proba, self.predict_proba(x_train, self.optr.get_params())))\n",
-"            tacc.append(_acc(y_test_proba, self.predict_proba(x_test, self.optr.get_params())))\n",
-"            cnt += 1\n",
-"            if cnt % 10 == 0:\n",
-"                print(f'>> epoch: {cnt}, train acc: {acc[-1]}, test acc: {tacc[-1]}')\n",
+"            acc.append(_acc(y_train_proba, self.forward(x_train, self.optr.get_params())))\n",
+"            tacc.append(_acc(y_test_proba, self.forward(x_test, self.optr.get_params())))\n",
+"            if (cnt + 1) % 10 == 0:\n",
+"                print(f'>> epoch: {cnt + 1}, train acc: {acc[-1]}, test acc: {tacc[-1]}')\n",
 "\n",
 "        return acc, loss, tacc, tloss"
]
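
The new initer module itself is not part of the rendered hunks, so its exact API is only implied by how it is called above: a dict of Dense.get_linear(...) specs is passed to Initer(self.config, key), and calling the initer returns the parameter pytree handed to Adam. Purely as an illustration, a minimal sketch of that pattern in plain JAX, assuming get_linear returns a shape spec and Initer draws He-initialized weights (both names and the init scheme are assumptions, not the repository's actual implementation):

import jax.numpy as jnp
from jax import random

# Hypothetical stand-ins for plugins.minitorch.nn.Dense and
# plugins.minitorch.initer.Initer; the real API may differ.
class Dense:
    @staticmethod
    def get_linear(input_dim, output_dim):
        # Assumed to return a plain layer spec consumed by Initer.
        return {'input_dim': input_dim, 'output_dim': output_dim}

class Initer:
    def __init__(self, config, key):
        self.config = config
        self.key = key

    def __call__(self):
        # Turn each spec into a {'w', 'b'} entry keyed by layer name.
        params = {}
        keys = random.split(self.key, len(self.config))
        for k, (name, spec) in zip(keys, self.config.items()):
            scale = jnp.sqrt(2.0 / spec['input_dim'])  # He-style scaling for ReLU stacks
            params[name] = {
                'w': scale * random.normal(k, (spec['input_dim'], spec['output_dim'])),
                'b': jnp.zeros((spec['output_dim'],)),
            }
        return params

key = random.PRNGKey(42)
config = {
    'fc:0': Dense.get_linear(32 * 32 * 3, 128 * 3),
    'fc:1': Dense.get_linear(128 * 3, 10),
}
params = Initer(config, key)()  # pytree like the one passed to Adam(initer(), lr=lr, ...)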
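Likewise, plugins.minitorch.loss is not shown in this diff. The calls self.losser.get_loss(train=True) and self.losser.get_embed_loss(x_test, y_test_proba, train=False) suggest one loss callable with signature (params, x, y_true) for the optimizer and one that closes over a fixed evaluation batch so it only takes params. A hedged sketch under that assumption (cross_entropy_loss is reimplemented here only to keep the snippet self-contained):

import jax.numpy as jnp

def cross_entropy_loss(y_true_proba, y_pred_proba, eps=1e-9):
    # Mean cross-entropy between (one-hot) targets and predicted probabilities.
    return -jnp.mean(jnp.sum(y_true_proba * jnp.log(y_pred_proba + eps), axis=1))

class CrossEntropyLoss:
    # Hypothetical sketch of the losser object: it wraps the model's forward function
    # and returns JIT-friendly loss callables.
    def __init__(self, forward):
        self.forward = forward

    def get_loss(self, train=True):
        # Matches self.optr.open(_loss, x_train, y_train_proba): loss(params, x, y) -> scalar.
        def _loss(params, x, y_true):
            return cross_entropy_loss(y_true, self.forward(x, params, train))
        return _loss

    def get_embed_loss(self, x, y_true, train=False):
        # Evaluation batch baked in, so the test loss is loss(params) -> scalar.
        def _loss(params):
            return cross_entropy_loss(y_true, self.forward(x, params, train))
        return _loss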
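The dropout helper also moves from plugins.minitorch.nn.dropout to a Dense.dropout method, still called as res, key = Dense.dropout(res, key, p=0.1, train=train); its body is not in this hunk either. A plausible inverted-dropout sketch with that calling convention (the scaling and key handling are assumptions):

import jax.numpy as jnp
from jax import random

def dropout(x, key, p=0.1, train=False):
    # Identity at evaluation time; the PRNG key is passed through unchanged.
    if not train:
        return x, key
    key, subkey = random.split(key)
    keep = random.bernoulli(subkey, 1.0 - p, x.shape)
    # Inverted dropout: rescale kept units so the expected activation is unchanged.
    return jnp.where(keep, x / (1.0 - p), 0.0), key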
