From f98762d5fb9a275af728fa88901f3a0647c9121f Mon Sep 17 00:00:00 2001
From: Jingchang Zhang
Date: Wed, 12 Mar 2025 21:58:57 -0700
Subject: [PATCH] Add support for compile optimizer step (#2810)

Summary:
This diff adds support for compiling the optimizer step in the TorchRec train pipelines. The changes include:

1. Add a new `optimizer_compile_config` argument to the train pipeline constructors. When a `TorchCompileConfig` is passed, the pipeline wraps the optimizer step in `torch.compile`; when it is left as `None`, the eager `optimizer.step()` is used as before.
2. Add `test_equal_to_non_pipelined_compiled` test cases to verify that the compiled optimizer step produces the same results as the non-compiled version.

A short usage sketch follows the diff.

Differential Revision: D71031049
---
 .../tests/test_train_pipelines.py      | 431 +++++++++++++++++-
 .../tests/test_train_pipelines_base.py |   2 +
 .../train_pipeline/train_pipelines.py  |  67 ++-
 3 files changed, 482 insertions(+), 18 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index 03aa3ea96..53b4b3b35 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -46,6 +46,7 @@ from torchrec.distributed.tests.test_fp_embeddingbag_utils import (
     create_module_and_freeze,
 )
+from torchrec.distributed.train_pipeline import TorchCompileConfig
 from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
     TrainPipelineSparseDistTestBase,
 )
@@ -134,6 +135,7 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
 class TrainPipelineBaseTest(unittest.TestCase):
     def setUp(self) -> None:
         self.device = torch.device("cuda:0")
+        self.optimizer_compile_config = TorchCompileConfig()
         torch.backends.cudnn.allow_tf32 = False
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -156,7 +158,43 @@ def test_equal_to_non_pipelined(self) -> None:
             for b in range(5)
         ]
         dataloader = iter(data)
-        pipeline = TrainPipelineBase(model_gpu, optimizer_gpu, self.device)
+        pipeline = TrainPipelineBase(
+            model_gpu, optimizer=optimizer_gpu, device=self.device
+        )
+
+        for batch in data[:-1]:
+            optimizer_cpu.zero_grad()
+            loss, pred = model_cpu(batch)
+            loss.backward()
+            optimizer_cpu.step()
+
+            pred_gpu = pipeline.progress(dataloader)
+
+            self.assertEqual(pred_gpu.device, self.device)
+            # Results will be close but not exactly equal as one model is on CPU and other on GPU
+            # If both were on GPU, the results will be exactly the same
+            self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))
+
+    def test_equal_to_non_pipelined_compiled(self) -> None:
+        model_cpu = TestModule()
+        model_gpu = TestModule().to(self.device)
+        model_gpu.load_state_dict(model_cpu.state_dict())
+        optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
+        optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
+        data = [
+            ModelInputSimple(
+                float_features=torch.rand((10,)),
+                label=torch.randint(2, (1,), dtype=torch.float32),
+            )
+            for b in range(5)
+        ]
+        dataloader = iter(data)
+        pipeline = TrainPipelineBase(
+            model=model_gpu,
+            optimizer=optimizer_gpu,
+            device=self.device,
+            optimizer_compile_config=self.optimizer_compile_config,
+        )

         for batch in data[:-1]:
             optimizer_cpu.zero_grad()
@@ -175,6 +213,7 @@ def test_equal_to_non_pipelined(self) -> None:
 class TrainPipelinePT2Test(unittest.TestCase):
     def setUp(self) -> None:
         self.device = torch.device("cuda:0")
+        self.optimizer_compile_config = TorchCompileConfig()
         torch.backends.cudnn.allow_tf32 = False
         torch.backends.cuda.matmul.allow_tf32 = False
@@ -234,7 +273,41 @@ def test_equal_to_non_pipelined(self) -> None:
             for b in range(5)
         ]
         dataloader = iter(data)
-        pipeline = TrainPipelinePT2(model_gpu, optimizer_gpu, self.device)
+        pipeline = TrainPipelinePT2(
+            model_gpu, optimizer=optimizer_gpu, device=self.device
+        )
+
+        for batch in data[:-1]:
+            optimizer_cpu.zero_grad()
+            loss, pred = model_cpu(batch)
+            loss.backward()
+            optimizer_cpu.step()
+
+            pred_gpu = pipeline.progress(dataloader)
+
+            self.assertEqual(pred_gpu.device, self.device)
+            self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))
+
+    def test_equal_to_non_pipelined_compiled(self) -> None:
+        model_cpu = TestModule()
+        model_gpu = TestModule().to(self.device)
+        model_gpu.load_state_dict(model_cpu.state_dict())
+        optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
+        optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
+        data = [
+            ModelInputSimple(
+                float_features=torch.rand((10,)),
+                label=torch.randint(2, (1,), dtype=torch.float32),
+            )
+            for b in range(5)
+        ]
+        dataloader = iter(data)
+        pipeline = TrainPipelinePT2(
+            model=model_gpu,
+            optimizer=optimizer_gpu,
+            device=self.device,
+            optimizer_compile_config=self.optimizer_compile_config,
+        )

         for batch in data[:-1]:
             optimizer_cpu.zero_grad()
@@ -271,7 +344,10 @@ def pre_compile_fn(model: nn.Module) -> None:
         dataloader = iter(data)
         pipeline = TrainPipelinePT2(
-            model_gpu, optimizer_gpu, self.device, pre_compile_fn=pre_compile_fn
+            model=model_gpu,
+            optimizer=optimizer_gpu,
+            device=self.device,
+            pre_compile_fn=pre_compile_fn,
         )
         self.assertEqual(model_gpu._dummy_setting, "dummy")
         for _ in range(len(data)):
@@ -315,7 +391,10 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         ]
         dataloader = iter(data)
         pipeline = TrainPipelinePT2(
-            model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing
+            model=model_gpu,
+            optimizer=optimizer_gpu,
+            device=self.device,
+            input_transformer=kjt_for_pt2_tracing,
         )

         for batch in data[:-1]:
@@ -545,6 +624,83 @@ def test_equal_to_non_pipelined(

         self.assertRaises(StopIteration, pipeline.progress, dataloader)

+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    @settings(max_examples=4, deadline=None)
+    # pyre-ignore[56]
+    @given(
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+            ]
+        ),
+        execute_all_batches=st.booleans(),
+    )
+    def test_equal_to_non_pipelined_compiled(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        execute_all_batches: bool,
+    ) -> None:
+        """
+        Checks that pipelined training is equivalent to non-pipelined training.
+ """ + data = self._generate_data( + num_batches=12, + batch_size=32, + ) + dataloader = iter(data) + + fused_params = {} + fused_params_pipelined = {} + + model = self._setup_model() + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params_pipelined + ) + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + pipeline = self.pipeline_class( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=execute_all_batches, + optimizer_compile_config=self.optimizer_compile_config, + ) + if not execute_all_batches: + data = data[:-2] + + for batch in data: + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + # Forward + backward w/ pipelining + pred_pipeline = pipeline.progress(dataloader) + torch.testing.assert_close(pred, pred_pipeline) + + self.assertRaises(StopIteration, pipeline.progress, dataloader) + @unittest.skipIf( not torch.cuda.is_available(), "Not enough GPUs, this test requires at least one GPU", @@ -1669,6 +1825,149 @@ def test_equal_to_non_pipelined( pred_pipeline = pipeline.progress(dataloader) self.assertRaises(StopIteration, pipeline.progress, dataloader) + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + @settings(max_examples=8, deadline=None) + # pyre-ignore[56] + @given( + start_batch=st.sampled_from([0, 6]), + stash_gradients=st.booleans(), + sharding_type=st.sampled_from( + [ + ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + ] + ), + zch=st.booleans(), + ) + def test_equal_to_non_pipelined_compiled( + self, + start_batch: int, + stash_gradients: bool, + sharding_type: str, + kernel_type: str, + zch: bool, + ) -> None: + """ + Checks that pipelined training is equivalent to non-pipelined training. + """ + # ZCH only supports row-wise currently + assume(not zch or (zch and sharding_type != ShardingType.TABLE_WISE.value)) + torch.autograd.set_detect_anomaly(True) + data = self._generate_data( + num_batches=12, + batch_size=32, + ) + dataloader = iter(data) + + fused_params = { + "stochastic_rounding": False, + } + fused_params_pipelined = { + **fused_params, + } + + model = self._setup_model(zch=zch) + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params_pipelined + ) + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + pipeline = TrainPipelineSemiSync( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=True, + start_batch=start_batch, + stash_gradients=stash_gradients, + optimizer_compile_config=self.optimizer_compile_config, + ) + + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute + # `sparse_forward`. 
+        prior_sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(
+            data[0].to(self.device)
+        )
+        prior_batch = data[0].to(self.device)
+        prior_stashed_grads = None
+        batch_index = 0
+        sparse_out = None
+        for batch in data[1:]:
+            batch_index += 1
+            # Forward + backward w/o pipelining
+            batch = batch.to(self.device)
+
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            #  `dense_forward`.
+            loss, pred = sharded_model._dmp_wrapped_module.dense_forward(
+                prior_batch, prior_sparse_out
+            )
+            if batch_index - 1 >= start_batch:
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                #  attribute `sparse_forward`.
+                sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(batch)
+
+            loss.backward()
+
+            stashed_grads = None
+            if batch_index - 1 >= start_batch and stash_gradients:
+                stashed_grads = []
+                for param in optim.param_groups[0]["params"]:
+                    stashed_grads.append(
+                        param.grad.clone() if param.grad is not None else None
+                    )
+                    param.grad = None
+
+            if prior_stashed_grads is not None:
+                for param, stashed_grad in zip(
+                    optim.param_groups[0]["params"], prior_stashed_grads
+                ):
+                    param.grad = stashed_grad
+            optim.step()
+            optim.zero_grad()
+
+            if batch_index - 1 < start_batch:
+                # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no
+                #  attribute `sparse_forward`.
+                sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(batch)
+
+            prior_stashed_grads = stashed_grads
+            prior_batch = batch
+            prior_sparse_out = sparse_out
+            # Forward + backward w/ pipelining
+            pred_pipeline = pipeline.progress(dataloader)
+
+            if batch_index >= start_batch:
+                self.assertTrue(
+                    pipeline.is_semi_sync(), msg="pipeline is not semi_sync"
+                )
+            else:
+                self.assertFalse(pipeline.is_semi_sync(), msg="pipeline is semi_sync")
+            self.assertTrue(
+                torch.equal(pred, pred_pipeline),
+                msg=f"batch {batch_index} doesn't match",
+            )
+
+        # one more batch
+        pred_pipeline = pipeline.progress(dataloader)
+        self.assertRaises(StopIteration, pipeline.progress, dataloader)
+

 class PrefetchTrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
     @unittest.skipIf(
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",
     )
@@ -1783,6 +2082,119 @@ def test_equal_to_non_pipelined(
             else:
                 torch.testing.assert_close(pred, pred_pipeline)

+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    @settings(max_examples=4, deadline=None)
+    # pyre-ignore[56]
+    @given(
+        execute_all_batches=st.booleans(),
+        weight_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        cache_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        load_factor=st.sampled_from(
+            [
+                0.2,
+                0.4,
+                0.6,
+            ]
+        ),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+            ]
+        ),
+    )
+    def test_equal_to_non_pipelined_compiled(
+        self,
+        execute_all_batches: bool,
+        weight_precision: DataType,
+        cache_precision: DataType,
+        load_factor: float,
+        sharding_type: str,
+        kernel_type: str,
+    ) -> None:
+        """
+        Checks that pipelined training is equivalent to non-pipelined training.
+ """ + mixed_precision: bool = weight_precision != cache_precision + self._set_table_weights_precision(weight_precision) + data = self._generate_data( + num_batches=12, + batch_size=32, + ) + dataloader = iter(data) + + fused_params = { + "cache_load_factor": load_factor, + "cache_precision": cache_precision, + "stochastic_rounding": False, # disable non-deterministic behavior when converting fp32<->fp16 + } + fused_params_pipelined = { + **fused_params, + "prefetch_pipeline": True, + } + + model = self._setup_model() + sharded_model, optim = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params + ) + + ( + sharded_model_pipelined, + optim_pipelined, + ) = self._generate_sharded_model_and_optimizer( + model, sharding_type, kernel_type, fused_params_pipelined + ) + copy_state_dict( + sharded_model.state_dict(), sharded_model_pipelined.state_dict() + ) + + pipeline = PrefetchTrainPipelineSparseDist( + model=sharded_model_pipelined, + optimizer=optim_pipelined, + device=self.device, + execute_all_batches=execute_all_batches, + optimizer_compile_config=self.optimizer_compile_config, + ) + + if not execute_all_batches: + data = data[:-3] + + for batch in data: + # Forward + backward w/o pipelining + batch = batch.to(self.device) + optim.zero_grad() + loss, pred = sharded_model(batch) + loss.backward() + optim.step() + + # Forward + backward w/ pipelining + pred_pipeline = pipeline.progress(dataloader) + + if not mixed_precision: + # Rounding error is expected when using different precisions for weights and cache + self.assertTrue(torch.equal(pred, pred_pipeline)) + else: + torch.testing.assert_close(pred, pred_pipeline) + class DataLoadingThreadTest(unittest.TestCase): def test_fetch_data(self) -> None: @@ -2365,3 +2777,14 @@ def test_equal_to_non_pipelined( execute_all_batches: bool, ) -> None: super().test_equal_to_non_pipelined() + + @unittest.skip( + "TrainPipelineSparseDistTest.test_equal_to_non_pipelined_compiled was called from multiple different executors, which fails hypothesis HealthChek, so we skip it here" + ) + def test_equal_to_non_pipelined_compiled( + self, + sharding_type: str, + kernel_type: str, + execute_all_batches: bool, + ) -> None: + super().test_equal_to_non_pipelined_compiled() diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py index 56e6ac636..a68f78295 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py @@ -24,6 +24,7 @@ TestEBCSharderMCH, TestSparseNN, ) +from torchrec.distributed.train_pipeline import TorchCompileConfig from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist from torchrec.distributed.types import ModuleSharder, ShardingEnv from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig @@ -62,6 +63,7 @@ def setUp(self) -> None: self.device = torch.device("cuda:0") self.pipeline_class = TrainPipelineSparseDist + self.optimizer_compile_config = TorchCompileConfig() def tearDown(self) -> None: super().tearDown() diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index fcd7efc24..1311832a2 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -121,6 +121,7 @@ def __init__( custom_model_fwd: Optional[ 
             Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
         ] = None,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -140,6 +141,16 @@ def __init__(
         self._cur_batch: Optional[In] = None
         self._connected = False

+        if optimizer_compile_config is not None:
+            self._optimizer_step: Callable[[], None] = torch.compile(
+                lambda: self._optimizer.step(),
+                fullgraph=optimizer_compile_config.fullgraph,
+                dynamic=optimizer_compile_config.dynamic,
+                backend=optimizer_compile_config.backend,
+            )
+        else:
+            self._optimizer_step: Callable[[], None] = self._optimizer.step
+
     def _connect(self, dataloader_iter: Iterator[In]) -> None:
         cur_batch = next(dataloader_iter)
         self._cur_batch = cur_batch
@@ -193,7 +204,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # Update
         if self._model.training:
             with record_function("## optimizer ##"):
-                self._optimizer.step()
+                self._optimizer_step()

         return output

@@ -221,6 +232,7 @@ def __init__(
         pre_compile_fn: Optional[Callable[[torch.nn.Module], None]] = None,
         post_compile_fn: Optional[Callable[[torch.nn.Module], None]] = None,
         input_transformer: Optional[Callable[[In], In]] = None,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -237,6 +249,16 @@ def __init__(
         self._iter = 0
         self._cur_batch: Optional[In] = None

+        if optimizer_compile_config is not None:
+            self._optimizer_step: Callable[[], None] = torch.compile(
+                lambda: self._optimizer.step(),
+                fullgraph=optimizer_compile_config.fullgraph,
+                dynamic=optimizer_compile_config.dynamic,
+                backend=optimizer_compile_config.backend,
+            )
+        else:
+            self._optimizer_step: Callable[[], None] = self._optimizer.step
+
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         if self._iter == 0:
             # Turn on sync collectives for PT2 pipeline.
@@ -292,7 +314,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             torch.sum(losses).backward()

         with record_function("## optimizer ##"):
-            self._optimizer.step()
+            self._optimizer_step()

         return output

@@ -337,6 +359,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -399,6 +422,16 @@ def __init__(
         self._batch_ip2: Optional[In] = None
         self._context: TrainPipelineContext = context_type(version=0)

+        if optimizer_compile_config is not None:
+            self._optimizer_step: Callable[[], None] = torch.compile(
+                lambda: self._optimizer.step(),
+                fullgraph=optimizer_compile_config.fullgraph,
+                dynamic=optimizer_compile_config.dynamic,
+                backend=optimizer_compile_config.backend,
+            )
+        else:
+            self._optimizer_step: Callable[[], None] = self._optimizer.step
+
     def detach(self) -> torch.nn.Module:
         """
         Detaches the model from sparse data dist (SDD) pipeline.
@@ -530,7 +563,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:

         # update
         with record_function("## optimizer ##"):
-            self._optimizer.step()
+            self._optimizer_step()

         self.dequeue_batch()
         return output
@@ -757,6 +790,7 @@ def __init__(
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
         strict: bool = False,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -767,6 +801,7 @@ def __init__(
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            optimizer_compile_config=optimizer_compile_config,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -836,7 +871,7 @@ def _mlp_optimizer_step(self, current_batch: int) -> None:
         # special case: not all optimizers support optim.step() on null gradidents
         if current_batch == self._start_batch and self._stash_gradients:
             return
-        self._optimizer.step()
+        self._optimizer_step()

     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         self.fill_pipeline(dataloader_iter)
@@ -1056,6 +1091,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
    ) -> None:
         super().__init__(
             model=model,
@@ -1066,6 +1102,7 @@ def __init__(
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            optimizer_compile_config=optimizer_compile_config,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1133,7 +1170,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:

         # update
         with record_function("## optimizer ##"):
-            self._optimizer.step()
+            self._optimizer_step()

         self._start_sparse_data_dist(self._batch_ip2)

@@ -1568,16 +1605,18 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        optimizer_compile_config: Optional[TorchCompileConfig] = None,
     ) -> None:
         super().__init__(
-            model,
-            optimizer,
-            device,
-            execute_all_batches,
-            apply_jit,
-            context_type,
-            pipeline_postproc,
-            custom_model_fwd,
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            execute_all_batches=execute_all_batches,
+            apply_jit=apply_jit,
+            context_type=context_type,
+            pipeline_postproc=pipeline_postproc,
+            custom_model_fwd=custom_model_fwd,
+            optimizer_compile_config=optimizer_compile_config,
         )
         torch._logging.set_logs(compiled_autograd_verbose=True)
@@ -1665,7 +1704,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:

         # update
         with record_function("## optimizer ##"):
-            self._optimizer.step()
+            self._optimizer_step()

         self.dequeue_batch()
         return output
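
Usage sketch (not part of the patch): the snippet below illustrates how the new optimizer_compile_config argument introduced above is intended to be used. Only TorchCompileConfig, TrainPipelineBase, and the optimizer_compile_config keyword come from this diff; ToyBatch, ToyModel, and the batch generator are invented placeholders. A real TorchRec model follows the same (loss, pred) forward convention, and its batches satisfy the Pipelineable protocol (a .to(device) and a .record_stream method), which the placeholder mimics.

from dataclasses import dataclass
from typing import Iterator, Tuple

import torch
from torch import nn, optim

from torchrec.distributed.train_pipeline import TorchCompileConfig
from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineBase


@dataclass
class ToyBatch:
    # Hypothetical batch type; mirrors the to()/record_stream() protocol the pipeline relies on.
    float_features: torch.Tensor
    label: torch.Tensor

    def to(self, device: torch.device, non_blocking: bool = False) -> "ToyBatch":
        return ToyBatch(
            self.float_features.to(device, non_blocking=non_blocking),
            self.label.to(device, non_blocking=non_blocking),
        )

    def record_stream(self, stream) -> None:
        self.float_features.record_stream(stream)
        self.label.record_stream(stream)


class ToyModel(nn.Module):
    # Hypothetical model following the (loss, pred) forward convention the pipelines expect.
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(10, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, batch: ToyBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        pred = self.linear(batch.float_features)
        return self.loss_fn(pred, batch.label), pred


device = torch.device("cuda:0")
model = ToyModel().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Passing a TorchCompileConfig makes the pipeline wrap optimizer.step() in torch.compile
# (its fullgraph/dynamic/backend fields are forwarded); leaving it as None keeps the
# eager optimizer.step() path unchanged, exactly as in the constructors above.
pipeline = TrainPipelineBase(
    model=model,
    optimizer=optimizer,
    device=device,
    optimizer_compile_config=TorchCompileConfig(),
)

batches: Iterator[ToyBatch] = iter(
    ToyBatch(torch.rand(10), torch.randint(0, 2, (1,)).float()) for _ in range(8)
)
while True:
    try:
        # Each call runs forward, backward, and the (compiled) optimizer step.
        pred = pipeline.progress(batches)
    except StopIteration:
        break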