Commit 6efd0e9: add bidir rnns

1 parent: 2be7797

5 files changed, +426 -46 lines

README.md (+27 -19)
````diff
@@ -69,19 +69,25 @@ Overview of Framework
   - Basic rnn kernel
   - LSTM kernel
   - GRU kernel
-  - BiLSTM kernel <mark>TODO</mark>
-  - BiGRU kernel <mark>TODO</mark>
+  - BiLSTM kernel
+  - BiGRU kernel
+  - Layer Norm <mark>TODO</mark>
 - FC
   - Dropout
   - Linear
 - Optimizer
-  - Raw GD
-  - Momentum
-  - Nesterov(NAG)
-  - AdaGrad
-  - RMSProp
-  - AdaDelta
-  - Adam[[6](#reference)]
+  - Algorithms
+    - Raw GD
+    - Momentum
+    - Nesterov(NAG)
+    - AdaGrad
+    - RMSProp
+    - AdaDelta
+    - Adam[[6](#reference)]
+  - Mechanisms
+    - Lr Decay. <mark>TODO</mark>
+    - Weight Decay. <mark>TODO</mark>
+    - Freeze. <mark>TODO</mark>
 - Utils
   - sigmoid
   - one hot
@@ -91,25 +97,27 @@ Overview of Framework
   - l1_regularization
   - l2_regularization
 
-#### 🗠 Number of Codes
+#### @ Number of Codes
 
 Last update: 2025.03.14.
 
 ```text
-68 text files.
-49 unique files.
-49 files ignored.
+236 text files.
+135 unique files.
+138 files ignored.
 
-github.com/AlDanial/cloc v 1.98 T=0.04 s (1365.4 files/s, 303229.6 lines/s)
+github.com/AlDanial/cloc v 1.98 T=0.05 s (2810.5 files/s, 307803.1 lines/s)
 -------------------------------------------------------------------------------
 Language                     files          blank        comment           code
 -------------------------------------------------------------------------------
-Jupyter Notebook                21              0           3954           2022
-Python                          21           1014           1805           1669
-Text                             6              1              0            295
-Markdown                         1             19              0            103
+Python                          33           1689           3297           3177
+Jupyter Notebook                21              0           3947           1913
+Text                             6              1              0            301
+CSV                             68              0              0            203
+Markdown                         5             40              0            198
+TOML                             2              3              0             16
 -------------------------------------------------------------------------------
-SUM:                            49           1034           5759           4089
+SUM:                           135           1733           7244           5808
 -------------------------------------------------------------------------------
 ```
 
````

gru_ucihar.ipynb (+23 -26)

Large diffs are not rendered by default.

plugins/lrkit

Submodule lrkit updated 35 files

plugins/minitorch/initer.py (+168)
````diff
@@ -10,7 +10,9 @@ class Initer:
     Supported layer types:
     - Basic Recurrent Neural Network (basic_rnn)
     - Long Short-Term Memory (lstm)
+    - Bidirectional Long Short-Term Memory (bilstm)
     - Gated Recurrent Unit (gru)
+    - Bidirectional Gated Recurrent Unit (bigru)
     - Fully Connected (fc) layers
     - 1D Convolutional layers (conv1d)
     - 2D Convolutional layers (conv2d)
@@ -19,7 +21,9 @@ class Initer:
     Their name should be like:
     - "basic_rnn:"
     - "lstm:"
+    - "bilstm:"
     - "gru:"
+    - "bigru:"
     - "fc:"
     - "conv1d:"
     - "conv2d:"
@@ -31,6 +35,7 @@ class Initer:
     '''
 
     SupportLayers = ('basic_rnn', 'lstm', 'gru',
+                     'bilstm', 'bigru',
                      'fc',
                      'conv1d', 'conv2d', 'conv3d')
 
@@ -81,6 +86,169 @@ def _init_param(self, name: str):
 
         return f(name)
 
+    def _bilstm(self, name):
+        '''
+        Initializes parameters for a bidirectional LSTM layer.
+
+        Config should be:
+        ```
+        name: {
+            'input_dim': int,   # Input dimension
+            'hidden_dim': int,  # Hidden state dimension
+            'strategy': str,    # Init strategy: 'None', 'Kaiming', or 'Xavier'
+        }
+        ```
+
+        Returns:
+            A dictionary containing:
+            - 'Ws': Weight matrices for input-to-hidden transformations (8, input_dim, hidden_dim).
+                - The first 4 matrices are for the forward direction.
+                - The last 4 matrices are for the backward direction.
+            - 'Us': Weight matrices for hidden-to-hidden transformations (8, hidden_dim, hidden_dim).
+                - The first 4 matrices are for the forward direction.
+                - The last 4 matrices are for the backward direction.
+            - 'Bs': Bias terms (8, hidden_dim).
+                - The first 4 biases are for the forward direction.
+                - The last 4 biases are for the backward direction.
+                - The forget gate biases (indices 0 and 4) are initialized to 1.
+        '''
+
+        match self.config[name]['strategy']:
+            case 'None':
+                return {
+                    'Ws': random.normal(self.key, (
+                        8,
+                        self.config[name]['input_dim'],
+                        self.config[name]['hidden_dim'],
+                    )),
+                    'Us': random.normal(self.key, (
+                        8,
+                        self.config[name]['hidden_dim'],
+                        self.config[name]['hidden_dim'],
+                    )),
+                    'Bs': jnp.zeros((
+                        8,
+                        self.config[name]['hidden_dim']
+                    )).at[0].set(1).at[4].set(1),  # Initialize forget gate biases to 1
+                }
+            case 'Kaiming':
+                return {
+                    'Ws': random.normal(self.key, (
+                        8,
+                        self.config[name]['input_dim'],
+                        self.config[name]['hidden_dim'],
+                    )) * jnp.sqrt(2 / (self.config[name]['input_dim'])),  # Kaiming
+                    'Us': random.normal(self.key, (
+                        8,
+                        self.config[name]['hidden_dim'],
+                        self.config[name]['hidden_dim'],
+                    )) * jnp.sqrt(2 / (self.config[name]['input_dim'])),
+                    'Bs': jnp.zeros((
+                        8,
+                        self.config[name]['hidden_dim']
+                    )).at[0].set(1).at[4].set(1),  # Initialize forget gate biases to 1
+                }
+            case 'Xavier':
+                return {
+                    'Ws': random.normal(self.key, (
+                        8,
+                        self.config[name]['input_dim'],
+                        self.config[name]['hidden_dim'],
+                    )) * jnp.sqrt(2 / (self.config[name]['input_dim'] + self.config[name]['hidden_dim'])),  # Xavier
+                    'Us': random.normal(self.key, (
+                        8,
+                        self.config[name]['hidden_dim'],
+                        self.config[name]['hidden_dim'],
+                    )) * jnp.sqrt(2 / (self.config[name]['input_dim'] + self.config[name]['hidden_dim'])),
+                    'Bs': jnp.zeros((
+                        8,
+                        self.config[name]['hidden_dim']
+                    )).at[0].set(1).at[4].set(1),  # Initialize forget gate biases to 1
+                }
+            case _:
+                raise ValueError(f'[x] Unsupported strategy: {self.config[name]["strategy"]} given by {name}.')
+
+    def _bigru(self, name):
+        '''
+        Initializes parameters for a bidirectional GRU layer.
+
+        Config should be:
+        ```
+        name: {
+            'input_dim': int,   # Input dimension
+            'hidden_dim': int,  # Hidden state dimension
+            'strategy': str,    # Init strategy: 'None', 'Kaiming', or 'Xavier'
+        }
+        ```
+
+        Returns:
+            A dictionary containing:
+            - 'Ws': Weight matrices for input-to-hidden transformations (6, input_dim, hidden_dim).
+                - The first 3 matrices are for the forward direction.
+                - The last 3 matrices are for the backward direction.
+            - 'Us': Weight matrices for hidden-to-hidden transformations (6, hidden_dim, hidden_dim).
+                - The first 3 matrices are for the forward direction.
+                - The last 3 matrices are for the backward direction.
+            - 'Bs': Bias terms (6, hidden_dim).
+                - The first 3 biases are for the forward direction.
+                - The last 3 biases are for the backward direction.
+        '''
+
+        match self.config[name]['strategy']:
+            case 'None':
+                return {
+                    'Ws': random.normal(self.key, (
+                        6,
+                        self.config[name]['input_dim'],
+                        self.config[name]['hidden_dim'],
+                    )),
+                    'Us': random.normal(self.key, (
+                        6,
+                        self.config[name]['hidden_dim'],
+                        self.config[name]['hidden_dim'],
+                    )),
+                    'Bs': jnp.zeros((
+                        6,
+                        self.config[name]['hidden_dim']
+                    )),
+                }
+            case 'Kaiming':
+                return {
+                    'Ws': random.normal(self.key, (
+                        6,
+                        self.config[name]['input_dim'],
+                        self.config[name]['hidden_dim'],
+                    )) * jnp.sqrt(2 / (self.config[name]['input_dim'])),  # Kaiming
+                    'Us': random.normal(self.key, (
+                        6,
+                        self.config[name]['hidden_dim'],
+                        self.config[name]['hidden_dim'],
+                    )) * jnp.sqrt(2 / (self.config[name]['input_dim'])),
+                    'Bs': jnp.zeros((
+                        6,
+                        self.config[name]['hidden_dim']
+                    )),
+                }
+            case 'Xavier':
+                return {
+                    'Ws': random.normal(self.key, (
+                        6,
+                        self.config[name]['input_dim'],
+                        self.config[name]['hidden_dim'],
+                    )) * jnp.sqrt(2 / (self.config[name]['input_dim'] + self.config[name]['hidden_dim'])),  # Xavier
+                    'Us': random.normal(self.key, (
+                        6,
+                        self.config[name]['hidden_dim'],
+                        self.config[name]['hidden_dim'],
+                    )) * jnp.sqrt(2 / (self.config[name]['input_dim'] + self.config[name]['hidden_dim'])),
+                    'Bs': jnp.zeros((
+                        6,
+                        self.config[name]['hidden_dim']
+                    )),
+                }
+            case _:
+                raise ValueError(f'[x] Unsupported strategy: {self.config[name]["strategy"]} given by {name}.')
+
     def _basic_rnn(self, name):
         '''
         Initializes parameters for a basic RNN layer.
````
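For readers of this commit, a minimal usage sketch follows. The config schema, the `"bilstm:"`/`"bigru:"` naming convention, and the returned shapes come from the docstrings above; the `Initer(config, key)` constructor signature and the direct `_init_param` call are assumptions for illustration.

```python
# Hypothetical usage sketch; Initer(config, key) is an assumed signature.
from jax import random

from plugins.minitorch.initer import Initer  # path as in this commit

config = {
    'bilstm:encoder': {        # "bilstm:" prefix selects the bidirectional LSTM initializer
        'input_dim': 9,        # per-timestep feature dimension (example value)
        'hidden_dim': 64,      # hidden state size per direction
        'strategy': 'Xavier',  # one of 'None', 'Kaiming', 'Xavier'
    },
    'bigru:head': {
        'input_dim': 128,      # 2 * 64: concatenated forward/backward BiLSTM outputs
        'hidden_dim': 32,
        'strategy': 'Kaiming',
    },
}

initer = Initer(config, random.PRNGKey(42))    # assumed constructor
params = initer._init_param('bilstm:encoder')

assert params['Ws'].shape == (8, 9, 64)    # 4 forward + 4 backward input weights
assert params['Us'].shape == (8, 64, 64)   # recurrent weights
assert params['Bs'].shape == (8, 64)       # biases; rows 0 and 4 start at 1 (forget gates)
```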

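The forward/backward split of the stacked parameters can also be illustrated with a scan-based pass. This is a sketch only: the diff pins the forget gate to indices 0 and 4, but the ordering of the remaining gates (input, output, candidate in slots 1-3 and 5-7 below) is an assumed convention, and the repo's actual BiLSTM kernel is not shown in this commit.

```python
# Sketch of a bidirectional pass over parameters shaped as in _bilstm.
# Gate order beyond the forget gate (indices 0/4) is an assumption.
import jax
import jax.numpy as jnp

def lstm_direction(x, Ws, Us, Bs):
    '''Run one direction over x of shape (T, input_dim) using 4 stacked gates.'''
    def step(carry, x_t):
        h, c = carry
        f = jax.nn.sigmoid(x_t @ Ws[0] + h @ Us[0] + Bs[0])  # forget gate (per docstring)
        i = jax.nn.sigmoid(x_t @ Ws[1] + h @ Us[1] + Bs[1])  # input gate (assumed slot)
        o = jax.nn.sigmoid(x_t @ Ws[2] + h @ Us[2] + Bs[2])  # output gate (assumed slot)
        g = jnp.tanh(x_t @ Ws[3] + h @ Us[3] + Bs[3])        # candidate (assumed slot)
        c = f * c + i * g
        h = o * jnp.tanh(c)
        return (h, c), h

    hidden = Us.shape[-1]
    init = (jnp.zeros(hidden), jnp.zeros(hidden))
    _, hs = jax.lax.scan(step, init, x)
    return hs  # (T, hidden_dim)

def bilstm_apply(params, x):
    '''Forward half uses slices [:4], backward half uses [4:]; outputs concatenate.'''
    fwd = lstm_direction(x, params['Ws'][:4], params['Us'][:4], params['Bs'][:4])
    bwd = lstm_direction(x[::-1], params['Ws'][4:], params['Us'][4:], params['Bs'][4:])
    return jnp.concatenate([fwd, bwd[::-1]], axis=-1)  # (T, 2 * hidden_dim)
```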