-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrain_agent.py
More file actions
85 lines (71 loc) · 2.21 KB
/
train_agent.py
File metadata and controls
85 lines (71 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import os
from contextlib import closing

import gymnasium as gym
import torch

from algorithms.a2c import A2CTrainer
from algorithms.dqn import DQNTrainer
from algorithms.ppo import PPOTrainer
from algorithms.reinforce import ReinforceTrainer
from algorithms.sarsa import SarsaTrainer
from algorithms.vanilla_dqn import VanillaDQNTrainer
def get_trainer(algorithm_name):
    """
    Build the trainer object that corresponds to an algorithm name.

    :param algorithm_name: str, one of "reinforce", "sarsa", "vanilla_dqn",
        "dqn", "a2c" or "ppo"
    :returns: a newly constructed trainer instance for that algorithm
    :raises ValueError: if algorithm_name is not one of the names above
    """
    # Zero-argument factories keyed by algorithm name. Only the factory that
    # is actually selected gets called, so exactly one trainer is built.
    factories = {
        "reinforce": lambda: ReinforceTrainer(),
        "sarsa": lambda: SarsaTrainer(),
        "vanilla_dqn": lambda: VanillaDQNTrainer(),
        "dqn": lambda: DQNTrainer(),
        "a2c": lambda: A2CTrainer(),
        "ppo": lambda: PPOTrainer(),
    }

    make_trainer = factories.get(algorithm_name)
    if make_trainer is None:
        raise ValueError("Unknown algorithm {}".format(algorithm_name))
    return make_trainer()
def main(args):
    """
    Train an agent with the selected algorithm and save it to disk.

    Two separate environments are created from the same environment name;
    one is handed to the trainer as ``env`` and the other as ``test_env``.
    The object returned by the trainer is written with ``torch.save`` to
    ``saved_agents/<save_name>``. Both environments are closed when training
    and saving finish, or when either step raises.

    :param args: parsed command line arguments providing ``env_name``,
        ``algorithm``, ``save_name`` and ``render``
    :raises ValueError: if ``args.algorithm`` is not a known algorithm name
    """
    save_dir = "saved_agents"

    # closing() guarantees env.close() runs for both environments even if
    # trainer construction, training or saving fails part-way through.
    with closing(gym.make(args.env_name)) as env, \
            closing(gym.make(args.env_name)) as test_env:
        trainer = get_trainer(args.algorithm)

        # exist_ok=True replaces the separate existence check, which could
        # race with another process creating the directory in between.
        os.makedirs(save_dir, exist_ok=True)

        agent = trainer.train_agent(
            env=env,
            test_env=test_env,
            save_name=args.save_name,
            render=args.render,
        )

        # Saved before the environments are closed, so the agent is
        # serialized in the same state as in the original ordering.
        torch.save(agent, f"{save_dir}/{args.save_name}")
def _parse_args():
    """
    Define the command line interface and parse the process arguments.

    :returns: argparse namespace with ``env_name``, ``algorithm``,
        ``render`` and ``save_name`` attributes
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--env-name",
        required=True,
        type=str,
        help="Gym environment to train an agent for",
    )
    cli.add_argument(
        "--algorithm",
        required=True,
        type=str,
        choices=["reinforce", "sarsa", "vanilla_dqn", "dqn", "a2c", "ppo"],
        help="Algorithm to use for training an agent",
    )
    # store_true already defaults the flag to False when it is not given.
    cli.add_argument(
        "--render",
        action="store_true",
        help="Whether or not to render the environment during training",
    )
    cli.add_argument(
        "--save-name",
        default="test_agent",
        type=str,
        help="Name to save the agent with after training",
    )
    return cli.parse_args()


# Script entry point: parse the command line, then train and save the agent.
if __name__ == "__main__":
    main(_parse_args())