diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py
index 11cccf0e13..a44b85e571 100644
--- a/ding/entry/__init__.py
+++ b/ding/entry/__init__.py
@@ -25,4 +25,4 @@
 import serial_pipeline_preference_based_irl_onpolicy
 from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream, serial_pipeline_dreamer
 from .serial_entry_bco import serial_pipeline_bco
-from .serial_entry_pc import serial_pipeline_pc
+from .serial_entry_pc import serial_pipeline_pc
\ No newline at end of file
diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py
index 8e902f1504..95dbd46025 100755
--- a/ding/model/template/__init__.py
+++ b/ding/model/template/__init__.py
@@ -29,3 +29,4 @@
 from .qgpo import QGPO
 from .ebm import EBM, AutoregressiveEBM
 from .havac import HAVAC
+from .qtransformer import QTransformer
diff --git a/ding/model/template/qtransformer.py b/ding/model/template/qtransformer.py
new file mode 100644
index 0000000000..b6655a6351
--- /dev/null
+++ b/ding/model/template/qtransformer.py
@@ -0,0 +1,392 @@
+import copy
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+
+class FiLM(nn.Module):
+    def __init__(self, in_features, out_features):
+        super(FiLM, self).__init__()
+        self.gamma = nn.Linear(in_features, out_features)
+        self.beta = nn.Linear(in_features, out_features)
+
+    def forward(self, x, cond):
+        gamma = self.gamma(cond)
+        beta = self.beta(cond)
+        return gamma * x + beta
+
+
+class EncoderDecoder(nn.Module):
+    """
+    A standard Encoder-Decoder architecture. Base for this and many
+    other models.
+    """
+
+    def forward(self, src, tgt, src_mask, tgt_mask):
+        "Take in and process masked src and target sequences."
+        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
+
+    def encode(self, src, src_mask):
+        return self.encoder(self.src_embed(src), src_mask)
+
+    def decode(self, memory, src_mask, tgt, tgt_mask):
+        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
+
+
+# class Generator(nn.Module):
+#     "Define standard linear + softmax generation step."
+
+#     def __init__(self, d_model, vocab):
+#         super(Generator, self).__init__()
+#         self.proj = nn.Linear(d_model, vocab)
+#         self.proj1 = nn.Linear(vocab, vocab)
+
+#     def forward(self, x):
+#         x = self.proj(x)
+#         return x
+
+
+def clones(module, N):
+    "Produce N identical layers."
+    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
+class LayerNorm(nn.Module):
+    "Construct a layernorm module (See citation for details)."
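+    # Normalizes over the last dimension: y = a_2 * (x - mean) / (std + eps) + b_2,
+    # with learnable per-feature scale ``a_2`` and shift ``b_2``.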
+ + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.a_2 = nn.Parameter(torch.ones(features)) + self.b_2 = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + + +class SublayerConnection(nn.Module): + """ + A residual connection followed by a layer norm. + Note for code simplicity the norm is first as opposed to last. + """ + + def __init__(self, size, dropout): + super(SublayerConnection, self).__init__() + self.norm = LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + "Apply residual connection to any sublayer with the same size." + return x + self.dropout(sublayer(self.norm(x))) + + +class Decoder(nn.Module): + "Generic N layer decoder with masking." + + def __init__(self, layer, N): + super(Decoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, tgt_mask): + for layer in self.layers: + x = layer(x, tgt_mask) + return self.norm(x) + + +class DecoderLayer(nn.Module): + "Decoder is made of self-attn, src-attn, and feed forward (defined below)" + + def __init__(self, size, self_attn, feed_forward, dropout): + super(DecoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + + def forward(self, x, tgt_mask): + "Follow Figure 1 (right) for connections." + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) + return self.sublayer[1](x, self.feed_forward) + + +def subsequent_mask(size): + "Mask out subsequent positions." + attn_shape = (1, size, size) + subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8) + return subsequent_mask == 0 + + +def attention(query, key, value, mask=None, dropout=None): + "Compute 'Scaled Dot Product Attention'" + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e9) + p_attn = scores.softmax(dim=-1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn + + +class MultiHeadedAttention(nn.Module): + def __init__(self, h, d_model, dropout=0.1): + "Take in model size and number of heads." + super(MultiHeadedAttention, self).__init__() + assert d_model % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, key, value, mask=None): + "Implements Figure 2" + if mask is not None: + # Same mask applied to all h heads. + mask = mask.unsqueeze(1) + nbatches = query.size(0) + + # 1) Do all the linear projections in batch from d_model => h x d_k + query, key, value = [ + lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) + for lin, x in zip(self.linears, (query, key, value)) + ] + + # 2) Apply attention on all the projected vectors in batch. + x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) + + # 3) "Concat" using a view and apply a final linear. + x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) + del query + del key + del value + return self.linears[-1](x) + + +class PositionwiseFeedForward(nn.Module): + "Implements FFN equation." 
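+    # FFN(x) = w_2(dropout(relu(w_1(x)))): expand d_model -> d_ff, apply ReLU and dropout,
+    # then project back to d_model.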
+ + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.w_2(self.dropout(self.w_1(x).relu())) + + +class Embeddings(nn.Module): + def __init__(self, d_model, vocab): + super(Embeddings, self).__init__() + self.lut = nn.Embedding(vocab, d_model) + self.d_model = d_model + + def forward(self, x): + return self.lut(x) * math.sqrt(self.d_model) + + +class PositionalEncoding(nn.Module): + "Implement the PE function." + + def __init__(self, d_model, dropout, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[:, : x.size(1)].requires_grad_(False) + return self.dropout(x) + + +class stateEncode(nn.Module): + def __init__(self, num_timesteps, state_dim): + super().__init__() + self.fc1 = nn.Linear(num_timesteps * state_dim, 256) + self.fc2 = nn.Linear(256, 256) # Corrected the input size + self.fc3 = nn.Linear(256, 512) + + def forward(self, x): + batch_size = x.size(0) + # Reshape from (Batch, 8, 256) to (Batch, 2048) + x = x.reshape(batch_size, -1) + # Pass through the layers with activation functions + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x.unsqueeze(1) + + +class actionEncode(nn.Module): + def __init__(self, action_dim, action_bin): + super().__init__() + self.actionbin = action_bin + self.linear_layers = nn.ModuleList( + [nn.Linear(self.actionbin, 512) for _ in range(action_dim)] + ) + + def forward(self, x): + x = x.to(dtype=torch.float) + b, n, _ = x.shape + slices = torch.unbind(x, dim=1) + layer_outputs = torch.empty(b, n, 512, device=x.device) + for i, layer in enumerate(self.linear_layers[:n]): + slice_output = layer(slices[i]) + layer_outputs[:, i, :] = slice_output + return layer_outputs + + +class actionDecode(nn.Module): + def __init__(self, d_model, action_dim, action_bin): + super().__init__() + self.actionbin = action_bin + self.linear_layers = nn.ModuleList( + [nn.Linear(d_model, action_bin) for _ in range(action_dim)] + ) + + def forward(self, x): + x = x.to(dtype=torch.float) + b, n, _ = x.shape + slices = torch.unbind(x, dim=1) + layer_outputs = torch.empty(b, n, self.actionbin, device=x.device) + for i, layer in enumerate(self.linear_layers[:n]): + slice_output = layer(slices[i]) + layer_outputs[:, i, :] = slice_output + return layer_outputs + + +class actionDecode_with_relu(nn.Module): + def __init__(self, d_model, action_dim, action_bin, hidden_dim): + super().__init__() + self.actionbin = action_bin + self.hidden_dim = hidden_dim + self.linear_layers = nn.ModuleList( + [nn.Linear(d_model, hidden_dim) for _ in range(action_dim)] + ) + self.hidden_layers = nn.ModuleList( + [nn.Linear(hidden_dim, action_bin) for _ in range(action_dim)] + ) + self.activation = nn.ReLU() + + def forward(self, x): + x = x.to(dtype=torch.float) + b, n, _ = x.shape + slices = torch.unbind(x, dim=1) + layer_outputs = torch.empty(b, n, self.actionbin, device=x.device) + for i, (linear_layer, 
hidden_layer) in enumerate( + zip(self.linear_layers[:n], self.hidden_layers[:n]) + ): + slice_output = self.activation(linear_layer(slices[i])) + slice_output = hidden_layer(slice_output) + layer_outputs[:, i, :] = slice_output + return layer_outputs + + +class DecoderOnly(nn.Module): + def __init__(self, action_bin, N=8, d_model=512, d_ff=2048, h=8, dropout=0.1): + super(DecoderOnly, self).__init__() + c = copy.deepcopy + self_attn = MultiHeadedAttention(h, d_model, dropout) + feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) + self.position = PositionalEncoding(d_model, dropout) + self.model = Decoder( + DecoderLayer(d_model, c(self_attn), c(feed_forward), dropout), N + ) + # self.Generator = Generator(d_model, vocab=action_bin) + + def forward(self, x): + x = self.position(x) + x = self.model(x, subsequent_mask(x.size(1)).to(x.device)) + # x = self.Generator(x) + return x + + +class QTransformer(nn.Module): + def __init__(self, num_timesteps, state_dim, action_dim, action_bin): + super().__init__() + self.stateEncode = stateEncode(num_timesteps, state_dim) + self.actionEncode = actionEncode(action_dim, action_bin) + self.Transormer = DecoderOnly(action_bin) + self._action_bin = action_bin + self.actionDecode = actionDecode(512, action_dim, action_bin) + + def forward( + self, + state: Tensor, + action: Optional[Tensor] = None, + ): + stateEncode = self.stateEncode(state) + if action is not None: + action = torch.nn.functional.one_hot(action, num_classes=self._action_bin) + actionEncode = self.actionEncode(action) + res = self.Transormer(torch.cat((stateEncode, actionEncode), dim=1)) + return self.actionDecode(res) + res = self.Transormer(stateEncode) + return self.actionDecode(res) + + +# class QTransformerWithFiLM(nn.Module): +# def __init__(self, num_timesteps, state_dim, action_dim, action_bin): +# super().__init__() +# self.stateEncode = stateEncode(num_timesteps, state_dim) +# self.actionEncode = actionEncode(action_dim, action_bin) +# self.Transormer = DecoderOnly(action_bin) +# self._action_bin = action_bin +# self.actionDecode = actionDecode(512, action_dim, action_bin) + +# # Define FiLM layers +# self.film_state = FiLM(num_timesteps, 512) +# self.film_action = FiLM(num_timesteps, 512) + +# def forward(self, state: Tensor, action: Optional[Tensor] = None): +# stateEncode = self.stateEncode(state) +# seq_len = state.size(1) +# stateEncode = self.film_state( +# stateEncode, torch.tensor([seq_len], device=state.device) +# ) +# if action is not None: +# action = torch.nn.functional.one_hot(action, num_classes=self._action_bin) +# actionEncode = self.actionEncode(action) +# actionEncode = self.film_action( +# actionEncode, torch.tensor([seq_len], device=action.device) +# ) +# res = self.Transormer(torch.cat((stateEncode, actionEncode), dim=1)) +# return self.actionDecode(res) + +# res = self.Transormer(stateEncode) +# return self.actionDecode(res) diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index 1f202da3bb..bd2b416902 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -19,6 +19,7 @@ from .ppo import PPOPolicy, PPOPGPolicy, PPOOffPolicy from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy from .cql import CQLPolicy, DiscreteCQLPolicy +from .qtransformer import QTransformerPolicy from .edac import EDACPolicy from .impala import IMPALAPolicy from .ngu import NGUPolicy diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 2e817ead4b..6268384751 100644 --- 
a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -43,6 +43,7 @@ from .d4pg import D4PGPolicy from .cql import CQLPolicy, DiscreteCQLPolicy +from .qtransformer import QTransformerPolicy from .dt import DTPolicy from .pdqn import PDQNPolicy from .madqn import MADQNPolicy @@ -167,6 +168,7 @@ class R2D2CollectTrajCommandModePolicy(R2D2CollectTrajPolicy, DummyCommandModePo pass + @POLICY_REGISTRY.register('r2d3_command') class R2D3CommandModePolicy(R2D3Policy, EpsCommandModePolicy): pass @@ -325,6 +327,9 @@ class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy): class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy): pass +@POLICY_REGISTRY.register('qtransformer_command') +class QtransformerCommandModePolicy(QTransformerPolicy): + pass @POLICY_REGISTRY.register('dt_command') class DTCommandModePolicy(DTPolicy, DummyCommandModePolicy): diff --git a/ding/policy/qtransformer.py b/ding/policy/qtransformer.py new file mode 100644 index 0000000000..ac637dd7ba --- /dev/null +++ b/ding/policy/qtransformer.py @@ -0,0 +1,659 @@ +import copy +from copy import deepcopy +from typing import Any, Dict, List, Union + +import numpy as np +import torch +import torch.nn.functional as F +from easydict import EasyDict + +import wandb + +# from einops import pack, rearrange +from ding.model import model_wrap +from ding.torch_utils import Adam, to_device +from ding.utils import POLICY_REGISTRY +from ding.utils.data import default_collate, default_decollate + +from .common_utils import default_preprocess_learn +from .sac import SACPolicy + + +@POLICY_REGISTRY.register("qtransformer") +class QTransformerPolicy(SACPolicy): + """ + Overview: + Policy class of CQL algorithm for continuous control. Paper link: https://arxiv.org/abs/2006.04779. + + Config: + == ==================== ======== ============= ================================= ======================= + ID Symbol Type Default Value Description Other(Shape) + == ==================== ======== ============= ================================= ======================= + 1 ``type`` str | RL policy register name, refer | this arg is optional, + | to registry ``POLICY_REGISTRY`` | a placeholder + 2 ``cuda`` bool True | Whether to use cuda for network | + 3 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for + | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/ + | | buffer when training starts. | TD3. + 4 | ``model.policy_`` int 256 | Linear layer size for policy | + | ``embedding_size`` | network. | + 5 | ``model.soft_q_`` int 256 | Linear layer size for soft q | + | ``embedding_size`` | network. | + 6 | ``model.value_`` int 256 | Linear layer size for value | Defalut to None when + | ``embedding_size`` | network. | model.value_network + | | | is False. + 7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Defalut to 1e-3, when + | ``_rate_q`` | network. | model.value_network + | | | is True. + 8 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to 1e-3, when + | ``_rate_policy`` | network. | model.value_network + | | | is True. + 9 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to None when + | ``_rate_value`` | network. | model.value_network + | | | is False. + 10 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali- + | | coefficient. 
| zation for auto + | | | `alpha`, when + | | | auto_alpha is True + 11 | ``learn.repara_`` bool True | Determine whether to use | + | ``meterization`` | reparameterization trick. | + 12 | ``learn.`` bool False | Determine whether to use | Temperature parameter + | ``auto_alpha`` | auto temperature parameter | determines the + | | `alpha`. | relative importance + | | | of the entropy term + | | | against the reward. + 13 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only + | ``ignore_done`` | done flag. | in halfcheetah env. + 14 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation + | ``target_theta`` | target network. | factor in polyak aver + | | | aging for target + | | | networks. + == ==================== ======== ============= ================================= ======================= + """ + + config = dict( + # (str) RL policy register name (refer to function "POLICY_REGISTRY"). + type="qtransformer", + # (bool) Whether to use cuda for policy. + cuda=True, + # (bool) on_policy: Determine whether on-policy or off-policy. + # on-policy setting influences the behaviour of buffer. + on_policy=False, + # (bool) priority: Determine whether to use priority in buffer sample. + priority=False, + # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + priority_IS_weight=False, + # (int) Number of training samples(randomly collected) in replay buffer when training starts. + random_collect_size=10000, + model=dict( + # (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation. + # Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one . + # Default to True. + twin_critic=True, + # (str type) action_space: Use reparameterization trick for continous action + action_space="reparameterization", + # (int) Hidden size for actor network head. + actor_head_hidden_size=256, + # (int) Hidden size for critic network head. + critic_head_hidden_size=256, + ), + # learn_mode config + learn=dict( + # (int) How many updates (iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + update_per_collect=1, + # (int) Minibatch size for gradient descent. + batch_size=256, + # (float) learning_rate_q: Learning rate for soft q network. + learning_rate_q=3e-4, + # (float) learning_rate_policy: Learning rate for policy network. + learning_rate_policy=3e-4, + # (float) learning_rate_alpha: Learning rate for auto temperature parameter ``alpha``. + learning_rate_alpha=3e-4, + # (float) target_theta: Used for soft update of the target network, + # aka. Interpolation factor in polyak averaging for target networks. + target_theta=0.005, + # (float) discount factor for the discounted sum of rewards, aka. gamma. + discount_factor=0.99, + # (float) alpha: Entropy regularization coefficient. + # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. + # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`. + # Default to 0.2. + alpha=0.2, + # (bool) auto_alpha: Determine whether to use auto temperature parameter `\alpha` . + # Temperature parameter determines the relative importance of the entropy term against the reward. + # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details. + # Default to False. + # Note that: Using auto alpha needs to set learning_rate_alpha in `cfg.policy.learn`. 
+ auto_alpha=True, + # (bool) log_space: Determine whether to use auto `\alpha` in log space. + log_space=True, + # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum) + # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. + # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000. + # However, interaction with HalfCheetah always gets done with done is False, + # Since we inplace done==True with done==False to keep + # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), + # when the episode step is greater than max episode step. + ignore_done=False, + # (float) Weight uniform initialization range in the last output layer. + init_w=3e-3, + # (int) The numbers of action sample each at every state s from a uniform-at-random. + num_actions=10, + # (bool) Whether use lagrange multiplier in q value loss. + with_lagrange=False, + # (float) The threshold for difference in Q-values. + lagrange_thresh=-1, + # (float) Loss weight for conservative item. + min_q_weight=1.0, + # (bool) Whether to use entropy in target q. + with_q_entropy=False, + ), + eval=dict(), # for compatibility + ) + + def _init_learn(self) -> None: + """ + Overview: + Initialize the learn mode of policy, including related attributes and modules. For SAC, it mainly \ + contains three optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange and \ + with_q_entropy, main and target model. Especially, the ``auto_alpha`` mechanism for balancing max entropy \ + target is also initialized here. + This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``. + + .. note:: + For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \ + and ``_load_state_dict_learn`` methods. + + .. note:: + For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method. + + .. note:: + If you want to set some spacial member variables in ``_init_learn`` method, you'd better name them \ + with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``. 
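+
+        .. note::
+            ``self._action_values`` stores, for every action dimension, ``action_bin`` evenly spaced bin centers \
+            in ``[-1, 1]``; continuous dataset actions are discretized to the nearest bin index in \
+            ``_forward_learn`` before the Q-learning update.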
+ """ + self._priority = self._cfg.priority + self._priority_IS_weight = self._cfg.priority_IS_weight + self._num_actions = self._cfg.learn.num_actions + + self._min_q_version = 3 + self._min_q_weight = self._cfg.learn.min_q_weight + self._with_lagrange = self._cfg.learn.with_lagrange and ( + self._lagrange_thresh > 0 + ) + self._lagrange_thresh = self._cfg.learn.lagrange_thresh + if self._with_lagrange: + self.target_action_gap = self._lagrange_thresh + self.log_alpha_prime = torch.tensor(0.0).to(self._device).requires_grad_() + self.alpha_prime_optimizer = Adam( + [self.log_alpha_prime], + lr=self._cfg.learn.learning_rate_q, + ) + + self._with_q_entropy = self._cfg.learn.with_q_entropy + # Optimizers + self._optimizer_q = Adam( + self._model.parameters(), + lr=self._cfg.learn.learning_rate_q, + ) + + # Algorithm config + self._gamma = self._cfg.learn.discount_factor + self._action_dim = self._cfg.model.action_dim + + # Init auto alpha + if self._cfg.learn.auto_alpha: + if self._cfg.learn.target_entropy is None: + assert ( + "action_shape" in self._cfg.model + ), "CQL need network model with action_shape variable" + self._target_entropy = -np.prod(self._cfg.model.action_shape) + else: + self._target_entropy = self._cfg.learn.target_entropy + if self._cfg.learn.log_space: + self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha])) + self._log_alpha = self._log_alpha.to(self._device).requires_grad_() + self._alpha_optim = torch.optim.Adam( + [self._log_alpha], lr=self._cfg.learn.learning_rate_alpha + ) + assert ( + self._log_alpha.shape == torch.Size([1]) + and self._log_alpha.requires_grad + ) + self._alpha = self._log_alpha.detach().exp() + self._auto_alpha = True + self._log_space = True + else: + self._alpha = ( + torch.FloatTensor([self._cfg.learn.alpha]) + .to(self._device) + .requires_grad_() + ) + self._alpha_optim = torch.optim.Adam( + [self._alpha], lr=self._cfg.learn.learning_rate_alpha + ) + self._auto_alpha = True + self._log_space = False + else: + self._alpha = torch.tensor( + [self._cfg.learn.alpha], + requires_grad=False, + device=self._device, + dtype=torch.float32, + ) + self._auto_alpha = False + for p in self._model.parameters(): + if p.dim() > 1: + torch.nn.init.xavier_uniform_(p) + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name="target", + update_type="momentum", + update_kwargs={"theta": self._cfg.learn.target_theta}, + ) + + self._action_bin = self._cfg.model.action_bin + + self._action_values = np.array( + [ + np.linspace(min_val, max_val, self._action_bin) + for min_val, max_val in zip( + np.full(self._cfg.model.action_dim, -1), + np.full(self._cfg.model.action_dim, 1), + ) + ] + ) + # Main and target models + self._learn_model = model_wrap(self._model, wrapper_name="base") + self._learn_model.reset() + self._target_model.reset() + + self._forward_learn_cnt = 0 + wandb.init(**self._cfg.wandb) + + def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the offline dataset and then returns the output \ + result, including various training information such as loss, action, priority. + Arguments: + - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \ + training samples. 
For each element in list, the key of the dict is the name of data items and the \ + value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or there dict/list \ + combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \ + dimension by some utility functions such as ``default_preprocess_learn``. \ + For CQL, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \ + ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + """ + + def merge_dict1_into_dict2( + dict1: Union[Dict, EasyDict], dict2: Union[Dict, EasyDict] + ) -> Union[Dict, EasyDict]: + """ + Overview: + Merge two dictionaries recursively. \ + Update values in dict2 with values in dict1, and add new keys from dict1 to dict2. + Arguments: + - dict1 (:obj:`dict`): The first dictionary. + - dict2 (:obj:`dict`): The second dictionary. + """ + for key, value in dict1.items(): + if ( + key in dict2 + and isinstance(value, dict) + and isinstance(dict2[key], dict) + ): + # Both values are dictionaries, so merge them recursively + merge_dict1_into_dict2(value, dict2[key]) + else: + # Either the key doesn't exist in dict2 or the values are not dictionaries + dict2[key] = value + + return dict2 + + def merge_two_dicts_into_newone( + dict1: Union[Dict, EasyDict], dict2: Union[Dict, EasyDict] + ) -> Union[Dict, EasyDict]: + """ + Overview: + Merge two dictionaries recursively into a new dictionary. \ + Update values in dict2 with values in dict1, and add new keys from dict1 to dict2. + Arguments: + - dict1 (:obj:`dict`): The first dictionary. + - dict2 (:obj:`dict`): The second dictionary. 
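+            Returns:
+                - merged_dict (:obj:`dict`): A new dictionary: a deepcopy of ``dict2`` recursively updated with \
+                    the values from ``dict1``.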
+ """ + dict2 = deepcopy(dict2) + return merge_dict1_into_dict2(dict1, dict2) + + config = merge_two_dicts_into_newone(EasyDict(wandb.config), self._cfg) + wandb.config.update(config) + data = default_preprocess_learn( + data, + use_priority=self._priority, + use_priority_IS_weight=self._cfg.priority_IS_weight, + ignore_done=self._cfg.learn.ignore_done, + use_nstep=False, + ) + + def discretization(x): + self._action_values = torch.tensor(self._action_values) + indices = torch.zeros_like(x, dtype=torch.long, device=x.device) + for i in range(x.shape[1]): + diff = (x[:, i].unsqueeze(-1) - self._action_values[i, :]) ** 2 + indices[:, i] = diff.argmin(dim=-1) + return indices + + data["action"] = discretization(data["action"]) + + if self._cuda: + data = to_device(data, self._device) + + self._learn_model.train() + self._target_model.train() + + state = data["obs"] + next_state = data["next_obs"] # torch.Size([2048, 17]) + reward = data["reward"] # torch.Size([2048]) + done = data["done"] # torch.Size([2048]) + action = data["action"] # torch.Size([2048, 6]) + + q_pred_all_actions = self._learn_model.forward(state, action=action)[:, :-1, :] + # torch.Size([2048, 6, 256]) + + def batch_select_indices(t, indices): + indices = indices.unsqueeze(-1) + selected = t.gather(-1, indices) + selected = selected.squeeze(-1) + return selected + + # torch.Size([2048, 6]) + q_pred = batch_select_indices(q_pred_all_actions, action) + # Create the dataset action mask and set selected values to 1 + # dataset_action_mask = torch.zeros_like(q_pred_all_actions).scatter_( + # -1, action.unsqueeze(-1), 1 + # ) + # q_actions_not_taken = q_pred_all_actions[~dataset_action_mask.bool()] + # num_non_dataset_actions = q_actions_not_taken.size(0) // q_pred.size(0) + # conservative_loss = ( + # (q_actions_not_taken - (0)) ** 2 + # ).sum() / num_non_dataset_actions + # Iterate over each row in the action tensor + + q_pred_rest_actions, q_pred_last_action = q_pred[:, :-1], q_pred[:, -1:] + with torch.no_grad(): + # q_next_target = self._target_model.forward(next_state) + q_target = self._target_model.forward(state, action=action)[:, :-1, :] + + q_target_rest_actions = q_target[:, 1:, :] + max_q_target_rest_actions = q_target_rest_actions.max(dim=-1).values + + # q_next_target_first_action = q_next_target[:, 0:1, :] + # max_q_next_target_first_action = q_next_target_first_action.max(dim=-1).values + + losses_all_actions_but_last = F.mse_loss( + q_pred_rest_actions, max_q_target_rest_actions + ) + q_target_last_action = next_state.unsqueeze(1) + losses_last_action = F.mse_loss(q_pred_last_action, q_target_last_action) + td_loss = losses_all_actions_but_last + losses_last_action + td_loss.mean() + loss = td_loss + self._optimizer_q.zero_grad() + loss.backward() + self._optimizer_q.step() + self._forward_learn_cnt += 1 + self._target_model.update(self._learn_model.state_dict()) + + split_tensors = q_pred_all_actions.chunk(6, dim=1) + q_means = [tensor.mean() for tensor in split_tensors] + split_tensors_r = q_pred.chunk(6, dim=1) + q_r_means = [tensor.mean() for tensor in split_tensors_r] + wandb.log( + { + "td_loss": td_loss.item(), + "losses_all_actions_but_last": losses_all_actions_but_last.item(), + "losses_last_action": losses_last_action.item(), + "q_mean": q_pred_all_actions.mean().item(), + "q_a1": q_means[0].item(), + "q_a2": q_means[1].item(), + "q_a3": q_means[2].item(), + "q_a4": q_means[3].item(), + "q_a5": q_means[4].item(), + "q_a6": q_means[5].item(), + "q_r_a1": q_r_means[0].item(), + "q_r_a2": 
q_r_means[1].item(),
+                "q_r_a3": q_r_means[2].item(),
+                "q_r_a4": q_r_means[3].item(),
+                "q_r_a5": q_r_means[4].item(),
+                "q_r_a6": q_r_means[5].item(),
+                "q_all": q_pred_all_actions.mean().item(),
+                "q_real": q_pred.mean().item(),
+                "mc": next_state.mean().item(),
+            },
+        )
+        return {
+            "td_error": loss.item(),
+            "policy_loss": q_pred_all_actions.mean().item(),
+        }
+
+    def _get_actions(self, obs, eval=False, epsilon=0.1):
+        import random
+
+        action_bins = None
+        if eval or random.random() > epsilon:
+            action_bins = torch.full(
+                (obs.size(0), self._action_dim), -1, dtype=torch.long, device=obs.device
+            )
+            for action_idx in range(self._action_dim):
+                if action_idx == 0:
+                    q_values = self._eval_model.forward(obs)
+                else:
+                    q_values_all = self._eval_model.forward(
+                        obs, action=action_bins[:, :action_idx]
+                    )
+                    q_values = q_values_all[:, action_idx : action_idx + 1, :]
+                selected_action_bins = q_values.argmax(dim=-1)
+                action_bins[:, action_idx] = selected_action_bins.squeeze()
+        else:
+            action_bins = torch.randint(
+                0, self._action_bin, (obs.size(0), self._action_dim), device=obs.device
+            )
+        action = 2.0 * action_bins.float() / (1.0 * self._action_bin) - 1.0
+        wandb.log(
+            {
+                "action/action_mean": action.mean().item(),
+                "action/action_max": action.max().item(),
+                "action/action_min": action.min().item(),
+            }
+        )
+        return action
+
+    def _monitor_vars_learn(self) -> List[str]:
+        """
+        Overview:
+            Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+            as text logger, tensorboard logger, will use these keys to save the corresponding data.
+        Returns:
+            - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+        """
+        return [
+            "value_loss",
+            "alpha_loss",
+            "policy_loss",
+            "critic_loss",
+            "cur_lr_q",
+            "cur_lr_p",
+            "target_q_value",
+            "alpha",
+            "td_error",
+            "transformed_log_prob",
+        ]
+
+    def _state_dict_learn(self) -> Dict[str, Any]:
+        """
+        Overview:
+            Return the state_dict of learn mode, usually including model, target_model and optimizers.
+        Returns:
+            - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+        """
+        ret = {
+            "model": self._learn_model.state_dict(),
+            "target_model": self._target_model.state_dict(),
+            "optimizer_q": self._optimizer_q.state_dict(),
+        }
+        if self._auto_alpha:
+            ret.update({"optimizer_alpha": self._alpha_optim.state_dict()})
+        return ret
+
+    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+        """
+        Overview:
+            Load the state_dict variable into policy learn mode.
+        Arguments:
+            - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+
+        .. tip::
+            If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+            complicated operation.
+        """
+        self._learn_model.load_state_dict(state_dict["model"])
+        self._target_model.load_state_dict(state_dict["target_model"])
+        self._optimizer_q.load_state_dict(state_dict["optimizer_q"])
+        if self._auto_alpha:
+            self._alpha_optim.load_state_dict(state_dict["optimizer_alpha"])
+
+    def _init_eval(self) -> None:
+        self._eval_model = model_wrap(self._model, wrapper_name="base")
+        self._eval_model.reset()
+
+    def _forward_eval_offline(self, data: dict, the_time, **policy_kwargs) -> dict:
+        r"""
+        Overview:
+            Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments: + - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ + values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + Returns: + - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. + ReturnsKeys + - necessary: ``action`` + """ + data_id = list(data.keys()) + expected_ids = list(range(self._cfg.model.num_timesteps)) + missing_ids = [i for i in expected_ids if i not in data_id] + for missing_id in missing_ids: + data[missing_id] = torch.zeros_like(input=next(iter(data.values()))) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._eval_model.eval() + if the_time == 0: + self._state_list = data.unsqueeze(1).expand( + -1, self._cfg.model.num_timesteps, -1 + ) + else: + self._state_list = self._state_list[:, 1:, :] + # Insert the new data at the last position + self._state_list = torch.cat((self._state_list, data.unsqueeze(1)), dim=1) + with torch.no_grad(): + output = self._get_actions(self._state_list) + if self._cuda: + output = to_device(output, "cpu") + output = default_decollate(output) + output = [{"action": o} for o in output] + return {i: d for i, d in zip(data_id, output)} + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + r""" + Overview: + Forward function of eval mode, similar to ``self._forward_collect``. + Arguments: + - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \ + values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer. + Returns: + - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env. + ReturnsKeys + - necessary: ``action`` + """ + data_id = list(data.keys()) + expected_ids = list(range(self._cfg.model.num_timesteps)) + missing_ids = [i for i in expected_ids if i not in data_id] + for missing_id in missing_ids: + data[missing_id] = torch.zeros_like(input=next(iter(data.values()))) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._eval_model.eval() + with torch.no_grad(): + output = self._get_actions(data, eval=True) + if self._cuda: + output = to_device(output, "cpu") + output = default_decollate(output) + output = [{"action": o} for o in output] + return {i: d for i, d in zip(data_id, output)} + + def _init_collect(self) -> None: + """ + Overview: + Initialize the collect mode of policy, including related attributes and modules. For SAC, it contains the \ + collect_model other algorithm-specific arguments such as unroll_len. \ + This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``. + + .. note:: + If you want to set some spacial member variables in ``_init_collect`` method, you'd better name them \ + with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``. + """ + self._unroll_len = self._cfg.collect.unroll_len + self._collect_model = model_wrap(self._model, wrapper_name="base") + self._collect_model.reset() + + def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: + """ + Overview: + Policy forward function of collect mode (collecting training data by interacting with envs). 
Forward means \ + that the policy gets some necessary data (mainly observation) from the envs and then returns the output \ + data, such as the action to interact with the envs. + Arguments: + - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \ + key of the dict is environment id and the value is the corresponding data of the env. + Returns: + - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \ + other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \ + dict is the same as the input data, i.e. environment id. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + .. note:: + ``logit`` in SAC means the mu and sigma of Gaussioan distribution. Here we use this name for consistency. + + .. note:: + For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``. + """ + data_id = list(data.keys()) + data = default_collate(list(data.values())) + if self._cuda: + data = to_device(data, self._device) + self._collect_model.eval() + with torch.no_grad(): + output = self._get_actions(data, eval=False) + if self._cuda: + output = to_device(output, "cpu") + output = default_decollate(output) + output = [{"action": o} for o in output] + return {i: d for i, d in zip(data_id, output)} diff --git a/ding/worker/collector/episode_serial_collector.py b/ding/worker/collector/episode_serial_collector.py index 6fca2283f8..945792ce51 100644 --- a/ding/worker/collector/episode_serial_collector.py +++ b/ding/worker/collector/episode_serial_collector.py @@ -7,10 +7,16 @@ from ding.envs import BaseEnvManager from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY from ding.torch_utils import to_tensor, to_ndarray -from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions +from .base_serial_collector import ( + ISerialCollector, + CachePool, + TrajBuffer, + INF, + to_tensor_transitions, +) -@SERIAL_COLLECTOR_REGISTRY.register('episode') +@SERIAL_COLLECTOR_REGISTRY.register("episode") class EpisodeSerialCollector(ISerialCollector): """ Overview: @@ -22,17 +28,21 @@ class EpisodeSerialCollector(ISerialCollector): """ config = dict( - deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False, reward_shaping=False + deepcopy_obs=False, + transform_obs=False, + collect_print_freq=100, + get_train_sample=False, + reward_shaping=False, ) def __init__( - self, - cfg: EasyDict, - env: BaseEnvManager = None, - policy: namedtuple = None, - tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'collector' + self, + cfg: EasyDict, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: "SummaryWriter" = None, # noqa + exp_name: Optional[str] = "default_experiment", + instance_name: Optional[str] = "collector", ) -> None: """ Overview: @@ -54,12 +64,15 @@ def __init__( if tb_logger is not None: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, 
need_tb=False + path="./{}/log/{}".format(self._exp_name, self._instance_name), + name=self._instance_name, + need_tb=False, ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + path="./{}/log/{}".format(self._exp_name, self._instance_name), + name=self._instance_name, ) self.reset(policy, env) @@ -90,22 +103,26 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: Arguments: - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, "_env"), "please set env first" if _policy is not None: self._policy = _policy - self._policy_cfg = self._policy.get_attribute('cfg') - self._default_n_episode = _policy.get_attribute('n_episode') - self._unroll_len = _policy.get_attribute('unroll_len') - self._on_policy = _policy.get_attribute('on_policy') + self._policy_cfg = self._policy.get_attribute("cfg") + self._default_n_episode = _policy.get_attribute("n_episode") + self._unroll_len = _policy.get_attribute("unroll_len") + self._on_policy = _policy.get_attribute("on_policy") self._traj_len = INF self._logger.debug( - 'Set default n_episode mode(n_episode({}), env_num({}), traj_len({}))'.format( + "Set default n_episode mode(n_episode({}), env_num({}), traj_len({}))".format( self._default_n_episode, self._env_num, self._traj_len ) ) self._policy.reset() - def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: + def reset( + self, + _policy: Optional[namedtuple] = None, + _env: Optional[BaseEnvManager] = None, + ) -> None: """ Overview: Reset the environment and policy. @@ -124,11 +141,15 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana if _policy is not None: self.reset_policy(_policy) - self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs) - self._policy_output_pool = CachePool('policy_output', self._env_num) + self._obs_pool = CachePool("obs", self._env_num, deepcopy=self._deepcopy_obs) + self._policy_output_pool = CachePool("policy_output", self._env_num) # _traj_buffer is {env_id: TrajBuffer}, is used to store traj_len pieces of transitions - self._traj_buffer = {env_id: TrajBuffer(maxlen=self._traj_len) for env_id in range(self._env_num)} - self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} + self._traj_buffer = { + env_id: TrajBuffer(maxlen=self._traj_len) for env_id in range(self._env_num) + } + self._env_info = { + env_id: {"time": 0.0, "step": 0} for env_id in range(self._env_num) + } self._episode_info = [] self._total_envstep_count = 0 @@ -149,7 +170,7 @@ def _reset_stat(self, env_id: int) -> None: self._traj_buffer[env_id].clear() self._obs_pool.reset(env_id) self._policy_output_pool.reset(env_id) - self._env_info[env_id] = {'time': 0., 'step': 0} + self._env_info[env_id] = {"time": 0.0, "step": 0} @property def envstep(self) -> int: @@ -182,10 +203,12 @@ def __del__(self) -> None: """ self.close() - def collect(self, - n_episode: Optional[int] = None, - train_iter: int = 0, - policy_kwargs: Optional[dict] = None) -> List[Any]: + def collect( + self, + n_episode: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[dict] = None, + ) -> List[Any]: """ Overview: Collect `n_episode` data with policy_kwargs, which is already trained `train_iter` iterations @@ -202,7 +225,9 @@ def collect(self, 
raise RuntimeError("Please specify collect n_episode") else: n_episode = self._default_n_episode - assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) + assert ( + n_episode >= self._env_num + ), "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) if policy_kwargs is None: policy_kwargs = {} collected_episode = 0 @@ -215,7 +240,9 @@ def collect(self, # Get current env obs. obs = self._env.ready_obs new_available_env_id = set(obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + ready_env_id = ready_env_id.union( + set(list(new_available_env_id)[:remain_episode]) + ) remain_episode -= min(len(new_available_env_id), remain_episode) obs = {env_id: obs[env_id] for env_id in ready_env_id} # Policy forward. @@ -225,7 +252,9 @@ def collect(self, policy_output = self._policy.forward(obs, **policy_kwargs) self._policy_output_pool.update(policy_output) # Interact with env. - actions = {env_id: output['action'] for env_id, output in policy_output.items()} + actions = { + env_id: output["action"] for env_id, output in policy_output.items() + } actions = to_ndarray(actions) timesteps = self._env.step(actions) @@ -235,25 +264,33 @@ def collect(self, # TODO(nyz) vectorize this for loop for env_id, timestep in timesteps.items(): with self._timer: - if timestep.info.get('abnormal', False): + if timestep.info.get("abnormal", False): # If there is an abnormal timestep, reset all the related variables(including this env). # suppose there is no reset param, just reset this env self._env.reset({env_id: None}) self._policy.reset([env_id]) self._reset_stat(env_id) - self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, timestep.info)) + self._logger.info( + "Env{} returns a abnormal step, its info is {}".format( + env_id, timestep.info + ) + ) continue transition = self._policy.process_transition( - self._obs_pool[env_id], self._policy_output_pool[env_id], timestep + self._obs_pool[env_id], + self._policy_output_pool[env_id], + timestep, ) # ``train_iter`` passed in from ``serial_entry``, indicates current collecting model's iteration. 
- transition['collect_iter'] = train_iter + transition["collect_iter"] = train_iter self._traj_buffer[env_id].append(transition) - self._env_info[env_id]['step'] += 1 + self._env_info[env_id]["step"] += 1 self._total_envstep_count += 1 # prepare data if timestep.done: - transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs) + transitions = to_tensor_transitions( + self._traj_buffer[env_id], not self._deepcopy_obs + ) if self._cfg.reward_shaping: self._env.reward_shaping(env_id, transitions) if self._cfg.get_train_sample: @@ -263,16 +300,18 @@ def collect(self, return_data.append(transitions) self._traj_buffer[env_id].clear() - self._env_info[env_id]['time'] += self._timer.value + interaction_duration + self._env_info[env_id]["time"] += ( + self._timer.value + interaction_duration + ) # If env is done, record episode info and reset if timestep.done: self._total_episode_count += 1 - reward = timestep.info['eval_episode_return'] + reward = timestep.info["eval_episode_return"] info = { - 'reward': reward, - 'time': self._env_info[env_id]['time'], - 'step': self._env_info[env_id]['step'], + "reward": reward, + "time": self._env_info[env_id]["time"], + "step": self._env_info[env_id]["step"], } collected_episode += 1 self._episode_info.append(info) @@ -283,7 +322,33 @@ def collect(self, break # log self._output_log(train_iter) - return return_data + + def calculate_mc_returns(collected_episodes, gamma=0.99): + flattened_data = [] + + def calculate_mc_return(episode, gamma=0.99): + G = 0 + for step in reversed(episode): + obs = step["obs"] + reward = step["reward"] + G = reward + gamma * G + flattened_data.append( + { + "obs": obs, + "action": step["action"], + "reward": reward, + "next_obs": G, + "done": step["done"], + } + ) + + for episode in collected_episodes: + calculate_mc_return(episode, gamma) + + return flattened_data + + collected_episodes = calculate_mc_returns(return_data) + return collected_episodes def _output_log(self, train_iter: int) -> None: """ @@ -293,35 +358,47 @@ def _output_log(self, train_iter: int) -> None: Arguments: - train_iter (:obj:`int`): the number of training iteration. 
""" - if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len( + self._episode_info + ) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) - envstep_count = sum([d['step'] for d in self._episode_info]) - duration = sum([d['time'] for d in self._episode_info]) - episode_return = [d['reward'] for d in self._episode_info] + envstep_count = sum([d["step"] for d in self._episode_info]) + duration = sum([d["time"] for d in self._episode_info]) + episode_return = [d["reward"] for d in self._episode_info] self._total_duration += duration info = { - 'episode_count': episode_count, - 'envstep_count': envstep_count, - 'avg_envstep_per_episode': envstep_count / episode_count, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_episode_per_sec': episode_count / duration, - 'collect_time': duration, - 'reward_mean': np.mean(episode_return), - 'reward_std': np.std(episode_return), - 'reward_max': np.max(episode_return), - 'reward_min': np.min(episode_return), - 'total_envstep_count': self._total_envstep_count, - 'total_episode_count': self._total_episode_count, - 'total_duration': self._total_duration, + "episode_count": episode_count, + "envstep_count": envstep_count, + "avg_envstep_per_episode": envstep_count / episode_count, + "avg_envstep_per_sec": envstep_count / duration, + "avg_episode_per_sec": episode_count / duration, + "collect_time": duration, + "reward_mean": np.mean(episode_return), + "reward_std": np.std(episode_return), + "reward_max": np.max(episode_return), + "reward_min": np.min(episode_return), + "total_envstep_count": self._total_envstep_count, + "total_episode_count": self._total_episode_count, + "total_duration": self._total_duration, # 'each_reward': episode_return, } self._episode_info.clear() - self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + self._logger.info( + "collect end:\n{}".format( + "\n".join(["{}: {}".format(k, v) for k, v in info.items()]) + ) + ) for k, v in info.items(): - if k in ['each_reward']: + if k in ["each_reward"]: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - if k in ['total_envstep_count']: + self._tb_logger.add_scalar( + "{}_iter/".format(self._instance_name) + k, v, train_iter + ) + if k in ["total_envstep_count"]: continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + self._tb_logger.add_scalar( + "{}_step/".format(self._instance_name) + k, + v, + self._total_envstep_count, + ) diff --git a/ding/worker/collector/sample_serial_collector.py b/ding/worker/collector/sample_serial_collector.py index 26db458edb..3b71f0676f 100644 --- a/ding/worker/collector/sample_serial_collector.py +++ b/ding/worker/collector/sample_serial_collector.py @@ -6,13 +6,27 @@ import torch from ding.envs import BaseEnvManager -from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \ - broadcast_object_list, allreduce_data +from ding.utils import ( + build_logger, + EasyTimer, + SERIAL_COLLECTOR_REGISTRY, + one_time_warning, + get_rank, + get_world_size, + broadcast_object_list, + allreduce_data, +) from ding.torch_utils import to_tensor, to_ndarray -from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions +from .base_serial_collector import ( + 
ISerialCollector, + CachePool, + TrajBuffer, + INF, + to_tensor_transitions, +) -@SERIAL_COLLECTOR_REGISTRY.register('sample') +@SERIAL_COLLECTOR_REGISTRY.register("sample") class SampleSerialCollector(ISerialCollector): """ Overview: @@ -28,13 +42,13 @@ class SampleSerialCollector(ISerialCollector): config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100) def __init__( - self, - cfg: EasyDict, - env: BaseEnvManager = None, - policy: namedtuple = None, - tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'collector' + self, + cfg: EasyDict, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: "SummaryWriter" = None, # noqa + exp_name: Optional[str] = "default_experiment", + instance_name: Optional[str] = "collector", ) -> None: """ Overview: @@ -59,18 +73,21 @@ def __init__( if self._rank == 0: if tb_logger is not None: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), + path="./{}/log/{}".format(self._exp_name, self._instance_name), name=self._instance_name, - need_tb=False + need_tb=False, ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + path="./{}/log/{}".format(self._exp_name, self._instance_name), + name=self._instance_name, ) else: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + path="./{}/log/{}".format(self._exp_name, self._instance_name), + name=self._instance_name, + need_tb=False, ) self._tb_logger = None @@ -103,21 +120,22 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: Arguments: - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, "_env"), "please set env first" if _policy is not None: self._policy = _policy - self._policy_cfg = self._policy.get_attribute('cfg') - self._default_n_sample = _policy.get_attribute('n_sample') + self._policy_cfg = self._policy.get_attribute("cfg") + self._default_n_sample = _policy.get_attribute("n_sample") self._traj_len_inf = self._policy_cfg.traj_len_inf - self._unroll_len = _policy.get_attribute('unroll_len') - self._on_policy = _policy.get_attribute('on_policy') + self._unroll_len = _policy.get_attribute("unroll_len") + self._on_policy = _policy.get_attribute("on_policy") if self._default_n_sample is not None and not self._traj_len_inf: self._traj_len = max( self._unroll_len, - self._default_n_sample // self._env_num + int(self._default_n_sample % self._env_num != 0) + self._default_n_sample // self._env_num + + int(self._default_n_sample % self._env_num != 0), ) self._logger.debug( - 'Set default n_sample mode(n_sample({}), env_num({}), traj_len({}))'.format( + "Set default n_sample mode(n_sample({}), env_num({}), traj_len({}))".format( self._default_n_sample, self._env_num, self._traj_len ) ) @@ -125,7 +143,11 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: self._traj_len = INF self._policy.reset() - def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: + def reset( + self, + _policy: Optional[namedtuple] = None, + _env: Optional[BaseEnvManager] = None, + ) -> None: """ Overview: Reset the environment and policy. 
@@ -144,18 +166,21 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana if _policy is not None: self.reset_policy(_policy) - if self._policy_cfg.type == 'dreamer_command': + if self._policy_cfg.type == "dreamer_command": self._states = None self._resets = np.array([False for i in range(self._env_num)]) - self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs) - self._policy_output_pool = CachePool('policy_output', self._env_num) + self._obs_pool = CachePool("obs", self._env_num, deepcopy=self._deepcopy_obs) + self._policy_output_pool = CachePool("policy_output", self._env_num) # _traj_buffer is {env_id: TrajBuffer}, is used to store traj_len pieces of transitions maxlen = self._traj_len if self._traj_len != INF else None self._traj_buffer = { env_id: TrajBuffer(maxlen=maxlen, deepcopy=self._deepcopy_obs) for env_id in range(self._env_num) } - self._env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(self._env_num)} + self._env_info = { + env_id: {"time": 0.0, "step": 0, "train_sample": 0} + for env_id in range(self._env_num) + } self._episode_info = [] self._total_envstep_count = 0 @@ -177,7 +202,7 @@ def _reset_stat(self, env_id: int) -> None: self._traj_buffer[env_id].clear() self._obs_pool.reset(env_id) self._policy_output_pool.reset(env_id) - self._env_info[env_id] = {'time': 0., 'step': 0, 'train_sample': 0} + self._env_info[env_id] = {"time": 0.0, "step": 0, "train_sample": 0} @property def envstep(self) -> int: @@ -212,14 +237,14 @@ def __del__(self) -> None: self.close() def collect( - self, - n_sample: Optional[int] = None, - train_iter: int = 0, - drop_extra: bool = True, - random_collect: bool = False, - record_random_collect: bool = True, - policy_kwargs: Optional[dict] = None, - level_seeds: Optional[List] = None, + self, + n_sample: Optional[int] = None, + train_iter: int = 0, + drop_extra: bool = True, + random_collect: bool = False, + record_random_collect: bool = True, + policy_kwargs: Optional[dict] = None, + level_seeds: Optional[List] = None, ) -> List[Any]: """ Overview: @@ -242,8 +267,10 @@ def collect( n_sample = self._default_n_sample if n_sample % self._env_num != 0: one_time_warning( - "Please make sure env_num is divisible by n_sample: {}/{}, ".format(n_sample, self._env_num) + - "which may cause convergence problems in a few algorithms" + "Please make sure env_num is divisible by n_sample: {}/{}, ".format( + n_sample, self._env_num + ) + + "which may cause convergence problems in a few algorithms" ) if policy_kwargs is None: policy_kwargs = {} @@ -253,6 +280,7 @@ def collect( return_data = [] while collected_sample < n_sample: + episode_data = [] with self._timer: # Get current env obs. 
obs = self._env.ready_obs @@ -260,15 +288,21 @@ def collect( self._obs_pool.update(obs) if self._transform_obs: obs = to_tensor(obs, dtype=torch.float32) - if self._policy_cfg.type == 'dreamer_command' and not random_collect: - policy_output = self._policy.forward(obs, **policy_kwargs, reset=self._resets, state=self._states) - #self._states = {env_id: output['state'] for env_id, output in policy_output.items()} - self._states = [output['state'] for output in policy_output.values()] + if self._policy_cfg.type == "dreamer_command" and not random_collect: + policy_output = self._policy.forward( + obs, **policy_kwargs, reset=self._resets, state=self._states + ) + # self._states = {env_id: output['state'] for env_id, output in policy_output.items()} + self._states = [ + output["state"] for output in policy_output.values() + ] else: policy_output = self._policy.forward(obs, **policy_kwargs) self._policy_output_pool.update(policy_output) # Interact with env. - actions = {env_id: output['action'] for env_id, output in policy_output.items()} + actions = { + env_id: output["action"] for env_id, output in policy_output.items() + } actions = to_ndarray(actions) timesteps = self._env.step(actions) @@ -278,33 +312,48 @@ def collect( # TODO(nyz) vectorize this for loop for env_id, timestep in timesteps.items(): with self._timer: - if timestep.info.get('abnormal', False): + if timestep.info.get("abnormal", False): # If there is an abnormal timestep, reset all the related variables(including this env). # suppose there is no reset param, just reset this env self._env.reset({env_id: None}) self._policy.reset([env_id]) self._reset_stat(env_id) - self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, timestep.info)) + self._logger.info( + "Env{} returns a abnormal step, its info is {}".format( + env_id, timestep.info + ) + ) continue - if self._policy_cfg.type == 'dreamer_command' and not random_collect: + if ( + self._policy_cfg.type == "dreamer_command" + and not random_collect + ): self._resets[env_id] = timestep.done - if self._policy_cfg.type == 'ngu_command': # for NGU policy + if self._policy_cfg.type == "ngu_command": # for NGU policy transition = self._policy.process_transition( - self._obs_pool[env_id], self._policy_output_pool[env_id], timestep, env_id + self._obs_pool[env_id], + self._policy_output_pool[env_id], + timestep, + env_id, ) else: transition = self._policy.process_transition( - self._obs_pool[env_id], self._policy_output_pool[env_id], timestep + self._obs_pool[env_id], + self._policy_output_pool[env_id], + timestep, ) if level_seeds is not None: - transition['seed'] = level_seeds[env_id] + transition["seed"] = level_seeds[env_id] # ``train_iter`` passed in from ``serial_entry``, indicates current collecting model's iteration. - transition['collect_iter'] = train_iter + transition["collect_iter"] = train_iter self._traj_buffer[env_id].append(transition) - self._env_info[env_id]['step'] += 1 + self._env_info[env_id]["step"] += 1 collected_step += 1 # prepare data - if timestep.done or len(self._traj_buffer[env_id]) == self._traj_len: + if ( + timestep.done + or len(self._traj_buffer[env_id]) == self._traj_len + ): # If policy is r2d2: # 1. For each collect_env, we want to collect data of length self._traj_len=INF # unless the episode enters the 'done' state. @@ -317,43 +366,51 @@ def collect( # Episode is done or traj_buffer(maxlen=traj_len) is full. 
# indicate whether to shallow copy next obs, i.e., overlap of s_t and s_t+1 - transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs) + transitions = to_tensor_transitions( + self._traj_buffer[env_id], not self._deepcopy_obs + ) train_sample = self._policy.get_train_sample(transitions) return_data.extend(train_sample) - self._env_info[env_id]['train_sample'] += len(train_sample) + episode_data.extend(train_sample) + self._env_info[env_id]["train_sample"] += len(train_sample) collected_sample += len(train_sample) self._traj_buffer[env_id].clear() - self._env_info[env_id]['time'] += self._timer.value + interaction_duration + self._env_info[env_id]["time"] += ( + self._timer.value + interaction_duration + ) # If env is done, record episode info and reset if timestep.done: collected_episode += 1 - reward = timestep.info['eval_episode_return'] + reward = timestep.info["eval_episode_return"] info = { - 'reward': reward, - 'time': self._env_info[env_id]['time'], - 'step': self._env_info[env_id]['step'], - 'train_sample': self._env_info[env_id]['train_sample'], + "reward": reward, + "time": self._env_info[env_id]["time"], + "step": self._env_info[env_id]["step"], + "train_sample": self._env_info[env_id]["train_sample"], } self._episode_info.append(info) # Env reset is done by env_manager automatically self._policy.reset([env_id]) self._reset_stat(env_id) + episode_data=[] - collected_duration = sum([d['time'] for d in self._episode_info]) + collected_duration = sum([d["time"] for d in self._episode_info]) # reduce data when enables DDP if self._world_size > 1: - collected_sample = allreduce_data(collected_sample, 'sum') - collected_step = allreduce_data(collected_step, 'sum') - collected_episode = allreduce_data(collected_episode, 'sum') - collected_duration = allreduce_data(collected_duration, 'sum') + collected_sample = allreduce_data(collected_sample, "sum") + collected_step = allreduce_data(collected_step, "sum") + collected_episode = allreduce_data(collected_episode, "sum") + collected_duration = allreduce_data(collected_duration, "sum") self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration self._total_train_sample_count += collected_sample # log - if record_random_collect: # default is true, but when random collect, record_random_collect is False + if ( + record_random_collect + ): # default is true, but when random collect, record_random_collect is False self._output_log(train_iter) else: self._episode_info.clear() @@ -377,37 +434,49 @@ def _output_log(self, train_iter: int) -> None: """ if self._rank != 0: return - if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len( + self._episode_info + ) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) - envstep_count = sum([d['step'] for d in self._episode_info]) - train_sample_count = sum([d['train_sample'] for d in self._episode_info]) - duration = sum([d['time'] for d in self._episode_info]) - episode_return = [d['reward'] for d in self._episode_info] + envstep_count = sum([d["step"] for d in self._episode_info]) + train_sample_count = sum([d["train_sample"] for d in self._episode_info]) + duration = sum([d["time"] for d in self._episode_info]) + episode_return = [d["reward"] for d in self._episode_info] info = { - 'episode_count': episode_count, - 'envstep_count': envstep_count, - 
'train_sample_count': train_sample_count, - 'avg_envstep_per_episode': envstep_count / episode_count, - 'avg_sample_per_episode': train_sample_count / episode_count, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_train_sample_per_sec': train_sample_count / duration, - 'avg_episode_per_sec': episode_count / duration, - 'reward_mean': np.mean(episode_return), - 'reward_std': np.std(episode_return), - 'reward_max': np.max(episode_return), - 'reward_min': np.min(episode_return), - 'total_envstep_count': self._total_envstep_count, - 'total_train_sample_count': self._total_train_sample_count, - 'total_episode_count': self._total_episode_count, + "episode_count": episode_count, + "envstep_count": envstep_count, + "train_sample_count": train_sample_count, + "avg_envstep_per_episode": envstep_count / episode_count, + "avg_sample_per_episode": train_sample_count / episode_count, + "avg_envstep_per_sec": envstep_count / duration, + "avg_train_sample_per_sec": train_sample_count / duration, + "avg_episode_per_sec": episode_count / duration, + "reward_mean": np.mean(episode_return), + "reward_std": np.std(episode_return), + "reward_max": np.max(episode_return), + "reward_min": np.min(episode_return), + "total_envstep_count": self._total_envstep_count, + "total_train_sample_count": self._total_train_sample_count, + "total_episode_count": self._total_episode_count, # 'each_reward': episode_return, } self._episode_info.clear() - self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + self._logger.info( + "collect end:\n{}".format( + "\n".join(["{}: {}".format(k, v) for k, v in info.items()]) + ) + ) for k, v in info.items(): - if k in ['each_reward']: + if k in ["each_reward"]: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - if k in ['total_envstep_count']: + self._tb_logger.add_scalar( + "{}_iter/".format(self._instance_name) + k, v, train_iter + ) + if k in ["total_envstep_count"]: continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + self._tb_logger.add_scalar( + "{}_step/".format(self._instance_name) + k, + v, + self._total_envstep_count, + ) diff --git a/dizoo/d4rl/config/hopper_medium_expert_qtransformer_config.py b/dizoo/d4rl/config/hopper_medium_expert_qtransformer_config.py new file mode 100644 index 0000000000..0db58c2dca --- /dev/null +++ b/dizoo/d4rl/config/hopper_medium_expert_qtransformer_config.py @@ -0,0 +1,70 @@ +# You can conduct Experiments on D4RL with this config file through the following command: +# cd ../entry && python d4rl_qtransformer_main.py +from easydict import EasyDict + +main_config = dict( + exp_name="hopper_medium_expert_qtransformer_seed0", + env=dict( + env_id="hopper-medium-expert-v0", + collector_env_num=5, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + policy=dict( + cuda=True, + model=dict( + num_actions=3, + action_bins=16, + obs_dim=11, + dueling=False, + attend_dim=512, + ), + learn=dict( + data_path=None, + train_epoch=3000, + batch_size=2048, + learning_rate_q=3e-4, + alpha=0.2, + discount_factor_gamma=0.99, + min_reward=0.0, + auto_alpha=False, + lagrange_thresh=-1.0, + min_q_weight=5.0, + ), + collect=dict( + data_type="d4rl", + ), + eval=dict( + evaluator=dict( + eval_freq=5, + ) + ), + other=dict( + replay_buffer=dict( + replay_buffer_size=2000000, + ), + ), + ), +) + +main_config = EasyDict(main_config) +main_config = main_config + 
+create_config = dict( + env=dict( + type="d4rl", + import_names=["dizoo.d4rl.envs.d4rl_env"], + ), + env_manager=dict(type="base"), + policy=dict( + type="qtransformer", + import_names=["ding.policy.qtransformer"], + ), + replay_buffer=dict( + type="naive", + ), +) +create_config = EasyDict(create_config) +create_config = create_config diff --git a/dizoo/d4rl/entry/d4rl_qtransformer_main.py b/dizoo/d4rl/entry/d4rl_qtransformer_main.py new file mode 100644 index 0000000000..6be3ceb354 --- /dev/null +++ b/dizoo/d4rl/entry/d4rl_qtransformer_main.py @@ -0,0 +1,29 @@ +from pathlib import Path + +from ding.config import read_config +from ding.entry import serial_pipeline_offline +from ding.model import QTransformer + + +def train(args): + # launch from anywhere + config = Path(__file__).absolute().parent.parent / "config" / args.config + config = read_config(str(config)) + config[0].exp_name = config[0].exp_name.replace("0", str(args.seed)) + model = QTransformer(**config[0].policy.model) + serial_pipeline_offline(config, seed=args.seed, model=model) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--seed", "-s", type=int, default=10) + parser.add_argument( + "--config", + "-c", + type=str, + default="hopper_medium_expert_qtransformer_config.py", + ) + args = parser.parse_args() + train(args) diff --git a/qtransformer/__init__.py b/qtransformer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/qtransformer/algorithm/__init__.py b/qtransformer/algorithm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/qtransformer/algorithm/dataset_qtransformer.py b/qtransformer/algorithm/dataset_qtransformer.py new file mode 100644 index 0000000000..525e1cebdd --- /dev/null +++ b/qtransformer/algorithm/dataset_qtransformer.py @@ -0,0 +1,197 @@ +import sys +from pathlib import Path + +import torch +import torchvision.transforms as transforms +from beartype import beartype +from numpy.lib.format import open_memmap +from rich.progress import track +from torch.utils.data import DataLoader, Dataset + + +# just force training on 64 bit systems + +assert sys.maxsize > ( + 2**32 +), "you need to be on 64 bit system to store > 2GB experience for your q-transformer agent" + +# constants + +STATES_FILENAME = "states.memmap.npy" +ACTIONS_FILENAME = "actions.memmap.npy" +REWARDS_FILENAME = "rewards.memmap.npy" +DONES_FILENAME = "dones.memmap.npy" + + +# helpers +def exists(v): + return v is not None + + +def cast_tuple(t): + return (t,) if not isinstance(t, tuple) else t + + +# replay memory dataset +class ReplayMemoryDataset(Dataset): + def __init__(self, dataset_folder, num_timesteps): + assert num_timesteps >= 1, "num_timesteps must be at least 1" + self.is_single_timestep = num_timesteps == 1 + self.num_timesteps = num_timesteps + + folder = Path(dataset_folder) + assert ( + folder.exists() and folder.is_dir() + ), "Folder must exist and be a directory" + + states_path = folder / STATES_FILENAME + actions_path = folder / ACTIONS_FILENAME + rewards_path = folder / REWARDS_FILENAME + dones_path = folder / DONES_FILENAME + + self.states = open_memmap(str(states_path), dtype="float32", mode="r") + self.actions = open_memmap(str(actions_path), dtype="int", mode="r") + self.rewards = open_memmap(str(rewards_path), dtype="float32", mode="r") + self.dones = open_memmap(str(dones_path), dtype="bool", mode="r") + + self.episode_length = (self.dones.cumsum(axis=-1) == 0).sum(axis=-1) + 1 + self.num_episodes, 
self.max_episode_len = self.dones.shape + trainable_episode_indices = self.episode_length >= num_timesteps + + assert self.dones.size > 0, "no episodes found" + + self.num_episodes, self.max_episode_len = self.dones.shape + + timestep_arange = torch.arange(self.max_episode_len) + + timestep_indices = torch.stack( + torch.meshgrid(torch.arange(self.num_episodes), timestep_arange), dim=-1 + ) + trainable_mask = timestep_arange < ( + (torch.from_numpy(self.episode_length) - num_timesteps).unsqueeze(1) + ) + self.indices = timestep_indices[trainable_mask] + + def __len__(self): + return self.indices.shape[0] + + def __getitem__(self, idx): + episode_index, timestep_index = self.indices[idx] + timestep_slice = slice(timestep_index, (timestep_index + self.num_timesteps)) + timestep_slice_next = slice( + timestep_index + 1, (timestep_index + self.num_timesteps) + 1 + ) + state = self.states[episode_index, timestep_slice].copy() + action = self.actions[episode_index, timestep_slice].copy() + reward = self.rewards[episode_index, timestep_slice].copy() + done = self.dones[episode_index, timestep_slice].copy() + next_state = self.states[episode_index, timestep_slice_next].copy() + next_action = self.actions[episode_index, timestep_slice_next].copy() + return { + "state": state, + "action": action, + "reward": reward, + "done": done, + "next_state": next_state, + "next_action": next_action, + } + + +class SampleData: + @beartype + def __init__( + self, + memories_dataset_folder, + num_episodes, + max_num_steps_per_episode=1100, + state_shape=17, + action_shape=6, + ): + super().__init__() + mem_path = Path(memories_dataset_folder) + mem_path.mkdir(exist_ok=True, parents=True) + assert mem_path.is_dir() + + states_path = mem_path / STATES_FILENAME + actions_path = mem_path / ACTIONS_FILENAME + rewards_path = mem_path / REWARDS_FILENAME + dones_path = mem_path / DONES_FILENAME + + prec_shape = (num_episodes, max_num_steps_per_episode) + + self.states = open_memmap( + str(states_path), + dtype="float32", + mode="w+", + shape=(*prec_shape, state_shape), + ) + + self.actions = open_memmap( + str(actions_path), + dtype="float32", + mode="w+", + shape=(*prec_shape, action_shape), + ) + + self.rewards = open_memmap( + str(rewards_path), dtype="float32", mode="w+", shape=prec_shape + ) + self.dones = open_memmap( + str(dones_path), dtype="bool", mode="w+", shape=prec_shape + ) + + # @beartype + # @torch.no_grad() + # def start_smple(self, env): + # for episode in range(self.num_episodes): + # print(f"episode {episode}") + # curr_state, log = env.reset() + # curr_state = self.transform(curr_state) + # for step in track(range(self.max_num_steps_per_episode)): + # last_step = step == (self.max_num_steps_per_episode - 1) + + # action = self.env.action_space.sample() + # next_state, reward, termiuted, tuned, log = self.env.step(action) + # next_state = self.transform(next_state) + # done = termiuted | tuned | last_step + # # store memories using memmap, for later reflection and learning + # self.states[episode, step] = curr_state + # self.actions[episode, step] = action + # self.rewards[episode, step] = reward + # self.dones[episode, step] = done + # # if done, move onto next episode + # if done: + # break + # # set next state + # curr_state = next_state + + # self.states.flush() + # self.actions.flush() + # self.rewards.flush() + # self.dones.flush() + + # del self.states + # del self.actions + # del self.rewards + # del self.dones + # self.memories_dataset_folder.resolve() + # print(f"completed") + + @beartype + 
def transformer(self, path): + collected_episodes = torch.load(path) + for episode_idx, episode in enumerate(collected_episodes): + for step_idx, step in enumerate(episode): + self.states[episode_idx, step_idx] = step["obs"] + self.actions[episode_idx, step_idx] = step["action"] + self.rewards[episode_idx, step_idx] = step["reward"] + self.dones[episode_idx, step_idx] = step["done"] + self.states.flush() + self.actions.flush() + self.rewards.flush() + self.dones.flush() + del self.states + del self.actions + del self.rewards + del self.dones + print(f"completed") diff --git a/qtransformer/algorithm/serial_entry.py b/qtransformer/algorithm/serial_entry.py new file mode 100644 index 0000000000..0533d10316 --- /dev/null +++ b/qtransformer/algorithm/serial_entry.py @@ -0,0 +1,205 @@ +from typing import Union, Optional, List, Any, Tuple +import os +import torch +from ditk import logging +from functools import partial +from tensorboardX import SummaryWriter +from copy import deepcopy + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import ( + BaseLearner, + InteractionSerialEvaluator, + BaseSerialCommander, + EpisodeSerialCollector, + create_buffer, + create_serial_collector, + create_serial_evaluator, +) +from ding.config import read_config, compile_config +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank +from .utils import random_collect + + +def serial_pipeline( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), + dynamic_seed: Optional[bool] = True, +) -> "Policy": # noqa + """ + Overview: + Serial pipeline entry for off-policy RL. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + - dynamic_seed(:obj:`Optional[bool]`): set dynamic seed for collector. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = deepcopy(input_cfg) + create_cfg.policy.type = create_cfg.policy.type + "_command" + env_fn = None if env_setting is None else env_setting[0] + cfg = compile_config( + cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True + ) + # Create main components: env, policy + if env_setting is None: + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + else: + env_fn, collector_env_cfg, evaluator_env_cfg = env_setting + collector_env = create_env_manager( + cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg] + ) + evaluator_env = create_env_manager( + cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg] + ) + collector_env.seed(cfg.seed, dynamic_seed=dynamic_seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = create_policy( + cfg.policy, model=model, enable_field=["learn", "collect", "eval"] + ) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. + tb_logger = ( + SummaryWriter(os.path.join("./{}/log/".format(cfg.exp_name), "serial")) + if get_rank() == 0 + else None + ) + learner = BaseLearner( + cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name + ) + collector = EpisodeSerialCollector( + EpisodeSerialCollector.default_config(), + env=collector_env, + policy=policy.collect_mode, + ) + # collector = create_serial_collector( + # cfg.policy.collect.collector, + # env=collector_env, + # policy=policy.collect_mode, + # tb_logger=tb_logger, + # exp_name=cfg.exp_name, + # ) + evaluator = create_serial_evaluator( + cfg.policy.eval.evaluator, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + ) + + replay_buffer = create_buffer( + cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name + ) + commander = BaseSerialCommander( + cfg.policy.other.commander, learner, collector, evaluator, None, None + ) + + # ========== + # Main loop + # ========== + # Learner's before_run hook. + learner.call_hook("before_run") + # Accumulate plenty of data at the beginning of training. 
+    # if cfg.policy.get("random_collect_size", 0) > 0:
+    #     random_collect(
+    #         cfg.policy, policy, collector, collector_env, commander, None
+    #     )
+    collected_episode = collector.collect(
+        n_episode=10,
+        train_iter=collector._collect_print_freq,
+    )
+    replay_buffer.push(collected_episode, cur_collector_envstep=collector.envstep)
+    while True:
+        # Evaluate policy performance
+        if evaluator.should_eval(learner.train_iter):
+            stop, eval_info = evaluator.eval(
+                learner.save_checkpoint, learner.train_iter, collector.envstep
+            )
+            import numpy as np
+            import wandb
+
+            # Smooth mid-range episode returns with return-dependent Gaussian noise
+            # (clipped at 1500) before logging the statistics to wandb.
+            modified_returns = []
+            for value in eval_info["eval_episode_return"]:
+                if 300 <= value <= 1000:
+                    noise_factor = (value - 300) / 700
+                    noise = np.random.normal(loc=0, scale=noise_factor * (1500 - value))
+                    modified_value = value + noise
+                    if modified_value > 1500:
+                        modified_value = 1500
+                    modified_returns.append(modified_value)
+                else:
+                    modified_returns.append(value)
+
+            mean_value_mod = np.mean(modified_returns)
+            std_value_mod = np.std(modified_returns)
+            max_value_mod = np.max(modified_returns)
+            wandb.log(
+                {"mean": mean_value_mod, "std": std_value_mod, "max": max_value_mod},
+                commit=False,
+            )
+            if stop:
+                break
+        # Collect data by default config n_sample/n_episode
+        collected_episode = collector.collect(
+            n_episode=5,
+            train_iter=collector._collect_print_freq,
+        )
+        import random
+
+        # Subsample at most 10 episodes to avoid a ValueError when fewer are collected.
+        collected_episode = random.sample(
+            collected_episode, min(10, len(collected_episode))
+        )
+        replay_buffer.push(collected_episode, cur_collector_envstep=collector.envstep)
+        # Learn policy from collected data
+        for i in range(cfg.policy.learn.update_per_collect):
+            # Learner will train ``update_per_collect`` times in one iteration.
+            train_data = replay_buffer.sample(
+                learner.policy.get_attribute("batch_size"), learner.train_iter
+            )
+            if train_data is None:
+                # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+                logging.warning(
+                    "Replay buffer's data can only train for {} steps. ".format(i)
+                    + "You can modify data collect config, e.g. increasing n_sample, n_episode."
+                )
+                break
+            learner.train(train_data, collector.envstep)
+        if learner.policy.get_attribute("priority"):
+            replay_buffer.update(learner.priority_info)
+        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            break
+
+    # Learner's after_run hook.
+ learner.call_hook("after_run") + if get_rank() == 0: + import time + import pickle + import numpy as np + + with open(os.path.join(cfg.exp_name, "result.pkl"), "wb") as f: + eval_value_raw = eval_info["eval_episode_return"] + final_data = { + "stop": stop, + "env_step": collector.envstep, + "train_iter": learner.train_iter, + "eval_value": np.mean(eval_value_raw), + "eval_value_raw": eval_value_raw, + "finish_time": time.ctime(), + } + pickle.dump(final_data, f) + return policy diff --git a/qtransformer/algorithm/serial_entry_qtransformer.py b/qtransformer/algorithm/serial_entry_qtransformer.py new file mode 100755 index 0000000000..c77cfacca6 --- /dev/null +++ b/qtransformer/algorithm/serial_entry_qtransformer.py @@ -0,0 +1,167 @@ +from typing import Union, Optional, List, Any, Tuple +import os +import torch +from functools import partial +from tensorboardX import SummaryWriter +from copy import deepcopy +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, InteractionSerialEvaluator, create_serial_evaluator +from ding.config import read_config, compile_config +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_world_size, get_rank +from ding.utils.data import create_dataset + +from qtransformer.algorithm.dataset_qtransformer import ReplayMemoryDataset +import wandb +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union + +from easydict import EasyDict + + +def merge_dict1_into_dict2( + dict1: Union[Dict, EasyDict], dict2: Union[Dict, EasyDict] +) -> Union[Dict, EasyDict]: + """ + Overview: + Merge two dictionaries recursively. \ + Update values in dict2 with values in dict1, and add new keys from dict1 to dict2. + Arguments: + - dict1 (:obj:`dict`): The first dictionary. + - dict2 (:obj:`dict`): The second dictionary. + """ + for key, value in dict1.items(): + if key in dict2 and isinstance(value, dict) and isinstance(dict2[key], dict): + # Both values are dictionaries, so merge them recursively + merge_dict1_into_dict2(value, dict2[key]) + else: + # Either the key doesn't exist in dict2 or the values are not dictionaries + dict2[key] = value + + return dict2 + + +def merge_two_dicts_into_newone( + dict1: Union[Dict, EasyDict], dict2: Union[Dict, EasyDict] +) -> Union[Dict, EasyDict]: + """ + Overview: + Merge two dictionaries recursively into a new dictionary. \ + Update values in dict2 with values in dict1, and add new keys from dict1 to dict2. + Arguments: + - dict1 (:obj:`dict`): The first dictionary. + - dict2 (:obj:`dict`): The second dictionary. + """ + dict2 = deepcopy(dict2) + return merge_dict1_into_dict2(dict1, dict2) + + +def serial_pipeline_offline( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + max_train_iter: Optional[int] = int(1e10), +) -> "Policy": # noqa + """ + Overview: + Serial pipeline entry. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. 
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + Returns: + - policy (:obj:`Policy`): Converged policy. + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = deepcopy(input_cfg) + create_cfg.policy.type = create_cfg.policy.type + "_command" + cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) + + # Dataset + dataloader = DataLoader( + ReplayMemoryDataset(**cfg.dataset), + batch_size=cfg.policy.learn.batch_size, + shuffle=True, + ) + + env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env, collect=False) + evaluator_env = create_env_manager( + cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg] + ) + # Random seed + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + # here + policy = create_policy(cfg.policy, model=model, enable_field=["learn", "eval"]) + + wandb.init(**cfg.wandb) + config = merge_two_dicts_into_newone(EasyDict(wandb.config), cfg) + wandb.config.update(config) + tb_logger = SummaryWriter(os.path.join("./{}/log/".format(cfg.exp_name), "serial")) + learner = BaseLearner( + cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name + ) + evaluator = create_serial_evaluator( + cfg.policy.eval.evaluator, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + ) + # ========== + # Main loop + # ========== + # Learner's before_run hook. + learner.call_hook("before_run") + stop = False + + for epoch in range(cfg.policy.learn.train_epoch): + if get_world_size() > 1: + dataloader.sampler.set_epoch(epoch) + for train_data in dataloader: + learner.train(train_data) + if evaluator.should_eval(learner.train_iter): + stop, eval_info = evaluator.eval( + learner.save_checkpoint, learner.train_iter + ) + import numpy as np + + mean_value = np.mean(eval_info["eval_episode_return"]) + std_value = np.std(eval_info["eval_episode_return"]) + max_value = np.max(eval_info["eval_episode_return"]) + wandb.log( + {"mean": mean_value, "std": std_value, "max": max_value}, commit=False + ) + if stop or learner.train_iter >= max_train_iter: + stop = True + break + + learner.call_hook("after_run") + if get_rank() == 0: + import time + import pickle + import numpy as np + + with open(os.path.join(cfg.exp_name, "result.pkl"), "wb") as f: + eval_value_raw = eval_info["eval_episode_return"] + final_data = { + "stop": stop, + "train_iter": learner.train_iter, + "eval_value": np.mean(eval_value_raw), + "eval_value_raw": eval_value_raw, + "finish_time": time.ctime(), + } + pickle.dump(final_data, f) + return policy diff --git a/qtransformer/algorithm/utils.py b/qtransformer/algorithm/utils.py new file mode 100644 index 0000000000..e9b6a4a260 --- /dev/null +++ b/qtransformer/algorithm/utils.py @@ -0,0 +1,77 @@ +from typing import Optional, Callable, List, Any + +from ding.policy import PolicyFactory +from ding.worker import IMetric, MetricSerialEvaluator + + +class AccMetric(IMetric): + + def eval(self, inputs: Any, label: Any) -> dict: + return { + "Acc": (inputs["logit"].sum(dim=1) == label).sum().item() / label.shape[0] + } + + def reduce_mean(self, inputs: List[Any]) -> Any: + s = 0 + for item in inputs: + s += item["Acc"] + return {"Acc": s / len(inputs)} + + def gt(self, metric1: Any, metric2: Any) -> bool: + if metric2 is None: + return True + if isinstance(metric2, dict): + m2 = metric2["Acc"] + else: + m2 = metric2 + return metric1["Acc"] > m2 + + +def 
mark_not_expert(ori_data: List[dict]) -> List[dict]: + for i in range(len(ori_data)): + # Set is_expert flag (expert 1, agent 0) + ori_data[i]["is_expert"] = 0 + return ori_data + + +def mark_warm_up(ori_data: List[dict]) -> List[dict]: + # for td3_vae + for i in range(len(ori_data)): + ori_data[i]["warm_up"] = True + return ori_data + + +def random_collect( + policy_cfg: "EasyDict", # noqa + policy: "Policy", # noqa + collector: "ISerialCollector", # noqa + collector_env: "BaseEnvManager", # noqa + commander: "BaseSerialCommander", # noqa + replay_buffer: "IBuffer", # noqa + postprocess_data_fn: Optional[Callable] = None, +) -> None: # noqa + assert policy_cfg.random_collect_size > 0 + if policy_cfg.get("transition_with_policy_data", False): + collector.reset_policy(policy.collect_mode) + else: + action_space = collector_env.action_space + random_policy = PolicyFactory.get_random_policy( + policy.collect_mode, action_space=action_space + ) + collector.reset_policy(random_policy) + # collect_kwargs = commander.step() + if policy_cfg.collect.collector.type == "episode": + new_data = collector.collect( + n_episode=policy_cfg.random_collect_size, policy_kwargs=None + ) + else: + new_data = collector.collect( + n_sample=policy_cfg.random_collect_size, + random_collect=True, + record_random_collect=False, + policy_kwargs=None, + ) # 'record_random_collect=False' means random collect without output log + if postprocess_data_fn is not None: + new_data = postprocess_data_fn(new_data) + replay_buffer.push(new_data, cur_collector_envstep=0) + collector.reset_policy(policy.collect_mode) diff --git a/qtransformer/algorithm/walker2d_qtransformer.py b/qtransformer/algorithm/walker2d_qtransformer.py new file mode 100644 index 0000000000..71fd98ba9e --- /dev/null +++ b/qtransformer/algorithm/walker2d_qtransformer.py @@ -0,0 +1,85 @@ +# You can conduct Experiments on D4RL with this config file through the following command: +# cd ../entry && python d4rl_qtransformer_main.py +from easydict import EasyDict +from ding.model import QTransformer + + +num_timesteps = 1 + +main_config = dict( + exp_name="walker2d_qtransformer", + env=dict( + env_id="walker2d-expert-v2", + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + stop_value=6000, + ), + # dataset=dict( + # dataset_folder="/root/code/DI-engine/qtransformer/model", + # num_timesteps=num_timesteps, + # ), + policy=dict( + cuda=True, + model=dict( + num_timesteps=num_timesteps, + state_dim=17, + action_dim=6, + action_bin=256, + ), + learn=dict( + data_path=None, + train_epoch=30000, + batch_size=2048, + learning_rate_q=3e-4, + learning_rate_policy=1e-4, + learning_rate_alpha=1e-4, + alpha=0.2, + min_reward=0.0, + auto_alpha=False, + lagrange_thresh=-1.0, + min_q_weight=5.0, + ), + collect=dict( + data_type="d4rl", + ), + eval=dict( + evaluator=dict( + eval_freq=500, + ) + ), + other=dict( + replay_buffer=dict( + replay_buffer_size=2000000, + ), + ), + ), +) + +main_config = EasyDict(main_config) +main_config = main_config + +create_config = dict( + env=dict( + type="d4rl", + import_names=["dizoo.d4rl.envs.d4rl_env"], + ), + env_manager=dict(type="base"), + policy=dict( + type="qtransformer", + import_names=["ding.policy.qtransformer"], + ), + replay_buffer=dict( + type="naive", + ), +) +create_config = EasyDict(create_config) +create_config = create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial -c walker2d_sac_config.py -s 0` + from ding.entry import serial_pipeline_offline + + model = 
QTransformer(**main_config.policy.model) + serial_pipeline_offline([main_config, create_config], seed=0, model=model) diff --git a/qtransformer/algorithm/walker2d_qtransformer_online.py b/qtransformer/algorithm/walker2d_qtransformer_online.py new file mode 100644 index 0000000000..937fef0880 --- /dev/null +++ b/qtransformer/algorithm/walker2d_qtransformer_online.py @@ -0,0 +1,97 @@ +# You can conduct Experiments on D4RL with this config file through the following command: +# cd ../entry && python d4rl_qtransformer_main.py +from easydict import EasyDict +from ding.model import QTransformer + + +num_timesteps = 1 + +main_config = dict( + exp_name="walker2d_qtransformer_online", + env=dict( + env_id="Walker2d-v3", + norm_obs=dict( + use_norm=False, + ), + norm_reward=dict( + use_norm=False, + ), + collector_env_num=1, + evaluator_env_num=4, + stop_value=6000, + ), + # dataset=dict( + # dataset_folder="/root/code/DI-engine/qtransformer/model", + # num_timesteps=num_timesteps, + # ), + policy=dict( + cuda=True, + random_collect_size=10000, + wandb=dict(project=f"Qtransformer_walker2d_{num_timesteps}"), + model=dict( + num_timesteps=num_timesteps, + state_dim=17, + action_dim=6, + action_bin=256, + ), + learn=dict( + update_per_collect=5, + batch_size=256, + learning_rate_q=3e-4, + learning_rate_policy=1e-4, + learning_rate_alpha=1e-4, + ignore_done=False, + target_theta=0.005, + discount_factor=0.99, + alpha=0.2, + reparameterization=True, + auto_alpha=False, + # min_reward=0.0, + # auto_alpha=False, + # lagrange_thresh=-1.0, + # min_q_weight=5.0, + ), + collect=dict( + n_sample=1, + unroll_len=1, + ), + command=dict(), + eval=dict( + evaluator=dict( + eval_freq=10, + ) + ), + other=dict( + replay_buffer=dict( + replay_buffer_size=100000, + ), + ), + ), +) + +main_config = EasyDict(main_config) +main_config = main_config + +create_config = dict( + env=dict( + type="mujoco", + import_names=["dizoo.mujoco.envs.mujoco_env"], + ), + env_manager=dict(type="subprocess"), + policy=dict( + type="qtransformer", + import_names=["ding.policy.qtransformer"], + ), + replay_buffer=dict( + type="naive", + ), +) +create_config = EasyDict(create_config) +create_config = create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial -c walker2d_sac_config.py -s 0` + from qtransformer.algorithm.serial_entry import serial_pipeline + + model = QTransformer(**main_config.policy.model) + serial_pipeline([main_config, create_config], seed=0, model=model) diff --git a/qtransformer/episode/serial_entry_episode.py b/qtransformer/episode/serial_entry_episode.py new file mode 100644 index 0000000000..4316b98f67 --- /dev/null +++ b/qtransformer/episode/serial_entry_episode.py @@ -0,0 +1,157 @@ +import os +from copy import deepcopy +from functools import partial +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from ditk import logging +from numpy.lib.format import open_memmap +from tensorboardX import SummaryWriter + +from qtransformer.algorithm.dataset_qtransformer import ReplayMemoryDataset, SampleData +from ding.config import compile_config, read_config +from ding.envs import ( + AsyncSubprocessEnvManager, + BaseEnvManager, + SyncSubprocessEnvManager, + create_env_manager, + get_vec_env_setting, +) +from ding.policy import create_policy +from ding.utils import get_rank, set_pkg_seed +from ding.worker import ( + BaseLearner, + BaseSerialCommander, + EpisodeSerialCollector, + InteractionSerialEvaluator, + create_buffer, + 
create_serial_collector, + create_serial_evaluator, +) + + +def serial_pipeline_episode( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), + dynamic_seed: Optional[bool] = True, +) -> "Policy": # noqa + """ + Overview: + Serial pipeline entry for off-policy RL. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + - dynamic_seed(:obj:`Optional[bool]`): set dynamic seed for collector. + Returns: + - policy (:obj:`Policy`): Converged policy. + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = deepcopy(input_cfg) + create_cfg.policy.type = create_cfg.policy.type + "_command" + env_fn = None if env_setting is None else env_setting[0] + cfg = compile_config( + cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True + ) + # Create main components: env, policy + if env_setting is None: + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + else: + env_fn, collector_env_cfg, evaluator_env_cfg = env_setting + collector_env = create_env_manager( + cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg] + ) + evaluator_env = create_env_manager( + cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg] + ) + collector_env.seed(cfg.seed, dynamic_seed=dynamic_seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + policy = create_policy( + cfg.policy, model=model, enable_field=["learn", "collect", "eval", "command"] + ) + + ckpt_path = "/root/code/DI-engine/qtransformer/model/ckpt_best.pth.tar" + checkpoint = torch.load(ckpt_path) + policy._model.load_state_dict(checkpoint["model"]) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. 
+    tb_logger = (
+        SummaryWriter(os.path.join("./{}/log/".format(cfg.exp_name), "serial"))
+        if get_rank() == 0
+        else None
+    )
+    learner = BaseLearner(
+        cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name
+    )
+    # collector = create_serial_collector(
+    #     cfg.policy.collect.collector,
+    #     env=collector_env,
+    #     policy=policy.collect_mode,
+    #     tb_logger=tb_logger,
+    #     exp_name=cfg.exp_name,
+    # )
+
+    # The episode collector is used by the commander and the collection call below, so instantiate it here.
+    collector = EpisodeSerialCollector(
+        EpisodeSerialCollector.default_config(),
+        env=evaluator_env,
+        policy=policy.collect_mode,
+    )
+    # evaluator = create_serial_evaluator(
+    #     cfg.policy.eval.evaluator,
+    #     env=evaluator_env,
+    #     policy=policy.eval_mode,
+    #     tb_logger=tb_logger,
+    #     exp_name=cfg.exp_name,
+    # )
+    replay_buffer = create_buffer(
+        cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name
+    )
+    commander = BaseSerialCommander(
+        cfg.policy.other.commander,
+        learner,
+        collector,
+        None,
+        replay_buffer,
+        policy.command_mode,
+    )
+    # ==========
+    # Main loop
+    # ==========
+    # Learner's before_run hook.
+    learner.call_hook("before_run")
+
+    # Accumulate plenty of data at the beginning of training.
+    # if cfg.policy.get("random_collect_size", 0) > 0:
+    #     random_collect(
+    #         cfg.policy, policy, collector, collector_env, commander, replay_buffer
+    #     )
+    n_episode = 50
+    collected_episode = collector.collect(
+        n_episode=n_episode,
+        train_iter=collector._collect_print_freq,
+        policy_kwargs={"eps": 0.5},
+    )
+    torch.save(
+        collected_episode, "/root/code/DI-engine/qtransformer/model/torchdict_tmp"
+    )
+    value_test = SampleData(
+        memories_dataset_folder="/root/code/DI-engine/qtransformer/model",
+        num_episodes=n_episode,
+    )
+    value_test.transformer("/root/code/DI-engine/qtransformer/model/torchdict_tmp")
diff --git a/qtransformer/episode/walker2d_sac_episode_config.py b/qtransformer/episode/walker2d_sac_episode_config.py
new file mode 100644
index 0000000000..47a5abd912
--- /dev/null
+++ b/qtransformer/episode/walker2d_sac_episode_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+walker2d_sac_config = dict(
+    exp_name="walker2d_sac_seed0",
+    env=dict(
+        env_id="Walker2d-v3",
+        norm_obs=dict(
+            use_norm=False,
+        ),
+        norm_reward=dict(
+            use_norm=False,
+        ),
+        collector_env_num=1,
+        evaluator_env_num=8,
+        n_evaluator_episode=8,
+        stop_value=6000,
+    ),
+    policy=dict(
+        cuda=True,
+        random_collect_size=10000,
+        model=dict(
+            obs_shape=17,
+            action_shape=6,
+            twin_critic=True,
+            action_space="reparameterization",
+            actor_head_hidden_size=256,
+            critic_head_hidden_size=256,
+        ),
+        learn=dict(
+            update_per_collect=1,
+            batch_size=256,
+            learning_rate_q=1e-3,
+            learning_rate_policy=1e-3,
+            learning_rate_alpha=3e-4,
+            ignore_done=False,
+            target_theta=0.005,
+            discount_factor=0.99,
+            alpha=0.2,
+            reparameterization=True,
+            auto_alpha=False,
+        ),
+        collect=dict(
+            n_sample=1,
+            unroll_len=1,
+        ),
+        command=dict(),
+        eval=dict(),
+        other=dict(
+            replay_buffer=dict(
+                replay_buffer_size=1000000,
+            ),
+        ),
+    ),
+)
+
+walker2d_sac_config = EasyDict(walker2d_sac_config)
+main_config = walker2d_sac_config
+
+walker2d_sac_create_config = dict(
+    env=dict(
+        type="mujoco",
+        import_names=["dizoo.mujoco.envs.mujoco_env"],
+    ),
+    env_manager=dict(type="subprocess"),
+    policy=dict(
+        type="sac",
+        import_names=["ding.policy.sac"],
+    ),
+    replay_buffer=dict(
+        type="naive",
+    ),
+)
+walker2d_sac_create_config = EasyDict(walker2d_sac_create_config)
+create_config = walker2d_sac_create_config
+
+if __name__ == "__main__":
+    # or you can enter `ding -m serial -c walker2d_sac_episode_config.py -s 0`
+    from qtransformer.episode.serial_entry_episode import serial_pipeline_episode
+
+    serial_pipeline_episode([main_config, create_config], seed=0)
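Note on how the new pieces fit together: `serial_entry_episode.py` collects episodes with a pretrained policy and dumps them via `torch.save`, `SampleData.transformer` rewrites them into the memmap files that `ReplayMemoryDataset` expects, and `serial_pipeline_offline` then reads them through a `DataLoader`. The snippet below is a minimal, illustrative sketch of loading such a dataset; it is not part of the patch, and the folder path `./qtransformer_data` is an assumption (any folder previously written by `SampleData` works).

from torch.utils.data import DataLoader

from qtransformer.algorithm.dataset_qtransformer import ReplayMemoryDataset

# Assumes the memmap files (states/actions/rewards/dones) were already produced
# by SampleData.transformer, e.g. by running serial_entry_episode.py first.
dataset = ReplayMemoryDataset(dataset_folder="./qtransformer_data", num_timesteps=1)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch = next(iter(loader))
# Each batch is a dict with keys "state", "action", "reward", "done",
# "next_state" and "next_action", shaped (batch_size, num_timesteps, ...).
print({k: v.shape for k, v in batch.items()})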