Skip to content

Commit 33056ce

Browse files
committed
evaluate: now compatible with separable envs
1 parent 34ce7a5 commit 33056ce

File tree

1 file changed

+73
-33
lines changed

1 file changed

+73
-33
lines changed

dragonfly/src/core/evaluate.py

Lines changed: 73 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Generic imports
2+
import numpy as np
23
import gymnasium as gym
34

45
# Custom imports
56
from dragonfly.src.utils.json import json_parser
67
from dragonfly.src.agent.agent import agent_factory
78
from dragonfly.src.env.environment import environment
9+
from dragonfly.src.env.mpi import mpi
810
from dragonfly.src.utils.renderer import renderer
911
from dragonfly.src.utils.prints import new_line, spacer
1012

@@ -46,43 +48,81 @@ def evaluate(net_folder, json_file, ns, nw, aw, eval_frequency):
4648
term_ns = False
4749
term_dn = True
4850

49-
# Reset
50-
n = 0
51-
scr = 0.0
52-
obs = env.reset_all()
53-
54-
# Specify warmup (unrolling without control)
55-
if (nw > 0):
56-
# Retrieve action type
57-
t = env.get_action_type()
58-
if (t == "continuous"):
59-
act = []
60-
for a in aw: act.append(float(a))
61-
act = [act]
62-
if (t == "discrete"):
63-
act = []
64-
for a in aw: act.append(int(a))
65-
66-
# Loop with neutral action
67-
for i in range(nw):
51+
# Check whether the environment is separable or not
52+
if env.spaces.separable():
53+
# Reset
54+
n = 0
55+
scr = 0.0
56+
natural_act_dim = env.spaces.natural_act_dim()
57+
true_obs_dim = env.spaces.true_obs_dim()
58+
obs = np.zeros((natural_act_dim, mpi.size, true_obs_dim))
59+
act = np.zeros((natural_act_dim, mpi.size))
60+
rwd = np.zeros((natural_act_dim, mpi.size))
61+
dne = np.zeros((natural_act_dim, mpi.size), dtype=bool)
62+
for i in range(natural_act_dim):
63+
obs[i,:] = env.reset_all()
64+
65+
# Unroll
66+
while True:
67+
for i in range(natural_act_dim):
68+
actions = agent.control(obs[i,:,:])
69+
act[i,:] = np.reshape(actions, (mpi.size))
70+
71+
for i in range(natural_act_dim):
72+
o, r, d, t = env.step(np.transpose(act))
73+
obs[i,:,:] = o[:,:]
74+
rwd[i,:] = r[:]
75+
dne[i,:] = d[:]
76+
77+
scr += np.sum(rwd[:,0], axis=0)
78+
79+
if (n%eval_frequency == 0): rnd.store(env)
80+
if (term_ns and n >= ns-1): break
81+
if (term_dn and dne[0]): break
82+
83+
n += 1
84+
85+
rnd.store(env)
86+
rnd.finish(".", 0, 0)
87+
env.close()
88+
else:
89+
# Reset
90+
n = 0
91+
scr = 0.0
92+
obs = env.reset_all()
93+
94+
# Specify warmup (unrolling without control)
95+
if (nw > 0):
96+
# Retrieve action type
97+
t = env.get_action_type()
98+
if (t == "continuous"):
99+
act = []
100+
for a in aw: act.append(float(a))
101+
act = [act]
102+
if (t == "discrete"):
103+
act = []
104+
for a in aw: act.append(int(a))
105+
106+
# Loop with neutral action
107+
for i in range(nw):
108+
obs, rwd, dne, trc = env.step(act)
109+
rnd.store(env)
110+
111+
# Unroll
112+
while True:
113+
act = agent.control(obs)
68114
obs, rwd, dne, trc = env.step(act)
69-
rnd.store(env)
115+
scr += rwd[0]
70116

71-
# Unroll
72-
while True:
73-
act = agent.control(obs)
74-
obs, rwd, dne, trc = env.step(act)
75-
scr += rwd[0]
117+
if (n%eval_frequency == 0): rnd.store(env)
118+
if (term_ns and n >= ns-1): break
119+
if (term_dn and dne): break
76120

77-
if (n%eval_frequency == 0): rnd.store(env)
78-
if (term_ns and n >= ns-1): break
79-
if (term_dn and dne): break
121+
n += 1
80122

81-
n += 1
82-
83-
rnd.store(env)
84-
rnd.finish(".", 0, 0)
85-
env.close()
123+
rnd.store(env)
124+
rnd.finish(".", 0, 0)
125+
env.close()
86126

87127
# Print
88128
new_line()

0 commit comments

Comments
 (0)