Skip to content

Commit

Permalink
Bug
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Mar 18, 2022
1 parent 5f9f621 commit e9e2ad5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
35 changes: 18 additions & 17 deletions scripts/optimize/optuna_crowd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from logging import ERROR
from typing import Optional

import cv2
import numpy as np
Expand Down Expand Up @@ -37,7 +38,7 @@ class Parser(BaseParser):
}


def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float:
def objective(trial: optuna.trial.BaseTrial, idx: Optional[int], worker_id: int, path: str) -> float:
# Get some parameters
lr = trial.suggest_loguniform("lr", 1e-5, 1e-2)
n_episodes = 1
Expand Down Expand Up @@ -156,7 +157,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
entity="redtachyon",
sync_tensorboard=True,
config=config,
name=f"trial{trial.number}",
name=f"trial{trial.number}-{idx if idx is not None else ''}",
)

model = RelationModel(config["model"], action_space=env.action_space)
Expand Down Expand Up @@ -188,7 +189,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float

# EVALUATION
env = UnitySimpleCrowdEnv(
file_name=args.env,
file_name=path,
virtual_display=(1600, 900),
no_graphics=False,
worker_id=worker_id+5,
Expand Down Expand Up @@ -243,21 +244,21 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
# Generate the dashboard
print("Generating dashboard")
print("Skipping dashboard")
# trajectory = du.read_trajectory(trajectory_path)
#
# plt.clf()
# du.make_dashboard(trajectory, save_path=dashboard_path)
trajectory = du.read_trajectory(trajectory_path)

plt.clf()
du.make_dashboard(trajectory, save_path=dashboard_path)

# Upload to wandb
# print("Uploading dashboard")
# wandb.log(
# {
# "dashboard": wandb.Image(
# dashboard_path,
# caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}",
# )
# }
# )
print("Uploading dashboard")
wandb.log(
{
"dashboard": wandb.Image(
dashboard_path,
caption=f"Dashboard {mode} {'det' if d else 'rng'} {i}",
)
}
)

frame_size = renders.shape[1:3]

Expand Down Expand Up @@ -300,7 +301,7 @@ def objective(trial: optuna.trial.BaseTrial, worker_id: int, path: str) -> float
study = optuna.load_study(storage=f"sqlite:///{args.optuna_name}.db", study_name=args.optuna_name)

study.optimize(
lambda trial: objective(trial, args.worker_id, args.env), n_trials=args.n_trials
lambda trial: objective(trial, None, args.worker_id, args.env), n_trials=args.n_trials
)

print("Best params:", study.best_params)
Expand Down
2 changes: 1 addition & 1 deletion scripts/optimize/retrain_crowd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ class Parser(BaseParser):
print(f"Trial {idx}")
for i in range(args.n_trials):
print(f"Run {i}")
objective(trial, args.worker_id, args.env)
objective(trial, i, args.worker_id, args.env)

0 comments on commit e9e2ad5

Please sign in to comment.