 python cql_train_cleaned.py --episodes 20000 --size 8 --alpha 0.25 --gamma 0.95

 """
+
 from __future__ import annotations

 import argparse

 # -------------------- helpers --------------------

+
 def setup_logging(level=logging.INFO):
     logging.basicConfig(
         level=level,
@@ -96,14 +98,21 @@ def make_agents(num_predators: int = 2, num_preys: int = 2) -> List[Agent]:
     for i in range(1, num_preys + 1):
         agents.append(Agent(agent_name=f"prey_{i}", agent_team=i, agent_type="prey"))
     for i in range(1, num_predators + 1):
-        agents.append(Agent(agent_name=f"predator_{i}", agent_team=i, agent_type="predator"))
+        agents.append(
+            Agent(agent_name=f"predator_{i}", agent_team=i, agent_type="predator")
+        )
     return agents


-def make_env_and_meta(agents: List[Agent], grid_size: int, seed: int) -> Tuple[GridWorldEnv, int, int]:
-    env = GridWorldEnv(agents=agents, render_mode=None, size=grid_size, perc_num_obstacle=10, seed=seed)
+def make_env_and_meta(
+    agents: List[Agent], grid_size: int, seed: int
+) -> Tuple[GridWorldEnv, int, int]:
+    env = GridWorldEnv(
+        agents=agents, render_mode=None, size=grid_size, perc_num_obstacle=10, seed=seed
+    )
     n_cells = grid_size * grid_size
-    # total joint states = n_cells ** n_agents
+
+    # total joint states = n_cells ** n_agents (may be very large)
     n_states = n_cells ** len(agents)
     n_actions = env.action_space.n
     return env, n_states, n_actions
@@ -113,11 +122,15 @@ def estimate_table_bytes(n_states: int, n_joint_actions: int, dtype=np.float32)
     return int(n_states) * int(n_joint_actions) * np.dtype(dtype).itemsize


-def init_joint_q_table(n_states: int, n_joint_actions: int, max_bytes: int | None = None) -> np.ndarray:
+def init_joint_q_table(
+    n_states: int, n_joint_actions: int, max_bytes: int | None = None
+) -> np.ndarray:
     """Create joint Q table; optionally check memory requirement first."""
     needed = estimate_table_bytes(n_states, n_joint_actions)
     if max_bytes is not None and needed > max_bytes:
-        raise MemoryError(f"Joint Q-table requires {needed / (1024 ** 3):.2f} GiB > allowed {max_bytes / (1024 ** 3):.2f} GiB")
+        raise MemoryError(
+            f"Joint Q-table requires {needed / (1024 ** 3):.2f} GiB > allowed {max_bytes / (1024 ** 3):.2f} GiB"
+        )
     return np.zeros((n_states, n_joint_actions), dtype=np.float32)

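A note on scale (not part of the diff): with the 8x8 grid from the usage example above and 2 predators + 2 preys, the joint state count is (8*8)**4 ≈ 1.7e7; assuming for illustration 5 actions per agent (the script takes the real value from env.action_space.n), the joint action count is 5**4 = 625, so a float32 table would need roughly 39 GiB, which is exactly the situation this memory check is meant to catch. A minimal standalone sketch of that arithmetic:

import numpy as np

def joint_table_gib(grid_size: int, n_agents: int, n_actions: int) -> float:
    # float32 joint-Q size: (cells ** n_agents) states x (actions ** n_agents) joint actions
    n_states = (grid_size * grid_size) ** n_agents
    n_joint_actions = n_actions ** n_agents
    return n_states * n_joint_actions * np.dtype(np.float32).itemsize / 1024**3

print(f"{joint_table_gib(8, 4, 5):.1f} GiB")  # ~39.1 GiB for an 8x8 grid, 4 agents, 5 actions (assumed)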
@@ -129,6 +142,7 @@ def save_q_table(path: str, Q: np.ndarray):

 # -------------------- training loop --------------------

+
 def train(
     episodes: int = 5000,
     max_steps: int = 200,
@@ -157,18 +171,26 @@ def train(

     env, n_states, n_actions = make_env_and_meta(agents, grid_size, seed)

+    n_agents = len(agent_names)
+
     # joint-action space size
-    n_joint_actions = n_actions ** n_agents
+    n_joint_actions = n_actions**n_agents

     # memory safety check (use a conservative default if not provided)
     if max_table_bytes is None:
         # set default max to 16 GiB for safety on typical dev machines
-        max_table_bytes = 16 * 1024 ** 3
+        max_table_bytes = 16 * 1024**3

-    LOGGER.info("Allocating joint Q-table: states=%d, joint_actions=%d", n_states, n_joint_actions)
+    LOGGER.info(
+        "Allocating joint Q-table: states=%d, joint_actions=%d",
+        n_states,
+        n_joint_actions,
+    )
     Q = init_joint_q_table(n_states, n_joint_actions, max_bytes=max_table_bytes)

-    save_path_Q = os.path.join(os.path.dirname(save_path) or ".", "central_cql_q_table.npz")
+    save_path_Q = os.path.join(
+        os.path.dirname(save_path) or ".", "central_cql_q_table.npz"
+    )

     eps = eps_start

@@ -200,9 +222,14 @@ def train(
             # select actions by marginalizing the joint-Q over others
             flat_row = Q[s]
             if flat_row.size != n_joint_actions:
-                raise ValueError("Unexpected joint-Q row length: %d != %d" % (flat_row.size, n_joint_actions))
+                raise ValueError(
+                    "Unexpected joint-Q row length: %d != %d"
+                    % (flat_row.size, n_joint_actions)
+                )

-            q_tensor = flat_row.reshape(action_shape)  # shape: (n_actions, n_actions, ...)
+            q_tensor = flat_row.reshape(
+                action_shape
+            )  # shape: (n_actions, n_actions, ...)

             # compute marginal per-agent action-values by averaging over other axes
             q_vals_per_agent = []
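For context on the marginalization step above (not part of the diff): the row Q[s] is reshaped into one axis per agent, and each agent's action values are the mean of the joint values over all other agents' axes. A self-contained sketch with illustrative sizes (4 agents, 5 actions; the script's actual counts come from the agent list and env.action_space):

import numpy as np

n_agents, n_actions = 4, 5                        # illustrative sizes only
flat_row = np.random.rand(n_actions**n_agents)    # stands in for one row Q[s]

q_tensor = flat_row.reshape((n_actions,) * n_agents)
q_vals_per_agent = []
for i in range(n_agents):
    other_axes = tuple(ax for ax in range(n_agents) if ax != i)
    # average over the other agents' action axes -> length-n_actions vector for agent i
    q_vals_per_agent.append(q_tensor.mean(axis=other_axes))

greedy_actions = [int(np.argmax(q)) for q in q_vals_per_agent]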
@@ -219,14 +246,18 @@ def train(
                 else:
                     row = np.asarray(q_vals_per_agent[i])
                     best = float(np.max(row))
-                    best_actions = np.flatnonzero(np.isclose(row, best)).astype(int).tolist()
+                    best_actions = (
+                        np.flatnonzero(np.isclose(row, best)).astype(int).tolist()
+                    )
                     a_i = int(rng.choice(best_actions))
                 chosen_actions.append(a_i)

             joint_idx = joint_actions_to_index(chosen_actions, n_actions)

             # build actions dict for env.step
-            actions = {agents[i].agent_name: int(chosen_actions[i]) for i in range(n_agents)}
+            actions = {
+                agents[i].agent_name: int(chosen_actions[i]) for i in range(n_agents)
+            }

             mgp = env.step(actions)
             next_obs, rewards = mgp["obs"], mgp["reward"]
@@ -260,7 +291,12 @@ def train(
             s2 = joint_state_index(next_positions, grid_size)

             # CQL update (centralized TD update)
-            td_target = central_r + (gamma * next_pot_sum) - current_pot_sum + gamma * np.max(Q[s2])
+            td_target = (
+                central_r
+                + (gamma * next_pot_sum)
+                - current_pot_sum
+                + gamma * np.max(Q[s2])
+            )
             td_error = td_target - Q[s, joint_idx]
             Q[s, joint_idx] += alpha * td_error

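Reading the target above as an equation (an interpretation, not text from the PR): it is centralized Q-learning with potential-based reward shaping folded into the reward, i.e. td_target = [r + γ·Φ(s') − Φ(s)] + γ·max Q(s', ·), where Φ is the summed per-agent potential (current_pot_sum / next_pot_sum in the code). A compact restatement using the same names, with q_next_row standing in for the NumPy row Q[s2]:

def shaped_td_target(central_r, current_pot_sum, next_pot_sum, q_next_row, gamma):
    # F(s, s') = gamma * Phi(s') - Phi(s): potential-based shaping term
    shaping = gamma * next_pot_sum - current_pot_sum
    return central_r + shaping + gamma * float(q_next_row.max())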
@@ -283,17 +319,32 @@ def train(
         writer.add_scalar("episode/captures", captures_this_episode, ep)

         for name in agent_names:
-            writer.add_scalar(f"episode/total_reward/{name}", float(total_reward_per_agent[name]), ep)
-            mean_reward_running = float(np.mean(rewards_per_ep[name][-window:])) if rewards_per_ep[name] else 0.0
+            writer.add_scalar(
+                f"episode/total_reward/{name}", float(total_reward_per_agent[name]), ep
+            )
+            mean_reward_running = (
+                float(np.mean(rewards_per_ep[name][-window:]))
+                if rewards_per_ep[name]
+                else 0.0
+            )
             writer.add_scalar(f"mean/{name}/reward", mean_reward_running, ep)

-        mean_captures_running = float(np.mean(captures_per_ep[-window:])) if captures_per_ep else 0.0
+        mean_captures_running = (
+            float(np.mean(captures_per_ep[-window:])) if captures_per_ep else 0.0
+        )
         writer.add_scalar("mean/captures", mean_captures_running, ep)

         # epsilon decay and logs
         if ep % 100 == 0:
             eps = max(eps_end, eps * eps_decay)
-            avg_per_agent = {name: np.mean(rewards_per_ep[name][-100:]) if rewards_per_ep[name] else 0.0 for name in agent_names}
+            avg_per_agent = {
+                name: (
+                    np.mean(rewards_per_ep[name][-100:])
+                    if rewards_per_ep[name]
+                    else 0.0
+                )
+                for name in agent_names
+            }
             LOGGER.info(
                 "Ep %d | eps=%.3f | averages(last100)=%s | mean captures(last100)=%.2f",
                 ep,
@@ -323,6 +374,7 @@ def train(

 # ---------------- CLI ----------------

+
 def parse_args():
     p = argparse.ArgumentParser("Train central CQL (tabular)")
     p.add_argument("--episodes", type=int, default=40000)
@@ -333,14 +385,19 @@ def parse_args():
     p.add_argument("--save-path", type=str, default="baselines/CQL/")
     p.add_argument("--predators", type=int, default=2)
     p.add_argument("--preys", type=int, default=2)
-    p.add_argument("--max-table-gb", type=float, default=16.0, help="Max allowed joint-Q memory in GiB before aborting")
+    p.add_argument(
+        "--max-table-gb",
+        type=float,
+        default=16.0,
+        help="Max allowed joint-Q memory in GiB before aborting",
+    )
     return p.parse_args()


 if __name__ == "__main__":
     setup_logging()
     args = parse_args()
-    max_table_bytes = int(args.max_table_gb * 1024 ** 3) if args.max_table_gb else None
+    max_table_bytes = int(args.max_table_gb * 1024**3) if args.max_table_gb else None
     try:
         train(
             episodes=args.episodes,