Skip to content

Commit

Permalink
AWAC implementation
Browse files Browse the repository at this point in the history
Summary: An AWAC-based actor-critic algorithm for offline RL.

Reviewed By: danielrjiang

Differential Revision: D60653469

fbshipit-source-id: 55f3150109339bad044858034f59bf76e6031dca
  • Loading branch information
Ruiyang Xu authored and facebook-github-bot committed Aug 6, 2024
1 parent 55c5f54 commit 6b0e81d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, Union

import torch
from pearl.action_representation_modules.action_representation_module import (
Expand Down Expand Up @@ -95,6 +95,10 @@ def __init__(
temperature_advantage_weighted_regression: float = 0.5,
advantage_clamp: float = 100.0,
action_representation_module: Optional[ActionRepresentationModule] = None,
actor_network_instance: Optional[ActorNetwork] = None,
critic_network_instance: Optional[
Union[ValueNetwork, QValueNetwork, torch.nn.Module]
] = None,
) -> None:
super(ImplicitQLearning, self).__init__(
state_dim=state_dim,
Expand All @@ -120,6 +124,8 @@ def __init__(
is_action_continuous=action_space.is_continuous, # inferred from the action space
on_policy=False,
action_representation_module=action_representation_module,
actor_network_instance=actor_network_instance,
critic_network_instance=critic_network_instance,
)

self._expectile = expectile
Expand Down
8 changes: 6 additions & 2 deletions pearl/utils/functional_utils/learning/critic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,12 @@ def update_critic_target_network(
)
else:
update_target_network(
target_network._model,
network._model,
(
target_network._model
if hasattr(target_network, "_model")
else target_network
),
network._model if hasattr(network, "_model") else network,
tau=tau,
)

Expand Down

0 comments on commit 6b0e81d

Please sign in to comment.