1 change: 1 addition & 0 deletions ding/entry/__init__.py
@@ -27,3 +27,4 @@
from .application_entry_drex_collect_data import drex_collecting_data
from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream
from .serial_entry_bco import serial_pipeline_bco
from .serial_entry_pc_mcts import serial_pipeline_pc_mcts
126 changes: 126 additions & 0 deletions ding/entry/serial_entry_pc_mcts.py
@@ -0,0 +1,126 @@
from typing import Union, Optional, Tuple
import os
import torch
from functools import partial
from tensorboardX import SummaryWriter
from copy import deepcopy
from torch.utils.data import DataLoader, Dataset
import pickle

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed


class MCTSPCDataset(Dataset):

def __init__(self, data_dic, seq_len=4):
self.observations = data_dic['obs']
self.actions = data_dic['actions']
self.hidden_states = data_dic['hidden_state']
self.seq_len = seq_len
self.length = len(self.observations) - seq_len - 1

def __getitem__(self, idx):
"""
Assume the trajectory is: o1, h2, h3, h4
"""
return {
'obs': self.observations[idx],
'hidden_states': list(reversed(self.hidden_states[idx + 1:idx + self.seq_len + 1])),
'action': self.actions[idx]
}

def __len__(self):
return self.length


def load_mcts_datasets(path, seq_len, batch_size=32):
with open(path, 'rb') as f:
dic = pickle.load(f)
tot_len = len(dic['obs'])
train_dic = {k: v[:-tot_len // 10] for k, v in dic.items()}
test_dic = {k: v[-tot_len // 10:] for k, v in dic.items()}
return DataLoader(MCTSPCDataset(train_dic, seq_len=seq_len), shuffle=True, batch_size=batch_size), \
DataLoader(MCTSPCDataset(test_dic, seq_len=seq_len), shuffle=True, batch_size=batch_size)


def serial_pipeline_pc_mcts(
input_cfg: Union[str, Tuple[dict, dict]],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
        max_iter: int = int(1e6),
) -> Union['Policy', bool]: # noqa
r"""
Overview:
        Serial pipeline entry of MCTS-based procedure cloning (imitation learning).
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_iter (:obj:`int`): Maximum number of training iterations.
Returns:
- policy (:obj:`Policy`): Converged policy.
        - convergence (:obj:`bool`): Whether the imitation learning training converged.
"""
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
else:
cfg, create_cfg = deepcopy(input_cfg)
cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)

# Env, Policy
env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
# Random seed
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])

# Main components
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
dataloader, test_dataloader = load_mcts_datasets(cfg.policy.expert_data_path, seq_len=cfg.policy.seq_len)
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
evaluator = InteractionSerialEvaluator(
cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)

# ==========
# Main loop
# ==========
learner.call_hook('before_run')
stop = False
iter_cnt = 0
for epoch in range(cfg.policy.learn.train_epoch):
# train
criterion = torch.nn.CrossEntropyLoss()
for i, train_data in enumerate(dataloader):
learner.train(train_data)
iter_cnt += 1
if iter_cnt >= max_iter:
stop = True
break
if epoch % 69 == 0:
policy._optimizer.param_groups[0]['lr'] /= 10
if stop:
break
losses = []
        accs = []
for _, test_data in enumerate(test_dataloader):
logits = policy._model.forward_eval(test_data['obs'].permute(0, 3, 1, 2).float().cuda() / 255.)

loss = criterion(logits, test_data['action'].cuda()).item()
preds = torch.argmax(logits, dim=-1)
acc = torch.sum((preds == test_data['action'].cuda())).item() / preds.shape[0]

losses.append(loss)
            accs.append(acc)
        print('Test Finished! Loss: {} acc: {}'.format(sum(losses) / len(losses), sum(accs) / len(accs)))
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)
learner.call_hook('after_run')
print('final reward is: {}'.format(reward))
return policy, stop
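A minimal usage sketch of the new entry point, assuming a standard DI-engine (user_config, create_cfg) pair. The imported config module path below is hypothetical; the fields expert_data_path, seq_len and learn.train_epoch mirror what the pipeline reads above.

# Hypothetical config module path; any config exposing the fields above would work.
from dizoo.atari.config.serial.pong.pong_pc_mcts_config import main_config, create_config
from ding.entry import serial_pipeline_pc_mcts

# Train procedure cloning from pre-collected MCTS data and evaluate the resulting policy.
policy, stop = serial_pipeline_pc_mcts((main_config, create_config), seed=0, max_iter=int(1e5))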
2 changes: 1 addition & 1 deletion ding/model/template/__init__.py
@@ -22,4 +22,4 @@
from .madqn import MADQN
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloning
from .procedure_cloning import ProcedureCloningMCTS
115 changes: 78 additions & 37 deletions ding/model/template/procedure_cloning.py
@@ -19,18 +19,25 @@ def __init__(
self.attention_layer = []

self.norm_layer = [nn.LayerNorm(att_hidden)] * n_att

self.attention_layer.append(Attention(cnn_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
for i in range(n_att - 1):
self.attention_layer.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))

self.attention_layer = nn.ModuleList(self.attention_layer)
self.att_drop = nn.Dropout(drop_p)

self.fc_blocks = []
self.fc_blocks.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU()))
for i in range(n_feedforward - 1):
self.fc_blocks.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
self.fc_blocks = nn.ModuleList(self.fc_blocks)

self.norm_layer.extend([nn.LayerNorm(feedforward_hidden)] * n_feedforward)
self.mask = torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T)
self.norm_layer = nn.ModuleList(self.norm_layer)

self.mask = nn.Parameter(
torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T), requires_grad=False
)

def forward(self, x: torch.Tensor):
for i in range(self.n_att):
@@ -42,36 +42,49 @@ def forward(self, x: torch.Tensor):
return x


@MODEL_REGISTRY.register('pc')
class ProcedureCloning(nn.Module):
@MODEL_REGISTRY.register('pc_mcts')
class ProcedureCloningMCTS(nn.Module):

def __init__(
self,
obs_shape: SequenceType,
action_dim: int,
cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256],
cnn_activation: Optional[nn.Module] = nn.ReLU(),
cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3],
cnn_stride: SequenceType = [1, 1, 1, 1, 1],
cnn_padding: Optional[SequenceType] = ['same', 'same', 'same', 'same', 'same'],
mlp_hidden_list: SequenceType = [256, 256],
mlp_activation: Optional[nn.Module] = nn.ReLU(),
att_heads: int = 8,
att_hidden: int = 128,
n_att: int = 4,
n_feedforward: int = 2,
feedforward_hidden: int = 256,
drop_p: float = 0.5,
augment: bool = True,
max_T: int = 17
self,
obs_shape: SequenceType,
hidden_shape: SequenceType,
action_dim: int,
seq_len: int,
cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256],
cnn_activation: Optional[nn.Module] = nn.ReLU(),
cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3],
cnn_stride: SequenceType = [1, 1, 1, 1, 1],
cnn_padding: Optional[SequenceType] = [1, 1, 1, 1, 1],
mlp_hidden_list: SequenceType = [256, 256],
mlp_activation: Optional[nn.Module] = nn.ReLU(),
att_heads: int = 8,
att_hidden: int = 128,
n_att: int = 4,
n_feedforward: int = 2,
feedforward_hidden: int = 256,
drop_p: float = 0.5,
augment: bool = True,
) -> None:
super().__init__()
self.obs_shape = obs_shape
self.hidden_shape = hidden_shape
self.seq_len = seq_len
max_T = seq_len + 1

        # Conv Encoder
self.embed_state = ConvEncoder(
obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding
)
self.embed_action = FCEncoder(action_dim, mlp_hidden_list, activation=mlp_activation)
self.embed_hidden = ConvEncoder(
hidden_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding
)

self.cnn_hidden_list = cnn_hidden_list
self.augment = augment
@@ -95,25 +109,52 @@ def __init__(
cnn_hidden_list[-1], att_hidden, att_heads, drop_p, max_T, n_att, feedforward_hidden, n_feedforward
)

self.predict_goal = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1])
self.predict_hidden_state = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1])
self.predict_action = torch.nn.Linear(cnn_hidden_list[-1], action_dim)

def forward(self, states: torch.Tensor, goals: torch.Tensor,
actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

B, T, _ = actions.shape
def _compute_embeddings(self, states: torch.Tensor, hidden_states: torch.Tensor):
B, T, *_ = hidden_states.shape

# shape: (B, h_dim)
# shape: (B, 1, h_dim)
state_embeddings = self.embed_state(states).reshape(B, 1, self.cnn_hidden_list[-1])
goal_embeddings = self.embed_state(goals).reshape(B, 1, self.cnn_hidden_list[-1])
# shape: (B, context_len, h_dim)
actions_embeddings = self.embed_action(actions)
# shape: (B, T, h_dim)
hidden_state_embeddings = self.embed_hidden(hidden_states.reshape(B * T, *hidden_states.shape[2:])) \
.reshape(B, T, self.cnn_hidden_list[-1])
return state_embeddings, hidden_state_embeddings

h = torch.cat((state_embeddings, goal_embeddings, actions_embeddings), dim=1)
def _compute_transformer(self, h):
B, T, *_ = h.shape
h = self.transformer(h)
h = h.reshape(B, T + 2, self.cnn_hidden_list[-1])
h = h.reshape(B, T, self.cnn_hidden_list[-1])

hidden_state_preds = self.predict_hidden_state(h[:, 0:-1, ...])
action_preds = self.predict_action(h[:, -1, :])
return hidden_state_preds, action_preds

def forward(self, states: torch.Tensor, hidden_states: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # states is the current observation.
        # hidden_states is a sequence including [L, R, ...].
        # The shapes of the observation and the hidden states may differ.
B, T, *_ = hidden_states.shape
assert T == self.seq_len
state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states)

h = torch.cat((state_embeddings, hidden_state_embeddings), dim=1)
hidden_state_preds, action_preds = self._compute_transformer(h)

return hidden_state_preds, action_preds, hidden_state_embeddings.detach()

def forward_eval(self, states: torch.Tensor) -> torch.Tensor:
batch_size = states.shape[0]
hidden_states = torch.zeros(batch_size, self.seq_len, *self.hidden_shape, dtype=states.dtype).to(states.device)
embedding_mask = torch.zeros(1, self.seq_len, 1).to(states.device)

state_embeddings, hidden_state_embeddings = self._compute_embeddings(states, hidden_states)

goal_preds = self.predict_goal(h[:, 0, :])
action_preds = self.predict_action(h[:, 1:, :])
for i in range(self.seq_len):
h = torch.cat((state_embeddings, hidden_state_embeddings * embedding_mask), dim=1)
hidden_state_embeddings, action_pred = self._compute_transformer(h)
embedding_mask[0, i, 0] = 1

return goal_preds, action_preds
return action_pred
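For orientation, a sketch of how the two heads returned by forward could be trained jointly. The concrete loss lives in the policy (ding/policy/pc.py), which is not shown in this diff, so the cross-entropy/MSE combination and the hidden_weight factor below are assumptions.

import torch.nn.functional as F

def pc_mcts_loss(model, obs, hidden_states, actions, hidden_weight=1.0):
    # forward returns predicted hidden-state embeddings (B, T, h_dim), action logits
    # (B, action_dim), and the detached target hidden-state embeddings (B, T, h_dim).
    hidden_preds, action_logits, hidden_targets = model(obs, hidden_states)
    action_loss = F.cross_entropy(action_logits, actions)   # actions: (B,) class indices
    hidden_loss = F.mse_loss(hidden_preds, hidden_targets)  # next hidden-state prediction
    return action_loss + hidden_weight * hidden_loss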
38 changes: 17 additions & 21 deletions ding/model/template/tests/test_procedure_cloning.py
@@ -1,34 +1,30 @@
import torch
import pytest
import numpy as np
from itertools import product

from ding.model.template import ProcedureCloning
from ding.torch_utils import is_differentiable
from ding.utils import squeeze
from ding.model.template import ProcedureCloningMCTS

B = 4
T = 15
obs_shape = [(64, 64, 3)]
action_dim = [9]
obs_shape = (64, 64, 3)
hidden_shape = (9, 9, 64)
action_dim = 9
obs_embeddings = 256
args = list(product(*[obs_shape, action_dim]))


@pytest.mark.unittest
@pytest.mark.parametrize('obs_shape, action_dim', args)
class TestProcedureCloning:
def test_procedure_cloning():
inputs = {
'states': torch.randn(B, *obs_shape),
'hidden_states': torch.randn(B, T, *hidden_shape),
'actions': torch.randn(B, action_dim)
}
model = ProcedureCloningMCTS(obs_shape=obs_shape, hidden_shape=hidden_shape, seq_len=T, action_dim=action_dim)

def test_procedure_cloning(self, obs_shape, action_dim):
inputs = {
'states': torch.randn(B, *obs_shape),
'goals': torch.randn(B, *obs_shape),
'actions': torch.randn(B, T, action_dim)
}
model = ProcedureCloning(obs_shape=obs_shape, action_dim=action_dim)
print(model)

print(model)
hidden_state_preds, action_preds, target_hidden_state = model(inputs['states'], inputs['hidden_states'])
assert hidden_state_preds.shape == (B, T, obs_embeddings)
assert action_preds.shape == (B, action_dim)

goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])
assert goal_preds.shape == (B, obs_embeddings)
assert action_preds.shape == (B, T + 1, action_dim)
action_eval = model.forward_eval(inputs['states'])
assert action_eval.shape == (B, action_dim)
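The eval-time path can be exercised directly with forward_eval, which fills the hidden-state sequence autoregressively and returns only action logits. A small inference sketch, assuming uint8 HWC image observations as in the entry's test loop above:

import torch

@torch.no_grad()
def select_action(model, obs_batch):
    # (B, H, W, C) uint8 -> (B, C, H, W) float in [0, 1]
    x = obs_batch.permute(0, 3, 1, 2).float() / 255.
    logits = model.forward_eval(x)       # (B, action_dim)
    return torch.argmax(logits, dim=-1)  # greedy action index per sample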
1 change: 1 addition & 0 deletions ding/policy/__init__.py
@@ -44,6 +44,7 @@

from .bc import BehaviourCloningPolicy
from .ibc import IBCPolicy
from .pc import ProcedureCloningPolicyMCTS

# new-type policy
from .ppof import PPOFPolicy