-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathinspect_agent.py
More file actions
42 lines (34 loc) · 937 Bytes
/
inspect_agent.py
File metadata and controls
42 lines (34 loc) · 937 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import argparse
import os
import gymnasium as gym
import torch
def main(args):
    """Load a saved agent and replay it in a rendered environment, forever.

    After every episode the summed (undiscounted) reward is printed. The loop
    only exits via an exception (e.g. Ctrl+C); the environment is closed on
    the way out.

    Args:
        args: Parsed CLI namespace providing ``env_name`` (the id passed to
            ``gym.make``) and ``agent_name`` (a file name inside
            ``saved_agents/``).
    """
    # Load the agent before opening the render window so a failed load does
    # not leave an unclosed environment behind.
    # The file holds a whole pickled agent object (it exposes ``act`` and
    # ``episode_reset``), not a plain state dict, so full unpickling is needed;
    # torch>=2.6 defaults to weights_only=True, which would reject it.
    # SECURITY: unpickling executes arbitrary code - only load trusted files.
    agent = torch.load(
        os.path.join("saved_agents", args.agent_name), weights_only=False
    )
    # In "human" mode Gymnasium renders inside reset()/step() itself, so no
    # explicit env.render() call is needed (it would draw each frame twice).
    env = gym.make(args.env_name, render_mode="human")
    try:
        while True:
            obs, _ = env.reset()
            agent.episode_reset()
            done = False
            total_reward = 0
            while not done:
                obs, reward, terminated, truncated, _ = env.step(agent.act(obs))
                total_reward += reward
                # An episode ends on either a terminal state or truncation
                # (e.g. a TimeLimit wrapper); ignoring truncation would keep
                # stepping an already-finished episode.
                done = terminated or truncated
            print(total_reward)
    finally:
        env.close()
if __name__ == "__main__":
    agent_dir = "saved_agents"
    # Guard the listing so `--help` still works when the directory is absent
    # (os.listdir would raise at parse-setup time), and sort it so help and
    # error messages list the agents in a stable order.
    saved_agents = sorted(os.listdir(agent_dir)) if os.path.isdir(agent_dir) else []

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env-name",
        type=str,
        required=True,
        help="Gym environment to test the agent on",
    )
    parser.add_argument(
        "--agent-name",
        type=str,
        required=True,
        # choices=None disables the check; the empty case is reported below
        # with a clearer message than "invalid choice ... (choose from )".
        choices=saved_agents or None,
        help="Name of the agent to load",
    )
    args = parser.parse_args()
    if not saved_agents:
        parser.error(f"no saved agents found in {agent_dir!r}")
    main(args)