
Commit 28857ae

committed Mar 17, 2025
add conv1dx3 for example
1 parent 5fc2a2a commit 28857ae

File tree

12 files changed: +75 −29 lines

 

‎example/clfs.py

+65-3
@@ -1,7 +1,7 @@
 import jax.numpy as jnp
 from jax import random

-from plugins.minitorch.nn import Rnn, Dense
+from plugins.minitorch.nn import Conv, Rnn, Dense
 from plugins.minitorch import Initer
 from plugins.minitorch.optimizer import Adam
 from plugins.minitorch.loss import CrossEntropyLoss
@@ -101,5 +101,67 @@ def fit(self, x, y):
         self.optr.close()


-class conv3x3(Clfs):
-    pass
+class conv1dx3(Clfs):
+    def __init__(self, lr, epoches, batch_size, depth=1):
+        super(conv1dx3, self).__init__()
+
+        self.config = {
+            'conv1d:00': Conv.get_conv1d(9, 16, (3,)),   # 128 -> 126
+            'conv1d:01': Conv.get_conv1d(16, 16, (3,)),  # 126 -> 124
+            'maxpooling1d:0': Conv.get_max_pool1d(2),    # 124 -> 62
+            'conv1d:10': Conv.get_conv1d(16, 32, (3,)),  # 62 -> 60
+            'conv1d:11': Conv.get_conv1d(32, 32, (3,)),  # 60 -> 58
+            'maxpooling1d:1': Conv.get_max_pool1d(2),    # 58 -> 29
+            'conv1d:20': Conv.get_conv1d(32, 64, (3,)),  # 29 -> 27
+            'conv1d:21': Conv.get_conv1d(64, 64, (3,)),  # 27 -> 25
+            'maxpooling1d:2': Conv.get_max_pool1d(2),    # 25 -> 12
+            'fc:0': Dense.get_linear(12 * 64, 256),      # 64 x 12 = 768
+            'fc:1': Dense.get_linear(256, 64),
+            'fc:2': Dense.get_linear(64, 6)
+        }
+
+        self.epoches = epoches
+        self.lr = lr
+        self.batch_size = batch_size
+        self.losr = CrossEntropyLoss(self.forward)
+
+    def conv_block(self, x, params, id):
+        res = Conv.conv1d(x, params[f'conv1d:{id}0'], self.config[f'conv1d:{id}0'])
+        res = Conv.conv1d(res, params[f'conv1d:{id}1'], self.config[f'conv1d:{id}1'])
+        res = Conv.max_pooling1d(res, self.config[f'maxpooling1d:{id}'])
+
+        return res
+
+    def forward(self, x, params, train=False):
+        res = self.conv_block(x, params, 0)
+        res = self.conv_block(res, params, 1)
+        res = self.conv_block(res, params, 2)
+
+        res = res.reshape(res.shape[0], -1)
+
+        res = Dense.linear(res, params['fc:0'])
+        res = Dense.linear(res, params['fc:1'])
+        res = Dense.linear(res, params['fc:2'])
+
+        return softmax(res)
+
+    @timing
+    def predict_proba(self, x):
+        return self.forward(params=self.optr.get_params(), x=x, train=False)
+
+    @timing
+    def fit(self, x, y):
+        self.optr = Adam(Initer(self.config, random.PRNGKey(42))(), lr=self.lr, batch_size=self.batch_size)
+        _loss = self.losr.get_loss(True)
+        self.optr.open(_loss, x, y)
+
+        _tloss = self.losr.get_loss(False)
+
+        log_wise = self.epoches // 10 if self.epoches >= 10 else self.epoches
+        for cnt in range(self.epoches):
+            if (cnt + 1) % log_wise == 0:
+                print(f'====> Epoch {cnt + 1}/{self.epoches}, loss: {_tloss(self.optr.get_params(), x, y)}')
+
+            self.optr.update()
+
+        self.optr.close()
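
A quick sanity check of the length annotations in the config above (a sketch with hypothetical helper names, assuming 'valid' no-padding 1-D convolutions with kernel size 3 and max pooling that floor-divides the length by 2, as the comments suggest):

# Not part of the repo: verify the chain 128 -> 126 -> 124 -> 62 -> 60 -> 58 -> 29 -> 27 -> 25 -> 12.
def conv1d_out(length, kernel=3):
    return length - (kernel - 1)   # a 'valid' convolution shrinks the length by kernel - 1

def pool1d_out(length, window=2):
    return length // window        # pooling floor-divides the length

length = 128
for _ in range(3):                 # three conv-conv-pool blocks
    length = pool1d_out(conv1d_out(conv1d_out(length)))

print(length)                      # 12
print(length * 64)                 # 768, matching Dense.get_linear(12 * 64, 256)
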
File renamed without changes.
File renamed without changes.

‎example/data_process.py

+1-1
@@ -75,7 +75,7 @@ def one_hot(y: jnp.ndarray, num_class: int):
 # X_train = jnp.transpose(X_train, (2, 0, 1))
 # X_test = jnp.transpose(X_test, (2, 0, 1))

-print('X_train 形状:', X_train.shape)  # expected (128, 7352, 9)
+print('X_train 形状:', X_train.shape)  # expected (7352, 9, 128)
 print('y_train 形状:', y_train.shape)  # expected (7352, 6)
 print('X_test 形状:', X_test.shape)
 print('y_test 形状:', y_test.shape)
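
The corrected comment pins down a (samples, channels, time steps) layout; a minimal consistency check (a sketch, assuming the arrays are loaded as above):

# Sketch only: 7352 samples, 9 input channels, 128 time steps per window,
# which lines up with 'conv1d:00': Conv.get_conv1d(9, 16, (3,)) in example/clfs.py.
assert X_train.shape == (7352, 9, 128)
assert y_train.shape == (7352, 6)
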

‎example/log/2025_03_17_17-24-24/hyper.toml

-11
This file was deleted.

‎example/log/2025_03_17_17-24-24/test.csv

-3
This file was deleted.

‎example/log/2025_03_17_17-24-24/valid.csv

-3
This file was deleted.

‎example/main.py

+4-3
@@ -1,18 +1,19 @@
 from plugins.lrkit.executer import KFlodCrossExecuter

-from clfs import lstm, gru
+from clfs import lstm, gru, conv1dx3
 from data_process import X_train, X_test, y_train, y_test

 excr = KFlodCrossExecuter(
     X_train, y_train, X_test, y_test,
     clf_dict={
-        'gru': gru(lr=0.01, epoches=20, batch_size=64),
+        'gru': gru(lr=0.01, epoches=30, batch_size=64),
         'lstm': lstm(lr=0.01, epoches=50, batch_size=64),
+        'conv1dx3': conv1dx3(lr=0.001, epoches=50, batch_size=128),
     },
     k=5,
     metric_list=['accuracy', 'macro_f1', 'micro_f1', 'avg_recall'],
     log=True,
     log_dir='./log/',
 )

-excr.run_all()
+excr.run_all(time=True)

‎mlp_cifar10.ipynb

+1-1
@@ -158,7 +158,7 @@
 " # Later we found the same problem occurs even without JIT, since other multi-threaded optimizations still apply without JIT, which triggers the issue here.\n",
 " res = res @ p['w'] + p['b'] \n",
 " res = jnp.maximum(0, res) # use relu activation function\n",
-" res, key = Dense.dropout(res, key, p=0.1, train=train)\n",
+" key, res = Dense.dropout(key, res, p=0.1, train=train)\n",
 "\n",
 " return softmax(res)\n",
 "\n",

‎mlp_mnist.ipynb

+1-1
@@ -157,7 +157,7 @@
 " # Later we found the same problem occurs even without JIT, since other multi-threaded optimizations still apply without JIT, which triggers the issue here.\n",
 " res = res @ p['w'] + p['b'] \n",
 " res = jnp.maximum(0, res) # use relu activation function\n",
-" res, key = Dense.dropout(res, key, p=0.1, train=train) # add dropout\n",
+" key, res = Dense.dropout(key, res, p=0.1, train=train) # add dropout\n",
 "\n",
 " return softmax(res)\n",
 "\n",

‎plugins/lrkit

‎plugins/minitorch/nn/JaxOptimized/fc.py

+2-2
@@ -35,7 +35,7 @@
 from jax import random


-def dropout(x: jnp.ndarray, key, p=0.5, train=True):
+def dropout(key, x: jnp.ndarray, p=0.5, train=True):
     '''
     Applies dropout to the input array during training.

@@ -76,7 +76,7 @@ def dropout(x: jnp.ndarray, key, p=0.5, train=True):
     new_key, use_key = random.split(key) # update key, to make mask different in different **batch**.
     mask = random.bernoulli(use_key, p_keep, x.shape)

-    return jnp.where(mask, x / p_keep, 0), new_key # scale here to make E(X) the same while evaluating.
+    return new_key, jnp.where(mask, x / p_keep, 0) # scale here to make E(X) the same while evaluating.


 def _linear(x: jnp.ndarray, w: jnp.ndarray, b: jnp.ndarray):
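
For reference, the reordered signature threads the PRNG key in first and returns it first, so call sites read key, res = Dense.dropout(key, res, ...). A minimal self-contained sketch of the same inverted-dropout idea (illustration only, not the project's actual module):

import jax.numpy as jnp
from jax import random

def dropout(key, x, p=0.5, train=True):
    # At evaluation time (or with p == 0) the input passes through unchanged.
    if not train or p == 0.0:
        return key, x
    p_keep = 1.0 - p
    new_key, use_key = random.split(key)               # fresh subkey per call, so each batch gets a new mask
    mask = random.bernoulli(use_key, p_keep, x.shape)
    # Scale kept activations by 1 / p_keep so the expected value matches evaluation mode.
    return new_key, jnp.where(mask, x / p_keep, 0.0)

key = random.PRNGKey(0)
x = jnp.ones((4, 8))
key, y = dropout(key, x, p=0.1, train=True)            # key first, matching the new call sites
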

0 commit comments
