output q1 and q2 from twin_critic_action_value_loss
Summary: [Pearl APS] output q1 and q2 from twin_critic_action_value_loss

Differential Revision: D59352504

fbshipit-source-id: b28c8d36d62c6e305cad7b4ffbbec4801c38459c
Zhihao Cen authored and facebook-github-bot committed Jul 8, 2024
1 parent 39582e7 commit 1827a08
Showing 8 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pearl/policy_learners/sequential_decision_making/ddpg.py
@@ -136,7 +136,7 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
         assert isinstance(self._critic, TwinCritic), "DDPG requires TwinCritic critic"
 
         # update twin critics towards bellman target
-        loss = twin_critic_action_value_loss(
+        loss, _, _ = twin_critic_action_value_loss(
             state_batch=batch.state,
             action_batch=batch.action,
             expected_target_batch=expected_state_action_values,
@@ -264,7 +264,7 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
         ), "Critic in ImplicitQLearning should be TwinCritic"
 
         # update twin critics towards target
-        loss = twin_critic_action_value_loss(
+        loss, _, _ = twin_critic_action_value_loss(
             state_batch=batch.state,
             action_batch=batch.action,
             expected_target_batch=target,
@@ -127,7 +127,7 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
         ) + reward_batch  # (batch_size), r + gamma * V(s)
 
         assert isinstance(self._critic, TwinCritic)
-        loss = twin_critic_action_value_loss(
+        loss, _, _ = twin_critic_action_value_loss(
             state_batch=batch.state,
             action_batch=batch.action,
             expected_target_batch=expected_state_action_values,
@@ -150,7 +150,7 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
         else:
             raise AssertionError("terminated_batch should not be None")
 
-        loss = twin_critic_action_value_loss(
+        loss, _, _ = twin_critic_action_value_loss(
             state_batch=batch.state,
             action_batch=batch.action,
             expected_target_batch=expected_state_action_values,
2 changes: 1 addition & 1 deletion pearl/policy_learners/sequential_decision_making/td3.py
@@ -180,7 +180,7 @@ def _critic_loss(self, batch: TransitionBatch) -> torch.Tensor:
 
         # update twin critics towards bellman target
         assert isinstance(self._critic, TwinCritic)
-        loss = twin_critic_action_value_loss(
+        loss, _, _ = twin_critic_action_value_loss(
             state_batch=batch.state,
             action_batch=batch.action,
             expected_target_batch=expected_state_action_values,
2 changes: 1 addition & 1 deletion pearl/safety_modules/reward_constrained_safety_module.py
@@ -178,7 +178,7 @@ def cost_critic_learn_batch(
 
         # update twin critics towards bellman target
         assert isinstance(self.cost_critic, TwinCritic)
-        loss = twin_critic_action_value_loss(
+        loss, _, _ = twin_critic_action_value_loss(
             state_batch=batch.state,
             action_batch=batch.action,
             expected_target_batch=expected_state_action_values,
8 changes: 5 additions & 3 deletions pearl/utils/functional_utils/learning/critic_utils.py
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import cast, Iterable, Optional, Type, Union
+from typing import cast, Iterable, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
@@ -167,7 +167,7 @@ def twin_critic_action_value_loss(
     action_batch: torch.Tensor,
     expected_target_batch: torch.Tensor,
     critic: TwinCritic,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     This method calculates the sum of the mean squared errors between the predicted Q-values
     using critic networks (LHS of the Bellman equation) and the input target estimates (RHS of the
@@ -185,11 +185,13 @@
         loss (torch.Tensor): Sum of mean squared errors in the Bellman equation (for action-value
             prediction) corresponding to both critic networks. The expected shape is `()`. This
            scalar loss is used to train both critics of the twin critic network.
+        q1: q1 critic network prediction
+        q2: q2 critic network prediction
     """
 
     criterion = torch.nn.MSELoss()
     q_1, q_2 = critic.get_q_values(state_batch, action_batch)
     loss = criterion(
         q_1.reshape_as(expected_target_batch), expected_target_batch.detach()
     ) + criterion(q_2.reshape_as(expected_target_batch), expected_target_batch.detach())
-    return loss
+    return loss, q_1, q_2
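The docstring above describes the loss as the sum of two mean-squared-error terms, one per critic network; the per-critic predictions q_1 and q_2 are already computed along the way, so returning them costs nothing extra. Below is a self-contained sketch of that computation, using two plain nn.Linear critics as stand-ins for Pearl's TwinCritic (the stand-in networks, dimensions, and variable names are illustrative assumptions, not part of this commit):

import torch
import torch.nn as nn

# Illustrative stand-ins for the two critic networks inside a twin critic;
# each maps a concatenated (state, action) pair to a scalar Q-value.
state_dim, action_dim, batch_size = 8, 2, 32
q1_net = nn.Linear(state_dim + action_dim, 1)
q2_net = nn.Linear(state_dim + action_dim, 1)

state_batch = torch.randn(batch_size, state_dim)
action_batch = torch.randn(batch_size, action_dim)
expected_target_batch = torch.randn(batch_size)  # RHS of the Bellman equation

x = torch.cat([state_batch, action_batch], dim=-1)
q_1 = q1_net(x).squeeze(-1)  # per-critic predictions, now also returned
q_2 = q2_net(x).squeeze(-1)

criterion = nn.MSELoss()
loss = criterion(q_1, expected_target_batch.detach()) + criterion(
    q_2, expected_target_batch.detach()
)

loss.backward()  # one scalar loss trains both critics
print(loss.item(), q_1.mean().item(), q_2.mean().item())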
2 changes: 1 addition & 1 deletion test/unit/with_pytorch/test_twin_critic.py
@@ -33,7 +33,7 @@ def test_twin_critic(self) -> None:
         state_batch = torch.randn(self.batch_size, self.state_dim)
         action_batch = torch.randn(self.batch_size, self.action_dim)
         optimizer = torch.optim.AdamW(twin_critics.parameters(), lr=1e-3)
-        loss = twin_critic_action_value_loss(
+        loss, _, _ = twin_critic_action_value_loss(
             state_batch=state_batch,
             action_batch=action_batch,
             expected_target_batch=torch.randn(self.batch_size),
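All call sites updated in this commit discard the new outputs (loss, _, _ = ...); a caller that wants them can unpack them for monitoring instead. Here is a hedged sketch of that pattern, using a minimal duck-typed stand-in rather than Pearl's real TwinCritic (the stub class, tensor shapes, and learning rate below are assumptions for illustration only):

import torch
import torch.nn as nn
from typing import Tuple

from pearl.utils.functional_utils.learning.critic_utils import (
    twin_critic_action_value_loss,
)


class StubTwinCritic(nn.Module):
    """Hypothetical stand-in exposing the get_q_values interface the loss relies on."""

    def __init__(self, state_dim: int, action_dim: int) -> None:
        super().__init__()
        self._q1 = nn.Linear(state_dim + action_dim, 1)
        self._q2 = nn.Linear(state_dim + action_dim, 1)

    def get_q_values(
        self, state_batch: torch.Tensor, action_batch: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = torch.cat([state_batch, action_batch], dim=-1)
        return self._q1(x).squeeze(-1), self._q2(x).squeeze(-1)


critic = StubTwinCritic(state_dim=8, action_dim=2)
optimizer = torch.optim.AdamW(critic.parameters(), lr=1e-3)

loss, q1, q2 = twin_critic_action_value_loss(
    state_batch=torch.randn(32, 8),
    action_batch=torch.randn(32, 2),
    expected_target_batch=torch.randn(32),
    critic=critic,  # stub used only for this sketch, not a real TwinCritic
)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The extra outputs are the per-critic predictions, handy for logging/diagnostics.
print(q1.detach().mean().item(), q2.detach().mean().item())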
