
Commit e5ba468

Merge pull request #6 from ProValarous/NehalBranch
PR #5
2 parents 63efcd5 + 5d2f1e8 commit e5ba468

23 files changed: +1053 −410 lines changed
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+name: "🕷️ Bug report"
+description: Report errors or unexpected behavior
+labels:
+  - Issue-Bug
+  - Needs-Triage
+
+body:
+  - type: markdown
+    attributes:
+      value: Please make sure to [search for existing issues](https://github.com/ProValarous/Predator-Prey-Gridworld-Environment/issues) before filing a new one!
+
+  - type: textarea
+    attributes:
+      label: Steps to reproduce
+      description: We highly suggest including screenshots and a bug report log (System tray > Report bug).
+      placeholder: Having detailed steps helps us reproduce the bug.
+    validations:
+      required: true
+
+  - type: textarea
+    attributes:
+      label: ✔️ Expected Behavior
+      placeholder: What were you expecting?
+    validations:
+      required: false
+
+  - type: textarea
+    attributes:
+      label: ❌ Actual Behavior
+      placeholder: What happened instead?
+    validations:
+      required: false
+
+  - id: additionalInfo
+    type: textarea
+    attributes:
+      label: Additional Information
+    validations:
+      required: false

src/baselines/CQL/cql_train.py

Lines changed: 78 additions & 21 deletions
@@ -18,6 +18,7 @@
 python cql_train_cleaned.py --episodes 20000 --size 8 --alpha 0.25 --gamma 0.95
 
 """
+
 from __future__ import annotations
 
 import argparse
@@ -52,6 +53,7 @@
 
 # -------------------- helpers --------------------
 
+
 def setup_logging(level=logging.INFO):
     logging.basicConfig(
         level=level,
@@ -96,14 +98,21 @@ def make_agents(num_predators: int = 2, num_preys: int = 2) -> List[Agent]:
     for i in range(1, num_preys + 1):
         agents.append(Agent(agent_name=f"prey_{i}", agent_team=i, agent_type="prey"))
     for i in range(1, num_predators + 1):
-        agents.append(Agent(agent_name=f"predator_{i}", agent_team=i, agent_type="predator"))
+        agents.append(
+            Agent(agent_name=f"predator_{i}", agent_team=i, agent_type="predator")
+        )
     return agents
 
 
-def make_env_and_meta(agents: List[Agent], grid_size: int, seed: int) -> Tuple[GridWorldEnv, int, int]:
-    env = GridWorldEnv(agents=agents, render_mode=None, size=grid_size, perc_num_obstacle=10, seed=seed)
+def make_env_and_meta(
+    agents: List[Agent], grid_size: int, seed: int
+) -> Tuple[GridWorldEnv, int, int]:
+    env = GridWorldEnv(
+        agents=agents, render_mode=None, size=grid_size, perc_num_obstacle=10, seed=seed
+    )
     n_cells = grid_size * grid_size
-    # total joint states = n_cells ** n_agents
+
+    # total joint states = n_cells ** n_agents (may be very large)
     n_states = n_cells ** len(agents)
     n_actions = env.action_space.n
     return env, n_states, n_actions
@@ -113,11 +122,15 @@ def estimate_table_bytes(n_states: int, n_joint_actions: int, dtype=np.float32)
     return int(n_states) * int(n_joint_actions) * np.dtype(dtype).itemsize
 
 
-def init_joint_q_table(n_states: int, n_joint_actions: int, max_bytes: int | None = None) -> np.ndarray:
+def init_joint_q_table(
+    n_states: int, n_joint_actions: int, max_bytes: int | None = None
+) -> np.ndarray:
     """Create joint Q table; optionally check memory requirement first."""
     needed = estimate_table_bytes(n_states, n_joint_actions)
     if max_bytes is not None and needed > max_bytes:
-        raise MemoryError(f"Joint Q-table requires {needed/(1024**3):.2f} GiB > allowed {max_bytes/(1024**3):.2f} GiB")
+        raise MemoryError(
+            f"Joint Q-table requires {needed/(1024**3):.2f} GiB > allowed {max_bytes/(1024**3):.2f} GiB"
+        )
     return np.zeros((n_states, n_joint_actions), dtype=np.float32)
 
 
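For a sense of the scale this check guards against, here is a small illustration using the same float32 layout as estimate_table_bytes; the grid size, agent count, and action count below are made-up values, not the project's defaults:

import numpy as np

# Hypothetical configuration: 8x8 grid, 4 agents, 5 actions per agent.
grid_size, n_agents, n_actions = 8, 4, 5
n_cells = grid_size * grid_size
n_states = n_cells ** n_agents           # 64**4 = 16,777,216 joint states
n_joint_actions = n_actions ** n_agents  # 5**4 = 625 joint actions
needed = n_states * n_joint_actions * np.dtype(np.float32).itemsize
print(f"{needed / 1024**3:.1f} GiB")     # ~39 GiB, which would trip a 16 GiB cap
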
@@ -129,6 +142,7 @@ def save_q_table(path: str, Q: np.ndarray):
 
 # -------------------- training loop --------------------
 
+
 def train(
     episodes: int = 5000,
     max_steps: int = 200,
@@ -157,18 +171,26 @@ def train(
 
     env, n_states, n_actions = make_env_and_meta(agents, grid_size, seed)
 
+    n_agents = len(agent_names)
+
     # joint-action space size
-    n_joint_actions = n_actions ** n_agents
+    n_joint_actions = n_actions**n_agents
 
     # memory safety check (use a conservative default if not provided)
     if max_table_bytes is None:
         # set default max to 8 GiB for safety on typical dev machines
-        max_table_bytes = 16 * 1024 ** 3
+        max_table_bytes = 16 * 1024**3
 
-    LOGGER.info("Allocating joint Q-table: states=%d, joint_actions=%d", n_states, n_joint_actions)
+    LOGGER.info(
+        "Allocating joint Q-table: states=%d, joint_actions=%d",
+        n_states,
+        n_joint_actions,
+    )
     Q = init_joint_q_table(n_states, n_joint_actions, max_bytes=max_table_bytes)
 
-    save_path_Q = os.path.join(os.path.dirname(save_path) or ".", "central_cql_q_table.npz")
+    save_path_Q = os.path.join(
+        os.path.dirname(save_path) or ".", "central_cql_q_table.npz"
+    )
 
     eps = eps_start
 
@@ -200,9 +222,14 @@ def train(
             # select actions by marginalizing the joint-Q over others
             flat_row = Q[s]
             if flat_row.size != n_joint_actions:
-                raise ValueError("Unexpected joint-Q row length: %d != %d" % (flat_row.size, n_joint_actions))
+                raise ValueError(
+                    "Unexpected joint-Q row length: %d != %d"
+                    % (flat_row.size, n_joint_actions)
+                )
 
-            q_tensor = flat_row.reshape(action_shape)  # shape: (n_actions, n_actions, ...)
+            q_tensor = flat_row.reshape(
+                action_shape
+            )  # shape: (n_actions, n_actions, ...)
 
             # compute marginal per-agent action-values by averaging over other axes
             q_vals_per_agent = []
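To make the marginalization step concrete, a minimal sketch of reshaping one joint-Q row and averaging over the other agents' action axes; the loop body is not shown in this hunk, so this is one plausible implementation with made-up sizes:

import numpy as np

n_agents, n_actions = 2, 5
action_shape = (n_actions,) * n_agents                 # (5, 5)
rng = np.random.default_rng(0)
flat_row = rng.normal(size=n_actions ** n_agents).astype(np.float32)

q_tensor = flat_row.reshape(action_shape)
q_vals_per_agent = []
for i in range(n_agents):
    other_axes = tuple(ax for ax in range(n_agents) if ax != i)
    # marginal value of agent i's actions, averaged over everyone else's choices
    q_vals_per_agent.append(q_tensor.mean(axis=other_axes))

print([v.shape for v in q_vals_per_agent])             # [(5,), (5,)]
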
@@ -219,14 +246,18 @@ def train(
                 else:
                     row = np.asarray(q_vals_per_agent[i])
                     best = float(np.max(row))
-                    best_actions = np.flatnonzero(np.isclose(row, best)).astype(int).tolist()
+                    best_actions = (
+                        np.flatnonzero(np.isclose(row, best)).astype(int).tolist()
+                    )
                     a_i = int(rng.choice(best_actions))
                 chosen_actions.append(a_i)
 
             joint_idx = joint_actions_to_index(chosen_actions, n_actions)
 
             # build actions dict for env.step
-            actions = {agents[i].agent_name: int(chosen_actions[i]) for i in range(n_agents)}
+            actions = {
+                agents[i].agent_name: int(chosen_actions[i]) for i in range(n_agents)
+            }
 
             mgp = env.step(actions)
             next_obs, rewards = mgp["obs"], mgp["reward"]
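joint_actions_to_index itself is not part of this diff; a common way to implement such a mapping is mixed-radix (base-n_actions) encoding, sketched below as an assumption rather than the project's actual helper:

from typing import List

def joint_actions_to_index(actions: List[int], n_actions: int) -> int:
    # Flatten one action per agent into a single joint-action column index.
    idx = 0
    for a in actions:
        idx = idx * n_actions + int(a)
    return idx

def index_to_joint_actions(idx: int, n_actions: int, n_agents: int) -> List[int]:
    # Inverse mapping, handy for reading a greedy joint action back out of Q.
    out = []
    for _ in range(n_agents):
        out.append(idx % n_actions)
        idx //= n_actions
    return list(reversed(out))

assert index_to_joint_actions(joint_actions_to_index([2, 3], 5), 5, 2) == [2, 3]
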
@@ -260,7 +291,12 @@ def train(
             s2 = joint_state_index(next_positions, grid_size)
 
             # CQL update (centralized TD update)
-            td_target = central_r + (gamma * next_pot_sum) - current_pot_sum + gamma * np.max(Q[s2])
+            td_target = (
+                central_r
+                + (gamma * next_pot_sum)
+                - current_pot_sum
+                + gamma * np.max(Q[s2])
+            )
             td_error = td_target - Q[s, joint_idx]
             Q[s, joint_idx] += alpha * td_error
 
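Written out in two steps, the target above is ordinary Q-learning plus a potential-based shaping term. A self-contained toy version follows; all values are made up, and current_pot_sum / next_pot_sum stand in for the potentials computed elsewhere in the file:

import numpy as np

gamma, alpha = 0.95, 0.25
central_r = 1.0                                  # summed team reward for the step
current_pot_sum, next_pot_sum = 0.4, 0.7         # Phi(s) and Phi(s')
Q = np.zeros((3, 4), dtype=np.float32)           # tiny stand-in joint Q-table
s, s2, joint_idx = 0, 1, 2

# F(s, s') = gamma * Phi(s') - Phi(s): potential-based shaping leaves the optimal policy unchanged.
shaping = gamma * next_pot_sum - current_pot_sum
td_target = (central_r + shaping) + gamma * np.max(Q[s2])
td_error = td_target - Q[s, joint_idx]
Q[s, joint_idx] += alpha * td_error
print(float(Q[s, joint_idx]))                    # 0.25 * 1.265 ≈ 0.316
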
@@ -283,17 +319,32 @@ def train(
         writer.add_scalar("episode/captures", captures_this_episode, ep)
 
         for name in agent_names:
-            writer.add_scalar(f"episode/total_reward/{name}", float(total_reward_per_agent[name]), ep)
-            mean_reward_running = float(np.mean(rewards_per_ep[name][-window:])) if rewards_per_ep[name] else 0.0
+            writer.add_scalar(
+                f"episode/total_reward/{name}", float(total_reward_per_agent[name]), ep
+            )
+            mean_reward_running = (
+                float(np.mean(rewards_per_ep[name][-window:]))
+                if rewards_per_ep[name]
+                else 0.0
+            )
             writer.add_scalar(f"mean/{name}/reward", mean_reward_running, ep)
 
-        mean_captures_running = float(np.mean(captures_per_ep[-window:])) if captures_per_ep else 0.0
+        mean_captures_running = (
+            float(np.mean(captures_per_ep[-window:])) if captures_per_ep else 0.0
+        )
         writer.add_scalar("mean/captures", mean_captures_running, ep)
 
         # epsilon decay and logs
         if ep % 100 == 0:
             eps = max(eps_end, eps * eps_decay)
-            avg_per_agent = {name: np.mean(rewards_per_ep[name][-100:]) if rewards_per_ep[name] else 0.0 for name in agent_names}
+            avg_per_agent = {
+                name: (
+                    np.mean(rewards_per_ep[name][-100:])
+                    if rewards_per_ep[name]
+                    else 0.0
+                )
+                for name in agent_names
+            }
             LOGGER.info(
                 "Ep %d | eps=%.3f | averages(last100)=%s | mean captures(last100)=%.2f",
                 ep,
@@ -323,6 +374,7 @@ def train(
 
 # ---------------- CLI ----------------
 
+
 def parse_args():
     p = argparse.ArgumentParser("Train central CQL (tabular)")
     p.add_argument("--episodes", type=int, default=40000)
@@ -333,14 +385,19 @@ def parse_args():
     p.add_argument("--save-path", type=str, default="baselines/CQL/")
     p.add_argument("--predators", type=int, default=2)
     p.add_argument("--preys", type=int, default=2)
-    p.add_argument("--max-table-gb", type=float, default=16.0, help="Max allowed joint-Q memory in GiB before aborting")
+    p.add_argument(
+        "--max-table-gb",
+        type=float,
+        default=16.0,
+        help="Max allowed joint-Q memory in GiB before aborting",
+    )
     return p.parse_args()
 
 
 if __name__ == "__main__":
     setup_logging()
     args = parse_args()
-    max_table_bytes = int(args.max_table_gb * 1024 ** 3) if args.max_table_gb else None
+    max_table_bytes = int(args.max_table_gb * 1024**3) if args.max_table_gb else None
     try:
         train(
             episodes=args.episodes,
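Assuming the remaining flags mirror the docstring example at the top of the file, a run that stays within the default memory cap might be invoked like this (illustrative only; paths and values depend on your setup):

python src/baselines/CQL/cql_train.py --episodes 20000 --predators 2 --preys 2 --max-table-gb 8
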

src/baselines/CQL/test_iql.py

Lines changed: 18 additions & 10 deletions
@@ -40,9 +40,12 @@ def try_load_qs(file_path: str) -> Dict[str, np.ndarray]:
         name = key[2:] if key.startswith("Q_") else key
         qs[name] = arr.astype(np.float32)
     if not qs:
-        raise RuntimeError(f"No valid Q-table arrays found in '{file_path}'. Keys: {list(data.files)}")
+        raise RuntimeError(
+            f"No valid Q-table arrays found in '{file_path}'. Keys: {list(data.files)}"
+        )
     return qs
 
+
 def state_index_from_obs(obs: dict, predator: Agent, prey: Agent, size: int) -> int:
     """
     Build state index using predator + prey positions.
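try_load_qs expects an .npz archive keyed per agent, optionally with a Q_ prefix that it strips; a minimal sketch of writing and reading such a file (the file name and array shapes here are made up):

import numpy as np

# Save per-agent tables under "Q_<agent_name>" keys; try_load_qs removes the "Q_".
q_tables = {
    "Q_predator_1": np.zeros((4096, 5), dtype=np.float32),
    "Q_prey_1": np.zeros((4096, 5), dtype=np.float32),
}
np.savez_compressed("iql_q_tables.npz", **q_tables)

data = np.load("iql_q_tables.npz")
print([k[2:] if k.startswith("Q_") else k for k in data.files])  # ['predator_1', 'prey_1']
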
@@ -54,12 +57,7 @@ def state_index_from_obs(obs: dict, predator: Agent, prey: Agent, size: int) ->
     pred_x, pred_y = int(pos_pred[0]), int(pos_pred[1])
     prey_x, prey_y = int(pos_prey[0]), int(pos_prey[1])
 
-    return (
-        pred_x * size * size * size
-        + pred_y * size * size
-        + prey_x * size
-        + prey_y
-    )
+    return pred_x * size * size * size + pred_y * size * size + prey_x * size + prey_y
 
 
 def choose_action(agent: Agent, q_table: np.ndarray, s_idx: int) -> int:
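The collapsed return packs both positions into a single base-size index; decoding works by peeling the digits back off, as in this hypothetical helper (not part of the file):

def decode_state_index(s_idx: int, size: int):
    # Invert the predator/prey packing used by state_index_from_obs.
    prey_y = s_idx % size
    s_idx //= size
    prey_x = s_idx % size
    s_idx //= size
    pred_y = s_idx % size
    pred_x = s_idx // size
    return (pred_x, pred_y), (prey_x, prey_y)

# Round trip on an 8x8 grid: predator at (3, 7), prey at (5, 1).
size = 8
s = 3 * size**3 + 7 * size**2 + 5 * size + 1
assert decode_state_index(s, size) == ((3, 7), (5, 1))
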
@@ -83,7 +81,9 @@ def run_test(
 
     prey, predator = make_agents()
     agents = [prey, predator]
-    env = GridWorldEnv(agents=agents, render_mode="human", size=size, perc_num_obstacle=10)
+    env = GridWorldEnv(
+        agents=agents, render_mode="human", size=size, perc_num_obstacle=10
+    )
 
     try:
         for ep in range(1, episodes + 1):
@@ -112,7 +112,9 @@ def run_test(
                     q_table = qs[k]
                     break
                 if q_table is None:
-                    raise RuntimeError(f"No Q-table for agent '{ag.agent_name}'. Keys: {list(qs.keys())}")
+                    raise RuntimeError(
+                        f"No Q-table for agent '{ag.agent_name}'. Keys: {list(qs.keys())}"
+                    )
 
                 a = choose_action(ag, q_table, s_idx)
                 actions[ag.agent_name] = a
@@ -145,4 +147,10 @@ def parse_args() -> argparse.Namespace:
 if __name__ == "__main__":
     setup_logging()
     args = parse_args()
-    run_test(q_file=args.file, size=args.size, episodes=args.episodes, max_steps=args.max_steps, pause=args.pause)
+    run_test(
+        q_file=args.file,
+        size=args.size,
+        episodes=args.episodes,
+        max_steps=args.max_steps,
+        pause=args.pause,
+    )
