
Commit 2ecc0b8

notebook doc
1 parent fdd45a2 commit 2ecc0b8

10 files changed: +1717 −82 lines

README.md (+4 −4)
@@ -41,11 +41,11 @@ On MNIST: (a) acc[96.80%] & loss vs. epochs for mlp; (b) acc[97.86%] & loss vs.
 
 Some small tests for debug during the development of this project:
 
-- How to Use Jax Gradient, <ins>*Ideas about how I manage parameters in this Framework*</ins>.
-- When to use JIT in Jax? <ins>*About Time & Space*</ins> <mark>TODO</mark>
 - How to Use Mini-torch? <ins>*A brief e.g. Doc*</ins> <mark>TODO</mark>
-- Kaiming Initialization[[2](#reference)] used in MLP & Conv, <ins>*With math derivation*</ins>
-- Difference between Conv2d Operation by python loop and by **Jax.Lax**.
+- How to Use Jax Gradient, <ins>*Ideas about how I manage parameters in this Framework*</ins>.
+- Some Jax Tips, <ins>*About How to Use Jax Builtins & JIT to Optimize Loops & Matrix Operations.*</ins>
+- Kaiming Initialization[[2](#reference)] used in MLP & Conv, <ins>*With math derivation.*</ins>
+- Difference between Conv2d Operation by python loop and by <ins>**Jax.lax**</ins>.
 - Dropout mechanism impl, <ins>*About Seed in Jax*.</ins>
 - Runge-Kuta solver for Neural ODE.
 
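For context only (not part of this commit): the Kaiming-initialization item in the list above refers to the separate derivation doc. A minimal JAX sketch of the idea, with hypothetical layer sizes, scales weights by sqrt(2 / fan_in) so ReLU activations keep roughly unit variance; this is illustrative and is not the repo's Initer code.

```python
import jax.numpy as jnp
from jax import random

def kaiming_normal(key, fan_in, fan_out):
    # He/Kaiming init for ReLU layers: std = sqrt(2 / fan_in)
    std = jnp.sqrt(2.0 / fan_in)
    return std * random.normal(key, (fan_in, fan_out))

key = random.PRNGKey(0)
W = kaiming_normal(key, fan_in=784, fan_out=128)  # e.g. an MNIST-sized layer
print(W.std())  # roughly sqrt(2 / 784) ≈ 0.05
```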

example/data_analysis.ipynb (+375)

Large diffs are not rendered by default.

example/data_process.py (+14 −7)
@@ -57,13 +57,20 @@ def one_hot(y: jnp.ndarray, num_class: int):
 TRAIN = None
 TEST = None
 
-shuffle_kernel = random.permutation(key, (X_train.shape[0]))
-X_train = X_train[shuffle_kernel][:TRAIN]
-y_train = y_train[shuffle_kernel][:TRAIN]
-shuffle_kernel = random.permutation(key, (X_test.shape[0]))
-X_test = X_test[shuffle_kernel][:TEST]
-y_test = y_test[shuffle_kernel][:TEST]
-
+Shuffle = False
+
+if Shuffle:
+    shuffle_kernel = random.permutation(key, (X_train.shape[0]))
+    X_train = X_train[shuffle_kernel][:TRAIN]
+    y_train = y_train[shuffle_kernel][:TRAIN]
+    shuffle_kernel = random.permutation(key, (X_test.shape[0]))
+    X_test = X_test[shuffle_kernel][:TEST]
+    y_test = y_test[shuffle_kernel][:TEST]
+else:
+    X_train = X_train[:TRAIN]
+    y_train = y_train[:TRAIN]
+    X_test = X_test[:TEST]
+    y_test = y_test[:TEST]
 
 # X_train = jnp.transpose(X_train, (2, 0, 1))
 # X_test = jnp.transpose(X_test, (2, 0, 1))
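As background for the hunk above (not part of the diff): `jax.random.permutation` called with an integer returns a shuffled index array, so indexing X and y with the same permutation keeps sample/label pairs aligned, and slicing with a `None` bound (as with the `TRAIN`/`TEST` constants) is a no-op. A small self-contained sketch with dummy data:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
X = jnp.arange(10).reshape(5, 2)   # 5 dummy samples
y = jnp.arange(5)                  # matching labels

# permutation of an int n returns a shuffled arange(n)
idx = random.permutation(key, X.shape[0])
X_shuf, y_shuf = X[idx], y[idx]    # same index keeps (x, y) pairs aligned

# slicing with None keeps everything: arr[:None] is the whole array
assert X_shuf[:None].shape == X_shuf.shape
```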

example/main.py (+49 −29)
@@ -5,26 +5,42 @@
 from plugins.minitorch.nn import Rnn, Dense, Model
 from plugins.minitorch.optimizer import Adam
 from plugins.minitorch.initer import Initer
-from plugins.minitorch.utils import softmax
+from plugins.minitorch.utils import softmax, cross_entropy_loss, l2_regularization
+from plugins.minitorch.loss import CrossEntropyLoss
 
 from data_process import X_train, X_test, y_train, y_test
 
 key = random.PRNGKey(0)
 
 
-class LSTM(Model):
+class MyLoss(CrossEntropyLoss):
+    def __init__(self, f):
+        super(MyLoss, self).__init__(f)
+
+    def get_loss(self, train):
+        loss_function = lambda params, x, y_true: cross_entropy_loss(y_true, self.f(x, params, train)) + l2_regularization(params, 0.01)
+        return loss_function
+
+    def get_embed_loss(self, x, y_true, train):
+        embed_loss_function = lambda params: cross_entropy_loss(y_true, self.f(x, params, train)) + l2_regularization(params, 0.01)
+        return embed_loss_function
+
+
+class SplitLSTM(Model):
     def __init__(self, lr, epoches, batch_size):
         super().__init__(lr=lr, epoches=epoches)
 
         self.config = {
             'lstm:0': Rnn.get_lstm(128, 9, 64),
             'lstm:1even': Rnn.get_lstm(64, 64, 32),
             'lstm:1odd': Rnn.get_lstm(64, 64, 32),
+            'lstm:2': Rnn.get_lstm(64, 64, 64),
             'fc:0': Dense.get_linear(64, 6),
         }
 
         initer = Initer(self.config, key)
         self.optr = Adam(initer(), lr=lr, batch_size=batch_size)
+        self.lossr = MyLoss(self.predict_proba)
 
     def predict_proba(self, x, params, train=True):
         res = jnp.transpose(x, (2, 0, 1))
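As a framework-free illustration of what the `MyLoss` lambdas compute (cross-entropy on predicted probabilities plus an L2 penalty over all parameter leaves): the sketch below assumes the minitorch helpers behave roughly like these plain-jnp versions; the names here are illustrative, not the library's API. Embedding `x`/`y_true` into the closure, as `get_embed_loss` does, leaves a function of `params` only, which is exactly the shape `jax.grad` wants.

```python
import jax
import jax.numpy as jnp

def cross_entropy(y_true, y_proba, eps=1e-8):
    # mean negative log-likelihood for one-hot targets
    return -jnp.mean(jnp.sum(y_true * jnp.log(y_proba + eps), axis=-1))

def l2_penalty(params, lam):
    # sum of squared entries over every leaf of the params pytree
    leaves = jax.tree_util.tree_leaves(params)
    return lam * sum(jnp.sum(leaf ** 2) for leaf in leaves)

def make_embed_loss(predict_proba, x, y_true, lam=0.01):
    # a function of params only, ready for jax.grad
    return lambda params: cross_entropy(y_true, predict_proba(x, params)) + l2_penalty(params, lam)
```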
@@ -36,18 +52,21 @@ def predict_proba(self, x, params, train=True):
         even, _, _ = Rnn.lstm(even, params['lstm:1even'], self.config['lstm:1even'])
         odd, _, _ = Rnn.lstm(odd, params['lstm:1odd'], self.config['lstm:1odd'])
 
-        res = jnp.concatenate((even[-1], odd[-1]), axis=1)
+        res = jnp.concatenate((even, odd), axis=2)
+
+        res, _, _ = Rnn.lstm(res, params['lstm:2'], self.config['lstm:2'])
+        res = res[-1]
 
         res = Dense.linear(res, params['fc:0'])
 
         return softmax(res)
 
 
-epochs = 40
+epochs = 200
 batch_size = 64
-learning_rate = 0.015
+learning_rate = 0.005
 
-model = LSTM(lr=learning_rate, epoches=epochs, batch_size=batch_size)
+model = SplitLSTM(lr=learning_rate, epoches=epochs, batch_size=batch_size)
 acc, loss, tacc, tloss = model.fit(
     x_train=X_train,
     y_train_proba=y_train,
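On the changed concatenation (not part of the diff): assuming time-major LSTM outputs of shape (T, N, H), the old code joined only the last timesteps along the feature axis, while the new code joins the full sequences along axis 2 and passes them through the extra `lstm:2` before taking the last step. A pure-jnp shape check under those assumed shapes:

```python
import jax.numpy as jnp

T, N, H = 64, 32, 32                                  # assumed: timesteps, batch, hidden size
even = jnp.zeros((T, N, H))
odd = jnp.zeros((T, N, H))

old = jnp.concatenate((even[-1], odd[-1]), axis=1)    # (N, 2H): last steps only
new = jnp.concatenate((even, odd), axis=2)            # (T, N, 2H): whole sequences

print(old.shape, new.shape)                           # (32, 64) (64, 32, 64)
last = new[-1]                                        # (N, 2H), the step lstm:2 would summarize
```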
@@ -56,33 +75,34 @@ def predict_proba(self, x, params, train=True):
 )
 
 
-fig, ax1 = plt.subplots()
+def plot_curve(acc, tacc, loss, tloss, epochs):
+    fig, ax1 = plt.subplots()
 
-plt.rcParams['font.family'] = 'Noto Serif SC'
-plt.rcParams['font.sans-serif'] = ['Noto Serif SC']
+    plt.rcParams['font.family'] = 'Noto Serif SC'
+    plt.rcParams['font.sans-serif'] = ['Noto Serif SC']
 
-color = 'tab:red'
-ax1.set_xlabel('Epochs')
-ax1.set_ylabel('Accuracy', color=color)
-ax1.plot(range(epochs), acc, color=color, label='Train Accuracy', linestyle='-')
-ax1.plot(range(epochs), tacc, color=color, label='Test Accuracy', linestyle='--')
-ax1.tick_params(axis='y', labelcolor=color)
+    color = 'tab:red'
+    ax1.set_xlabel('Epochs')
+    ax1.set_ylabel('Accuracy', color=color)
+    ax1.plot(range(epochs), acc, color=color, label='Train Accuracy', linestyle='-')
+    ax1.plot(range(epochs), tacc, color=color, label='Test Accuracy', linestyle='--')
+    ax1.tick_params(axis='y', labelcolor=color)
 
-ax2 = ax1.twinx()
+    ax2 = ax1.twinx()
 
-color = 'tab:blue'
-ax2.set_ylabel('Loss', color=color)
-ax2.plot(range(epochs), loss, color=color, label='Train Loss', linestyle='-')
-ax2.plot(range(epochs), tloss, color=color, label='Test Loss', linestyle='--')
-ax2.tick_params(axis='y', labelcolor=color)
+    color = 'tab:blue'
+    ax2.set_ylabel('Loss', color=color)
+    ax2.plot(range(epochs), loss, color=color, label='Train Loss', linestyle='-')
+    ax2.plot(range(epochs), tloss, color=color, label='Test Loss', linestyle='--')
+    ax2.tick_params(axis='y', labelcolor=color)
 
-handles1, labels1 = ax1.get_legend_handles_labels()
-handles2, labels2 = ax2.get_legend_handles_labels()
-ax1.legend(handles1 + handles2, labels1 + labels2, loc='lower right')
+    handles1, labels1 = ax1.get_legend_handles_labels()
+    handles2, labels2 = ax2.get_legend_handles_labels()
+    ax1.legend(handles1 + handles2, labels1 + labels2, loc='lower right')
 
-plt.title('Training and Testing Accuracy and Loss over Epochs')
-fig.tight_layout()
-plt.show()
+    plt.title('Training and Testing Accuracy and Loss over Epochs')
+    fig.tight_layout()
+    plt.show()
 
-print(f'final train, test acc : {acc[-1]}, {tacc[-1]}')
-print(f'final train, test loss: {loss[-1]}, {tloss[-1]}')
+    print(f'final train, test acc : {acc[-1]}, {tacc[-1]}')
+    print(f'final train, test loss: {loss[-1]}, {tloss[-1]}')

example/plot.py

Whitespace-only changes.

example/rebuild_date.ipynb (+425)

Large diffs are not rendered by default.

notebook_docs/grad.ipynb (+3 −1)
@@ -158,7 +158,9 @@
 "Here is a very simple MLP case. As you can see, we get a gradient dict of trainable parameters we inited before. And then you can apply this result to GD algorithms like SGD, Adam... easy right? \n",
 "But this is also not what we want. This kind of initalization and optimization is very complex. So we can apply Pipeline Pattern to make it more easy to manage this procedure for users: \n",
 "\n",
-"![pipeline.svg](../assets/notebook_docs/minitorch.svg)\n",
+"<p align=\"center\">\n",
+" <img src=\"../assets/notebook_docs/minitorch.svg\" alt=\"Overview of framework\", width=\"50%\">\n",
+"</p>\n",
 "\n",
 "<p align=\"center\">\n",
 "Overview of Framework\n",
