diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py
index 0788fafdab..e25e9e2a77 100644
--- a/mmengine/_strategy/fsdp.py
+++ b/mmengine/_strategy/fsdp.py
@@ -408,7 +408,9 @@ def load_optim_state_dict(self, state_dict: dict) -> None:
                 ``optimizer.state_dict()``
         """
         optim_state_dict = FSDP.optim_state_dict_to_load(
-            state_dict, self.model, self.optim_wrapper.optimizer)
+            optim_state_dict=state_dict,
+            model=self.model,
+            optim=self.optim_wrapper.optimizer)
         self.optim_wrapper.load_state_dict(optim_state_dict)

     def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None:
@@ -539,7 +541,9 @@ def build_optim_wrapper(
             # Force to load the converted optim_state_dict in full mode.
             with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
                 optim_state_dict = FSDP.optim_state_dict_to_load(
-                    optim_state_dict, model, new_optimizer)
+                    optim_state_dict=optim_state_dict,
+                    model=model,
+                    optim=new_optimizer)
                 new_optimizer.load_state_dict(optim_state_dict)
             optim_wrapper.optimizer = new_optimizer
             return optim_wrapper
diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index fef95f729a..6c169924b8 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import inspect
+import warnings
 from typing import List, Union

 import torch
@@ -115,7 +116,7 @@ def register_sophia_optimizers() -> List[str]:
     Returns:
         List[str]: A list of registered optimizers' name.
     """
-    optimizers = []
+    optimizers = []  # type: ignore
     try:
         import Sophia
     except ImportError:
@@ -128,7 +129,7 @@ def register_sophia_optimizers() -> List[str]:
                 try:
                     OPTIMIZERS.register_module(module=_optim)
                 except Exception as e:
-                    warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
+                    warnings.warn(f'Failed to import {Sophia} for {e}')
     return optimizers


@@ -161,7 +162,7 @@ def register_bitsandbytes_optimizers() -> List[str]:
             try:
                 OPTIMIZERS.register_module(module=optim_cls, name=name)
             except Exception as e:
-                warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
+                warnings.warn(f'Failed to import {optim_cls.__name__} for {e}')
             dadaptation_optimizers.append(name)
     return dadaptation_optimizers

@@ -179,7 +180,7 @@ def register_transformers_optimizers():
         try:
             OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
         except Exception as e:
-            warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
+            warnings.warn(f'Failed to import Adafactor for {e}')
         transformer_optimizers.append('Adafactor')
     return transformer_optimizers
diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py
index f5556ed236..d55e6d6c3a 100644
--- a/mmengine/runner/checkpoint.py
+++ b/mmengine/runner/checkpoint.py
@@ -344,7 +344,8 @@ def load_from_local(filename, map_location):
     filename = osp.expanduser(filename)
     if not osp.isfile(filename):
         raise FileNotFoundError(f'{filename} can not be found.')
-    checkpoint = torch.load(filename, map_location=map_location, weights_only=False)
+    checkpoint = torch.load(
+        filename, map_location=map_location, weights_only=False)
     return checkpoint


@@ -412,7 +413,8 @@ def load_from_pavi(filename, map_location=None):
     with TemporaryDirectory() as tmp_dir:
         downloaded_file = osp.join(tmp_dir, model.name)
         model.download(downloaded_file)
-        checkpoint = torch.load(downloaded_file, map_location=map_location, weights_only=False)
+        checkpoint = torch.load(
+            downloaded_file, map_location=map_location, weights_only=False)
     return checkpoint

@@ -435,7 +437,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
     file_backend = get_file_backend(
         filename, backend_args={'backend': backend})
     with io.BytesIO(file_backend.get(filename)) as buffer:
-        checkpoint = torch.load(buffer, map_location=map_location, weights_only=False)
+        checkpoint = torch.load(
+            buffer, map_location=map_location, weights_only=False)
     return checkpoint


