diff --git a/ding/entry/utils.py b/ding/entry/utils.py index bbfbaa83bd..a3b66bfe70 100644 --- a/ding/entry/utils.py +++ b/ding/entry/utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Callable, List, Any +from typing import Optional, Callable, List, Any, Dict from ding.policy import PolicyFactory from ding.worker import IMetric, MetricSerialEvaluator @@ -46,7 +46,8 @@ def random_collect( collector_env: 'BaseEnvManager', # noqa commander: 'BaseSerialCommander', # noqa replay_buffer: 'IBuffer', # noqa - postprocess_data_fn: Optional[Callable] = None + postprocess_data_fn: Optional[Callable] = None, + collect_kwargs: Optional[Dict] = None, ) -> None: # noqa assert policy_cfg.random_collect_size > 0 if policy_cfg.get('transition_with_policy_data', False): @@ -55,7 +56,8 @@ def random_collect( action_space = collector_env.action_space random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space) collector.reset_policy(random_policy) - collect_kwargs = commander.step() + if collect_kwargs is None: + collect_kwargs = commander.step() if policy_cfg.collect.collector.type == 'episode': new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs) else: diff --git a/ding/envs/env_manager/envpool_env_manager.py b/ding/envs/env_manager/envpool_env_manager.py index a8d1a4ae03..bcfa5ae3ce 100644 --- a/ding/envs/env_manager/envpool_env_manager.py +++ b/ding/envs/env_manager/envpool_env_manager.py @@ -2,7 +2,11 @@ from easydict import EasyDict from copy import deepcopy import numpy as np +import torch +import treetensor.torch as ttorch +import treetensor.numpy as tnp from collections import namedtuple +import enum from typing import Any, Union, List, Tuple, Dict, Callable, Optional from ditk import logging try: @@ -17,17 +21,28 @@ from ding.torch_utils import to_ndarray -@ENV_MANAGER_REGISTRY.register('env_pool') +class EnvState(enum.IntEnum): + VOID = 0 + INIT = 1 + RUN = 2 + RESET = 3 + DONE = 4 + ERROR = 5 + NEED_RESET = 6 + + +@ENV_MANAGER_REGISTRY.register('envpool') class PoolEnvManager: - ''' + """ Overview: + PoolEnvManager supports old pipeline of DI-engine. Envpool now supports Atari, Classic Control, Toy Text, ViZDoom. Here we list some commonly used env_ids as follows. For more examples, you can refer to . 
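# Minimal usage sketch (not part of the patch; argument values are assumed) of the new
# `collect_kwargs` parameter added to `random_collect` above: a caller that already knows
# the policy kwargs (e.g. an epsilon for eps-greedy collection) can pass them in directly,
# and `commander.step()` is only consulted when `collect_kwargs` is left as None.
#
# random_collect(
#     cfg.policy,                     # requires policy_cfg.random_collect_size > 0
#     policy,
#     collector,
#     collector_env,
#     commander,
#     replay_buffer,
#     collect_kwargs={'eps': 1.0},    # hypothetical key; whatever the collector expects as policy_kwargs
# )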
- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5" - Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1" - ''' + """ @classmethod def default_config(cls) -> EasyDict: @@ -39,10 +54,17 @@ def default_config(cls) -> EasyDict: # Async mode: batch_size < env_num env_num=8, batch_size=8, + image_observation=True, + episodic_life=False, + reward_clip=False, + gray_scale=True, + stack_num=4, + frame_skip=4, ) def __init__(self, cfg: EasyDict) -> None: - self._cfg = cfg + self._cfg = self.default_config() + self._cfg.update(cfg) self._env_num = cfg.env_num self._batch_size = cfg.batch_size self._ready_obs = {} @@ -55,6 +77,7 @@ def launch(self) -> None: seed = 0 else: seed = self._seed + self._envs = envpool.make( task_id=self._cfg.env_id, env_type="gym", @@ -65,8 +88,10 @@ def launch(self) -> None: reward_clip=self._cfg.reward_clip, stack_num=self._cfg.stack_num, gray_scale=self._cfg.gray_scale, - frame_skip=self._cfg.frame_skip + frame_skip=self._cfg.frame_skip, ) + self._action_space = self._envs.action_space + self._observation_space = self._envs.observation_space self._closed = False self.reset() @@ -77,6 +102,8 @@ def reset(self) -> None: obs, _, _, info = self._envs.recv() env_id = info['env_id'] obs = obs.astype(np.float32) + if self._cfg.image_observation: + obs /= 255.0 self._ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, self._ready_obs) if len(self._ready_obs) == self._env_num: break @@ -91,6 +118,8 @@ def step(self, action: dict) -> Dict[int, namedtuple]: obs, rew, done, info = self._envs.recv() obs = obs.astype(np.float32) + if self._cfg.image_observation: + obs /= 255.0 rew = rew.astype(np.float32) env_id = info['env_id'] timesteps = {} @@ -124,3 +153,153 @@ def env_num(self) -> int: @property def ready_obs(self) -> Dict[int, Any]: return self._ready_obs + + @property + def observation_space(self) -> 'gym.spaces.Space': # noqa + try: + return self._observation_space + except AttributeError: + self.launch() + self.close() + return self._observation_space + + @property + def action_space(self) -> 'gym.spaces.Space': # noqa + try: + return self._action_space + except AttributeError: + self.launch() + self.close() + return self._action_space + + +@ENV_MANAGER_REGISTRY.register('envpool_v2') +class PoolEnvManagerV2: + """ + Overview: + PoolEnvManagerV2 supports new pipeline of DI-engine. + Envpool now supports Atari, Classic Control, Toy Text, ViZDoom. + Here we list some commonly used env_ids as follows. + For more examples, you can refer to . 
+ + - Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5" + - Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1" + """ + + @classmethod + def default_config(cls) -> EasyDict: + return EasyDict(deepcopy(cls.config)) + + config = dict( + type='envpool_v2', + env_num=8, + batch_size=8, + image_observation=True, + episodic_life=False, + reward_clip=False, + gray_scale=True, + stack_num=4, + frame_skip=4, + ) + + def __init__(self, cfg: EasyDict) -> None: + super().__init__() + self._cfg = self.default_config() + self._cfg.update(cfg) + self._env_num = cfg.env_num + self._batch_size = cfg.batch_size + + self._closed = True + self._seed = None + + def launch(self) -> None: + assert self._closed, "Please first close the env manager" + if self._seed is None: + seed = 0 + else: + seed = self._seed + + self._envs = envpool.make( + task_id=self._cfg.env_id, + env_type="gym", + num_envs=self._env_num, + batch_size=self._batch_size, + seed=seed, + episodic_life=self._cfg.episodic_life, + reward_clip=self._cfg.reward_clip, + stack_num=self._cfg.stack_num, + gray_scale=self._cfg.gray_scale, + frame_skip=self._cfg.frame_skip, + ) + self._action_space = self._envs.action_space + self._observation_space = self._envs.observation_space + self._closed = False + return self.reset() + + def reset(self) -> None: + self._envs.async_reset() + ready_obs = {} + while True: + obs, _, _, info = self._envs.recv() + env_id = info['env_id'] + obs = obs.astype(np.float32) + if self._cfg.image_observation: + obs /= 255.0 + for i in range(len(list(env_id))): + ready_obs[env_id[i]] = obs[i] + if len(ready_obs) == self._env_num: + break + self._eval_episode_return = [0. for _ in range(self._env_num)] + + return ready_obs + + def send_action(self, action, env_id) -> Dict[int, namedtuple]: + self._envs.send(action, env_id) + + def receive_data(self): + next_obs, rew, done, info = self._envs.recv() + next_obs = next_obs.astype(np.float32) + if self._cfg.image_observation: + next_obs /= 255.0 + rew = rew.astype(np.float32) + + return next_obs, rew, done, info + + def close(self) -> None: + if self._closed: + return + # Envpool has no `close` API + self._closed = True + + @property + def closed(self) -> None: + return self._closed + + def seed(self, seed: int, dynamic_seed=False) -> None: + # The i-th environment seed in Envpool will be set with i+seed, so we don't do extra transformation here + self._seed = seed + logging.warning("envpool doesn't support dynamic_seed in different episode") + + @property + def env_num(self) -> int: + return self._env_num + + @property + def observation_space(self) -> 'gym.spaces.Space': # noqa + try: + return self._observation_space + except AttributeError: + self.launch() + self.close() + self._ready_obs = {} + return self._observation_space + + @property + def action_space(self) -> 'gym.spaces.Space': # noqa + try: + return self._action_space + except AttributeError: + self.launch() + self.close() + self._ready_obs = {} + return self._action_space diff --git a/ding/envs/env_manager/tests/test_envpool_env_manager.py b/ding/envs/env_manager/tests/test_envpool_env_manager.py index 9ac7730773..3d9e0dd5de 100644 --- a/ding/envs/env_manager/tests/test_envpool_env_manager.py +++ b/ding/envs/env_manager/tests/test_envpool_env_manager.py @@ -3,7 +3,7 @@ import numpy as np from easydict import EasyDict -from ..envpool_env_manager import PoolEnvManager +from ding.envs.env_manager.envpool_env_manager import PoolEnvManager, PoolEnvManagerV2 env_num_args = [[16, 8], [8, 8]] @@ -30,17 +30,51 @@ def 
test_naive(self, env_num, batch_size): env_manager = PoolEnvManager(env_manager_cfg) assert env_manager._closed env_manager.launch() - # Test step - start_time = time.time() - for count in range(20): + for count in range(5): env_id = env_manager.ready_obs.keys() action = {i: np.random.randint(4) for i in env_id} timestep = env_manager.step(action) assert len(timestep) == env_manager_cfg.batch_size - print('Count {}'.format(count)) - print([v.info for v in timestep.values()]) - end_time = time.time() - print('total step time: {}'.format(end_time - start_time)) - # Test close env_manager.close() assert env_manager._closed + + +@pytest.mark.envpooltest +@pytest.mark.parametrize('env_num, batch_size', env_num_args) +class TestPoolEnvManagerV2: + + def test_naive(self, env_num, batch_size): + env_manager_cfg = EasyDict( + { + 'env_id': 'Pong-v5', + 'env_num': env_num, + 'batch_size': batch_size, + 'seed': 3, + # env wrappers + 'episodic_life': False, + 'reward_clip': False, + 'gray_scale': True, + 'stack_num': 4, + 'frame_skip': 4, + } + ) + env_manager = PoolEnvManagerV2(env_manager_cfg) + assert env_manager._closed + ready_obs = env_manager.launch() + env_id = list(ready_obs.keys()) + for count in range(5): + action = {i: np.random.randint(4) for i in env_id} + action_send = np.array([action[i] for i in action.keys()]) + env_id_send = np.array(list(action.keys())) + env_manager.send_action(action_send, env_id_send) + next_obs, rew, done, info = env_manager.receive_data() + assert next_obs.shape == (env_manager_cfg.batch_size, 4, 84, 84) + assert rew.shape == (env_manager_cfg.batch_size, ) + assert done.shape == (env_manager_cfg.batch_size, ) + assert info['env_id'].shape == (env_manager_cfg.batch_size, ) + env_manager.close() + assert env_manager._closed + + +if __name__ == "__main__": + TestPoolEnvManagerV2().test_naive(16, 8) diff --git a/ding/example/dqn_nstep_envpool.py b/ding/example/dqn_nstep_envpool.py new file mode 100644 index 0000000000..7ab7a74677 --- /dev/null +++ b/ding/example/dqn_nstep_envpool.py @@ -0,0 +1,117 @@ +import datetime +from easydict import EasyDict +from ditk import logging +from ding.model import DQN +from ding.policy import DQNFastPolicy +from ding.envs.env_manager.envpool_env_manager import PoolEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import envpool_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, online_logger, \ + termination_checker, wandb_online_logger, epoch_timer, EnvpoolStepCollector, EnvpoolOffPolicyLearner +from ding.utils import set_pkg_seed +from dizoo.atari.config.serial import pong_dqn_envpool_config + + +def main(cfg): + logging.getLogger().setLevel(logging.INFO) + cfg.exp_name = 'Pong-v5-DQN-envpool-' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + + collector_env_cfg = EasyDict( + { + 'env_id': cfg.env.env_id, + 'env_num': cfg.env.collector_env_num, + 'batch_size': cfg.env.collector_batch_size, + # env wrappers + 'episodic_life': True, # collector: True + 'reward_clip': False, # collector: True + 'gray_scale': cfg.env.get('gray_scale', True), + 'stack_num': cfg.env.get('stack_num', 4), + } + ) + cfg.env["collector_env_cfg"] = collector_env_cfg + evaluator_env_cfg = EasyDict( + { + 'env_id': cfg.env.env_id, + 'env_num': cfg.env.evaluator_env_num, + 'batch_size': cfg.env.evaluator_batch_size, + # env wrappers + 'episodic_life': 
False, # evaluator: False + 'reward_clip': False, # evaluator: False + 'gray_scale': cfg.env.get('gray_scale', True), + 'stack_num': cfg.env.get('stack_num', 4), + } + ) + cfg.env["evaluator_env_cfg"] = evaluator_env_cfg + cfg = compile_config(cfg, PoolEnvManagerV2, DQNFastPolicy, save_cfg=task.router.node_id == 0) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_env = PoolEnvManagerV2(cfg.env.collector_env_cfg) + evaluator_env = PoolEnvManagerV2(cfg.env.evaluator_env_cfg) + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNFastPolicy(cfg.policy, model=model) + + # Consider the case with multiple processes + if task.router.is_active: + # You can use labels to distinguish between workers with different roles, + # here we use node_id to distinguish. + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + elif task.router.node_id == 1: + task.add_role(task.role.EVALUATOR) + else: + task.add_role(task.role.COLLECTOR) + + # Sync their context and model between each worker. + task.use(ContextExchanger(skip_n_iter=1)) + task.use(ModelExchanger(model)) + task.use(epoch_timer()) + task.use(envpool_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use( + EnvpoolStepCollector( + cfg, + policy.collect_mode, + collector_env, + random_collect_size=cfg.policy.random_collect_size if hasattr(cfg.policy, 'random_collect_size') else 0, + ) + ) + task.use(data_pusher(cfg, buffer_)) + task.use(EnvpoolOffPolicyLearner(cfg, policy, buffer_)) + task.use(online_logger(train_show_freq=10)) + task.use( + wandb_online_logger( + metric_list=policy._monitor_vars_learn(), + model=policy._model, + exp_config=cfg, + anonymous=True, + project_name=cfg.exp_name, + wandb_sweep=False, + ) + ) + #task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000)) + task.use(termination_checker(max_env_step=10000000)) + task.run() + + +if __name__ == "__main__": + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=0, help="random seed") + parser.add_argument("--collector_env_num", type=int, default=8, help="collector env number") + parser.add_argument("--collector_batch_size", type=int, default=8, help="collector batch size") + arg = parser.parse_args() + + pong_dqn_envpool_config.env.collector_env_num = arg.collector_env_num + pong_dqn_envpool_config.env.collector_batch_size = arg.collector_batch_size + pong_dqn_envpool_config.seed = arg.seed + + main(pong_dqn_envpool_config) diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py index b9e3c5005d..aff23d79fc 100644 --- a/ding/framework/middleware/__init__.py +++ b/ding/framework/middleware/__init__.py @@ -1,6 +1,6 @@ from .functional import * -from .collector import StepCollector, EpisodeCollector, PPOFStepCollector -from .learner import OffPolicyLearner, HERLearner +from .collector import StepCollector, EpisodeCollector, PPOFStepCollector, EnvpoolStepCollector +from .learner import OffPolicyLearner, HERLearner, EnvpoolOffPolicyLearner from .ckpt_handler import CkptSaver from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger from .barrier import Barrier, BarrierRuntime diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index beb4894ad9..ed58c80993 
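# Usage sketch for the example entry above (hedged: assumes `ding.example` is importable
# as a package, as the __main__ block suggests; the values are just the defaults).
# Equivalent CLI: python ding/example/dqn_nstep_envpool.py --seed 0 --collector_env_num 8 --collector_batch_size 8
from dizoo.atari.config.serial import pong_dqn_envpool_config
from ding.example.dqn_nstep_envpool import main

pong_dqn_envpool_config.env.collector_env_num = 8
pong_dqn_envpool_config.env.collector_batch_size = 8
pong_dqn_envpool_config.seed = 0
main(pong_dqn_envpool_config)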
100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -10,6 +10,9 @@ if TYPE_CHECKING: from ding.framework import OnlineRLContext +import numpy as np +import torch + class StepCollector: """ @@ -68,6 +71,188 @@ def __call__(self, ctx: "OnlineRLContext") -> None: break +class EnvpoolStepCollector: + + def __new__(cls, *args, **kwargs): + if task.router.is_active and not task.has_role(task.role.COLLECTOR): + return task.void() + return super(EnvpoolStepCollector, cls).__new__(cls) + + def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None: + """ + Arguments: + - cfg (:obj:`EasyDict`): Config. + - policy (:obj:`Policy`): The policy to be collected. + - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \ + its derivatives are supported. + - random_collect_size (:obj:`int`): The count of samples that will be collected randomly, \ + typically used in initial runs. + """ + self.cfg = cfg + self.env = env + + self._ready_obs_receive = {} + self._ready_obs_send = {} + self._ready_action_send = {} + self._trajectory = {i: [] for i in range(env.env_num)} + self._nsteps = self.cfg.policy.nstep if hasattr(self.cfg.policy, 'nstep') else 1 + self._discount_ratio_list = [self.cfg.policy.discount_factor ** (i + 1) for i in range(self._nsteps)] + self._nsteps_range = list(range(1, self._nsteps)) + self.policy = policy + self.random_collect_size = random_collect_size + + def __call__(self, ctx: "OnlineRLContext") -> None: + """ + Overview: + An encapsulation of inference and rollout middleware. Stop when completing \ + the target number of steps. + Input of ctx: + - env_step (:obj:`int`): The env steps which will increase during collection. + """ + old = ctx.env_step + + if self.random_collect_size > 0 and old < self.random_collect_size: + target_size = self.random_collect_size - old + random = True + else: + target_size = self.cfg.policy.collect.n_sample * self.cfg.policy.collect.unroll_len + random = False + + if self.env.closed: + self._ready_obs_receive = self.env.launch() + + counter = 0 + + while True: + if len(self._ready_obs_receive.keys()) > 0: + if random: + action_to_send = { + i: { + "action": np.array([self.env.action_space.sample()]) + } + for i in self._ready_obs_receive.keys() + } + else: + action_by_policy = self.policy.forward(self._ready_obs_receive, **ctx.collect_kwargs) + + if isinstance(list(action_by_policy.values())[0]['action'], torch.Tensor): + # transfer to numpy + action_to_send = { + i: { + "action": action_by_policy[i]['action'].cpu().numpy() + } + for i in action_by_policy.keys() + } + else: + action_to_send = action_by_policy + self._ready_obs_send.update(self._ready_obs_receive) + self._ready_obs_receive = {} + self._ready_action_send.update(action_to_send) + + action_send = np.array([action_to_send[i]['action'] for i in action_to_send.keys()]) + if action_send.ndim == 2 and action_send.shape[1] == 1: + action_send = action_send.squeeze(1) + env_id_send = np.array(list(action_to_send.keys())) + self.env.send_action(action_send, env_id_send) + + next_obs, rew, done, info = self.env.receive_data() + env_id_receive = info['env_id'] + counter += len(env_id_receive) + self._ready_obs_receive.update({i: next_obs[i] for i in range(len(next_obs))}) + + #todo + for i in range(len(env_id_receive)): + current_reward = rew[i] + if self._nsteps > 1: + self._trajectory[env_id_receive[i]].append( + { + 'obs': self._ready_obs_send[env_id_receive[i]], + 'action': 
self._ready_action_send[env_id_receive[i]]['action'], + 'next_obs': next_obs[i], + # n-step reward + 'reward': [current_reward], + 'done': done[i], + } + ) + else: + self._trajectory[env_id_receive[i]].append( + { + 'obs': self._ready_obs_send[env_id_receive[i]], + 'action': self._ready_action_send[env_id_receive[i]]['action'], + 'next_obs': next_obs[i], + # n-step reward + 'reward': current_reward, + 'done': done[i], + } + ) + + if self._nsteps > 1: + if done[i] is False and counter < target_size: + reverse_record_position = min(self._nsteps, len(self._trajectory[env_id_receive[i]])) + real_reverse_record_position = reverse_record_position + + for j in range(1, reverse_record_position + 1): + if j == 1: + pass + else: + if self._trajectory[env_id_receive[i]][-j]['done'] is True: + real_reverse_record_position = j - 1 + break + else: + self._trajectory[env_id_receive[i]][-j]['reward'].append(current_reward) + + if real_reverse_record_position == self._nsteps: + self._trajectory[env_id_receive[i] + ][-real_reverse_record_position]['next_n_obs'] = next_obs[i] + self._trajectory[env_id_receive[i]][-real_reverse_record_position][ + 'value_gamma'] = self._discount_ratio_list[real_reverse_record_position - 1] + + else: # done[i] is True or counter >= target_size + + reverse_record_position = min(self._nsteps, len(self._trajectory[env_id_receive[i]])) + real_reverse_record_position = reverse_record_position + + for j in range(1, reverse_record_position + 1): + if j == 1: + self._trajectory[env_id_receive[i]][-j]['reward'].extend( + [ + 0.0 for _ in + range(self._nsteps - len(self._trajectory[env_id_receive[i]][-j]['reward'])) + ] + ) + self._trajectory[env_id_receive[i]][-j]['next_n_obs'] = next_obs[i] + self._trajectory[env_id_receive[i]][-j]['value_gamma'] = self._discount_ratio_list[j - + 1] + else: + if self._trajectory[env_id_receive[i]][-j]['done'] is True: + real_reverse_record_position = j + break + else: + self._trajectory[env_id_receive[i]][-j]['reward'].append(current_reward) + self._trajectory[env_id_receive[i]][-j]['reward'].extend( + [ + 0.0 for _ in range( + self._nsteps - len(self._trajectory[env_id_receive[i]][-j]['reward']) + ) + ] + ) + self._trajectory[env_id_receive[i]][-j]['next_n_obs'] = next_obs[i] + self._trajectory[env_id_receive[i]][-j]['value_gamma'] = self._discount_ratio_list[ + j - 1] + + else: + self._trajectory[env_id_receive[i]][-1]['value_gamma'] = self._discount_ratio_list[0] + + if counter >= target_size: + break + + ctx.trajectories = [] + for i in range(self.env.env_num): + ctx.trajectories.extend(self._trajectory[i]) + self._trajectory[i] = [] + ctx.env_step += len(ctx.trajectories) + + class PPOFStepCollector: """ Overview: diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 8474f2626e..cb7dbcf9d6 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -1,8 +1,8 @@ from .trainer import trainer, multistep_trainer from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, \ - offline_data_fetcher_from_mem, sqil_data_pusher, buffer_saver + offline_data_fetcher_from_mem, sqil_data_pusher, buffer_saver, offpolicy_data_fetcher_v2 from .collector import inferencer, rolloutor, TransitionList -from .evaluator import interaction_evaluator, interaction_evaluator_ttorch +from .evaluator import interaction_evaluator, interaction_evaluator_ttorch, envpool_evaluator from .termination_checker 
import termination_checker, ddp_termination_checker from .logger import online_logger, offline_logger, wandb_online_logger, wandb_offline_logger from .ctx_helper import final_ctx_saver diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index ab1f1a5544..e254e4ad3b 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -1,4 +1,5 @@ import os +import torch.multiprocessing as mp from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional from easydict import EasyDict from ditk import logging @@ -8,9 +9,13 @@ from ding.framework import task from ding.utils import get_rank +from ding.policy.common_utils import default_preprocess_learn, fast_preprocess_learn + if TYPE_CHECKING: from ding.framework import OnlineRLContext, OfflineRLContext +import time + def data_pusher(cfg: EasyDict, buffer_: Buffer, group_by_env: Optional[bool] = None): """ @@ -31,7 +36,6 @@ def _push(ctx: "OnlineRLContext"): - trajectories (:obj:`List[Dict]`): Trajectories. - episodes (:obj:`List[Dict]`): Episodes. """ - if ctx.trajectories is not None: # each data in buffer is a transition if group_by_env: for i, t in enumerate(ctx.trajectories): @@ -170,10 +174,105 @@ def _fetch(ctx: "OnlineRLContext"): index = [d.index for d in buffered_data] meta = [d.meta for d in buffered_data] # such as priority - if isinstance(ctx.train_output, List): - priority = ctx.train_output.pop()['priority'] + if isinstance(ctx.train_output_for_post_process, List): + priority = ctx.train_output_for_post_process.pop()['priority'] + else: + priority = ctx.train_output_for_post_process['priority'] + for idx, m, p in zip(index, meta, priority): + m['priority'] = p + buffer_.update(index=idx, data=None, meta=m) + + return _fetch + + +def offpolicy_data_fetcher_v2( + cfg: EasyDict, + buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]], + data_shortage_warning: bool = False, +) -> Callable: + """ + Overview: + The return function is a generator which meanly fetch a batch of data from a buffer, \ + a list of buffers, or a dict of buffers. + Arguments: + - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`. + - buffer (:obj:`Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]`): \ + The buffer where the data is fetched from. \ + ``Buffer`` type means a buffer.\ + ``List[Tuple[Buffer, float]]`` type means a list of tuple. In each tuple there is a buffer and a float. \ + The float defines, how many batch_size is the size of the data \ + which is sampled from the corresponding buffer.\ + ``Dict[str, Buffer]`` type means a dict in which the value of each element is a buffer. \ + For each key-value pair of dict, batch_size of data will be sampled from the corresponding buffer \ + and assigned to the same key of `ctx.train_data`. + - data_shortage_warning (:obj:`bool`): Whether to output warning when data shortage occurs in fetching. + """ + + def _fetch(ctx: "OnlineRLContext"): + """ + Input of ctx: + - train_output (:obj:`Union[Dict, Deque[Dict]]`): This attribute should exist \ + if `buffer_` is of type Buffer and if `buffer_` use the middleware `PriorityExperienceReplay`. 
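# Illustrative sketch (names assumed, not part of the patch) of the n-step bookkeeping
# that EnvpoolStepCollector above maintains incrementally: every transition ends up with
# a zero-padded list of up to `nstep` future rewards, the observation reached after the
# last accumulated reward as `next_n_obs`, and `value_gamma = gamma ** k`, where k is the
# number of real rewards accumulated before the episode ended or the n-step window closed.
def annotate_nstep(transitions, nstep, gamma):
    # transitions: list of dicts with 'reward' (float), 'done' (bool), 'next_obs' (array)
    out = []
    for t, tr in enumerate(transitions):
        rewards, k = [], 0
        for j in range(nstep):
            if t + j >= len(transitions):
                break
            rewards.append(transitions[t + j]['reward'])
            k = j + 1
            if transitions[t + j]['done']:
                break
        rewards += [0.0] * (nstep - len(rewards))
        out.append(dict(tr, reward=rewards,
                        next_n_obs=transitions[t + k - 1]['next_obs'],
                        value_gamma=gamma ** k))
    return out

# e.g. annotate_nstep([{'reward': 1.0, 'done': False, 'next_obs': 'o1'},
#                      {'reward': 0.5, 'done': True,  'next_obs': 'o2'}], nstep=3, gamma=0.99)
# -> first sample gets reward=[1.0, 0.5, 0.0], next_n_obs='o2', value_gamma=0.99 ** 2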
\ + The meta data `priority` of the sampled data in the `buffer_` will be updated \ + to the `priority` attribute of `ctx.train_output` if `ctx.train_output` is a dict, \ + or the `priority` attribute of `ctx.train_output`'s popped element \ + if `ctx.train_output` is a deque of dicts. + Output of ctx: + - train_data (:obj:`Union[List[Dict], Dict[str, List[Dict]]]`): The fetched data. \ + ``List[Dict]`` type means a list of data. + `train_data` is of this type if the type of `buffer_` is Buffer or List. + ``Dict[str, List[Dict]]]`` type means a dict, in which the value of each key-value pair + is a list of data. `train_data` is of this type if the type of `buffer_` is Dict. + """ + try: + unroll_len = cfg.policy.collect.unroll_len + if isinstance(buffer_, Buffer): + if unroll_len > 1: + buffered_data = buffer_.sample( + cfg.policy.learn.batch_size, groupby="env", unroll_len=unroll_len, replace=True + ) + ctx.train_data_sample = [[t.data for t in d] for d in buffered_data] # B, unroll_len + else: + buffered_data = buffer_.sample(cfg.policy.learn.batch_size) + ctx.train_data_sample = [d.data for d in buffered_data] + elif isinstance(buffer_, List): # like sqil, r2d3 + assert unroll_len == 1, "not support" + buffered_data = [] + for buffer_elem, p in buffer_: + data_elem = buffer_elem.sample(int(cfg.policy.learn.batch_size * p)) + assert data_elem is not None + buffered_data.append(data_elem) + buffered_data = sum(buffered_data, []) + ctx.train_data_sample = [d.data for d in buffered_data] + elif isinstance(buffer_, Dict): # like ppg_offpolicy + assert unroll_len == 1, "not support" + buffered_data = {k: v.sample(cfg.policy.learn.batch_size) for k, v in buffer_.items()} + ctx.train_data_sample = {k: [d.data for d in v] for k, v in buffered_data.items()} + else: + raise TypeError("not support buffer argument type: {}".format(type(buffer_))) + + assert buffered_data is not None + except (ValueError, AssertionError): + if data_shortage_warning: + # You can modify data collect config to avoid this warning, e.g. increasing n_sample, n_episode. + # Fetcher will skip this this attempt. + logging.warning( + "Replay buffer's data is not enough to support training, so skip this training to wait more data." 
+ ) + ctx.train_data_sample = None + return + + yield + + if isinstance(buffer_, Buffer): + if any([isinstance(m, PriorityExperienceReplay) for m in buffer_._middleware]): + index = [d.index for d in buffered_data] + meta = [d.meta for d in buffered_data] + # such as priority + if isinstance(ctx.train_output_for_post_process, List): + priority = ctx.train_output_for_post_process.pop()['priority'] else: - priority = ctx.train_output['priority'] + priority = ctx.train_output_for_post_process['priority'] for idx, m, p in zip(index, meta, priority): m['priority'] = p buffer_.update(index=idx, data=None, meta=m) @@ -185,7 +284,6 @@ def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable: from threading import Thread from queue import Queue - import time stream = torch.cuda.Stream() def producer(queue, dataset, batch_size, device): diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 611bbcdea6..31b70153a7 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -15,6 +15,8 @@ from ding.torch_utils import to_ndarray, get_shape0 from ding.utils import lists_to_dicts +import time + class IMetric(ABC): @@ -305,6 +307,127 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): return _evaluate +def envpool_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False) -> Callable: + """ + Overview: + The middleware that executes the evaluation. + Arguments: + - cfg (:obj:`EasyDict`): Config. + - policy (:obj:`Policy`): The policy to be evaluated. + - env (:obj:`BaseEnvManager`): The env for the evaluation. + - render (:obj:`bool`): Whether to render env images and policy logits. + """ + if task.router.is_active and not task.has_role(task.role.EVALUATOR): + return task.void() + + env.seed(cfg.seed, dynamic_seed=False) + + def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): + """ + Overview: + - The evaluation will be executed if the task begins and enough train_iter passed \ + since last evaluation. + Input of ctx: + - last_eval_iter (:obj:`int`): Last evaluation iteration. + - train_iter (:obj:`int`): Current train iteration. + Output of ctx: + - eval_value (:obj:`float`): The average reward in the current evaluation. 
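# Note on the two context fields introduced in this file (shapes are illustrative, for an
# Atari batch of size B with nstep=3): `ctx.train_data_sample` carries the raw sampled
# transitions, while `ctx.train_data` is produced later by `fast_preprocess_learn` inside
# the learner's worker thread as a dict of batched torch tensors.
#
# ctx.train_data_sample  # list[B] of {'obs': (4, 84, 84) float32, 'action': (1,), 'reward': [r1, r2, r3], ...}
# ctx.train_data         # {'obs': (B, 4, 84, 84), 'action': (B,), 'reward': (3, B), 'value_gamma': (B,), ...}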
+ """ + + # evaluation will be executed if the task begins or enough train_iter after last evaluation + if ctx.last_eval_iter != -1 and \ + (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): + return + + ready_obs_receive = {} + ready_obs_send = {} + ready_action_send = {} + trajectory = {i: [] for i in range(env.env_num)} + + if env.closed: + ready_obs_receive = env.launch() + else: + ready_obs_receive = env.reset() + policy.reset() + eval_monitor = VectorEvalMonitor(env.env_num, cfg.env.n_evaluator_episode) + + while not eval_monitor.is_finished(): + + if len(ready_obs_receive.keys()) > 0: + action_to_send = policy.forward(ready_obs_receive) + output = [v for v in action_to_send.values()] + + ready_obs_send.update(ready_obs_receive) + ready_obs_receive = {} + ready_action_send.update(action_to_send) + + action_send = np.array([action_to_send[i]['action'] for i in action_to_send.keys()]) + if action_send.ndim == 2 and action_send.shape[1] == 1: + action_send = action_send.squeeze(1) + env_id_send = np.array(list(action_to_send.keys())) + env.send_action(action_send, env_id_send) + + next_obs, rew, done, info = env.receive_data() + env_id_receive = info['env_id'] + ready_obs_receive.update({i: next_obs[i] for i in range(len(next_obs))}) + + #todo + for i in range(len(env_id_receive)): + current_reward = ttorch.tensor(np.array([rew[i]])) + trajectory[env_id_receive[i]].append( + { + 'obs': ttorch.tensor(ready_obs_send[env_id_receive[i]]), + 'action': ttorch.tensor(ready_action_send[env_id_receive[i]]['action']), + 'next_obs': ttorch.tensor(next_obs[i]), + # n-step reward + 'reward': [current_reward], + 'done': ttorch.tensor(done[i]) + } + ) + + if done[i] is True: + episode_return_i = 0.0 + for item in trajectory[env_id_receive[i]]: + episode_return_i += item['reward'][0] + eval_monitor.update_reward(env_id_receive[i], episode_return_i) + policy.reset([env_id_receive[i]]) + trajectory[env_id_receive[i]] = [] + + episode_return = eval_monitor.get_episode_return() + episode_return_min = np.min(episode_return) + episode_return_max = np.max(episode_return) + episode_return_std = np.std(episode_return) + episode_return = np.mean(episode_return) + stop_flag = episode_return >= cfg.env.stop_value and ctx.train_iter > 0 + if isinstance(ctx, OnlineRLContext): + logging.info( + 'Evaluation: Train Iter({})\tEnv Step({})\tEpisode Return({:.3f})'.format( + ctx.train_iter, ctx.env_step, episode_return + ) + ) + elif isinstance(ctx, OfflineRLContext): + logging.info('Evaluation: Train Iter({})\tEval Reward({:.3f})'.format(ctx.train_iter, episode_return)) + else: + raise TypeError("not supported ctx type: {}".format(type(ctx))) + ctx.last_eval_iter = ctx.train_iter + ctx.eval_value = episode_return + ctx.eval_value_min = episode_return_min + ctx.eval_value_max = episode_return_max + ctx.eval_value_std = episode_return_std + ctx.last_eval_value = ctx.eval_value + ctx.eval_output = {'episode_return': episode_return} + episode_info = eval_monitor.get_episode_info() + if episode_info is not None: + ctx.eval_output['episode_info'] = episode_info + + ctx.eval_output['output'] = output # for compatibility + + if stop_flag: + task.finish = True + + return _evaluate + + def interaction_evaluator_ttorch( seed: int, policy: Policy, diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index 9f62e2f429..419b3f53a4 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -302,6 
+302,19 @@ def _plot(ctx: "OnlineRLContext"): "If you want to use wandb to visualize the result, please set plot_logger = True in the config." ) + if hasattr(ctx, "evaluator_time"): + info_for_logging.update({"evaluator_time": ctx.evaluator_time}) + if hasattr(ctx, "collector_time"): + info_for_logging.update({"collector_time": ctx.collector_time}) + if hasattr(ctx, "learner_time"): + info_for_logging.update({"learner_time": ctx.learner_time}) + if hasattr(ctx, "data_pusher_time"): + info_for_logging.update({"data_pusher_time": ctx.data_pusher_time}) + if hasattr(ctx, "nstep_time"): + info_for_logging.update({"nstep_time": ctx.nstep_time}) + if hasattr(ctx, "total_time"): + info_for_logging.update({"total_time": ctx.total_time}) + if ctx.eval_value != -np.inf: if hasattr(ctx, "eval_value_min"): info_for_logging.update({ diff --git a/ding/framework/middleware/learner.py b/ding/framework/middleware/learner.py index 9abf88e9b3..8779ced315 100644 --- a/ding/framework/middleware/learner.py +++ b/ding/framework/middleware/learner.py @@ -4,13 +4,46 @@ from ding.framework import task from ding.data import Buffer -from .functional import trainer, offpolicy_data_fetcher, reward_estimator, her_data_enhancer +from .functional import trainer, offpolicy_data_fetcher, reward_estimator, her_data_enhancer, offpolicy_data_fetcher_v2 if TYPE_CHECKING: from ding.framework import Context, OnlineRLContext from ding.policy import Policy from ding.reward_model import BaseRewardModel +from queue import Queue +import time +from threading import Thread +from ding.policy.common_utils import fast_preprocess_learn + + +def data_process_func( + data_queue_input: Queue, + data_queue_output: Queue, + use_priority: bool = False, + use_priority_IS_weight: bool = False, + use_nstep: bool = False, + cuda: bool = True, + device: str = "cuda:0", +): + while True: + if data_queue_input.empty(): + time.sleep(0.001) + else: + data = data_queue_input.get() + if data is None: + break + else: + output_data = fast_preprocess_learn( + data, + use_priority=use_priority, + use_priority_IS_weight=use_priority_IS_weight, + use_nstep=use_nstep, + cuda=cuda, + device=device, + ) + data_queue_output.put(output_data) + class OffPolicyLearner: """ @@ -63,9 +96,104 @@ def __call__(self, ctx: "OnlineRLContext") -> None: self._reward_estimator(ctx) self._trainer(ctx) train_output_queue.append(ctx.train_output) + ctx.train_output_for_post_process = ctx.train_output ctx.train_output = train_output_queue +class EnvpoolOffPolicyLearner: + """ + Overview: + The class of the off-policy learner, including data fetching and model training. Use \ + the `__call__` method to execute the whole learning process. + """ + + def __new__(cls, *args, **kwargs): + if task.router.is_active and not task.has_role(task.role.LEARNER): + return task.void() + return super(EnvpoolOffPolicyLearner, cls).__new__(cls) + + def __init__( + self, + cfg: EasyDict, + policy: 'Policy', + buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]], + reward_model: Optional['BaseRewardModel'] = None, + log_freq: int = 100, + ) -> None: + """ + Arguments: + - cfg (:obj:`EasyDict`): Config. + - policy (:obj:`Policy`): The policy to be trained. + - buffer (:obj:`Buffer`): The replay buffer to store the data for training. + - reward_model (:obj:`BaseRewardModel`): Additional reward estimator likes RND, ICM, etc. \ + default to None. + - log_freq (:obj:`int`): The frequency (iteration) of showing log. 
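# Design note (stand-alone sketch, independent of the classes above): the learner decouples
# CPU-side batch preprocessing from GPU training with a plain Thread plus two Queues, and
# shuts the worker down with a `None` sentinel, mirroring data_process_func above. The toy
# `preprocess` callable here stands in for `fast_preprocess_learn`.
import time
from queue import Queue
from threading import Thread

def _worker(q_in: Queue, q_out: Queue, preprocess):
    while True:
        if q_in.empty():
            time.sleep(0.001)
            continue
        item = q_in.get()
        if item is None:          # sentinel: stop the worker
            break
        q_out.put(preprocess(item))

q_in, q_out = Queue(), Queue()
t = Thread(target=_worker, args=(q_in, q_out, lambda batch: batch))
t.start()
q_in.put([{'obs': 0}])            # raw batch in
processed = q_out.get()           # preprocessed batch out
q_in.put(None)                    # shutdown
t.join()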
+ """ + self.cfg = cfg + + self._fetcher = task.wrap(offpolicy_data_fetcher_v2(cfg, buffer_)) + + self._data_queue_input = Queue() + self._data_queue_output = Queue() + + self.thread_worker = Thread( + target=data_process_func, + args=( + self._data_queue_input, + self._data_queue_output, + cfg.policy.priority, + cfg.policy.priority_IS_weight, + cfg.policy.nstep > 1, + cfg.policy.cuda, + policy._device, + ) + ) + self.thread_worker.start() + + self._trainer = task.wrap(trainer(cfg, policy.learn_mode, log_freq=log_freq)) + if reward_model is not None: + self._reward_estimator = task.wrap(reward_estimator(cfg, reward_model)) + else: + self._reward_estimator = None + + def __call__(self, ctx: "OnlineRLContext") -> None: + """ + Output of ctx: + - train_output (:obj:`Deque`): The training output in deque. + """ + train_output_queue = [] + data_counter = 0 + for _ in range(self.cfg.policy.learn.update_per_collect): + self._fetcher(ctx) + if ctx.train_data_sample is None: + break + self._data_queue_input.put(ctx.train_data_sample) + data_counter += 1 + + for _ in range(data_counter): + while True: + if self._data_queue_output.empty(): + time.sleep(0.001) + continue + else: + ctx.train_data = self._data_queue_output.get() + break + if self._reward_estimator: + self._reward_estimator(ctx) + self._trainer(ctx) + + train_output_queue.append(ctx.train_output) + ctx.train_output_for_post_process = ctx.train_output + + ctx.train_output = train_output_queue + + yield + + if task.finish: + self._data_queue_input.put(None) + self.thread_worker.join() + + class HERLearner: """ Overview: diff --git a/ding/framework/middleware/tests/test_distributer.py b/ding/framework/middleware/tests/test_distributer.py index 7651e66ec7..942bbc7621 100644 --- a/ding/framework/middleware/tests/test_distributer.py +++ b/ding/framework/middleware/tests/test_distributer.py @@ -246,18 +246,15 @@ def train(ctx): task.use(train) else: y_pred1 = policy.predict(X) - print("y_pred1: ", y_pred1) stale = 1 def pred(ctx): nonlocal stale y_pred2 = policy.predict(X) - print("y_pred2: ", y_pred2) stale += 1 assert stale <= 3 or all(y_pred1 == y_pred2) if any(y_pred1 != y_pred2): stale = 1 - sleep(0.3) task.use(pred) diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py index f74a179962..fe30a1efe7 100644 --- a/ding/model/common/utils.py +++ b/ding/model/common/utils.py @@ -1,5 +1,6 @@ import copy import torch +import torch.nn as nn from easydict import EasyDict from ding.utils import import_module, MODEL_REGISTRY diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index 25e8b67c4d..1d428cf99b 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -1,6 +1,6 @@ from .base_policy import Policy, CommandModePolicy, create_policy, get_policy_cls from .common_utils import single_env_forward_wrapper, single_env_forward_wrapper_ttorch, default_preprocess_learn -from .dqn import DQNSTDIMPolicy, DQNPolicy +from .dqn import DQNSTDIMPolicy, DQNPolicy, DQNFastPolicy from .mdqn import MDQNPolicy from .iqn import IQNPolicy from .fqf import FQFPolicy diff --git a/ding/policy/common_utils.py b/ding/policy/common_utils.py index fd2c7d3d61..a8e095974c 100644 --- a/ding/policy/common_utils.py +++ b/ding/policy/common_utils.py @@ -1,9 +1,9 @@ from typing import List, Any, Dict, Callable -import torch import numpy as np +import torch import treetensor.torch as ttorch from ding.utils.data import default_collate -from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze +from ding.torch_utils import to_tensor, 
to_ndarray, unsqueeze, squeeze, to_device def default_preprocess_learn( @@ -69,6 +69,94 @@ def default_preprocess_learn( return data +def fast_preprocess_learn( + data: List[np.ndarray], + use_priority_IS_weight: bool = False, + use_priority: bool = False, + use_nstep: bool = False, + cuda: bool = False, + device: str = 'cpu', +) -> dict: + """ + Overview: + Fast data pre-processing before policy's ``_forward_learn`` method, including stacking batch data, transform \ + data to PyTorch Tensor and move data to GPU, etc. This function is faster than ``default_preprocess_learn`` \ + but less flexible. + This function abandons calling ``default_collate`` to stack data because ``default_collate`` \ + is recursive and cumbersome. In this function, we alternatively stack the data and send it to GPU, so that it \ + is faster. In addition, this function is usually used in a special data process thread in learner. + Arguments: + - data (:obj:`List[np.ndarray]`): The list of a training batch samples, each sample is a dict of PyTorch Tensor. + - use_priority_IS_weight (:obj:`bool`): Whether to use priority IS weight correction, if True, this function \ + will set the weight of each sample to the priority IS weight. + - use_priority (:obj:`bool`): Whether to use priority, if True, this function will set the priority IS weight. + - cuda (:obj:`bool`): Whether to use cuda in policy, if True, this function will move the input data to cuda. + - device (:obj:`str`): The device name to move the input data to. + Returns: + - data (:obj:`dict`): The preprocessed dict data whose values can be directly used for \ + the following model forward and loss computation. + """ + processes_data = {} + + action = torch.tensor(np.array([data[i]['action'] for i in range(len(data))])) + if cuda: + action = to_device(action, device=device) + if action.ndim == 2 and action.shape[1] == 1 and action.dtype in [torch.int64, torch.int32]: + action = action.squeeze(1) + processes_data['action'] = action + + obs = torch.tensor(np.array([data[i]['obs'] for i in range(len(data))])) + if cuda: + obs = to_device(obs, device=device) + processes_data['obs'] = obs + + next_obs = torch.tensor(np.array([data[i]['next_obs'] for i in range(len(data))])) + if cuda: + next_obs = to_device(next_obs, device=device) + processes_data['next_obs'] = next_obs + + if 'next_n_obs' in data[0]: + next_n_obs = torch.tensor(np.array([data[i]['next_n_obs'] for i in range(len(data))])) + if cuda: + next_n_obs = to_device(next_n_obs, device=device) + processes_data['next_n_obs'] = next_n_obs + + reward = torch.tensor(np.array([data[i]['reward'] for i in range(len(data))]), dtype=torch.float32) + if cuda: + reward = to_device(reward, device=device) + if use_nstep: + reward = reward.permute(1, 0).contiguous() + processes_data['reward'] = reward + + if 'value_gamma' in data[0]: + value_gamma = torch.tensor(np.array([data[i]['value_gamma'] for i in range(len(data))]), dtype=torch.float32) + if cuda: + value_gamma = to_device(value_gamma, device=device) + processes_data['value_gamma'] = value_gamma + + done = torch.tensor(np.array([data[i]['done'] for i in range(len(data))]), dtype=torch.float32) + if cuda: + done = to_device(done, device=device) + processes_data['done'] = done + + if use_priority and use_priority_IS_weight: + if 'priority_IS' in data: + weight = data['priority_IS'] + else: # for compability + weight = data['IS'] + else: + if 'weight' in data[0]: + weight = torch.tensor(np.array([data[i]['weight'] for i in range(len(data))])) + else: + weight = None 
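# Usage sketch for fast_preprocess_learn (toy shapes assumed; cuda disabled so it runs on CPU):
# it stacks a list of per-transition dicts into batched tensors, permutes n-step rewards to
# (nstep, B), and squeezes discrete actions of shape (B, 1) down to (B,).
import numpy as np
from ding.policy.common_utils import fast_preprocess_learn

batch = [
    {
        'obs': np.zeros((4, 84, 84), dtype=np.float32),
        'next_obs': np.zeros((4, 84, 84), dtype=np.float32),
        'next_n_obs': np.zeros((4, 84, 84), dtype=np.float32),
        'action': np.array([1]),
        'reward': [1.0, 0.0, 0.0],       # nstep = 3
        'done': False,
        'value_gamma': 0.97 ** 3,
    } for _ in range(2)
]
out = fast_preprocess_learn(batch, use_nstep=True, cuda=False)
# out['obs'].shape == (2, 4, 84, 84); out['reward'].shape == (3, 2); out['action'].shape == (2,)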
+ + if weight is not None and cuda: + weight = to_device(weight, device=device) + processes_data['weight'] = weight + + return processes_data + + def single_env_forward_wrapper(forward_fn: Callable) -> Callable: """ Overview: diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py index d1f6fdbb49..82e54462bf 100644 --- a/ding/policy/dqn.py +++ b/ding/policy/dqn.py @@ -199,14 +199,16 @@ def _init_learn(self) -> None: # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) - if 'target_update_freq' in self._cfg.learn: + if 'target_update_freq' in self._cfg.learn and self._cfg.learn.target_update_freq is not None \ + and self._cfg.learn.target_update_freq > 0: self._target_model = model_wrap( self._target_model, wrapper_name='target', update_type='assign', update_kwargs={'freq': self._cfg.learn.target_update_freq} ) - elif 'target_theta' in self._cfg.learn: + elif 'target_theta' in self._cfg.learn and self._cfg.learn.target_theta is not None \ + and self._cfg.learn.target_theta > 0.0: self._target_model = model_wrap( self._target_model, wrapper_name='target', @@ -248,6 +250,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: .. note:: For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. """ + # Data preprocessing operations, such as stack data, cpu to cuda device data = default_preprocess_learn( data, @@ -256,6 +259,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: ignore_done=self._cfg.learn.ignore_done, use_nstep=True ) + if self._cuda: data = to_device(data, self._device) # Q-learning forward @@ -284,6 +288,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: # Postprocessing operations, such as updating target model, return logged values and priority. self._target_model.update(self._learn_model.state_dict()) + return { 'cur_lr': self._optimizer.defaults['lr'], 'total_loss': loss.item(), @@ -484,6 +489,461 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} + def monitor_vars(self) -> List[str]: + return ['cur_lr', 'total_loss', 'q_value'] + + def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = False) -> Dict[str, Any]: + """ + Overview: + Calculate priority for replay buffer. + Arguments: + - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training. + Returns: + - priority (:obj:`Dict[str, Any]`): Dict type priority data, values are python scalar or a list of scalars. + ArgumentsKeys: + - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done`` + - optional: ``value_gamma`` + ReturnsKeys: + - necessary: ``priority`` + """ + + if update_target_model: + self._target_model.load_state_dict(self._learn_model.state_dict()) + + data = default_preprocess_learn( + data, + use_priority=False, + use_priority_IS_weight=False, + ignore_done=self._cfg.learn.ignore_done, + use_nstep=True + ) + if self._cuda: + data = to_device(data, self._device) + # ==================== + # Q-learning forward + # ==================== + self._learn_model.eval() + self._target_model.eval() + with torch.no_grad(): + # Current q value (main model) + q_value = self._learn_model.forward(data['obs'])['logit'] + # Target q value + target_q_value = self._target_model.forward(data['next_obs'])['logit'] + # Max q value action (main model), i.e. 
Double DQN + target_q_action = self._learn_model.forward(data['next_obs'])['action'] + data_n = q_nstep_td_data( + q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight'] + ) + value_gamma = data.get('value_gamma') + loss, td_error_per_sample = q_nstep_td_error( + data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma + ) + return {'priority': td_error_per_sample.abs().tolist()} + + +@POLICY_REGISTRY.register('dqn_fast') +class DQNFastPolicy(Policy): + """ + Overview: + Policy class of DQN algorithm, extended by Double DQN/Dueling DQN/PER/multi-step TD. + + Config: + == ===================== ======== ============== ======================================= ======================= + ID Symbol Type Default Value Description Other(Shape) + == ===================== ======== ============== ======================================= ======================= + 1 ``type`` str dqn | RL policy register name, refer to | This arg is optional, + | registry ``POLICY_REGISTRY`` | a placeholder + 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff- + | erent from modes + 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy + | or off-policy + 4 ``priority`` bool False | Whether use priority(PER) | Priority sample, + | update priority + 5 | ``priority_IS`` bool False | Whether use Importance Sampling + | ``_weight`` | Weight to correct biased update. If + | True, priority must be True. + 6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse + | ``factor`` [0.95, 0.999] | gamma | reward env + 7 ``nstep`` int 1, | N-step reward discount sum for target + [3, 5] | q_value estimation + 8 | ``model.dueling`` bool True | dueling head architecture + 9 | ``model.encoder`` list [32, 64, | Sequence of ``hidden_size`` of | default kernel_size + | ``_hidden`` (int) 64, 128] | subsequent conv layers and the | is [8, 4, 3] + | ``_size_list`` | final dense layer. | default stride is + | [4, 2 ,1] + 10 | ``learn.update`` int 3 | How many updates(iterations) to train | This args can be vary + | ``per_collect`` | after collector's one collection. | from envs. Bigger val + | Only valid in serial training | means more off-policy + 11 | ``learn.batch_`` int 64 | The number of samples of an iteration + | ``size`` + 12 | ``learn.learning`` float 0.001 | Gradient step length of an iteration. + | ``_rate`` + 13 | ``learn.target_`` int 100 | Frequence of target network update. | Hard(assign) update + | ``update_freq`` + 14 | ``learn.target_`` float 0.005 | Frequence of target network update. | Soft(assign) update + | ``theta`` | Only one of [target_update_freq, + | | target_theta] should be set + 15 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some + | ``done`` | calculation. | fake termination env + 16 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from + | call of collector. | different envs + 17 ``collect.n_episode`` int 8 | The number of training episodes of a | only one of [n_sample + | call of collector | ,n_episode] should + | | be set + 18 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1 + | ``_len`` + 19 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp', + | 'linear']. 
+ 20 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1] + | ``start`` + 21 | ``other.eps.`` float 0.1 | end value of exploration rate | [0,1] + | ``end`` + 22 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. set + | ``decay`` | decay=10000 means + | the exploration rate + | decay from start + | value to end value + | during decay length. + == ===================== ======== ============== ======================================= ======================= + """ + + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type='dqn_fast', + # (bool) Whether use cuda in policy. + cuda=False, + # (bool) Whether learning policy is the same as collecting data policy(on-policy). + on_policy=False, + # (bool) Whether enable priority experience sample. + priority=False, + # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + priority_IS_weight=False, + # (float) Discount factor(gamma) for returns. + discount_factor=0.97, + model=dict( + #(list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer. + encoder_hidden_size_list=[128, 128, 64], + ), + learn=dict( + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + update_per_collect=3, + # (int) How many samples in a training batch. + batch_size=64, + # (float) The step size of gradient descent. + learning_rate=0.001, + # (int) Frequence of target network update. + # Only one of [target_update_freq, target_theta] should be set. + target_update_freq=100, + # (float) : Used for soft update of the target network. + # aka. Interpolation factor in EMA update for target network. + # Only one of [target_update_freq, target_theta] should be set. + target_theta=0.005, + # (bool) Whether ignore done(usually for max step termination env). + # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. + # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000. + # However, interaction with HalfCheetah always gets done with done is False, + # Since we inplace done==True with done==False to keep + # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), + # when the episode step is greater than max episode step. + ignore_done=False, + ), + # collect_mode config + collect=dict( + # (int) How many training samples collected in one collection procedure. + # Only one of [n_sample, n_episode] shoule be set. + n_sample=8, + # (int) Split episodes or trajectories into pieces with length `unroll_len`. + unroll_len=1, + ), + eval=dict(), + # other config + other=dict( + # Epsilon greedy with decay. + eps=dict( + # (str) Decay type. Support ['exp', 'linear']. + type='exp', + # (float) Epsilon start value. + start=0.95, + # (float) Epsilon end value. + end=0.1, + # (int) Decay length(env step). + decay=10000, + ), + replay_buffer=dict( + # (int) Maximum size of replay buffer. Usually, larger buffer size is good. + replay_buffer_size=10000, + ), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names + + .. 
note:: + The user can define and use customized network model but must obey the same inferface definition indicated \ + by import_names path. For DQN, ``ding.model.template.q_learning.DQN`` + """ + return 'dqn', ['ding.model.template.q_learning'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \ + and target model. + """ + self._priority = self._cfg.priority + self._priority_IS_weight = self._cfg.priority_IS_weight + # Optimizer + self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate) + + self._gamma = self._cfg.discount_factor + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + if 'target_update_freq' in self._cfg.learn and self._cfg.learn.target_update_freq is not None \ + and self._cfg.learn.target_update_freq > 0: + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.learn.target_update_freq} + ) + elif 'target_theta' in self._cfg.learn and self._cfg.learn.target_theta is not None \ + and self._cfg.learn.target_theta > 0.0: + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.learn.target_theta} + ) + else: + raise RuntimeError("DQN needs target network, please either indicate target_update_freq or target_theta") + self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample') + self._learn_model.reset() + self._target_model.reset() + self.time_counter = dict( + set_model_train_time=0, + forward_q_value_time=0, + forward_target_next_time=0, + q_nstep_td_data_time=0, + get_value_gamma_time=0, + loss_time=0, + backward_time=0, + gradient_step_time=0, + target_update_time=0, + time_learn_total=0, + counter_learn=0, + ) + + def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]: + + # ==================== + # Q-learning forward + # ==================== + self._learn_model.train() + self._target_model.train() + + q_value = self._learn_model.forward(data['obs'])['logit'] + + # Target q value + with torch.no_grad(): + target_next_n_q_value = self._target_model.forward(data['next_n_obs'])['logit'] + # Max q value action (main model), i.e. 
+            target_next_n_action = self._learn_model.forward(data['next_n_obs'])['action']
+
+        data_n = q_nstep_td_data(
+            q_value, target_next_n_q_value, data['action'], target_next_n_action, data['reward'], data['done'],
+            data['weight']
+        )
+
+        if self._cfg.nstep == 1:
+            value_gamma = None
+        else:
+            value_gamma = data.get(
+                'value_gamma'
+            ) if 'value_gamma' in data else self._cfg.discount_factor * torch.ones_like(data['done'])
+
+        loss, td_error_per_sample = q_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma)
+
+        # ====================
+        # Q-learning update
+        # ====================
+        self._optimizer.zero_grad()
+        loss.backward()
+        if self._cfg.multi_gpu:
+            self.sync_gradients(self._learn_model)
+        self._optimizer.step()
+
+        # =============
+        # after update
+        # =============
+        self._target_model.update(self._learn_model.state_dict())
+
+        return {
+            'cur_lr': self._optimizer.defaults['lr'],
+            'total_loss': loss.item(),
+            'q_value': q_value.mean().item(),
+            'target_q_value': target_next_n_q_value.mean().item(),
+            'priority': td_error_per_sample.abs().tolist(),
+            # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
+            # '[histogram]action_distribution': data['action'],
+        }
+
+    def _monitor_vars_learn(self) -> List[str]:
+        return ['cur_lr', 'total_loss', 'q_value', 'target_q_value']
+
+    def _state_dict_learn(self) -> Dict[str, Any]:
+        """
+        Overview:
+            Return the state_dict of learn mode, usually including model and optimizer.
+        Returns:
+            - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+        """
+        return {
+            'model': self._learn_model.state_dict(),
+            'target_model': self._target_model.state_dict(),
+            'optimizer': self._optimizer.state_dict(),
+        }
+
+    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+        """
+        Overview:
+            Load the state_dict variable into policy learn mode.
+        Arguments:
+            - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+
+        .. tip::
+            If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
+            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+            complicated operations.
+        """
+        self._learn_model.load_state_dict(state_dict['model'])
+        self._target_model.load_state_dict(state_dict['target_model'])
+        self._optimizer.load_state_dict(state_dict['optimizer'])
+
+    def _init_collect(self) -> None:
+        """
+        Overview:
+            Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model, \
+            enable the eps_greedy_sample for exploration.
+        """
+        self._unroll_len = self._cfg.collect.unroll_len
+        self._gamma = self._cfg.discount_factor  # necessary for parallel
+        self._nstep = self._cfg.nstep  # necessary for parallel
+        self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
+        self._collect_model.reset()
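For reference, the exploration behind the ``eps_greedy_sample`` wrapper and the ``other.eps`` config fields reduces to an exponentially decayed epsilon plus a random-vs-argmax choice. The sketch below is only an illustration of that schedule and sampling rule (the function names are illustrative); the real logic lives in DI-engine's epsilon-greedy helpers and model wrappers.

    import math
    import random

    import torch


    def exp_eps(step: int, start: float = 0.95, end: float = 0.1, decay: int = 10000) -> float:
        # 'exp' schedule: epsilon decays from `start` towards `end` over roughly `decay` env steps.
        return end + (start - end) * math.exp(-step / decay)


    def eps_greedy_action(q_values: torch.Tensor, eps: float) -> int:
        # With probability eps take a uniformly random action, otherwise the argmax of the Q-values.
        if random.random() < eps:
            return random.randrange(q_values.shape[-1])
        return int(q_values.argmax(dim=-1).item())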
+    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+        """
+        Overview:
+            Forward computation graph of collect mode (collect training data), with eps_greedy for exploration.
+        Arguments:
+            - data (:obj:`Dict[int, Any]`): Dict type data, stacked env data for predicting policy_output (action), \
+                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+            - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
+        Returns:
+            - output (:obj:`Dict[int, Any]`): The dict of predicting policy_output (action) for the interaction with \
+                env and the constructing of transition.
+        ArgumentsKeys:
+            - necessary: ``obs``
+        ReturnsKeys:
+            - necessary: ``logit``, ``action``
+        """
+        data_id = list(data.keys())
+        data = default_collate(list(data.values()))
+        if self._cuda:
+            data = to_device(data, self._device)
+        self._collect_model.eval()
+        with torch.no_grad():
+            output = self._collect_model.forward(data, eps=eps)
+        if self._cuda:
+            output = to_device(output, 'cpu')
+        output = default_decollate(output)
+        return {i: d for i, d in zip(data_id, output)}
+
+    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """
+        Overview:
+            For a given trajectory (a list of transitions), process it into a list of samples that can be used \
+            for training directly. A train sample can be a processed transition (DQN with nstep TD) \
+            or several continuous transitions (DRQN).
+        Arguments:
+            - data (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is in \
+                the same format as the return value of the ``self._process_transition`` method.
+        Returns:
+            - samples (:obj:`dict`): The list of training samples.
+
+        .. note::
+            We will vectorize the ``process_transition`` and ``get_train_sample`` methods in the following release \
+            version. And the user can customize this data processing procedure by overriding these two methods \
+            and the collector itself.
+        """
+        data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
+        return get_train_sample(data, self._unroll_len)
+
+    def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]:
+        """
+        Overview:
+            Generate a transition (e.g. <s, a, s', r, d>) for this algorithm's training.
+        Arguments:
+            - obs (:obj:`Any`): Env observation.
+            - policy_output (:obj:`Dict[str, Any]`): The output of policy collect mode (``self._forward_collect``), \
+                including at least ``action``.
+            - timestep (:obj:`namedtuple`): The output after env step (execute policy output action), including at \
+                least ``obs``, ``reward``, ``done`` (here obs indicates obs after env step).
+        Returns:
+            - transition (:obj:`dict`): Dict type transition data.
+        """
+        transition = {
+            'obs': obs,
+            'next_obs': timestep.obs,
+            'action': policy_output['action'],
+            'reward': timestep.reward,
+            'done': timestep.done,
+        }
+        return transition
+
+    def _init_eval(self) -> None:
+        r"""
+        Overview:
+            Evaluate mode init method. Called by ``self.__init__``, initialize eval_model.
+        """
+        self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+        self._eval_model.reset()
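The collect/eval forward passes in this policy follow one pattern: collate the per-env observations keyed by env_id into a single batch, run the network once, and decollate the outputs back per env_id. A minimal sketch of that pattern, assuming a plain tensor observation and a callable model (DI-engine's ``default_collate``/``default_decollate`` also handle nested dict/list structures):

    import torch


    def batch_forward(per_env_obs: dict, model):
        # Stack per-env observations into one batch, run the network once,
        # then split the outputs back per env_id. Illustration only.
        env_ids = list(per_env_obs.keys())
        batch = torch.stack([torch.as_tensor(per_env_obs[i]) for i in env_ids])
        with torch.no_grad():
            out = model(batch)  # e.g. a (batch, num_actions) tensor of Q-values
        return {env_id: out[idx] for idx, env_id in enumerate(env_ids)}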
+    def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+        """
+        Overview:
+            Forward computation graph of eval mode (evaluate policy performance); in most cases, it is similar to \
+            ``self._forward_collect``.
+        Arguments:
+            - data (:obj:`Dict[int, Any]`): Dict type data, stacked env data for predicting policy_output (action), \
+                values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+        Returns:
+            - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+        ArgumentsKeys:
+            - necessary: ``obs``
+        ReturnsKeys:
+            - necessary: ``action``
+        """
+        data_id = list(data.keys())
+        data = default_collate(list(data.values()))
+        if self._cuda:
+            data = to_device(data, self._device)
+        self._eval_model.eval()
+        with torch.no_grad():
+            output = self._eval_model.forward(data)
+        if self._cuda:
+            output = to_device(output, 'cpu')
+        output = default_decollate(output)
+        return {i: d for i, d in zip(data_id, output)}
+
     def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = False) -> Dict[str, Any]:
         """
         Overview:
diff --git a/ding/policy/tests/test_common_utils.py b/ding/policy/tests/test_common_utils.py
index 96fbde0963..a7e279b6a0 100644
--- a/ding/policy/tests/test_common_utils.py
+++ b/ding/policy/tests/test_common_utils.py
@@ -5,6 +5,7 @@
 import treetensor.torch as ttorch
 
 from ding.policy.common_utils import default_preprocess_learn
+from ding.policy.common_utils import fast_preprocess_learn
 
 shape_test = [
     [2],
@@ -173,3 +174,64 @@ def test_default_preprocess_learn_nstep():
     assert data['reward'][0][0] == torch.tensor(1.0)
     assert data['reward'][1][0] == torch.tensor(2.0)
     assert data['reward'][2][0] == torch.tensor(0.0)
+
+
+@pytest.mark.unittest
+def test_fast_preprocess_learn_action():
+
+    for shape in shape_test:
+        for dtype in dtype_test:
+            data = [
+                {
+                    'obs': np.random.randn(4, 84, 84),
+                    'action': np.random.randn(*shape).astype(dtype),
+                    'reward': 1.0,
+                    'next_obs': np.random.randn(4, 84, 84),
+                    'done': False,
+                    'weight': 1.0,
+                } for _ in range(10)
+            ]
+            use_priority_IS_weight = False
+            use_priority = False
+            use_nstep = False
+            data = fast_preprocess_learn(
+                data, use_priority_IS_weight, use_priority, use_nstep, cuda=False, device="cpu"
+            )
+
+            assert data['obs'].shape == torch.Size([10, 4, 84, 84])
+            if dtype in ["int64"] and shape[0] == 1:
+                assert data['action'].shape == torch.Size([10])
+            else:
+                assert data['action'].shape == torch.Size([10, *shape])
+            assert data['reward'].shape == torch.Size([10])
+            assert data['next_obs'].shape == torch.Size([10, 4, 84, 84])
+            assert data['done'].shape == torch.Size([10])
+            assert data['weight'].shape == torch.Size([10])
+
+
+@pytest.mark.unittest
+def test_fast_preprocess_learn_nstep():
+
+    data = [
+        {
+            'obs': np.random.randn(4, 84, 84),
+            'action': np.random.randn(2),
+            'reward': np.array([1.0, 2.0, 0.0]),
+            'next_obs': np.random.randn(4, 84, 84),
+            'done': False,
+            'weight': 1.0,
+        } for _ in range(10)
+    ]
+    use_priority_IS_weight = False
+    use_priority = False
+    use_nstep = True
+    data = fast_preprocess_learn(data, use_priority_IS_weight, use_priority, use_nstep, cuda=False, device="cpu")
+
+    assert data['reward'].shape == torch.Size([3, 10])
+    assert data['reward'][0][0] == torch.tensor(1.0)
+    assert data['reward'][1][0] == torch.tensor(2.0)
+    assert data['reward'][2][0] == torch.tensor(0.0)
+
+
+if __name__ == "__main__":
+    test_fast_preprocess_learn_nstep()
diff --git a/ding/utils/default_helper.py b/ding/utils/default_helper.py
index d76b6936f3..396a89e191 100644
--- a/ding/utils/default_helper.py
+++ b/ding/utils/default_helper.py
@@ -8,7 +8,7 @@
 import treetensor.torch as ttorch
 
 
-def get_shape0(data: Union[List, Dict, torch.Tensor, ttorch.Tensor]) -> int:
+def get_shape0(data: Union[List, Dict, np.ndarray, torch.Tensor, ttorch.Tensor]) -> int:
     """
     Overview:
         Get shape[0] of data's torch tensor or treetensor
@@ -34,6 +34,8 @@ def fn(t):
                 return fn(item)
         return fn(data.shape)
+    elif isinstance(data, np.ndarray):
+        return data.shape[0]
     else:
         raise TypeError("Error in getting shape0, not support type: {}".format(data))
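The new tests pin down the batching contract of ``fast_preprocess_learn`` without showing its body: observations stack to ``(batch, ...)``, n-step rewards end up as ``(nstep, batch)``, and single int64 actions are flattened to ``(batch,)``. The following standalone sketch reproduces only that contract; it is an illustration under those assumptions, not the actual DI-engine implementation.

    import numpy as np
    import torch


    def stack_transitions(data):
        # Minimal sketch of the shape contract checked by the tests above.
        batch = {}
        for key in ('obs', 'next_obs', 'action', 'done', 'weight', 'reward'):
            batch[key] = torch.as_tensor(np.stack([np.asarray(d[key]) for d in data]))
        # n-step rewards arrive per sample as arrays of length nstep; the tests expect (nstep, batch).
        if batch['reward'].dim() == 2:
            batch['reward'] = batch['reward'].t()
        # Discrete int64 actions of shape (1,) are flattened to a (batch,) tensor.
        action = batch['action']
        if action.dtype == torch.int64 and action.dim() == 2 and action.shape[1] == 1:
            batch['action'] = action.squeeze(1)
        return batch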
diff --git a/ding/worker/collector/sample_serial_collector.py b/ding/worker/collector/sample_serial_collector.py
index 26db458edb..07bac75ae5 100644
--- a/ding/worker/collector/sample_serial_collector.py
+++ b/ding/worker/collector/sample_serial_collector.py
@@ -25,7 +25,7 @@ class SampleSerialCollector(ISerialCollector):
         envstep
     """
 
-    config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100)
+    config = dict(type='sample', deepcopy_obs=False, transform_obs=False, collect_print_freq=100)
 
     def __init__(
             self,
@@ -34,7 +34,8 @@ def __init__(
             policy: namedtuple = None,
             tb_logger: 'SummaryWriter' = None,  # noqa
             exp_name: Optional[str] = 'default_experiment',
-            instance_name: Optional[str] = 'collector'
+            instance_name: Optional[str] = 'collector',
+            timer_cuda: bool = False,
     ) -> None:
         """
         Overview:
@@ -44,6 +45,10 @@ def __init__(
             - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager)
             - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy
             - tb_logger (:obj:`SummaryWriter`): tensorboard handle
+            - exp_name (:obj:`Optional[str]`): name of the project folder of this experiment
+            - instance_name (:obj:`Optional[str]`): instance name, used to specify the saving path of log and model
+            - timer_cuda (:obj:`bool`): whether to use a CUDA timer; if True, the timer measures the time of the \
+                forward process on GPU, otherwise it measures the time of the forward process on CPU.
         """
         self._exp_name = exp_name
         self._instance_name = instance_name
@@ -51,7 +56,7 @@ def __init__(
         self._deepcopy_obs = cfg.deepcopy_obs  # whether to deepcopy each data
         self._transform_obs = cfg.transform_obs
         self._cfg = cfg
-        self._timer = EasyTimer()
+        self._timer = EasyTimer(cuda=timer_cuda)
         self._end_flag = False
         self._rank = get_rank()
         self._world_size = get_world_size()
diff --git a/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py b/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py
index 0b80e41548..8a9f9c1721 100644
--- a/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py
+++ b/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py
@@ -8,14 +8,16 @@
         evaluator_env_num=8,
         evaluator_batch_size=8,
         n_evaluator_episode=8,
-        stop_value=20,
-        env_id='PongNoFrameskip-v4',
+        stop_value=21,
+        env_id='Pong-v5',  # 'ALE/Pong-v5' is also available, but it needs special settings after gym.make.
         frame_stack=4,
     ),
+    nstep=3,
     policy=dict(
         cuda=True,
         priority=False,
+        random_collect_size=50000,
         model=dict(
             obs_shape=[4, 84, 84],
             action_shape=6,
@@ -24,10 +26,15 @@
         nstep=3,
         discount_factor=0.99,
         learn=dict(
-            update_per_collect=10,
+            update_per_collect=2,
             batch_size=32,
             learning_rate=0.0001,
-            target_update_freq=500,
+            # If the target network is updated by hard replacement,
+            # target_update_freq should be larger than 0.
+            # If the target network is updated by blending a small fraction of the online weights (EMA update),
+            # target_update_freq should be None (or 0) and target_theta should be set.
+            target_update_freq=None,
+            target_theta=0.04,
         ),
         collect=dict(n_sample=96, ),
         eval=dict(evaluator=dict(eval_freq=4000, )),
@@ -49,7 +56,7 @@
         type='atari',
         import_names=['dizoo.atari.envs.atari_env'],
     ),
-    env_manager=dict(type='env_pool'),
+    env_manager=dict(type='envpool'),
     policy=dict(type='dqn'),
     replay_buffer=dict(type='deque'),
 )
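The comment added to the Pong config distinguishes the two target-network update modes that ``_init_learn`` selects through the ``target`` model wrapper (``update_type='assign'`` vs ``'momentum'``). The sketch below illustrates the underlying update rules; the function and argument names are illustrative rather than DI-engine's wrapper API.

    import torch
    import torch.nn as nn


    def update_target(target: nn.Module, online: nn.Module, train_iter: int,
                      target_update_freq=None, target_theta=None) -> None:
        # Hard replacement ('assign'): copy the online weights every `target_update_freq` train iterations.
        if target_update_freq is not None and target_update_freq > 0:
            if train_iter % target_update_freq == 0:
                target.load_state_dict(online.state_dict())
        # Soft/EMA update ('momentum'): target <- theta * online + (1 - theta) * target, every iteration.
        elif target_theta is not None and target_theta > 0.0:
            with torch.no_grad():
                for t, o in zip(target.parameters(), online.parameters()):
                    t.mul_(1.0 - target_theta).add_(o, alpha=target_theta)

With ``target_theta=0.04``, the target network tracks the online network with a time constant of roughly 25 learner updates.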
diff --git a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_envpool_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_envpool_config.py
new file mode 100644
index 0000000000..da56810f0c
--- /dev/null
+++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_envpool_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+spaceinvaders_dqn_envpool_config = dict(
+    exp_name='spaceinvaders_dqn_envpool_seed0',
+    env=dict(
+        collector_env_num=8,
+        collector_batch_size=8,
+        evaluator_env_num=8,
+        evaluator_batch_size=8,
+        n_evaluator_episode=8,
+        stop_value=10000000000,
+        env_id='SpaceInvaders-v5',
+        # 'ALE/SpaceInvaders-v5' is also available, but it needs special settings after gym.make.
+        frame_stack=4,
+    ),
+    policy=dict(
+        cuda=True,
+        priority=False,
+        random_collect_size=5000,
+        model=dict(
+            obs_shape=[4, 84, 84],
+            action_shape=6,
+            encoder_hidden_size_list=[128, 128, 512],
+        ),
+        nstep=3,
+        discount_factor=0.99,
+        learn=dict(
+            update_per_collect=10,
+            batch_size=32,
+            learning_rate=0.0001,
+            target_update_freq=500,
+        ),
+        collect=dict(n_sample=100, ),
+        eval=dict(evaluator=dict(eval_freq=4000, )),
+        other=dict(
+            eps=dict(
+                type='exp',
+                start=1.,
+                end=0.05,
+                decay=1000000,
+            ),
+            replay_buffer=dict(replay_buffer_size=400000, ),
+        ),
+    ),
+)
+spaceinvaders_dqn_envpool_config = EasyDict(spaceinvaders_dqn_envpool_config)
+main_config = spaceinvaders_dqn_envpool_config
+spaceinvaders_dqn_envpool_create_config = dict(
+    env=dict(
+        type='atari',
+        import_names=['dizoo.atari.envs.atari_env'],
+    ),
+    # Use the 'envpool' registry name introduced by this PR (the old 'env_pool' name is no longer registered).
+    env_manager=dict(type='envpool'),
+    policy=dict(type='dqn'),
+    replay_buffer=dict(type='deque'),
+)
+spaceinvaders_dqn_envpool_create_config = EasyDict(spaceinvaders_dqn_envpool_create_config)
+create_config = spaceinvaders_dqn_envpool_create_config
+
+if __name__ == '__main__':
+    # or you can enter `ding -m serial -c spaceinvaders_dqn_envpool_config.py -s 0`
+    from ding.entry import serial_pipeline
+    serial_pipeline((main_config, create_config), seed=0)
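Both configs set ``nstep=3``; in ``_forward_learn`` the n-step return is combined with Double DQN action selection through ``q_nstep_td_data``/``q_nstep_td_error``. The sketch below shows a simplified version of the target value those utilities compute (no importance weights, and a fixed ``gamma ** nstep`` tail instead of a per-sample ``value_gamma``); it is an illustration, not the library routine.

    import torch


    def n_step_double_dqn_target(reward_n, done, next_n_q_online, next_n_q_target, gamma: float, nstep: int):
        # reward_n: (nstep, batch) rewards r_t ... r_{t+n-1}; done: (batch,) terminal flag at step t+n.
        # next_n_q_online / next_n_q_target: (batch, num_actions) Q-values at s_{t+n}.
        # Double DQN: select the next action with the online network, evaluate it with the target network.
        next_action = next_n_q_online.argmax(dim=-1)
        next_q = next_n_q_target.gather(-1, next_action.unsqueeze(-1)).squeeze(-1)
        # Discounted n-step return plus the bootstrapped tail.
        discounts = gamma ** torch.arange(nstep, dtype=reward_n.dtype)
        n_step_return = (discounts.unsqueeze(1) * reward_n).sum(dim=0)
        return n_step_return + (gamma ** nstep) * (1.0 - done.float()) * next_q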