This repository was archived by the owner on Apr 17, 2023. It is now read-only.
forked from open-mmlab/mmdetection
-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathrunners.py
77 lines (64 loc) · 3.41 KB
/
runners.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
import time
from mmcv.runner import EpochBasedRunner
from mmcv.runner import RUNNERS
from mmcv.runner.utils import get_host_info
from .utils import check_nncf_is_enabled
@RUNNERS.register_module()
class AccuracyAwareRunner(EpochBasedRunner):
    """
    Epoch-based mmdet runner for NNCF accuracy-aware training.

    The class derives from ``EpochBasedRunner`` and replaces its ``run``
    method. The "workflow" and "max_epochs" parameters of the parent runner
    are not used here: scheduling is delegated to the training loop that NNCF
    creates from the "accuracy_aware_training" parameters of its config.
    """

    def __init__(self, *args, target_metric_name, **kwargs):
        """
        Create the runner and remember the name of the target metric.

        Args:
            *args: Positional arguments forwarded to ``EpochBasedRunner``.
            target_metric_name: Required keyword-only name of the runner
                attribute that ``validation_fn`` reads the metric from.
            **kwargs: Keyword arguments forwarded to ``EpochBasedRunner``.
        """
        super().__init__(*args, **kwargs)
        self.target_metric_name = target_metric_name

    def run(self, data_loaders, *args, compression_ctrl=None,
            nncf_config=None, configure_optimizers_fn=None):
        """
        Hand the model over to the NNCF accuracy-aware training loop.

        Args:
            data_loaders: List of data loaders; only the first one is used.
            *args: Accepted for signature compatibility and ignored.
            compression_ctrl: Passed to the NNCF training-loop factory.
            nncf_config: Passed to the NNCF training-loop factory.
            configure_optimizers_fn: Forwarded to the training loop's ``run``.

        Returns:
            Whatever the NNCF training loop's ``run`` returns (bound to the
            name ``model`` by the original author).
        """
        check_nncf_is_enabled()

        # The nncf import is deferred until the check above has been made.
        from nncf.common.accuracy_aware_training import (
            create_accuracy_aware_training_loop,
        )

        assert isinstance(data_loaders, list)

        if self.work_dir is not None:
            shown_work_dir = self.work_dir
        else:
            shown_work_dir = 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), shown_work_dir)
        self.logger.warning('Note that the workflow and max_epochs parameters '
                            'are not used in NNCF-based accuracy-aware training')

        training_loop = create_accuracy_aware_training_loop(
            nncf_config, compression_ctrl, verbose=False)

        # NNCF training uses the first data loader only.
        primary_loader = data_loaders[0]
        self.train_data_loader = primary_loader

        # Upper bounds on epochs/iterations, needed for progress tracking.
        epoch_limit = training_loop.runner.maximal_total_epochs
        self._max_epochs = epoch_limit
        self._max_iters = epoch_limit * len(primary_loader)

        self.call_hook('before_run')
        loop_callbacks = dict(
            train_epoch_fn=self.train_fn,
            validate_fn=self.validation_fn,
            configure_optimizers_fn=configure_optimizers_fn,
        )
        resulting_model = training_loop.run(self.model, **loop_callbacks)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
        return resulting_model

    def train_fn(self, *args, **kwargs):
        """
        Train for one epoch on the data loader stored by ``run``.

        Passed to the NNCF training loop as its ``train_epoch_fn`` callback;
        any arguments supplied by the caller are ignored.
        """
        loader = self.train_data_loader
        self.train(loader)

    def validation_fn(self, *args, **kwargs):
        """
        Return the target metric value read from the runner's attributes.

        Passed to the NNCF training loop as its ``validate_fn`` callback; any
        arguments supplied by the caller are ignored. No evaluation is run
        here: the value is expected to have been set on the runner already by
        EvalHook.evaluate().

        Raises:
            RuntimeError: If the attribute is missing or is ``None``.
        """
        metric_value = getattr(self, self.target_metric_name, None)
        if metric_value is not None:
            return metric_value
        raise RuntimeError(f'Could not find the {self.target_metric_name} key')