@@ -504,7 +507,8 @@ def load_from_openmmlab(filename, map_location=None):
         filename = osp.join(_get_mmengine_home(), model_url)
     if not osp.isfile(filename):
         raise FileNotFoundError(f'{filename} can not be found.')
-    checkpoint = torch.load(filename, map_location=map_location, weights_only=False)
+    checkpoint = torch.load(
+        filename, map_location=map_location, weights_only=False)
     return checkpoint


diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py
index 1816f47f07..452bbaddaa 100644
--- a/mmengine/utils/package_utils.py
+++ b/mmengine/utils/package_utils.py
@@ -14,7 +14,7 @@ def is_installed(package: str) -> bool:
     # Therefore, import it in function scope to save time.
     import importlib.util

-    import pkg_resources
+    import pkg_resources  # type: ignore
     from pkg_resources import get_distribution

     # refresh the pkg_resources
diff --git a/tests/test_fileio/test_fileclient.py b/tests/test_fileio/test_fileclient.py
index 345832a026..bd3b601eaa 100644
--- a/tests/test_fileio/test_fileclient.py
+++ b/tests/test_fileio/test_fileclient.py
@@ -16,8 +16,6 @@ from mmengine.utils import has_method

 sys.modules['ceph'] = MagicMock()
-sys.modules['petrel_client'] = MagicMock()
-sys.modules['petrel_client.client'] = MagicMock()
 sys.modules['mc'] = MagicMock()


@@ -289,7 +287,9 @@ def test_disk_backend(self):
             osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
         }

-    @patch('petrel_client.client.Client', MockPetrelClient)
+    @patch.dict(
+        sys.modules,
+        {'petrel_client': MagicMock(**{'client.Client': MockPetrelClient})})
     @pytest.mark.parametrize('backend,prefix', [('petrel', None),
                                                 (None, 's3')])
     def test_petrel_backend(self, backend, prefix):
diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py
index d731a42b76..13914341f7 100644
--- a/tests/test_hooks/test_checkpoint_hook.py
+++ b/tests/test_hooks/test_checkpoint_hook.py
@@ -458,13 +458,17 @@ def test_with_runner(self, training_type):
         cfg = copy.deepcopy(common_cfg)
         runner = self.build_runner(cfg)
         runner.train()
-        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        ckpt = torch.load(
+            osp.join(cfg.work_dir, f'{training_type}_11.pth'),
+            weights_only=False)
         self.assertIn('optimizer', ckpt)

         cfg.default_hooks.checkpoint.save_optimizer = False
         runner = self.build_runner(cfg)
         runner.train()
-        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        ckpt = torch.load(
+            osp.join(cfg.work_dir, f'{training_type}_11.pth'),
+            weights_only=False)
         self.assertNotIn('optimizer', ckpt)

         # Test save_param_scheduler=False
@@ -479,13 +483,17 @@ def test_with_runner(self, training_type):
         ]
         runner = self.build_runner(cfg)
         runner.train()
-        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        ckpt = torch.load(
+            osp.join(cfg.work_dir, f'{training_type}_11.pth'),
+            weights_only=False)
         self.assertIn('param_schedulers', ckpt)

         cfg.default_hooks.checkpoint.save_param_scheduler = False
         runner = self.build_runner(cfg)
         runner.train()
-        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        ckpt = torch.load(
+            osp.join(cfg.work_dir, f'{training_type}_11.pth'),
+            weights_only=False)
         self.assertNotIn('param_schedulers', ckpt)

         self.clear_work_dir()
@@ -533,7 +541,9 @@ def test_with_runner(self, training_type):
             self.assertFalse(
                 osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))

-        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
+        ckpt = torch.load(
+            osp.join(cfg.work_dir, f'{training_type}_11.pth'),
+            weights_only=False)
         self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'],
                          [9, 10, 11])

@@ -574,9 +584,11 @@ def test_with_runner(self, training_type):
         runner.train()
         best_ckpt_path = osp.join(cfg.work_dir,
                                   f'best_test_acc_{training_type}_5.pth')
