Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions NeuralJAXwork/loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Import jit from JAX
from jax import jit
from NeuralJAXwork import Errors
from .errors import Errors

class Loss:
"""
Expand Down Expand Up @@ -37,9 +37,9 @@ def __init__(self, loss, loss_prime):
# If it fails, switch to regular Python interpreter
try:
self.loss_prime = jit(loss_prime)
except:
except Exception:
print(Errors.jit_error)
self.loss = loss_prime
self.loss_prime = loss_prime

def loss(self, y_true, y_pred):
"""
Expand Down
18 changes: 14 additions & 4 deletions NeuralJAXwork/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,18 @@ def forward(self, x):
x = layer.forward(x)
return x

def backward(self, grad):
def backward(self, grad, learning_rate):
"""Propagate the gradient through the network.

Parameters
----------
grad: jax.numpy.ndarray
Gradient of the loss with respect to the network output.
learning_rate: float
Learning rate used to update the parameters of each layer.
"""
for layer in reversed(self.layers):
grad = layer.backward(grad)
grad = layer.backward(grad, learning_rate)
return grad

def train(self, x_train, y_train, epochs = 1000, learning_rate = 0.01, verbose = True):
Expand All @@ -27,7 +36,7 @@ def train(self, x_train, y_train, epochs = 1000, learning_rate = 0.01, verbose =

# backward
grad = self.loss.prime(y, output)
self.backward(grad)
self.backward(grad, learning_rate)

error /= len(x_train)

Expand All @@ -38,4 +47,5 @@ def __call__(self, x):
return self.forward(x)

def __repr__(self):
return f'SequentialModel({zip(enumerate(self.layers))})'
layers_repr = ", ".join(layer.__class__.__name__ for layer in self.layers)
return f"SequentialModel([{layers_repr}])"
Loading