Skip to content

Commit 99b1f74

Browse files
Sasha Sheng authored and facebook-github-bot committed
[refactor] reporting + meter (#844)
Summary: * Clean the API up to prep for moving some of the report/meter updating inside the base_model for pytorch lightning early stopping/checkpointing. Pull Request resolved: #844 Reviewed By: vedanuj Differential Revision: D27486844 Pulled By: ytsheng fbshipit-source-id: 759fbaccdc1ce2ef2a8e736fa179afeec48b89dc
1 parent 2379d0c commit 99b1f74

File tree

7 files changed

+95
-70
lines changed

7 files changed

+95
-70
lines changed

mmf/common/meter.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from collections import defaultdict, deque
44

55
import torch
6+
from mmf.common.registry import registry
7+
from mmf.utils.distributed import reduce_dict
8+
from mmf.utils.general import scalarize_dict_values
69

710

811
class SmoothedValue:
@@ -55,11 +58,47 @@ def __init__(self, delimiter=", "):
5558
self.meters = defaultdict(SmoothedValue)
5659
self.delimiter = delimiter
5760

58-
def update(self, update_dict, batch_size):
59-
for k, v in update_dict.items():
60-
if isinstance(v, torch.Tensor):
61-
if v.dim() != 0:
62-
v = v.mean()
61+
def update_from_report(self, report, should_update_loss=True):
    """Update this meter with the losses/metrics carried by a report.

    Loss and metric dicts are first reduced via ``reduce_dict`` and their
    tensor values scalarized before being folded into the meter. When
    losses are included, a ``"<dataset_type>/total_loss"`` entry (the sum
    of all individual losses) is also added and registered in the global
    registry.

    Args:
        report (Report): report object which content is used to populate
            the current meter.
        should_update_loss (bool): when False (e.g. for dataset-level
            evaluation where only metrics matter), losses are skipped and
            no total-loss entry is produced. Defaults to True.

    Usage::

        >>> meter = Meter()
        >>> report = Report(prepared_batch, model_output)
        >>> meter.update_from_report(report)
    """
    # Metrics are optional on a report: only present once self.metrics(...)
    # has been attached by the caller (see evaluation/training loops).
    # NOTE(review): reduce_dict presumably aggregates values across
    # distributed workers — confirm in mmf.utils.distributed.
    if hasattr(report, "metrics"):
        metrics_dict = report.metrics
        reduced_metrics_dict = reduce_dict(metrics_dict)

    if should_update_loss:
        loss_dict = report.losses
        reduced_loss_dict = reduce_dict(loss_dict)

    # no_grad: the meter only records values; no autograd graph is needed.
    with torch.no_grad():
        meter_update_dict = {}
        if should_update_loss:
            # Collapse each loss tensor to a scalar before accumulation.
            meter_update_dict = scalarize_dict_values(reduced_loss_dict)
            total_loss_key = report.dataset_type + "/total_loss"
            total_loss = sum(meter_update_dict.values())
            # Registered globally so other components can read the latest
            # total loss back from the registry by key.
            registry.register(total_loss_key, total_loss)
            meter_update_dict.update({total_loss_key: total_loss})

        if hasattr(report, "metrics"):
            metrics_dict = scalarize_dict_values(reduced_metrics_dict)
            meter_update_dict.update(**metrics_dict)

        self._update(meter_update_dict, report.batch_size)
98+
99+
def _update(self, update_dict, batch_size):
100+
scalarized = scalarize_dict_values(update_dict)
101+
for k, v in scalarized.items():
63102
# Skipping .item() call
64103
# __format__() for tensor has .item
65104
# Therefore it will implicitly get called when needed

mmf/trainers/core/evaluation_loop.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def evaluation_loop(
3939
model_output = self.model(prepared_batch)
4040
report = Report(prepared_batch, model_output)
4141

42-
self.update_meter(report, meter)
42+
meter.update_from_report(report)
4343

4444
# accumulate necessary params for metric calculation
4545
if combined_report is None:
@@ -52,8 +52,8 @@ def evaluation_loop(
5252
)
5353
combined_report.batch_size += report.batch_size
5454

55-
# Each node generates a separate copy of predict JSON from the report,
56-
# which will be used to evaluate dataset-level metrics
55+
# Each node generates a separate copy of predict JSON from the
56+
# report, which will be used to evaluate dataset-level metrics
5757
# (such as mAP in object detection or CIDEr in image captioning)
5858
# Since `reporter.add_to_report` changes report keys (e.g. scores),
5959
# do this after `combined_report.accumulate_tensor_fields_and_loss`
@@ -73,7 +73,7 @@ def evaluation_loop(
7373
combined_report.prediction_report = reporter.report
7474

7575
combined_report.metrics = self.metrics(combined_report, combined_report)
76-
self.update_meter(combined_report, meter, eval_mode=True)
76+
meter.update_from_report(combined_report, should_update_loss=False)
7777

7878
# enable train mode again
7979
self.model.train()

mmf/trainers/core/reporting.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

mmf/trainers/core/training_loop.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Dict
77

88
import torch
9+
from mmf.common.meter import Meter
910
from mmf.common.registry import registry
1011
from mmf.common.report import Report
1112
from mmf.common.sample import to_device
@@ -21,6 +22,7 @@ class TrainerTrainingLoopMixin(ABC):
2122
current_epoch: int = 0
2223
current_iteration: int = 0
2324
num_updates: int = 0
25+
meter: Meter = Meter()
2426

2527
def training_loop(self) -> None:
2628
self.max_updates = self._calculate_max_updates()
@@ -118,7 +120,7 @@ def run_training_epoch(self) -> None:
118120
combined_report.metrics = self.metrics(
119121
combined_report, combined_report
120122
)
121-
self.update_meter(combined_report, self.meter)
123+
self.meter.update_from_report(combined_report)
122124

123125
self.on_update_end(
124126
report=combined_report, meter=self.meter, should_log=should_log

mmf/trainers/mmf_trainer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from mmf.trainers.core.device import TrainerDeviceMixin
1717
from mmf.trainers.core.evaluation_loop import TrainerEvaluationLoopMixin
1818
from mmf.trainers.core.profiling import TrainerProfilingMixin
19-
from mmf.trainers.core.reporting import TrainerReportingMixin
2019
from mmf.trainers.core.training_loop import TrainerTrainingLoopMixin
2120
from mmf.utils.build import build_model, build_optimizer
2221
from mmf.utils.general import print_model_parameters
@@ -32,7 +31,6 @@ class MMFTrainer(
3231
TrainerTrainingLoopMixin,
3332
TrainerDeviceMixin,
3433
TrainerEvaluationLoopMixin,
35-
TrainerReportingMixin,
3634
TrainerProfilingMixin,
3735
BaseTrainer,
3836
):

mmf/utils/general.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
import time
1010
import warnings
1111
from bisect import bisect
12-
from typing import Any, Callable
12+
from typing import Any, Callable, Dict
1313

1414
import torch
1515
from mmf.utils.distributed import get_rank, get_world_size, is_xla
1616
from mmf.utils.file_io import PathManager
17-
from torch import nn
17+
from torch import Tensor, nn
1818

1919

2020
logger = logging.getLogger(__name__)
@@ -446,3 +446,20 @@ def retry_n(n: int, fn: Callable, *args, log_tries=False, **kwargs) -> Any:
446446
raise
447447

448448
return output
449+
450+
451+
def scalarize_dict_values(dict_with_tensors: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Return a new dict whose tensor values are reduced to scalars.

    Tensor values with one or more dimensions are replaced by their mean;
    0-dim tensors and non-tensor values are passed through unchanged. The
    input dict is not mutated.

    Args:
        dict_with_tensors: mapping whose values may be tensors of any rank.

    Returns:
        Dict: a new dict with the same keys and scalarized values.
    """
    dict_with_scalar_tensors = {}
    for key, val in dict_with_tensors.items():
        # Reduce non-scalar tensors to their mean so downstream consumers
        # (e.g. meters) can treat every value as a single number.
        if torch.is_tensor(val) and val.dim() != 0:
            val = val.mean()
        dict_with_scalar_tensors[key] = val
    return dict_with_scalar_tensors

tests/common/test_meter.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import unittest
3+
4+
import torch
5+
from mmf.common.meter import Meter
6+
from mmf.common.report import Report
7+
from mmf.common.sample import SampleList
8+
9+
10+
class TestMeter(unittest.TestCase):
    def test_meter_update_from_report(self):
        """Feeding losses 0..4 through update_from_report averages to 2.0."""
        meter = Meter()
        batch = SampleList(
            {"targets": torch.tensor([1, 2, 3, 4]), "dataset_type": "val"}
        )
        for step in range(5):
            output = {
                "scores": torch.tensor([0, 1, 2, 3]),
                "losses": {"loss": float(step)},
            }
            meter.update_from_report(Report(batch, output))

        # mean(0, 1, 2, 3, 4) == 2.0 for both the global and windowed average
        self.assertEqual(meter.loss.global_avg, 2.0)
        self.assertEqual(meter.loss.avg, 2.0)

0 commit comments

Comments
 (0)