Skip to content

Commit

Permalink
upgrade pyre version in fbcode/pearl - batch 1
Browse files Browse the repository at this point in the history
Differential Revision: D60991999

fbshipit-source-id: 318e24f7f9b3a19b9e9fa07d94073728c83deb40
  • Loading branch information
generatedunixname89002005307016 authored and facebook-github-bot committed Aug 9, 2024
1 parent 6b0e81d commit f84334c
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pearl/neural_networks/contextual_bandit/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def matrix_inv_fallback_pinv(self, A: torch.Tensor) -> torch.Tensor:
try:
return torch.linalg.inv(A).contiguous()
# pyre-ignore[16]: Module `_C` has no attribute `_LinAlgError`.
# pyre-fixme[66]: Exception handler type annotation `unknown` must extend
# BaseException.
except torch._C._LinAlgError as e:
logger.warning(
"Exception raised during A inversion, falling back to pseudo-inverse",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def _partition_batch_by_arm(self, batch: TransitionBatch) -> List[TransitionBatc
state = batch.state[:, arm, :]
batches.append(
TransitionBatch(
# pyre-fixme[61]: `state` is undefined, or not always defined.
state=state[mask],
reward=batch.reward[mask],
weight=(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def create_offline_data(
action = agent.act(exploit=False)

action_result = env.step(action)
# pyre-fixme[58]: `+` is not supported for operand types `int` and `object`.
g += action_result.reward
agent.observe(action_result)
transition_tuple = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def run_episode(
action.cpu() if isinstance(action, torch.Tensor) else action
) # action can be int sometimes
action_result = env.step(action)
# pyre-fixme[58]: `+` is not supported for operand types `int` and `object`.
cum_reward += action_result.reward
if (
num_risky_sa is not None
Expand Down Expand Up @@ -296,8 +297,12 @@ def run_episode(

info = {"return": cum_reward}
if num_risky_sa is not None:
# pyre-fixme[6]: For 1st argument expected `SupportsKeysAndGetItem[str,
# int]` but got `Dict[str, float]`.
info.update({"risky_sa_ratio": num_risky_sa / episode_steps})
if cum_cost is not None:
# pyre-fixme[6]: For 1st argument expected `SupportsKeysAndGetItem[str,
# int]` but got `Dict[str, float]`.
info.update({"return_cost": cum_cost})

return info, episode_steps

0 comments on commit f84334c

Please sign in to comment.