5
5
from plugins .minitorch .nn import Rnn , Dense , Model
6
6
from plugins .minitorch .optimizer import Adam
7
7
from plugins .minitorch .initer import Initer
8
- from plugins .minitorch .utils import softmax
8
+ from plugins .minitorch .utils import softmax , cross_entropy_loss , l2_regularization
9
+ from plugins .minitorch .loss import CrossEntropyLoss
9
10
10
11
from data_process import X_train , X_test , y_train , y_test
11
12
12
13
# Fixed PRNG seed so parameter initialization is reproducible across runs.
key = random.PRNGKey(0)
13
14
14
15
15
class MyLoss(CrossEntropyLoss):
    """Cross-entropy loss with an L2 weight-decay penalty (strength 0.01).

    `f` is the model's forward function: f(x, params, train) -> class
    probabilities.
    """

    def __init__(self, f):
        # Zero-argument super() — equivalent to super(MyLoss, self).
        super().__init__(f)

    def _penalized_loss(self, params, x, y_true, train):
        # Data term (cross-entropy) plus an L2 penalty over all parameters.
        probs = self.f(x, params, train)
        return cross_entropy_loss(y_true, probs) + l2_regularization(params, 0.01)

    def get_loss(self, train):
        """Return loss(params, x, y_true) for the given train flag."""
        return lambda params, x, y_true: self._penalized_loss(params, x, y_true, train)

    def get_embed_loss(self, x, y_true, train):
        """Return loss(params) with the batch (x, y_true) closed over."""
        return lambda params: self._penalized_loss(params, x, y_true, train)
29
+ class SplitLSTM (Model ):
16
30
def __init__(self, lr, epoches, batch_size):
    """Build the split-LSTM network, its Adam optimizer and its loss.

    lr: learning rate; epoches: number of training epochs (spelling
    follows the Model base class); batch_size: minibatch size for Adam.
    """
    super().__init__(lr=lr, epoches=epoches)

    # Layer shapes — assumed to be get_lstm(seq_len, input_dim, hidden_dim);
    # TODO confirm against plugins.minitorch.nn.Rnn.
    layer_config = {
        'lstm:0': Rnn.get_lstm(128, 9, 64),
        'lstm:1even': Rnn.get_lstm(64, 64, 32),
        'lstm:1odd': Rnn.get_lstm(64, 64, 32),
        'lstm:2': Rnn.get_lstm(64, 64, 64),
        'fc:0': Dense.get_linear(64, 6),
    }
    self.config = layer_config

    # Initialize parameters from the module-level PRNG key and hand the
    # fresh parameter tree straight to the optimizer.
    self.optr = Adam(Initer(self.config, key)(), lr=lr, batch_size=batch_size)
    self.lossr = MyLoss(self.predict_proba)
28
44
29
45
def predict_proba (self , x , params , train = True ):
30
46
res = jnp .transpose (x , (2 , 0 , 1 ))
@@ -36,18 +52,21 @@ def predict_proba(self, x, params, train=True):
36
52
even , _ , _ = Rnn .lstm (even , params ['lstm:1even' ], self .config ['lstm:1even' ])
37
53
odd , _ , _ = Rnn .lstm (odd , params ['lstm:1odd' ], self .config ['lstm:1odd' ])
38
54
39
- res = jnp .concatenate ((even [- 1 ], odd [- 1 ]), axis = 1 )
55
+ res = jnp .concatenate ((even , odd ), axis = 2 )
56
+
57
+ res , _ , _ = Rnn .lstm (res , params ['lstm:2' ], self .config ['lstm:2' ])
58
+ res = res [- 1 ]
40
59
41
60
res = Dense .linear (res , params ['fc:0' ])
42
61
43
62
return softmax (res )
44
63
45
64
46
# Training hyperparameters.
epochs = 200           # full passes over the training set
batch_size = 64        # minibatch size fed to Adam
learning_rate = 0.005  # Adam step size
49
68
50
# Assemble the model with the hyperparameters defined above.
model = SplitLSTM(lr=learning_rate, epoches=epochs, batch_size=batch_size)
51
70
acc , loss , tacc , tloss = model .fit (
52
71
x_train = X_train ,
53
72
y_train_proba = y_train ,
@@ -56,33 +75,34 @@ def predict_proba(self, x, params, train=True):
56
75
)
57
76
58
77
59
def plot_curve(acc, tacc, loss, tloss, epochs):
    """Plot train/test accuracy (left axis) and loss (right axis) vs. epoch."""
    fig, acc_ax = plt.subplots()

    # Use a CJK-capable serif font for all figure text.
    plt.rcParams['font.family'] = 'Noto Serif SC'
    plt.rcParams['font.sans-serif'] = ['Noto Serif SC']

    xs = range(epochs)

    acc_color = 'tab:red'
    acc_ax.set_xlabel('Epochs')
    acc_ax.set_ylabel('Accuracy', color=acc_color)
    acc_ax.plot(xs, acc, color=acc_color, label='Train Accuracy', linestyle='-')
    acc_ax.plot(xs, tacc, color=acc_color, label='Test Accuracy', linestyle='--')
    acc_ax.tick_params(axis='y', labelcolor=acc_color)

    # Loss shares the x-axis but gets its own y-axis on the right.
    loss_ax = acc_ax.twinx()
    loss_color = 'tab:blue'
    loss_ax.set_ylabel('Loss', color=loss_color)
    loss_ax.plot(xs, loss, color=loss_color, label='Train Loss', linestyle='-')
    loss_ax.plot(xs, tloss, color=loss_color, label='Test Loss', linestyle='--')
    loss_ax.tick_params(axis='y', labelcolor=loss_color)

    # Merge both axes' legend entries into a single legend box.
    handles_acc, labels_acc = acc_ax.get_legend_handles_labels()
    handles_loss, labels_loss = loss_ax.get_legend_handles_labels()
    acc_ax.legend(handles_acc + handles_loss, labels_acc + labels_loss, loc='lower right')

    plt.title('Training and Testing Accuracy and Loss over Epochs')
    fig.tight_layout()
    plt.show()
86
106
87
# NOTE(review): plot_curve is defined above but never invoked anywhere in
# this file — confirm whether the training curves should be plotted here
# (the pre-refactor script plotted them at this point).
print(f'final train, test acc : {acc[-1]}, {tacc[-1]}')
print(f'final train, test loss: {loss[-1]}, {tloss[-1]}')
0 commit comments