diff --git a/src/components/episode_buffer.py b/src/components/episode_buffer.py
index 919a350f4..279ec0ac2 100644
--- a/src/components/episode_buffer.py
+++ b/src/components/episode_buffer.py
@@ -1,3 +1,5 @@
+import numbers
+
 import torch as th
 import numpy as np
 from types import SimpleNamespace as SN
@@ -60,7 +62,7 @@ def _setup_data(self, scheme, groups, batch_size, max_seq_length, preprocess):
             group = field_info.get("group", None)
             dtype = field_info.get("dtype", th.float32)
 
-            if isinstance(vshape, int):
+            if isinstance(vshape, numbers.Integral):
                 vshape = (vshape,)
 
             if group:
@@ -100,7 +102,7 @@ def update(self, data, bs=slice(None), ts=slice(None), mark_filled=True):
                 raise KeyError("{} not found in transition or episode data".format(k))
 
             dtype = self.scheme[k].get("dtype", th.float32)
-            v = th.tensor(v, dtype=dtype, device=self.device)
+            v = th.as_tensor(v, dtype=dtype, device=self.device)
             self._check_safe_view(v, target[k][_slices])
             target[k][_slices] = v.view_as(target[k][_slices])
 
diff --git a/src/learners/coma_learner.py b/src/learners/coma_learner.py
index 207aaccd8..293997f15 100644
--- a/src/learners/coma_learner.py
+++ b/src/learners/coma_learner.py
@@ -93,7 +93,7 @@ def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
 
             self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env)
             self.logger.log_stat("coma_loss", coma_loss.item(), t_env)
-            self.logger.log_stat("agent_grad_norm", grad_norm, t_env)
+            self.logger.log_stat("agent_grad_norm", grad_norm.item(), t_env)
             self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env)
             self.log_stats_t = t_env
 
diff --git a/src/learners/q_learner.py b/src/learners/q_learner.py
index 02de44c4a..1221a7bf3 100644
--- a/src/learners/q_learner.py
+++ b/src/learners/q_learner.py
@@ -108,7 +108,7 @@ def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
 
         if t_env - self.log_stats_t >= self.args.learner_log_interval:
             self.logger.log_stat("loss", loss.item(), t_env)
-            self.logger.log_stat("grad_norm", grad_norm, t_env)
+            self.logger.log_stat("grad_norm", grad_norm.item(), t_env)
             mask_elems = mask.sum().item()
             self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env)
             self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
diff --git a/src/learners/qtran_learner.py b/src/learners/qtran_learner.py
index bf0f369f2..528c02b34 100644
--- a/src/learners/qtran_learner.py
+++ b/src/learners/qtran_learner.py
@@ -142,7 +142,7 @@ def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
             self.logger.log_stat("td_loss", td_loss.item(), t_env)
             self.logger.log_stat("opt_loss", opt_loss.item(), t_env)
             self.logger.log_stat("nopt_loss", nopt_loss.item(), t_env)
-            self.logger.log_stat("grad_norm", grad_norm, t_env)
+            self.logger.log_stat("grad_norm", grad_norm.item(), t_env)
             if self.args.mixer == "qtran_base":
                 mask_elems = mask.sum().item()
                 self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env)
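
Note (not part of the patch): a minimal standalone sketch of why each change matters, assuming recent torch/numpy versions; the values and parameter names below are illustrative only and do not come from the repo.

    # Sketch of the three behaviours the patch relies on.
    import numbers

    import numpy as np
    import torch as th

    # 1) numpy integer scalars are not Python ints, but they do register as
    #    numbers.Integral, so a scheme entry like {"vshape": np.int64(3)}
    #    now takes the tuple-wrapping branch in _setup_data.
    vshape = np.int64(3)
    assert not isinstance(vshape, int)
    assert isinstance(vshape, numbers.Integral)

    # 2) th.as_tensor reuses the storage of an existing tensor when dtype and
    #    device already match, whereas th.tensor always copies (and warns when
    #    handed a tensor input).
    v = th.arange(4, dtype=th.float32)
    assert th.as_tensor(v, dtype=th.float32).data_ptr() == v.data_ptr()

    # 3) clip_grad_norm_ returns a 0-dim tensor in recent PyTorch releases,
    #    so the logger needs a plain Python float via .item().
    params = [th.nn.Parameter(th.randn(2, 2))]
    params[0].grad = th.randn(2, 2)
    grad_norm = th.nn.utils.clip_grad_norm_(params, max_norm=10.0)
    print(type(grad_norm), grad_norm.item())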