-        best_ckpt = torch.load(best_ckpt_path)
+        best_ckpt = torch.load(best_ckpt_path, weights_only=False)

-        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
+        ckpt = torch.load(
+            osp.join(cfg.work_dir, f'{training_type}_5.pth'),
+            weights_only=False)
         self.assertEqual(best_ckpt_path,
                          ckpt['message_hub']['runtime_info']['best_ckpt'])

@@ -603,11 +615,13 @@ def test_with_runner(self, training_type):
         runner.train()
         best_ckpt_path = osp.join(cfg.work_dir,
                                   f'best_test_acc_{training_type}_5.pth')
-        best_ckpt = torch.load(best_ckpt_path)
+        best_ckpt = torch.load(best_ckpt_path, weights_only=False)

         # if the current ckpt is the best, the interval will be ignored the
         # the ckpt will also be saved
-        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth'))
+        ckpt = torch.load(
+            osp.join(cfg.work_dir, f'{training_type}_5.pth'),
+            weights_only=False)
         self.assertEqual(best_ckpt_path,
                          ckpt['message_hub']['runtime_info']['best_ckpt'])

diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py
index 6dad7ba4f0..9467da45dc 100644
--- a/tests/test_hooks/test_ema_hook.py
+++ b/tests/test_hooks/test_ema_hook.py
@@ -230,7 +230,8 @@ def test_with_runner(self):
         self.assertTrue(
             isinstance(ema_hook.ema_model, ExponentialMovingAverage))

-        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+        checkpoint = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=False)
         self.assertTrue('ema_state_dict' in checkpoint)
         self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)

@@ -245,7 +246,8 @@ def test_with_runner(self):
         runner.test()

         # Test load checkpoint without ema_state_dict
-        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+        checkpoint = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=False)
         checkpoint.pop('ema_state_dict')
         torch.save(checkpoint,
                    osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))
@@ -274,7 +276,9 @@ def test_with_runner(self):
         runner = self.build_runner(cfg)
         runner.train()
         state_dict = torch.load(
-            osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
+            osp.join(self.temp_dir.name, 'epoch_4.pth'),
+            map_location='cpu',
+            weights_only=False)
         self.assertIn('ema_state_dict', state_dict)
         for k, v in state_dict['state_dict'].items():
             assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
@@ -287,12 +291,16 @@ def test_with_runner(self):
         runner = self.build_runner(cfg)
         runner.train()
         state_dict = torch.load(
-            osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
+            osp.join(self.temp_dir.name, 'iter_4.pth'),
+            map_location='cpu',
+            weights_only=False)
         self.assertIn('ema_state_dict', state_dict)
         for k, v in state_dict['state_dict'].items():
             assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
         state_dict = torch.load(
-            osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
+            osp.join(self.temp_dir.name, 'iter_5.pth'),
+            map_location='cpu',
+            weights_only=False)
         self.assertIn('ema_state_dict', state_dict)

     def _test_swap_parameters(self, func_name, *args, **kwargs):
diff --git a/tests/test_hooks/test_empty_cache_hook.py b/tests/test_hooks/test_empty_cache_hook.py
index d30972d360..69a2fd27af 100644
--- a/tests/test_hooks/test_empty_cache_hook.py
+++ b/tests/test_hooks/test_empty_cache_hook.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
 from unittest.mock import patch

 import pytest
@@ -13,7 +14,7 @@ class TestEmptyCacheHook(RunnerTestCase):
         not is_cuda_available(), reason='cuda should be available')
     def test_with_runner(self):
         with patch('torch.cuda.empty_cache') as mock_empty_cache:
-            cfg = self.epoch_based_cfg
+            cfg = deepcopy(self.epoch_based_cfg)
             cfg.custom_hooks = [dict(type='EmptyCacheHook')]
             cfg.train_cfg.val_interval = 1e6  # disable validation during training  # noqa: E501
             runner = self.build_runner(cfg)
