|
| 1 | +import pytest |
| 2 | +import torch |
| 3 | +import torch.distributed as dist |
| 4 | +from packaging.version import Version |
| 5 | +from torch.optim import Adam |
| 6 | +from utils import shared_tempdir |
| 7 | +from copy import deepcopy |
| 8 | + |
| 9 | +import colossalai |
| 10 | +from colossalai.booster import Booster |
| 11 | +from colossalai.booster.plugin import HybridParallelPlugin |
| 12 | +from colossalai.shardformer.layer.utils import Randomizer |
| 13 | +from colossalai.tensor.d_tensor.api import clear_layout_converter |
| 14 | +from colossalai.checkpoint_io import DistributedCheckpointIO |
| 15 | +from colossalai.testing import ( |
| 16 | + assert_close_loose, |
| 17 | + check_state_dict_equal, |
| 18 | + clear_cache_before_run, |
| 19 | + parameterize, |
| 20 | + rerun_if_address_is_in_use, |
| 21 | + spawn, |
| 22 | +) |
| 23 | +from tests.kit.model_zoo import model_zoo |
| 24 | + |
| 25 | + |
| 26 | +TEST_CONFIGS = [ |
| 27 | + ({"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1}, |
| 28 | + {"tp_size": 2, "pp_size": 1, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},) |
| 29 | +] |
| 30 | + |
| 31 | + |
| 32 | +@parameterize("shard", [False, True]) |
| 33 | +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) |
| 34 | +@parameterize("size_per_shard", [1]) |
| 35 | +@parameterize("test_config", TEST_CONFIGS) |
| 36 | +@parameterize("use_async", [False, True]) |
| 37 | +@parameterize("low_cpu_mem_mode", [False, True]) |
| 38 | +@clear_cache_before_run() |
| 39 | +def exam_state_dict( |
| 40 | + shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool, low_cpu_mem_mode: bool |
| 41 | +): |
| 42 | + (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next( |
| 43 | + iter(model_zoo.get_sub_registry(model_name).values()) |
| 44 | + ) |
| 45 | + criterion = loss_fn |
| 46 | + test_config_0, test_config_1 = test_config |
| 47 | + plugin_0 = HybridParallelPlugin(**test_config_0) |
| 48 | + booster_0 = Booster(plugin=plugin_0) |
| 49 | + hybrid_ckp_0 = booster_0.checkpoint_io |
| 50 | + booster_0.checkpoint_io = DistributedCheckpointIO(hybrid_ckp_0.global_dp_group, hybrid_ckp_0.pp_group, hybrid_ckp_0.tp_group, hybrid_ckp_0.sp_group, hybrid_ckp_0.use_zero) |
| 51 | + |
| 52 | + def _criterion(outputs, inputs): |
| 53 | + outputs = output_transform_fn(outputs) |
| 54 | + loss = criterion(outputs) |
| 55 | + return loss |
| 56 | + |
| 57 | + def _preprocess_data(data): |
| 58 | + if booster_0.plugin.stage_manager is not None: |
| 59 | + for k, v in data.items(): |
| 60 | + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: |
| 61 | + new_shape = [1] * v.dim() |
| 62 | + new_shape[0] = 4 |
| 63 | + data[k] = v.to("cuda").repeat(*new_shape) |
| 64 | + return iter([data]) |
| 65 | + else: |
| 66 | + return {k: v.cuda() for k, v in data.items()} |
| 67 | + |
| 68 | + model_0 = model_fn().cuda() |
| 69 | + optimizer_0 = Adam(model_0.parameters(), lr=1e-3) |
| 70 | + model_0, optimizer_0, criterion, _, _ = booster_0.boost(model_0, optimizer_0, criterion) |
| 71 | + |
| 72 | + data = data_gen_fn() |
| 73 | + model_0.train() |
| 74 | + if booster_0.plugin.stage_manager is not None: |
| 75 | + booster_0.execute_pipeline(_preprocess_data(data), model_0, _criterion, optimizer_0, return_loss=True) |
| 76 | + else: |
| 77 | + output = model_0(**_preprocess_data(data)) |
| 78 | + loss = criterion(output) |
| 79 | + optimizer_0.backward(loss) |
| 80 | + |
| 81 | + optimizer_0.step() |
| 82 | + optimizer_0.zero_grad() |
| 83 | + with shared_tempdir() as tempdir: |
| 84 | + model_ckpt_path_0 = f"{tempdir}/model_0" |
| 85 | + |
| 86 | + booster_0.save_model(model_0, model_ckpt_path_0, shard=shard, size_per_shard=size_per_shard, use_async=use_async) |
| 87 | + booster_0.checkpoint_io._sync_d2h() |
| 88 | + booster_0.checkpoint_io._sync_io() |
| 89 | + dist.barrier() |
| 90 | + |
| 91 | + plugin_1 = HybridParallelPlugin(**test_config_1) |
| 92 | + booster_1 = Booster(plugin=plugin_1) |
| 93 | + hybrid_ckp_1 = booster_1.checkpoint_io |
| 94 | + booster_1.checkpoint_io = DistributedCheckpointIO(hybrid_ckp_1.global_dp_group, hybrid_ckp_1.pp_group, hybrid_ckp_1.tp_group, hybrid_ckp_1.sp_group, hybrid_ckp_1.use_zero) |
| 95 | + |
| 96 | + model_1 = model_fn().cuda() |
| 97 | + optimizer_1 = Adam(model_1.parameters(), lr=1e-3) |
| 98 | + model_1, optimizer_1, criterion, _, _ = booster_1.boost(model_1, optimizer_1, criterion) |
| 99 | + |
| 100 | + booster_1.load_model(model_1, model_ckpt_path_0, low_cpu_mem_mode=low_cpu_mem_mode) |
| 101 | + |
| 102 | + model_ckpt_path_1 = f"{tempdir}/model_1" |
| 103 | + booster_1.save_model(model_1, model_ckpt_path_1, shard=shard, size_per_shard=size_per_shard, use_async=use_async) |
| 104 | + booster_1.checkpoint_io._sync_d2h() |
| 105 | + booster_1.checkpoint_io._sync_io() |
| 106 | + dist.barrier() |
| 107 | + |
| 108 | + model_2 = model_fn().cuda() |
| 109 | + optimizer_2 = Adam(model_2.parameters(), lr=1e-3) |
| 110 | + model_2, optimizer_2, criterion, _, _ = booster_0.boost(model_2, optimizer_2, criterion) |
| 111 | + |
| 112 | + booster_0.load_model(model_2, model_ckpt_path_1, low_cpu_mem_mode=low_cpu_mem_mode) |
| 113 | + check_state_dict_equal(model_0.unwrap().state_dict(), model_2.unwrap().state_dict()) |
| 114 | + |
| 115 | + dist.barrier() |
| 116 | + Randomizer.reset_index() |
| 117 | + clear_layout_converter() |
| 118 | + |
| 119 | + |
| 120 | +def run_dist(rank, world_size, port): |
| 121 | + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") |
| 122 | + exam_state_dict() |
| 123 | + |
| 124 | + |
| 125 | +@pytest.mark.dist |
| 126 | +@pytest.mark.parametrize("world_size", [4]) |
| 127 | +@rerun_if_address_is_in_use() |
| 128 | +def test_hybrid_ckpIO(world_size): |
| 129 | + spawn(run_dist, world_size) |
| 130 | + |
| 131 | + |
| 132 | +if __name__ == "__main__": |
| 133 | + test_hybrid_ckpIO(4) |
0 commit comments