Skip to content

Commit 1ac9895

Browse files
committed
return early from save_checkpoint
1 parent c6eee50 commit 1ac9895

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

deepspeed/runtime/engine.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2396,6 +2396,10 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
23962396
module = self.module_state_dict()
23972397
self._curr_ckpt_path = None
23982398

2399+
# Only a subset of procs may need to save the general model params
2400+
if not self.save_non_zero_checkpoint:
2401+
return
2402+
23992403
state = dict(module=module,
24002404
buffer_names=self._get_buffer_names(),
24012405
optimizer=self.optimizer.state_dict()
@@ -2412,10 +2416,9 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
24122416
ds_version=version)
24132417
state.update(client_state)
24142418

2415-
if self.save_non_zero_checkpoint:
2416-
log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
2417-
#logger.info('Saving model checkpoint: {}'.format(save_path))
2418-
torch.save(state, save_path)
2419+
log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
2420+
#logger.info('Saving model checkpoint: {}'.format(save_path))
2421+
torch.save(state, save_path)
24192422

24202423
def _get_buffer_names(self):
24212424
buffer_names = []

0 commit comments

Comments
 (0)