@@ -24,12 +25,14 @@ def test_with_runner(self):

             # Call `torch.cuda.empty_cache` after each epoch:
             # runner.train: `max_epochs` times.
+            # runner.val: last epoch will always trigger validation (BC caused by `e258c848`)  # noqa: E501
             # runner.val: `1` time.
             # runner.test: `1` time.
-            target_called_times = runner.max_epochs + 2
+            target_called_times = runner.max_epochs + 3
             self.assertEqual(mock_empty_cache.call_count,
                              target_called_times)
-
+        # with patch('torch.cuda.empty_cache') as mock_empty_cache:
+            cfg = deepcopy(self.epoch_based_cfg)
             cfg.custom_hooks = [dict(type='EmptyCacheHook', before_epoch=True)]
             runner = self.build_runner(cfg)
@@ -39,13 +42,15 @@ def test_with_runner(self):

             # Call `torch.cuda.empty_cache` after/before each epoch:
             # runner.train: `max_epochs*2` times.
-            # runner.val: `1*2` times.
+            # runner.val: (max_epochs + 1)*2 times, last epoch will always trigger validation (BC caused by `e258c848`)  # noqa: E501
             # runner.test: `1*2` times.
-            target_called_times = runner.max_epochs * 2 + 4
+            target_called_times = runner.max_epochs * 2 + (runner.max_epochs +
+                                                           1) * 2 + 1 * 2
             self.assertEqual(mock_empty_cache.call_count,
                              target_called_times)

         with patch('torch.cuda.empty_cache') as mock_empty_cache:
+            cfg = deepcopy(self.epoch_based_cfg)
             cfg.custom_hooks = [
                 dict(
                     type='EmptyCacheHook', after_iter=True, before_epoch=True)
@@ -58,13 +63,13 @@ def test_with_runner(self):

             # Call `torch.cuda.empty_cache` after/before each epoch,
             # after each iteration:
-            # runner.train: `max_epochs*2 + len(dataloader)*max_epochs` times.  # noqa: E501
-            # runner.val: `1*2 + len(val_dataloader)` times.
-            # runner.test: `1*2 + len(val_dataloader)` times.
+            # runner.train: max_epochs * (2 + len(train_dataloader)) times.
+            # runner.val: (max_epochs + 1(interval) + 1(last)) * (2 + len(val_dataloader)) times  # noqa: E501
+            # runner.test: 1 * (2 + len(test_dataloader)) times
             target_called_times = \
-                runner.max_epochs * 2 + 4 + \
-                len(runner.train_dataloader) * runner.max_epochs + \
-                len(runner.val_dataloader) + \
-                len(runner.test_dataloader)
+                runner.max_epochs * (2 + len(runner.train_dataloader)) + \
+                (runner.max_epochs + 1) * (2 + len(runner.val_dataloader)) + \
+                1 * (2 + len(runner.test_dataloader))
+
             self.assertEqual(mock_empty_cache.call_count,
                              target_called_times)
diff --git a/tests/test_hooks/test_sync_buffers_hook.py b/tests/test_hooks/test_sync_buffers_hook.py
index 6d4019dc58..71db44e38a 100644
--- a/tests/test_hooks/test_sync_buffers_hook.py
+++ b/tests/test_hooks/test_sync_buffers_hook.py
@@ -1,15 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os
 from unittest.mock import MagicMock

 import torch
 import torch.distributed as torch_dist
 import torch.nn as nn
+from torch.testing._internal.common_distributed import DistributedTestBase

 from mmengine.dist import all_gather
 from mmengine.hooks import SyncBuffersHook
 from mmengine.registry import MODELS
-from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.testing.runner_test_case import RunnerTestCase, ToyModel


@@ -23,22 +22,14 @@ def __init__(self, data_preprocessor=None):
     def init_weights(self):
         for buffer in self.buffers():
             buffer.fill_(
-                torch.tensor(int(os.environ['RANK']), dtype=torch.float32))
+                torch.tensor(torch_dist.get_rank(), dtype=torch.float32))
         return super().init_weights()


