2 changes: 1 addition & 1 deletion lambeq/training/nelder_mead_optimizer.py
```diff
@@ -259,7 +259,7 @@ def objective(self, x: Iterable[Any], y: ArrayLike, w: ArrayLike) -> float:
             raise ValueError(
                 'Objective function must return a scalar'
             ) from e
-        return result  # type: ignore[return-value]
+        return result

     def backward(self, batch: tuple[Iterable[Any], np.ndarray]) -> float:
         """Calculate the gradients of the loss function.
```
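For context, a hedged sketch of the pattern this hunk implies. The surrounding `try`/`except` is collapsed in the diff, so the exact lambeq source is assumed here, not shown: once the objective value has been reassigned through `float(...)`, the variable is a plain `float` by the time it reaches the `return`, which is why the `# type: ignore[return-value]` comment can be dropped.

```python
# Hypothetical reconstruction, NOT the actual lambeq source; it only
# illustrates why `type: ignore[return-value]` becomes unnecessary.
from typing import Any

def ensure_scalar(result: Any) -> float:
    try:
        result = float(result)  # after this, `result` is a plain float
    except (TypeError, ValueError) as e:
        raise ValueError(
            'Objective function must return a scalar'
        ) from e
    return result  # mypy accepts this without an ignore comment
```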
7 changes: 4 additions & 3 deletions lambeq/training/pytorch_trainer.py
```diff
@@ -174,7 +174,7 @@ def validation_step(
         with torch.no_grad():
             y_hat = self.model(x)
             loss = self.loss_function(y_hat, y.to(self.device))
-        return y_hat, loss.item()
+        return y_hat.detach(), loss.detach().item()

     def training_step(
         self,
```
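A minimal sketch (an assumed illustration, not lambeq code) of what the validation-side `.detach()` buys: inside `torch.no_grad()` the model output carries no autograd graph anyway, so detaching here mainly guarantees a graph-free return value even if the same code path ever runs with gradient tracking enabled.

```python
import torch

model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)

with torch.no_grad():
    y_hat = model(x)
print(y_hat.requires_grad)  # False: no_grad() already blocks tracking

y_hat = model(x)            # the same call with tracking enabled
print(y_hat.requires_grad)  # True: here .detach() genuinely matters
```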
```diff
@@ -196,8 +196,9 @@ def training_step(
         x, y = batch
         y_hat = self.model(x)
         loss = self.loss_function(y_hat, y.to(self.device))
-        self.train_costs.append(loss.item())
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
-        return y_hat, loss.item()
+        loss_item = loss.detach().item()
+        self.train_costs.append(loss_item)
+        return y_hat.detach(), loss_item
```
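The training-side change is the one with real memory impact. A sketch (an assumed illustration, not lambeq code): a model output keeps its whole autograd graph alive for as long as any reference to it exists, so a caller that accumulates raw `y_hat` tensors across batches retains every per-batch graph. `loss.item()` already yields a plain Python float, so the explicit `detach()` there is belt-and-braces; `y_hat.detach()` is what actually releases the graph.

```python
import torch

model = torch.nn.Linear(4, 1)
attached, detached = [], []

for _ in range(3):
    x = torch.randn(8, 4)
    y_hat = model(x)
    attached.append(y_hat)           # still holds the autograd graph
    detached.append(y_hat.detach())  # plain tensor: graph can be freed

print(attached[0].grad_fn)   # a Backward node: graph still referenced
print(detached[0].grad_fn)   # None: safe to accumulate across epochs
```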