🐛 Fixed summary usage
The summary call should not be wrapped inside `eval_context`, because we need to track which tensors have `requires_grad` set in order to mark them as 'trainable'.
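
For illustration only (not code from this repo), here is a minimal PyTorch sketch of the underlying issue, assuming `eval_context` enters a no-grad context of some kind: under `torch.no_grad()`, forward outputs never have `requires_grad` set, so a summary that inspects `requires_grad` to decide which tensors are trainable would see everything as frozen.

```python
import torch

# Hypothetical stand-in for a model layer; its parameters are created with requires_grad=True.
layer = torch.nn.Linear(4, 2)
x = torch.randn(1, 4)

out = layer(x)
print(out.requires_grad)  # True: the output participates in the autograd graph

with torch.no_grad():  # roughly what an eval/no-grad context enters
    out = layer(x)
print(out.requires_grad)  # False: trainability can no longer be inferred from the output
```

Running the summary forward pass outside such a context keeps the `requires_grad` information intact, which is what this commit restores.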
carefree0910 committed Jan 20, 2024
1 parent ddd5d63 commit 9ee79ad
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions core/learn/trainer.py
@@ -277,13 +277,12 @@ def fit(
         show_summary = not self.tqdm_settings.in_distributed
         ## should always summary to sync the statuses in distributed training
         input_sample = get_input_sample(train_loader, self.device)
-        with self.model.eval_context():
-            summary_msg = summary(
-                self.model.m,
-                input_sample,
-                return_only=not show_summary or not self.is_local_rank_0,
-                summary_forward=self.model.summary_forward,
-            )
+        summary_msg = summary(
+            self.model.m,
+            input_sample,
+            return_only=not show_summary or not self.is_local_rank_0,
+            summary_forward=self.model.summary_forward,
+        )
         if self.is_local_rank_0:
             with open(os.path.join(self.workspace, self.summary_log_file), "w") as f:
                 f.write(summary_msg)
