8 changes: 6 additions & 2 deletions mmengine/_strategy/fsdp.py
@@ -408,7 +408,9 @@ def load_optim_state_dict(self, state_dict: dict) -> None:
``optimizer.state_dict()``
"""
optim_state_dict = FSDP.optim_state_dict_to_load(
- state_dict, self.model, self.optim_wrapper.optimizer)
+ optim_state_dict=state_dict,
+ model=self.model,
+ optim=self.optim_wrapper.optimizer)
self.optim_wrapper.load_state_dict(optim_state_dict)

def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
@@ -539,7 +541,9 @@ def build_optim_wrapper(
# Force to load the converted optim_state_dict in full mode.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
optim_state_dict = FSDP.optim_state_dict_to_load(
- optim_state_dict, model, new_optimizer)
+ optim_state_dict=optim_state_dict,
+ model=model,
+ optim=new_optimizer)
new_optimizer.load_state_dict(optim_state_dict)
optim_wrapper.optimizer = new_optimizer
return optim_wrapper
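
Both fsdp.py hunks switch FSDP.optim_state_dict_to_load to keyword arguments. The diff itself does not state why, but a likely motivation is that the positional parameter order of this FSDP API has shifted across PyTorch releases, so keyword arguments keep both call sites unambiguous. A minimal sketch of the resulting call pattern (the helper name and surrounding code are illustrative, not part of the PR):

# Illustrative helper, not part of the PR: load a saved optimizer state dict
# into an FSDP-wrapped model using keyword arguments only.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def load_fsdp_optim_state(model, optimizer, saved_state_dict):
    # Keyword arguments avoid depending on the positional order of
    # ``optim_state_dict_to_load``, which is assumed here to differ between
    # PyTorch versions.
    optim_state_dict = FSDP.optim_state_dict_to_load(
        optim_state_dict=saved_state_dict, model=model, optim=optimizer)
    optimizer.load_state_dict(optim_state_dict)
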
9 changes: 5 additions & 4 deletions mmengine/optim/optimizer/builder.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import warnings
from typing import List, Union

import torch
@@ -115,7 +116,7 @@ def register_sophia_optimizers() -> List[str]:
Returns:
List[str]: A list of registered optimizers' name.
"""
- optimizers = []
+ optimizers = [] # type: ignore
try:
import Sophia
except ImportError:
@@ -128,7 +129,7 @@ def register_sophia_optimizers() -> List[str]:
try:
OPTIMIZERS.register_module(module=_optim)
except Exception as e:
- warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
+ warnings.warn(f'Failed to import {Sophia} for {e}')
return optimizers


@@ -161,7 +162,7 @@ def register_bitsandbytes_optimizers() -> List[str]:
try:
OPTIMIZERS.register_module(module=optim_cls, name=name)
except Exception as e:
- warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
+ warnings.warn(f'Failed to import {optim_cls.__name__} for {e}')
dadaptation_optimizers.append(name)
return dadaptation_optimizers

@@ -179,7 +180,7 @@ def register_transformers_optimizers():
try:
OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
except Exception as e:
- warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
+ warnings.warn(f'Failed to import Adafactor for {e}')
transformer_optimizers.append('Adafactor')
return transformer_optimizers

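
The builder.py changes are mostly warning-message fixes: in the Sophia and transformers helpers the old messages interpolated optim_cls.__name__, but no optim_cls variable exists in those scopes, so the except block itself raised NameError instead of emitting the warning. A small self-contained illustration of that failure mode (the stub function is hypothetical):

# Hypothetical stub reproducing the old bug: the warning text references a
# name that does not exist in this scope, so the handler raises NameError and
# the intended warning is never emitted.
import warnings

def _register_stub():
    try:
        raise RuntimeError('registration failed')  # stand-in for register_module()
    except Exception as e:
        warnings.warn(f'Failed to import {optim_cls.__name__} for {e}')  # noqa: F821

try:
    _register_stub()
except NameError as err:
    print(f'warning never emitted: {err}')  # name 'optim_cls' is not defined
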
12 changes: 8 additions & 4 deletions mmengine/runner/checkpoint.py
@@ -344,7 +344,8 @@ def load_from_local(filename, map_location):
filename = osp.expanduser(filename)
if not osp.isfile(filename):
raise FileNotFoundError(f'{filename} can not be found.')
- checkpoint = torch.load(filename, map_location=map_location, weights_only=False)
+ checkpoint = torch.load(
+ filename, map_location=map_location, weights_only=False)
return checkpoint


@@ -412,7 +413,8 @@ def load_from_pavi(filename, map_location=None):
with TemporaryDirectory() as tmp_dir:
downloaded_file = osp.join(tmp_dir, model.name)
model.download(downloaded_file)
- checkpoint = torch.load(downloaded_file, map_location=map_location, weights_only=False)
+ checkpoint = torch.load(
+ downloaded_file, map_location=map_location, weights_only=False)
return checkpoint


@@ -435,7 +437,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
file_backend = get_file_backend(
filename, backend_args={'backend': backend})
with io.BytesIO(file_backend.get(filename)) as buffer:
- checkpoint = torch.load(buffer, map_location=map_location, weights_only=False)
+ checkpoint = torch.load(
+ buffer, map_location=map_location, weights_only=False)
return checkpoint


@@ -504,7 +507,8 @@ def load_from_openmmlab(filename, map_location=None):
filename = osp.join(_get_mmengine_home(), model_url)
if not osp.isfile(filename):
raise FileNotFoundError(f'{filename} can not be found.')
- checkpoint = torch.load(filename, map_location=map_location, weights_only=False)
+ checkpoint = torch.load(
+ filename, map_location=map_location, weights_only=False)
return checkpoint


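
The checkpoint.py hunks only re-wrap existing torch.load(..., weights_only=False) calls onto two lines, presumably to satisfy the line-length limit. The weights_only=False argument itself matters because PyTorch 2.6 changed the default to True, which refuses to unpickle the arbitrary Python objects MMEngine checkpoints carry (optimizer state, message hub, configs). A minimal sketch of the loading pattern (the helper name is illustrative):

# Illustrative helper, not part of the PR: load a full MMEngine-style
# checkpoint (tensors plus pickled metadata) from a trusted file.
import torch

def load_full_checkpoint(filename, map_location='cpu'):
    # PyTorch >= 2.6 defaults to weights_only=True; passing False keeps the
    # pre-2.6 behaviour and allows non-tensor objects to be restored.
    return torch.load(filename, map_location=map_location, weights_only=False)
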
2 changes: 1 addition & 1 deletion mmengine/utils/package_utils.py
@@ -14,7 +14,7 @@ def is_installed(package: str) -> bool:
# Therefore, import it in function scope to save time.
import importlib.util

- import pkg_resources
+ import pkg_resources # type: ignore
from pkg_resources import get_distribution

# refresh the pkg_resources
6 changes: 3 additions & 3 deletions tests/test_fileio/test_fileclient.py
@@ -16,8 +16,6 @@
from mmengine.utils import has_method

sys.modules['ceph'] = MagicMock()
- sys.modules['petrel_client'] = MagicMock()
- sys.modules['petrel_client.client'] = MagicMock()
sys.modules['mc'] = MagicMock()


@@ -289,7 +287,9 @@ def test_disk_backend(self):
osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
}

- @patch('petrel_client.client.Client', MockPetrelClient)
+ @patch.dict(
+ sys.modules,
+ {'petrel_client': MagicMock(**{'client.Client': MockPetrelClient})})
@pytest.mark.parametrize('backend,prefix', [('petrel', None),
(None, 's3')])
def test_petrel_backend(self, backend, prefix):
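
The test_fileclient.py change replaces the module-level sys.modules['petrel_client'] = MagicMock() assignments with a patch.dict decorator on the test itself, so the fake module only exists while test_petrel_backend runs and is removed from the import machinery afterwards. A self-contained sketch of the pattern (the names below are illustrative):

# Illustrative sketch of scoping a fake module with patch.dict: sys.modules is
# restored to its previous state when the decorated function returns.
import sys
from unittest.mock import MagicMock, patch

class FakePetrelClient:  # stand-in for the tests' MockPetrelClient
    pass

@patch.dict(
    sys.modules,
    {'petrel_client': MagicMock(**{'client.Client': FakePetrelClient})})
def use_fake_petrel_client():
    import petrel_client
    assert petrel_client.client.Client is FakePetrelClient

use_fake_petrel_client()
# Assuming the real petrel_client is not installed, the fake did not leak:
assert 'petrel_client' not in sys.modules
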
32 changes: 23 additions & 9 deletions tests/test_hooks/test_checkpoint_hook.py
@@ -458,13 +458,17 @@ def test_with_runner(self, training_type):
cfg = copy.deepcopy(common_cfg)
runner = self.build_runner(cfg)
runner.train()
- ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+ ckpt = torch.load(
+ osp.join(cfg.work_dir, f'{training_type}_11.pth'),
+ weights_only=False)
self.assertIn('optimizer', ckpt)

cfg.default_hooks.checkpoint.save_optimizer = False
runner = self.build_runner(cfg)
runner.train()
- ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+ ckpt = torch.load(
+ osp.join(cfg.work_dir, f'{training_type}_11.pth'),
+ weights_only=False)
self.assertNotIn('optimizer', ckpt)

# Test save_param_scheduler=False
@@ -479,13 +483,17 @@ def test_with_runner(self, training_type):
]
runner = self.build_runner(cfg)
runner.train()
- ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+ ckpt = torch.load(
+ osp.join(cfg.work_dir, f'{training_type}_11.pth'),
+ weights_only=False)
self.assertIn('param_schedulers', ckpt)

cfg.default_hooks.checkpoint.save_param_scheduler = False
runner = self.build_runner(cfg)
runner.train()
- ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+ ckpt = torch.load(
+ osp.join(cfg.work_dir, f'{training_type}_11.pth'),
+ weights_only=False)
self.assertNotIn('param_schedulers', ckpt)

self.clear_work_dir()
@@ -533,7 +541,9 @@ def test_with_runner(self, training_type):
self.assertFalse(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))

- ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+ ckpt = torch.load(
+ osp.join(cfg.work_dir, f'{training_type}_11.pth'),
+ weights_only=False)
self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'],
[9, 10, 11])

@@ -574,9 +584,11 @@ def test_with_runner(self, training_type):
runner.train()
best_ckpt_path = osp.join(cfg.work_dir,
f'best_test_acc_{training_type}_5.pth')
- best_ckpt = torch.load(best_ckpt_path)
+ best_ckpt = torch.load(best_ckpt_path, weights_only=False)

- ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
+ ckpt = torch.load(
+ osp.join(cfg.work_dir, f'{training_type}_5.pth'),
+ weights_only=False)
self.assertEqual(best_ckpt_path,
ckpt['message_hub']['runtime_info']['best_ckpt'])

@@ -603,11 +615,13 @@ def test_with_runner(self, training_type):
runner.train()
best_ckpt_path = osp.join(cfg.work_dir,
f'best_test_acc_{training_type}_5.pth')
- best_ckpt = torch.load(best_ckpt_path)
+ best_ckpt = torch.load(best_ckpt_path, weights_only=False)

# if the current ckpt is the best, the interval will be ignored the
# the ckpt will also be saved
- ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
+ ckpt = torch.load(
+ osp.join(cfg.work_dir, f'{training_type}_5.pth'),
+ weights_only=False)
self.assertEqual(best_ckpt_path,
ckpt['message_hub']['runtime_info']['best_ckpt'])

18 changes: 13 additions & 5 deletions tests/test_hooks/test_ema_hook.py
@@ -230,7 +230,8 @@ def test_with_runner(self):
self.assertTrue(
isinstance(ema_hook.ema_model, ExponentialMovingAverage))

- checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+ checkpoint = torch.load(
+ osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=False)
self.assertTrue('ema_state_dict' in checkpoint)
self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)

@@ -245,7 +246,8 @@ def test_with_runner(self):
runner.test()

# Test load checkpoint without ema_state_dict
- checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+ checkpoint = torch.load(
+ osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=False)
checkpoint.pop('ema_state_dict')
torch.save(checkpoint,
osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))
@@ -274,7 +276,9 @@ def test_with_runner(self):
runner = self.build_runner(cfg)
runner.train()
state_dict = torch.load(
- osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
+ osp.join(self.temp_dir.name, 'epoch_4.pth'),
+ map_location='cpu',
+ weights_only=False)
self.assertIn('ema_state_dict', state_dict)
for k, v in state_dict['state_dict'].items():
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
@@ -287,12 +291,16 @@ def test_with_runner(self):
runner = self.build_runner(cfg)
runner.train()
state_dict = torch.load(
- osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
+ osp.join(self.temp_dir.name, 'iter_4.pth'),
+ map_location='cpu',
+ weights_only=False)
self.assertIn('ema_state_dict', state_dict)
for k, v in state_dict['state_dict'].items():
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
state_dict = torch.load(
- osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
+ osp.join(self.temp_dir.name, 'iter_5.pth'),
+ map_location='cpu',
+ weights_only=False)
self.assertIn('ema_state_dict', state_dict)

def _test_swap_parameters(self, func_name, *args, **kwargs):
29 changes: 17 additions & 12 deletions tests/test_hooks/test_empty_cache_hook.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+ from copy import deepcopy
from unittest.mock import patch

import pytest
@@ -13,7 +14,7 @@ class TestEmptyCacheHook(RunnerTestCase):
not is_cuda_available(), reason='cuda should be available')
def test_with_runner(self):
with patch('torch.cuda.empty_cache') as mock_empty_cache:
- cfg = self.epoch_based_cfg
+ cfg = deepcopy(self.epoch_based_cfg)
cfg.custom_hooks = [dict(type='EmptyCacheHook')]
cfg.train_cfg.val_interval = 1e6 # disable validation during training # noqa: E501
runner = self.build_runner(cfg)
@@ -24,12 +25,14 @@ def test_with_runner(self):

# Call `torch.cuda.empty_cache` after each epoch:
# runner.train: `max_epochs` times.
+ # runner.val: last epoch will always trigger validation (BC caused by `e258c848`) # noqa: E501
# runner.val: `1` time.
# runner.test: `1` time.
- target_called_times = runner.max_epochs + 2
+ target_called_times = runner.max_epochs + 3
self.assertEqual(mock_empty_cache.call_count, target_called_times)

+ #
with patch('torch.cuda.empty_cache') as mock_empty_cache:
+ cfg = deepcopy(self.epoch_based_cfg)
cfg.custom_hooks = [dict(type='EmptyCacheHook', before_epoch=True)]
runner = self.build_runner(cfg)

@@ -39,13 +42,15 @@ def test_with_runner(self):

# Call `torch.cuda.empty_cache` after/before each epoch:
# runner.train: `max_epochs*2` times.
- # runner.val: `1*2` times.
+ # runner.val: (max_epochs + 1)*2 times, last epoch will always trigger validation (BC caused by `e258c848`) # noqa: E501
# runner.test: `1*2` times.

- target_called_times = runner.max_epochs * 2 + 4
+ target_called_times = runner.max_epochs * 2 + (runner.max_epochs +
+ 1) * 2 + 1 * 2
self.assertEqual(mock_empty_cache.call_count, target_called_times)

with patch('torch.cuda.empty_cache') as mock_empty_cache:
+ cfg = deepcopy(self.epoch_based_cfg)
cfg.custom_hooks = [
dict(
type='EmptyCacheHook', after_iter=True, before_epoch=True)
@@ -58,13 +63,13 @@ def test_with_runner(self):

# Call `torch.cuda.empty_cache` after/before each epoch,
# after each iteration:
- # runner.train: `max_epochs*2 + len(dataloader)*max_epochs` times. # noqa: E501
- # runner.val: `1*2 + len(val_dataloader)` times.
- # runner.test: `1*2 + len(val_dataloader)` times.
+ # runner.train: max_epochs * (2 + len(train_dataloader)) times.
+ # runner.val: (max_epochs + 1(interval) + 1(last)) * (2 + len(val_dataloader)) times # noqa: E501
+ # runner.test: 1 * (2 + len(test_dataloader)) times

target_called_times = \
- runner.max_epochs * 2 + 4 + \
- len(runner.train_dataloader) * runner.max_epochs + \
- len(runner.val_dataloader) + \
- len(runner.test_dataloader)
+ runner.max_epochs * (2 + len(runner.train_dataloader)) + \
+ (runner.max_epochs + 1) * (2 + len(runner.val_dataloader)) + \
+ 1 * (2 + len(runner.test_dataloader))

self.assertEqual(mock_empty_cache.call_count, target_called_times)
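
The test_empty_cache_hook.py updates do two things: each case now deep-copies the base config so one case's custom_hooks cannot bleed into the next, and the expected torch.cuda.empty_cache call counts account for the validation loop that the last training epoch now always triggers. A worked example of the final formula with assumed sizes (the real values come from RunnerTestCase's config and may differ):

# Worked example of the updated call-count arithmetic, using assumed values.
max_epochs = 2
len_train, len_val, len_test = 4, 4, 4  # assumed dataloader lengths

# EmptyCacheHook(after_iter=True, before_epoch=True) empties the cache before
# each epoch, after each epoch, and after every iteration of that epoch.
train_calls = max_epochs * (2 + len_train)
# Validation runs after each training epoch (the last epoch always validates)
# plus once more for the explicit runner.val() call.
val_calls = (max_epochs + 1) * (2 + len_val)
test_calls = 1 * (2 + len_test)

print(train_calls + val_calls + test_calls)  # 12 + 18 + 6 = 36 for these values
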