129 129    "from plugins.minitorch.optimizer import Adam\n",
130 130    "from plugins.minitorch.utils import softmax, cross_entropy_loss\n",
131 131    "from plugins.minitorch.initer import Initer\n",
132     -  "from plugins.minitorch.nn import dropout\n",
    132 +  "from plugins.minitorch.nn import Dense\n",
    133 +  "from plugins.minitorch.loss import CrossEntropyLoss\n",
133 134    "\n",
134 135    "class mlp_clf:\n",
135 136    "    def __init__(self, lr=0.01):\n",
136 137    "        self.config = {\n",
137     -  "            'fc4relu:0' : {\n",
138     -  "                'input_dim': 32 * 32 * 3,\n",
139     -  "                'output_dim': 128 * 3,\n",
140     -  "            },\n",
141     -  "            'relu': {},\n",
142     -  "            'fc4relu:1' : {\n",
143     -  "                'input_dim': 128 * 3,\n",
144     -  "                'output_dim': 64 * 3,\n",
145     -  "            },\n",
146     -  "            'fc4relu:2' : {\n",
147     -  "                'input_dim': 64 * 3,\n",
148     -  "                'output_dim': 32 * 3,\n",
149     -  "            },\n",
150     -  "            'fc4relu:3' : {\n",
151     -  "                'input_dim': 32 * 3,\n",
152     -  "                'output_dim': 16 * 3,\n",
153     -  "            },\n",
154     -  "            'fc4relu:4' : {\n",
155     -  "                'input_dim': 16 * 3,\n",
156     -  "                'output_dim': 10,\n",
157     -  "            }\n",
    138 +  "            'fc:0': Dense.get_linear(32 * 32 * 3, 128 * 3),\n",
    139 +  "            'fc:1': Dense.get_linear(128 * 3, 64 * 3),\n",
    140 +  "            'fc:2': Dense.get_linear(64 * 3, 32 * 3),\n",
    141 +  "            'fc:3': Dense.get_linear(32 * 3, 16 * 3),\n",
    142 +  "            'fc:4': Dense.get_linear(16 * 3, 10),\n",
158 143    "        }\n",
159 144    "\n",
160 145    "        initer = Initer(self.config, key)\n",
161 146    "        self.optr = Adam(initer(), lr=lr, batch_size=512)\n",
    147 +  "        self.losser = CrossEntropyLoss(self.forward)\n",
162 148    "\n",
163     -  "    def predict_proba(self, x: jnp.ndarray, params, train=False):\n",
    149 +  "    def forward(self, x: jnp.ndarray, params, train=False):\n",
164 150    "        res = x\n",
165 151    "        key = random.PRNGKey(42)\n",
166 152    "        for p in params.values():\n",
...
172 158    "            # Later we found that the same problem occurs even without JIT, because other multi-threaded optimizations kick in without JIT as well and trigger the issue here.\n",
173 159    "            res = res @ p['w'] + p['b']\n",
174 160    "            res = jnp.maximum(0, res)  # use relu activation function\n",
175     -  "            res, key = dropout(res, key, p=0.1, train=train)\n",
    161 +  "            res, key = Dense.dropout(res, key, p=0.1, train=train)\n",
176 162    "\n",
177 163    "        return softmax(res)\n",
178 164    "\n",
179 165    "    def fit(self, x_train, y_train_proba, x_test, y_test_proba, epoches=100):\n",
180     -  "        cnt = 0\n",
181     -  "\n",
182 166    "        @jit\n",
183 167    "        def _acc(y_true_proba, y_pred_proba):\n",
184 168    "            y_true = jnp.argmax(y_true_proba, axis=1)\n",
185 169    "            y_pred = jnp.argmax(y_pred_proba, axis=1)\n",
186 170    "            return jnp.mean(y_true == y_pred)\n",
187 171    "\n",
188     -  "        _loss = lambda params, x, y_true: cross_entropy_loss(y_true, self.predict_proba(x, params, True))\n",
189     -  "        _loss = jit(_loss)  # accelerate loss function by JIT\n",
    172 +  "        _loss = self.losser.get_loss(train=True)\n",
    173 +  "        _loss = jit(_loss)\n",
190 174    "        self.optr.open(_loss, x_train, y_train_proba)\n",
191 175    "\n",
192     -  "        _tloss = lambda params: cross_entropy_loss(y_test_proba, self.predict_proba(x_test, params, False))\n",
193     -  "        _tloss = jit(_tloss)  # accelerate loss function by JIT\n",
    176 +  "        _tloss = self.losser.get_embed_loss(x_test, y_test_proba, train=False)\n",
    177 +  "        _tloss = jit(_tloss)\n",
194 178    "\n",
195 179    "\n",
196 180    "        acc, loss, tacc, tloss = [], [], [], []  # train acc, train loss, test acc, test loss\n",
197 181    "\n",
198     -  "        for _ in range(epoches):\n",
    182 +  "        for cnt in range(epoches):\n",
199 183    "            loss.append(_loss(self.optr.get_params(), x_train, y_train_proba))\n",
200 184    "            tloss.append(_tloss(self.optr.get_params()))\n",
201 185    "\n",
202     -  "            self.train = True  # use dropout only while updating grads\n",
203 186    "            self.optr.update()\n",
204     -  "            self.train = False\n",
205 187    "\n",
206     -  "            acc.append(_acc(y_train_proba, self.predict_proba(x_train, self.optr.get_params())))\n",
207     -  "            tacc.append(_acc(y_test_proba, self.predict_proba(x_test, self.optr.get_params())))\n",
208     -  "            cnt += 1\n",
209     -  "            if cnt % 10 == 0:\n",
210     -  "                print(f'>> epoch: {cnt}, train acc: {acc[-1]}, test acc: {tacc[-1]}')\n",
    188 +  "            acc.append(_acc(y_train_proba, self.forward(x_train, self.optr.get_params())))\n",
    189 +  "            tacc.append(_acc(y_test_proba, self.forward(x_test, self.optr.get_params())))\n",
    190 +  "            if (cnt + 1) % 10 == 0:\n",
    191 +  "                print(f'>> epoch: {cnt + 1}, train acc: {acc[-1]}, test acc: {tacc[-1]}')\n",
211 192    "\n",
212 193    "        return acc, loss, tacc, tloss"
213 194  ]
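
Note for reviewers: this diff calls several `plugins.minitorch` helpers that are not shown above (`Dense.get_linear`, `Dense.dropout`, and the `CrossEntropyLoss` wrapper with `get_loss` / `get_embed_loss`). The sketch below is a hypothetical reconstruction inferred purely from the call sites in this cell, not the actual plugin code; the names, signatures, and bodies are assumptions.

```python
# Hypothetical reconstruction of the plugin helpers, inferred from the call
# sites above; the real plugins.minitorch implementations may differ.
import jax.numpy as jnp
from jax import random


class Dense:
    @staticmethod
    def get_linear(input_dim, output_dim):
        # Stands in for the hand-written {'input_dim': ..., 'output_dim': ...}
        # entries that this diff removes from self.config.
        return {'input_dim': input_dim, 'output_dim': output_dim}

    @staticmethod
    def dropout(x, key, p=0.1, train=False):
        # Inverted dropout: identity at inference, rescaled mask in training.
        if not train:
            return x, key
        key, subkey = random.split(key)
        keep = random.bernoulli(subkey, 1.0 - p, x.shape)
        return jnp.where(keep, x / (1.0 - p), 0.0), key


class CrossEntropyLoss:
    def __init__(self, forward):
        # `forward` is the model's forward pass: (x, params, train) -> probs.
        self.forward = forward

    def get_loss(self, train=False):
        # Loss as a function of (params, x, y_true), the shape that
        # optr.open(_loss, ...) consumes in fit().
        def _loss(params, x, y_true):
            y_pred = self.forward(x, params, train)
            return -jnp.mean(jnp.sum(y_true * jnp.log(y_pred + 1e-9), axis=1))
        return _loss

    def get_embed_loss(self, x, y_true, train=False):
        # Same loss with the data baked in, i.e. a function of params only.
        _loss = self.get_loss(train)
        return lambda params: _loss(params, x, y_true)
```

If `get_embed_loss` does work this way, returning a closure over the test set is what lets `_tloss` be wrapped in a plain `jit(_tloss)` and evaluated with nothing but `self.optr.get_params()`.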
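
For completeness, a hedged usage sketch of the refactored class; `key` (the PRNG key passed to `Initer`) and the one-hot CIFAR-10 splits `x_train`, `y_train_proba`, `x_test`, `y_test_proba` are assumed to be defined in earlier notebook cells:

```python
# Assumes earlier cells define `key = random.PRNGKey(...)` and the
# CIFAR-10 splits with inputs flattened to length 32 * 32 * 3.
clf = mlp_clf(lr=0.01)
acc, loss, tacc, tloss = clf.fit(x_train, y_train_proba,
                                 x_test, y_test_proba, epoches=100)
```

Dropping the manual `self.train = True` / `self.train = False` toggling looks sound here: the train/eval switch now travels explicitly through the `train` argument of `forward` and `Dense.dropout` instead of through mutable object state, which also plays better with JIT tracing.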