diff --git a/NeuralJAXwork/loss.py b/NeuralJAXwork/loss.py
index b0fab4b..9fc582d 100644
--- a/NeuralJAXwork/loss.py
+++ b/NeuralJAXwork/loss.py
@@ -1,6 +1,6 @@
 # Import jit from JAX
 from jax import jit
-from NeuralJAXwork import Errors
+from .errors import Errors

 class Loss:
     """
@@ -37,9 +37,9 @@ def __init__(self, loss, loss_prime):
         # If it fails, switch to regular Python interpreator
         try:
             self.loss_prime = jit(loss_prime)
-        except:
+        except Exception:
             print(Errors.jit_error)
-            self.loss = loss_prime
+            self.loss_prime = loss_prime

     def loss(self, y_true, y_pred):
         """
diff --git a/NeuralJAXwork/model.py b/NeuralJAXwork/model.py
index 09e1593..c426741 100644
--- a/NeuralJAXwork/model.py
+++ b/NeuralJAXwork/model.py
@@ -10,9 +10,18 @@ def forward(self, x):
             x = layer.forward(x)
         return x

-    def backward(self, grad):
+    def backward(self, grad, learning_rate):
+        """Propagate the gradient through the network.
+
+        Parameters
+        ----------
+        grad: jax.numpy.ndarray
+            Gradient of the loss with respect to the network output.
+        learning_rate: float
+            Learning rate used to update the parameters of each layer.
+        """
         for layer in reversed(self.layers):
-            grad = layer.backward(grad)
+            grad = layer.backward(grad, learning_rate)
         return grad

     def train(self, x_train, y_train, epochs = 1000, learning_rate = 0.01, verbose = True):
@@ -27,7 +36,7 @@

                 # backward
                 grad = self.loss.prime(y, output)
-                self.backward(grad)
+                self.backward(grad, learning_rate)

             error /= len(x_train)

@@ -38,4 +47,5 @@ def __call__(self, x):
         return self.forward(x)

     def __repr__(self):
-        return f'SequentialModel({zip(enumerate(self.layers))})'
\ No newline at end of file
+        layers_repr = ", ".join(layer.__class__.__name__ for layer in self.layers)
+        return f"SequentialModel([{layers_repr}])"
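
Note (not part of the patch): since Model.backward now forwards learning_rate, every layer's backward method must accept it as a second argument. The sketch below shows a layer conforming to that interface; the Dense class name, its weights/bias attributes, and the SGD-style update are illustrative assumptions, not NeuralJAXwork's actual layer implementation.

# Illustrative sketch only: a hypothetical layer matching the new
# backward(grad, learning_rate) signature threaded through Model.backward.
import jax

class Dense:
    def __init__(self, input_size, output_size, key=jax.random.PRNGKey(0)):
        w_key, b_key = jax.random.split(key)
        self.weights = jax.random.normal(w_key, (output_size, input_size))
        self.bias = jax.random.normal(b_key, (output_size, 1))

    def forward(self, x):
        self.input = x                      # cache the input for backward
        return self.weights @ x + self.bias

    def backward(self, grad, learning_rate):
        weights_grad = grad @ self.input.T  # gradient w.r.t. the weights
        input_grad = self.weights.T @ grad  # gradient passed to the previous layer
        # SGD-style update using the learning rate supplied by Model.backward
        self.weights = self.weights - learning_rate * weights_grad
        self.bias = self.bias - learning_rate * grad
        return input_grad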