@@ -88,8 +88,8 @@ def test_main_console(script_mod):
# Config updates that make a run load the pre-trained CartPole test policy.
_rl_agent_loading_configs = dict(
    agent_path=CARTPOLE_TEST_POLICY_PATH,
    # FIXME(yawen): the policy we load was trained on 8 parallel environments
    # and for some reason using it breaks if we use just 1 (like would be the
    # default with the fast named_config)
    common={"num_vec": 8},
)
9595
@@ -232,6 +232,40 @@ def test_train_dagger_main(tmpdir):
232232 assert isinstance (run .result , dict )
233233
234234
def test_train_dagger_warmstart(tmpdir):
    """Smoke test: run DAgger, then rerun it with `agent_path` set.

    The second run receives `<log_dir>/scratch/policy-latest.pt` from the first
    run's log directory as `agent_path` (presumably the policy saved by the
    first run -- confirm against the DAgger trainer).
    """

    def _launch(**extra_updates):
        # Fresh list/dict objects per call, so the two runs share no config state.
        return train_imitation.train_imitation_ex.run(
            command_name="dagger",
            named_configs=["cartpole"] + ALGO_FAST_CONFIGS["imitation"],
            config_updates={
                "common": {"log_root": tmpdir},
                "demonstrations": {"rollout_path": CARTPOLE_TEST_ROLLOUT_PATH},
                "dagger": {
                    "expert_policy_type": "ppo",
                    "expert_policy_path": CARTPOLE_TEST_POLICY_PATH,
                },
                **extra_updates,
            },
        )

    initial_run = _launch()
    assert initial_run.status == "COMPLETED"

    scratch_dir = pathlib.Path(initial_run.config["common"]["log_dir"]) / "scratch"
    warm_run = _launch(agent_path=scratch_dir / "policy-latest.pt")
    assert warm_run.status == "COMPLETED"
    assert isinstance(warm_run.result, dict)
267+
268+
235269def test_train_dagger_error_and_exceptions (tmpdir ):
236270 with pytest .raises (Exception , match = ".*expert_policy_path cannot be None.*" ):
237271 train_imitation .train_imitation_ex .run (
@@ -261,6 +295,32 @@ def test_train_bc_main(tmpdir):
261295 assert isinstance (run .result , dict )
262296
263297
def test_train_bc_warmstart(tmpdir):
    """Smoke test: run BC, then rerun it with `agent_path` set to `final.th`.

    `final.th` is looked up in the first run's log directory (presumably the
    policy file that run saved -- confirm against the BC training script).
    """

    def _launch(**extra_updates):
        # Fresh list/dict objects per call, so the two runs share no config state.
        return train_imitation.train_imitation_ex.run(
            command_name="bc",
            named_configs=["cartpole"] + ALGO_FAST_CONFIGS["imitation"],
            config_updates={
                "common": {"log_root": tmpdir},
                "demonstrations": {"rollout_path": CARTPOLE_TEST_ROLLOUT_PATH},
                **extra_updates,
            },
        )

    initial_run = _launch()
    assert initial_run.status == "COMPLETED"

    first_log_dir = pathlib.Path(initial_run.config["common"]["log_dir"])
    warm_run = _launch(agent_path=first_log_dir / "final.th")

    assert warm_run.status == "COMPLETED"
    assert isinstance(warm_run.result, dict)
322+
323+
# Config-update variants: no overrides, and loading a pre-trained agent.
# NOTE(review): presumably used to parametrize the PPO train_rl tests -- confirm.
TRAIN_RL_PPO_CONFIGS = [
    {},
    _rl_agent_loading_configs,
]
265325
266326
@@ -376,6 +436,35 @@ def test_train_adversarial(tmpdir, named_configs, command):
376436 _check_train_ex_result (run .result )
377437
378438
@pytest.mark.parametrize("command", ("airl", "gail"))
def test_train_adversarial_warmstart(tmpdir, command):
    """Smoke test: run adversarial training, then rerun it with `agent_path` set.

    Args:
        tmpdir: pytest temporary directory, used as the `log_root` of both runs.
        command: name of the experiment command to run ("airl" or "gail").
    """
    named_configs = ["cartpole"] + ALGO_FAST_CONFIGS["adversarial"]
    config_updates = {
        "common": dict(log_root=tmpdir),
        "demonstrations": dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
    }
    run = train_adversarial.train_adversarial_ex.run(
        command_name=command,
        named_configs=named_configs,
        config_updates=config_updates,
    )
    # Fail here, not later with a confusing load error, if the first run did
    # not finish; this also matches the DAgger and BC warm-start tests.
    assert run.status == "COMPLETED"

    log_dir = pathlib.Path(run.config["common"]["log_dir"])
    # Path handed to the second run as `agent_path` (presumably the generator
    # policy checkpoint saved by the first run -- confirm against the script).
    policy_path = log_dir / "checkpoints" / "final" / "gen_policy"

    run_warmstart = train_adversarial.train_adversarial_ex.run(
        command_name=command,
        named_configs=named_configs,
        config_updates={
            "agent_path": policy_path,
            **config_updates,
        },
    )

    assert run_warmstart.status == "COMPLETED"
    _check_train_ex_result(run_warmstart.result)
466+
467+
379468@pytest .mark .parametrize ("command" , ("airl" , "gail" ))
380469def test_train_adversarial_sac (tmpdir , command ):
381470 """Smoke test for imitation.scripts.train_adversarial."""
0 commit comments