Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@ optimizer:
amsgrad: false

train:
n_epoch: 10
n_epoch: 50
verbose: true
early_stopping_flag: true
early_stopping_patience: 2
early_stopping_delta: 0.1

save:
path_to_folder: 'models/test_main/'
Expand Down
3 changes: 3 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def main(path_to_config: str):
optimizer=optimizer,
device=device,
n_epoch=config["train"]["n_epoch"],
early_stopping_flag=config["train"]["early_stopping_flag"],
early_stopping_patience=config["train"]["early_stopping_patience"],
early_stopping_delta=config["train"]["early_stopping_delta"],
verbose=config["train"]["verbose"],
)

Expand Down
65 changes: 65 additions & 0 deletions pytorch_ner/pytorchtools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np
import torch


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Tracks the best (lowest) validation loss seen so far. Each time the loss
    fails to improve by more than ``delta``, an internal counter is
    incremented; once the counter reaches ``patience``, ``early_stop`` is set
    to True so the caller can break out of the training loop. Whenever the
    loss does improve, the current model weights are checkpointed to ``path``.
    """

    def __init__(
        self,
        patience=7,
        verbose=False,
        delta=0.05,
        path="checkpoint.pt",
        trace_func=print,
    ):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0.05
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive epochs without sufficient improvement
        self.best_score = None    # highest -val_loss seen so far (None until first call)
        self.early_stop = False   # flag the training loop polls to decide whether to stop
        # BUGFIX: `np.Inf` was removed in NumPy 2.0; `np.inf` is the supported spelling.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss; checkpoint on improvement, count stalls otherwise.

        Args:
            val_loss (float): validation loss for the current epoch.
            model (torch.nn.Module): model whose state_dict is checkpointed on improvement.
        """
        # Negate so that "higher score is better" — a lower loss gives a higher score.
        score = -val_loss

        if self.best_score is None:
            # First call: treat as an improvement and save an initial checkpoint.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not enough improvement: bump the stall counter and maybe trigger stopping.
            self.counter += 1
            self.trace_func(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Genuine improvement: reset the counter and checkpoint the model.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decrease."""
        if self.verbose:
            self.trace_func(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
22 changes: 22 additions & 0 deletions pytorch_ner/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from tqdm import tqdm

from pytorch_ner.metrics import calculate_metrics

# import EarlyStopping
from pytorch_ner.pytorchtools import EarlyStopping
from pytorch_ner.utils import to_numpy


Expand Down Expand Up @@ -144,13 +147,22 @@ def train(
optimizer: optim.Optimizer,
device: torch.device,
n_epoch: int,
early_stopping_flag: bool,
early_stopping_patience: int,
early_stopping_delta: float,
testloader: Optional[DataLoader] = None,
verbose: bool = True,
):
"""
Training / validation loop for n_epoch with final testing.
"""

if early_stopping_flag:
# initialize the early_stopping object
early_stopping = EarlyStopping(
patience=early_stopping_patience, delta=early_stopping_delta, verbose=True
)

for epoch in range(n_epoch):

if verbose:
Expand Down Expand Up @@ -183,6 +195,16 @@ def train(
print(f"val {metric_name}: {np.mean(metric_list)}")
print()

if early_stopping_flag:
# early_stopping needs the validation loss to check if it has decreased,
# and if it has, it will make a checkpoint of the current model
valid_loss = np.mean(val_metrics["loss"])
early_stopping(valid_loss, model)

if early_stopping.early_stop:
print("Early stopping")
break

if testloader is not None:

test_metrics = validate_loop(
Expand Down
3 changes: 3 additions & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@
optimizer=optimizer,
device=device,
n_epoch=5,
early_stopping_flag=True,
early_stopping_patience=10,
early_stopping_delta=0.05,
verbose=False,
)

Expand Down