-class TestSyncBuffersHook(MultiProcessTestCase, RunnerTestCase):
-
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def prepare_subprocess(self):
-        MODELS.register_module(module=ToyModuleWithNorm, force=True)
-        super(MultiProcessTestCase, self).setUp()
+class TestSyncBuffersHook(DistributedTestBase, RunnerTestCase):

     def test_sync_buffers_hook(self):
-        self.setup_dist_env()
+        self.create_pg('cuda')
         runner = MagicMock()
         runner.model = ToyModuleWithNorm()
         runner.model.init_weights()
@@ -53,9 +44,12 @@ def test_sync_buffers_hook(self):
         for buffer in runner.model.buffers():
             buffer1, buffer2 = all_gather(buffer)
             self.assertTrue(torch.allclose(buffer1, buffer2))
+        torch_dist.destroy_process_group()

     def test_with_runner(self):
-        self.setup_dist_env()
+        MODELS.register_module(module=ToyModuleWithNorm, force=True)
+        self.create_pg('cuda')
+        RunnerTestCase.setUp(self)
         cfg = self.epoch_based_cfg
         cfg.model = dict(type='ToyModuleWithNorm')
         cfg.launch = 'pytorch'
@@ -67,8 +61,6 @@ def test_with_runner(self):
             buffer1, buffer2 = all_gather(buffer)
             self.assertTrue(torch.allclose(buffer1, buffer2))

-    def setup_dist_env(self):
-        super().setup_dist_env()
-        os.environ['RANK'] = str(self.rank)
-        torch_dist.init_process_group(
-            backend='gloo', rank=self.rank, world_size=self.world_size)
+    @property
+    def world_size(self) -> int:
+        return 2
diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py
index 4655a4c5da..5954896158 100644
--- a/tests/test_runner/test_checkpoint.py
+++ b/tests/test_runner/test_checkpoint.py
@@ -354,6 +354,7 @@ def load_from_abc(filename, map_location):
     assert loader.__name__ == 'load_from_abc'


+@patch.dict(sys.modules, {'petrel_client': MagicMock()})
 def test_save_checkpoint(tmp_path):
     model = Model()
     optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index e7668054bb..7e105f0895 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -2272,7 +2272,7 @@ def test_checkpoint(self):
         self.assertTrue(osp.exists(path))
         self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))

-        ckpt = torch.load(path)
+        ckpt = torch.load(path, weights_only=False)
         self.assertEqual(ckpt['meta']['epoch'], 3)
         self.assertEqual(ckpt['meta']['iter'], 12)
         self.assertEqual(ckpt['meta']['experiment_name'],
@@ -2444,7 +2444,7 @@ def test_checkpoint(self):
         self.assertTrue(osp.exists(path))
         self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))

-        ckpt = torch.load(path)
+        ckpt = torch.load(path, weights_only=False)
         self.assertEqual(ckpt['meta']['epoch'], 0)
         self.assertEqual(ckpt['meta']['iter'], 12)
         assert isinstance(ckpt['optimizer'], dict)
@@ -2455,7 +2455,7 @@ def test_checkpoint(self):
         self.assertEqual(message_hub.get_info('iter'), 11)
         # 2.1.2 check class attribute _statistic_methods can be saved
         HistoryBuffer._statistics_methods.clear()
-        ckpt = torch.load(path)
+        ckpt = torch.load(path, weights_only=False)
         self.assertIn('min', HistoryBuffer._statistics_methods)

         # 2.2 test `load_checkpoint`
diff --git a/tests/test_utils/test_package_utils.py b/tests/test_utils/test_package_utils.py
index bed91b6c18..11ce294c29 100644
--- a/tests/test_utils/test_package_utils.py
+++ b/tests/test_utils/test_package_utils.py
@@ -2,7 +2,7 @@
 import os.path as osp
 import sys

-import pkg_resources
+import pkg_resources  # type: ignore
 import pytest

 from mmengine.utils import get_installed_path, is_installed