From ef349ba7cfdebf339e9aedcd09d89d6c917f86e5 Mon Sep 17 00:00:00 2001 From: Christopher Zhang Cui Date: Fri, 5 Sep 2025 06:12:21 +0000 Subject: [PATCH 1/4] Added scienceworld and jericho splits, fixed alfworld split bug --- tales/get_env_splits.py | 169 ++++++++++++++++++++----- tales/jericho/jericho_data.py | 35 ++++- tales/scienceworld/scienceworld_env.py | 6 +- tales/textworld/textworld_data.py | 10 +- tales/textworld/textworld_env.py | 3 +- tales/textworld_express/twx_env.py | 8 +- 6 files changed, 189 insertions(+), 42 deletions(-) diff --git a/tales/get_env_splits.py b/tales/get_env_splits.py index 47c737b..1ea8484 100644 --- a/tales/get_env_splits.py +++ b/tales/get_env_splits.py @@ -1,12 +1,17 @@ # This is literally just a wrapper to get the train and test-time splits. Is 99.99% just building on Marc's existing code. import glob from os.path import join as pjoin + +from tales.alfworld import alfworld_data, alfworld_env +from tales.jericho import jericho_data +from tales.scienceworld import scienceworld_data, scienceworld_env from tales.textworld import textworld_data, textworld_env from tales.textworld_express import twx_data, twx_env -from tales.alfworld import alfworld_data, alfworld_env -def get_textworld_env_splits(difficulties = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], games_per_difficulty=1): +def get_textworld_env_splits( + difficulties=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], games_per_difficulty=1 +): # Returns a list of envs for training and test splits for Textworld-Cookingworld: # For training, we let the user specify difficulties and how many games per difficulty to include. # For testing, we use all difficulties from 1 to 10, and use one game each, similar to the evaluation in the original paper. @@ -29,36 +34,95 @@ def get_textworld_env_splits(difficulties = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], gam return train_games_files, test_games_files -def get_alfworld_env_splits(games_per_task = 2): + +def get_alfworld_env_splits(games_per_task=2): # For alfworld, we just generate the test split first and then condition the train split to not have the same files as the text split. alfworld_data.prepare_alfworld_data() # make sure the data is ready test_games_files = [] for task in alfworld_data.TASK_TYPES: - game_files_seen = sorted(glob.glob(pjoin(alfworld_data.TALES_CACHE_ALFWORLD_VALID_SEEN, f"{task}*", "**", "*.tw-pddl"))) - game_files_unseen = sorted(glob.glob(pjoin(alfworld_data.TALES_CACHE_ALFWORLD_VALID_UNSEEN, f"{task}*", "**", "*.tw-pddl"))) + game_files_seen = sorted( + glob.glob( + pjoin( + alfworld_data.TALES_CACHE_ALFWORLD_VALID_SEEN, + f"{task}*", + "**", + "*.tw-pddl", + ) + ) + ) + game_files_unseen = sorted( + glob.glob( + pjoin( + alfworld_data.TALES_CACHE_ALFWORLD_VALID_UNSEEN, + f"{task}*", + "**", + "*.tw-pddl", + ) + ) + ) # The test split always only takes the first game file in the split. - test_games_files.extend(game_files_seen[[0]]) - test_games_files.extend(game_files_unseen[[0]]) + test_games_files.append(game_files_seen[0]) + test_games_files.append(game_files_unseen[0]) + print(len(test_games_files)) # Assert we have the right number of files. + print(len(test_games_files)) + print(len(alfworld_data.TASK_TYPES)) assert len(test_games_files) == 2 * len(alfworld_data.TASK_TYPES) # Now, get the training split. # We want to make sure that the training split does not have any files that are in the test split. train_games_files = [] for task in alfworld_data.TASK_TYPES: - game_files_seen = sorted(glob.glob(pjoin(alfworld_data.TALES_CACHE_ALFWORLD_VALID_SEEN, f"{task}*", "**", "*.tw-pddl"))) - game_files_unseen = sorted(glob.glob(pjoin(alfworld_data.TALES_CACHE_ALFWORLD_VALID_UNSEEN, f"{task}*", "**", "*.tw-pddl"))) + game_files_seen = sorted( + glob.glob( + pjoin( + alfworld_data.TALES_CACHE_ALFWORLD_VALID_SEEN, + f"{task}*", + "**", + "*.tw-pddl", + ) + ) + ) + game_files_unseen = sorted( + glob.glob( + pjoin( + alfworld_data.TALES_CACHE_ALFWORLD_VALID_UNSEEN, + f"{task}*", + "**", + "*.tw-pddl", + ) + ) + ) # Remove any files that are in the test split. - filtered_game_files_seen = [f for f in game_files_seen if not any(s in f for s in test_games_files)] - filtered_game_files_unseen = [f for f in game_files_unseen if not any(s in f for s in test_games_files)] + filtered_game_files_seen = [ + f for f in game_files_seen if not any(s in f for s in test_games_files) + ] + filtered_game_files_unseen = [ + f for f in game_files_unseen if not any(s in f for s in test_games_files) + ] # Now get the requested number of games per task type train_games_files.extend(filtered_game_files_seen[:games_per_task]) train_games_files.extend(filtered_game_files_unseen[:games_per_task]) return train_games_files, test_games_files - + + +def get_jericho_env_splits(): + # For jericho, we just use the predefined train/test split. + jericho_data.prepare_jericho_data() # make sure the data is ready + all_games = sorted(jericho_data.GAMES_INFOS.keys()) + train_games = jericho_data.JERICHO_TRAIN_GAMES + test_games = [g for g in all_games if g not in train_games] + + # Get the game files: + train_games_files = [jericho_data.get_game(g) for g in train_games] + test_games_files = [jericho_data.get_game(g) for g in test_games] + + return train_games_files, test_games_files + + class GeneralTALESEnv: # A general env wrapper such that the train/test files gotten from the above functions can easily just be plugged into an env and ran. # This returns a 'fake' batch env that will always deterministically cycle through the provided env file/seeds unless explicitly told to randomize (for training) @@ -76,49 +140,88 @@ def __init__(self, env_name, split, *args, **kwargs): self.game_files = self.train_envs else: self.game_files = self.test_envs - self.env = textworld_env.TextWorldEnv(self.game_files[self.env_idx], - *args, **kwargs) + self.env = textworld_env.TextWorldEnv( + self.game_files[self.env_idx], *args, **kwargs + ) elif env_name == "twx": # Train/test in twx are just seed based. self.game_files = twx_data.TASKS - self.env = twx_env.TextWorldExpressEnv(game_name = self.game_files[self.env_idx][1], - game_params = self.game_files[self.env_idx][2], - admissible_commands=False, - split=split, - *args, **kwargs) + self.env = twx_env.TextWorldExpressEnv( + game_name=self.game_files[self.env_idx][1], + game_params=self.game_files[self.env_idx][2], + admissible_commands=False, + split=split, + *args, + **kwargs, + ) elif env_name == "alfworld": self.train_envs, self.test_envs = get_alfworld_env_splits(**kwargs) if split == "train": self.game_files = self.train_envs else: self.game_files = self.test_envs - self.env = alfworld_env.ALFWorldEnv(self.game_files[self.env_idx], - *args, **kwargs) + self.env = alfworld_env.ALFWorldEnv( + self.game_files[self.env_idx], *args, **kwargs + ) + elif env_name == "scienceworld": + self.game_files = scienceworld_data.get_task_names() + self.env = scienceworld_env.ScienceWorldEnv( + task_name=self.game_files[self.env_idx], split=split, *args, **kwargs + ) + elif env_name == "jericho": + self.train_envs, self.test_envs = get_jericho_env_splits() + if split == "train": + self.game_files = self.train_envs + else: + self.game_files = self.test_envs + self.env = textworld_env.TextWorldEnv( + self.game_files[self.env_idx], *args, **kwargs + ) else: - raise ValueError(f"Unknown environment name: {env_name}, please choose from textworld, twx, or alfworld.") - + raise ValueError( + f"Environment {env_name} not supported. Supported envs are textworld, twx, alfworld, scienceworld, and jericho." + ) + # Not sure if this is right, need to double check w/ Marc def reset(self, *, seed=None, options=None): return self.env.reset(seed=seed, options=options) - def get_next_task(self, seed = None, options=None): + def get_next_task(self, seed=None, options=None): # Move to the next env in the list. self.env_idx = (self.env_idx + 1) % len(self.game_files) if self.env is not None: self.env.close() if self.env_name == "textworld": - self.env = textworld_env.TextWorldEnv(self.game_files[self.env_idx], *self.args, **self.kwargs) + self.env = textworld_env.TextWorldEnv( + self.game_files[self.env_idx], *self.args, **self.kwargs + ) elif self.env_name == "twx": - self.env = twx_env.TextWorldExpressEnv(game_name = self.game_files[self.env_idx][1], - game_params = self.game_files[self.env_idx][2], - *self.args, **self.kwargs) + self.env = twx_env.TextWorldExpressEnv( + game_name=self.game_files[self.env_idx][1], + game_params=self.game_files[self.env_idx][2], + *self.args, + **self.kwargs, + ) elif self.env_name == "alfworld": - self.env = alfworld_env.ALFWorldEnv(self.game_files[self.env_idx], *self.args, **self.kwargs) + self.env = alfworld_env.ALFWorldEnv( + self.game_files[self.env_idx], *self.args, **self.kwargs + ) + elif self.env_name == "scienceworld": + self.env = scienceworld_env.ScienceWorldEnv( + task_name=self.game_files[self.env_idx], + split=self.split, + *self.args, + **self.kwargs, + ) + elif self.env_name == "jericho": + self.env = textworld_env.TextWorldEnv( + self.game_files[self.env_idx], *self.args, **self.kwargs + ) else: - raise ValueError(f"next_task not implemented for env {self.env_name}, only for textworld and alfworld.") - return self.reset(seed = seed, options = options) + raise ValueError( + f"next_task not implemented for env {self.env_name}, only for textworld and alfworld." + ) + return self.reset(seed=seed, options=options) def step(self, action): return self.env.step(action) - - diff --git a/tales/jericho/jericho_data.py b/tales/jericho/jericho_data.py index 6eab578..1cbcc2b 100644 --- a/tales/jericho/jericho_data.py +++ b/tales/jericho/jericho_data.py @@ -8,9 +8,41 @@ GAMES_URLS = "https://github.com/BYU-PCCL/z-machine-games/raw/master/jericho-game-suite" GAMES_JSON_URL = "https://raw.githubusercontent.com/microsoft/tale-suite/refs/heads/main/tales/jericho/games.json" TALES_CACHE_JERICHO = pjoin(TALES_CACHE_HOME, "jericho") +JERICHO_TRAIN_GAMES = [ + "loose", + "karn", + "ballyhoo", + "zork2", + "adventureland", + "omniquest", + "weapon", + "905", + "wishbringer", + "night", + "tryst205", + "zork3", + "murdac", + "afflicted", + "moonlit", + "dragon", + "reverb", + "jewel", + "enter", + "snacktime", + "enchanter", + "acorncourt", + "huntdark", + "gold", + "yomomma", + "inhumane", + "zenon", +] # Check if the games json exists, and if not, then download it. -if not os.path.exists(pjoin(os.path.dirname(__file__), "games.json")) or TALES_FORCE_DOWNLOAD: +if ( + not os.path.exists(pjoin(os.path.dirname(__file__), "games.json")) + or TALES_FORCE_DOWNLOAD +): download( GAMES_JSON_URL, dst=pjoin(os.path.dirname(__file__)), @@ -19,7 +51,6 @@ ) - with open(pjoin(os.path.dirname(__file__), "games.json")) as f: GAMES_INFOS = json.load(f) diff --git a/tales/scienceworld/scienceworld_env.py b/tales/scienceworld/scienceworld_env.py index 64e2dbf..6572e0a 100644 --- a/tales/scienceworld/scienceworld_env.py +++ b/tales/scienceworld/scienceworld_env.py @@ -9,12 +9,14 @@ class ScienceWorldEnv(gym.Env): - def __init__(self, task_name, admissible_commands=False, *args, **kwargs): + def __init__( + self, task_name, admissible_commands=False, split="Test", *args, **kwargs + ): self.task_name = task_name self.admissible_commands = admissible_commands self.env = scienceworld.ScienceWorldEnv(self.task_name, envStepLimit=np.inf) self.variations = scienceworld_data.get_variations( - self.task_name, split="test", env=self.env + self.task_name, split=split, env=self.env ) self.variation = self.variations[0] diff --git a/tales/textworld/textworld_data.py b/tales/textworld/textworld_data.py index 38a5e92..f1d5b49 100644 --- a/tales/textworld/textworld_data.py +++ b/tales/textworld/textworld_data.py @@ -42,9 +42,13 @@ def prepare_twcooking_data(force=TALES_FORCE_DOWNLOAD): def get_cooking_game(difficulty, split="test"): prepare_twcooking_data() # make sure the data is ready if split == "train": - cooking_dir = pjoin(TALES_CACHE_TWCOOKING_TRAIN, f"difficulty_level_{difficulty}") + cooking_dir = pjoin( + TALES_CACHE_TWCOOKING_TRAIN, f"difficulty_level_{difficulty}" + ) elif split == "test": - cooking_dir = pjoin(TALES_CACHE_TWCOOKING_TEST, f"difficulty_level_{difficulty}") - + cooking_dir = pjoin( + TALES_CACHE_TWCOOKING_TEST, f"difficulty_level_{difficulty}" + ) + game_files = glob.glob(pjoin(cooking_dir, "*.z8")) return game_files diff --git a/tales/textworld/textworld_env.py b/tales/textworld/textworld_env.py index e8c8557..324240c 100644 --- a/tales/textworld/textworld_env.py +++ b/tales/textworld/textworld_env.py @@ -5,6 +5,7 @@ from . import textworld_data + class TextWorldEnv(gym.Env): def __init__(self, gamefile, admissible_commands=False, *args, **kwargs): self.infos = textworld.EnvInfos( @@ -25,7 +26,7 @@ def reset(self, *, seed=None, options=None): if self.env is None: self.env = textworld.start(self.gamefile, self.infos, wrappers=[Filter]) - + return self.env.reset() def step(self, action): diff --git a/tales/textworld_express/twx_env.py b/tales/textworld_express/twx_env.py index 2c29d9f..0324bd1 100644 --- a/tales/textworld_express/twx_env.py +++ b/tales/textworld_express/twx_env.py @@ -10,7 +10,13 @@ class TextWorldExpressEnv(gym.Env): def __init__( - self, game_name, game_params, admissible_commands=False, split="test", *args, **kwargs + self, + game_name, + game_params, + admissible_commands=False, + split="test", + *args, + **kwargs, ): self.game_name = game_name self.game_params = game_params From a7e695b1445262b66a4b59fb8cd04ab3e6e67e40 Mon Sep 17 00:00:00 2001 From: Christopher Zhang Cui Date: Sat, 13 Sep 2025 23:10:56 +0000 Subject: [PATCH 2/4] Added train and test splits directly into env init --- README.md | 1 + tales/README.md | 31 +++++++++++++++++++++++++++++ tales/alfworld/__init__.py | 21 ++++++++++++++++++- tales/alfworld/alfworld_env.py | 6 +++--- tales/get_env_splits.py | 12 +++++------ tales/scienceworld/__init__.py | 11 +++++++++- tales/textworld/__init__.py | 14 ++++++++++++- tales/textworld/textworld_env.py | 6 +++--- tales/textworld_express/__init__.py | 11 +++++++++- 9 files changed, 97 insertions(+), 16 deletions(-) create mode 100644 tales/README.md diff --git a/README.md b/README.md index 9a804ef..dac18d0 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ We provide a pre-built docker image at An example script can be found in the scripts folder. ## 2. Getting Started +0. For training details, please see the README.md in the tales folder. 1. Run benchmark evaluation on all the games for the specified random agent: diff --git a/tales/README.md b/tales/README.md new file mode 100644 index 0000000..0684dca --- /dev/null +++ b/tales/README.md @@ -0,0 +1,31 @@ +# Training and Testing on TALES +TALES offers both train splits and test splits, the latter of which make up the games all models in our technical report were evaluated on. + +The following is an example of how to import desired environments and allow an agent to play through them. + +Note that importing the relevant framework automatically registers all environments in that framework with gym. +You can individually import the frameworks if you want to only evaluate on them one at a time. +For now, we do not include a jericho train split. + +``` +import gymnasium as gym +from tales import * + +# Training splits +train_envs = [env_spec.id for env_spec in gym.envs.registry.values() if "tales/" in env_spec.id and 'train' in env_spec.id] + +# Testing splits +envs = [env_spec.id for env_spec in gym.envs.registry.values() if "tales/" in env_spec.id and 'train' not in env_spec.id] + +train_env = gym.make( + train_envs[0], + disable_env_checker=True, + admissible_commands=True, +) + +test_env = gym.make( + envs[0], + disable_env_checker=True, + admissible_commands=True, +) +``` \ No newline at end of file diff --git a/tales/alfworld/__init__.py b/tales/alfworld/__init__.py index e63d292..d3b7178 100644 --- a/tales/alfworld/__init__.py +++ b/tales/alfworld/__init__.py @@ -4,9 +4,14 @@ from .alfworld_env import ALFWorldTask environments = [] +train_environments = [] for split in ["seen", "unseen"]: for task_type in TASK_TYPES: + gamefiles = sorted(alfworld_data.get_alfworld_game(task_type, split)) + train_gamefiles = gamefiles[1:] + test_gamefiles = [gamefiles[0]] + task_name = task_type.replace("_", " ").title().replace(" ", "") env_name = f"ALFWorld{task_name}{split.title()}" environments.append([env_name, "v0"]) @@ -14,7 +19,21 @@ gym.register( id=f"tales/{env_name}-v0", entry_point="tales.alfworld:ALFWorldTask", - kwargs={"task_type": task_type, "split": split}, + kwargs={ + "all_gamefiles": test_gamefiles, + "start_gamefile": test_gamefiles[0], + }, + ) + + train_env_name = env_name + "_train" + train_environments.append([train_env_name, "v0"]) + gym.register( + id=f"tales/{train_env_name}-v0", + entry_point="tales.alfworld:ALFWorldTask", + kwargs={ + "all_gamefiles": train_gamefiles, + "start_gamefile": train_gamefiles[0], + }, ) diff --git a/tales/alfworld/alfworld_env.py b/tales/alfworld/alfworld_env.py index 380885b..9499ff4 100644 --- a/tales/alfworld/alfworld_env.py +++ b/tales/alfworld/alfworld_env.py @@ -50,9 +50,9 @@ def step(self, action): class ALFWorldTask(ALFWorldEnv): - def __init__(self, task_type, split, *args, **kwargs): - self.gamefiles = sorted(alfworld_data.get_alfworld_game(task_type, split)) - super().__init__(self.gamefiles[0], *args, **kwargs) + def __init__(self, all_gamefiles, start_gamefile, *args, **kwargs): + self.gamefiles = all_gamefiles + super().__init__(start_gamefile, *args, **kwargs) def reset(self, *, seed=None, options=None): if seed is not None: diff --git a/tales/get_env_splits.py b/tales/get_env_splits.py index 1ea8484..43f85e2 100644 --- a/tales/get_env_splits.py +++ b/tales/get_env_splits.py @@ -2,8 +2,8 @@ import glob from os.path import join as pjoin +import tales.jericho as jericho from tales.alfworld import alfworld_data, alfworld_env -from tales.jericho import jericho_data from tales.scienceworld import scienceworld_data, scienceworld_env from tales.textworld import textworld_data, textworld_env from tales.textworld_express import twx_data, twx_env @@ -111,14 +111,14 @@ def get_alfworld_env_splits(games_per_task=2): def get_jericho_env_splits(): # For jericho, we just use the predefined train/test split. - jericho_data.prepare_jericho_data() # make sure the data is ready - all_games = sorted(jericho_data.GAMES_INFOS.keys()) - train_games = jericho_data.JERICHO_TRAIN_GAMES + jericho.jericho_data.prepare_jericho_data() # make sure the data is ready + all_games = sorted(jericho.jericho_data.GAMES_INFOS.keys()) + train_games = jericho.jericho_data.JERICHO_TRAIN_GAMES test_games = [g for g in all_games if g not in train_games] # Get the game files: - train_games_files = [jericho_data.get_game(g) for g in train_games] - test_games_files = [jericho_data.get_game(g) for g in test_games] + train_games_files = [jericho.jericho_data.get_game(g) for g in train_games] + test_games_files = [jericho.jericho_data.get_game(g) for g in test_games] return train_games_files, test_games_files diff --git a/tales/scienceworld/__init__.py b/tales/scienceworld/__init__.py index cdfb549..dbf5b31 100644 --- a/tales/scienceworld/__init__.py +++ b/tales/scienceworld/__init__.py @@ -3,6 +3,7 @@ from .scienceworld_env import TASK_NAMES, ScienceWorldEnv environments = [] +train_environments = [] for task_name in TASK_NAMES: env_name = f"ScienceWorld{task_name.title().replace('-', '')}" @@ -11,7 +12,15 @@ gym.register( id=f"tales/{env_name}-v0", entry_point="tales.scienceworld:ScienceWorldEnv", - kwargs={"task_name": task_name}, + kwargs={"task_name": task_name, "split": "test"}, + ) + + train_env_name = env_name + "_train" + train_environments.append([train_env_name, "v0"]) + gym.register( + id=f"tales/{train_env_name}-v0", + entry_point="tales.scienceworld:ScienceWorldEnv", + kwargs={"task_name": task_name, "split": "train"}, ) diff --git a/tales/textworld/__init__.py b/tales/textworld/__init__.py index 146e1da..cd5835b 100644 --- a/tales/textworld/__init__.py +++ b/tales/textworld/__init__.py @@ -4,16 +4,28 @@ from .textworld_env import TextWorldEnv, TWCookingEnv environments = [] +train_environments = [] # TWCookingEnv for difficulty in range(1, 10 + 1): + gamefiles = sorted(textworld_data.get_cooking_game(difficulty)) + train_gamefiles = gamefiles[1:] + test_gamefiles = [gamefiles[0]] env_name = f"TWCookingLevel{difficulty}" environments.append([env_name, "v0"]) gym.register( id=f"tales/{env_name}-v0", entry_point="tales.textworld:TWCookingEnv", - kwargs={"difficulty": difficulty}, + kwargs={"all_gamefiles": test_gamefiles, "start_gamefile": test_gamefiles[0]}, + ) + + train_env_name = env_name + "_train" + train_environments.append([train_env_name, "v0"]) + gym.register( + id=f"tales/{train_env_name}-v0", + entry_point="tales.textworld:TWCookingEnv", + kwargs={"all_gamefiles": train_gamefiles, "start_gamefile": train_gamefiles[0]}, ) diff --git a/tales/textworld/textworld_env.py b/tales/textworld/textworld_env.py index 324240c..30bbbe6 100644 --- a/tales/textworld/textworld_env.py +++ b/tales/textworld/textworld_env.py @@ -34,9 +34,9 @@ def step(self, action): class TWCookingEnv(TextWorldEnv): - def __init__(self, difficulty, *args, **kwargs): - self.gamefiles = sorted(textworld_data.get_cooking_game(difficulty)) - super().__init__(self.gamefiles[0], *args, **kwargs) + def __init__(self, all_gamefiles, start_gamefile, *args, **kwargs): + self.gamefiles = all_gamefiles + super().__init__(start_gamefile, *args, **kwargs) def reset(self, *, seed=None, options=None): if seed is not None: diff --git a/tales/textworld_express/__init__.py b/tales/textworld_express/__init__.py index a2e73d0..ff22589 100644 --- a/tales/textworld_express/__init__.py +++ b/tales/textworld_express/__init__.py @@ -3,6 +3,7 @@ from .twx_env import TASKS, TextWorldExpressEnv environments = [] +train_environments = [] for task_name, game_name, game_params in TASKS: env_name = f"TWX{task_name}" @@ -11,7 +12,15 @@ gym.register( id=f"tales/{env_name}-v0", entry_point="tales.textworld_express:TextWorldExpressEnv", - kwargs={"game_name": game_name, "game_params": game_params}, + kwargs={"game_name": game_name, "game_params": game_params, "split": "test"}, + ) + + train_env_name = env_name + "_train" + train_environments.append([train_env_name, "v0"]) + gym.register( + id=f"tales/{train_env_name}-v0", + entry_point="tales.textworld_express:TextWorldExpressEnv", + kwargs={"game_name": game_name, "game_params": game_params, "split": "train"}, ) From 1214b521185ac6af3191fd9be082bff56b59d6a8 Mon Sep 17 00:00:00 2001 From: Christopher Zhang Cui Date: Wed, 22 Oct 2025 03:47:24 +0000 Subject: [PATCH 3/4] Got rid of get_env_split (not used anymore) Signed-off-by: Christopher Zhang Cui --- tales/get_env_splits.py | 186 ---------------------------------------- 1 file changed, 186 deletions(-) delete mode 100644 tales/get_env_splits.py diff --git a/tales/get_env_splits.py b/tales/get_env_splits.py deleted file mode 100644 index f5d69b4..0000000 --- a/tales/get_env_splits.py +++ /dev/null @@ -1,186 +0,0 @@ -# This is literally just a wrapper to get the train and test-time splits. Is 99.99% just building on Marc's existing code. -import glob -from os.path import join as pjoin - -from tales.alfworld import alfworld_data, alfworld_env -from tales.textworld import textworld_data, textworld_env -from tales.textworld_express import twx_data, twx_env - - -def get_textworld_env_splits( - difficulties=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], games_per_difficulty=1 -): - # Returns a list of envs for training and test splits for Textworld-Cookingworld: - # For training, we let the user specify difficulties and how many games per difficulty to include. - # For testing, we use all difficulties from 1 to 10, and use one game each, similar to the evaluation in the original paper. - textworld_data.prepare_twcooking_data() # make sure the data is ready - - # Training split: - # Get the game files: - train_games_files = [] - for diff in difficulties: - all_games = sorted(textworld_data.get_cooking_game(diff, split="train")) - train_games_files.extend(all_games[:games_per_difficulty]) - - # Testing split: - test_games_files = [] - for i in range(1, 11): - # Just get one game per difficulty for testing. - # This is similar to the evaluation in the original paper. - all_games = sorted(textworld_data.get_cooking_game(i, split="test")) - test_games_files.append(all_games[0]) - - return train_games_files, test_games_files - - -def get_alfworld_env_splits(games_per_task=2): - # For alfworld, we just generate the test split first and then condition the train split to not have the same files as the text split. - alfworld_data.prepare_alfworld_data() # make sure the data is ready - test_games_files = [] - for task in alfworld_data.TASK_TYPES: - game_files_seen = sorted( - glob.glob( - pjoin( - alfworld_data.TALES_CACHE_ALFWORLD_VALID_SEEN, - f"{task}*", - "**", - "*.tw-pddl", - ) - ) - ) - game_files_unseen = sorted( - glob.glob( - pjoin( - alfworld_data.TALES_CACHE_ALFWORLD_VALID_UNSEEN, - f"{task}*", - "**", - "*.tw-pddl", - ) - ) - ) - # The test split always only takes the first game file in the split. - test_games_files.append(game_files_seen[0]) - test_games_files.append(game_files_unseen[0]) - print(len(test_games_files)) - - # Assert we have the right number of files. - print(len(test_games_files)) - print(len(alfworld_data.TASK_TYPES)) - assert len(test_games_files) == 2 * len(alfworld_data.TASK_TYPES) - - # Now, get the training split. - # We want to make sure that the training split does not have any files that are in the test split. - train_games_files = [] - for task in alfworld_data.TASK_TYPES: - game_files_seen = sorted( - glob.glob( - pjoin( - alfworld_data.TALES_CACHE_ALFWORLD_VALID_SEEN, - f"{task}*", - "**", - "*.tw-pddl", - ) - ) - ) - game_files_unseen = sorted( - glob.glob( - pjoin( - alfworld_data.TALES_CACHE_ALFWORLD_VALID_UNSEEN, - f"{task}*", - "**", - "*.tw-pddl", - ) - ) - ) - # Remove any files that are in the test split. - filtered_game_files_seen = [ - f for f in game_files_seen if not any(s in f for s in test_games_files) - ] - filtered_game_files_unseen = [ - f for f in game_files_unseen if not any(s in f for s in test_games_files) - ] - - # Now get the requested number of games per task type - train_games_files.extend(filtered_game_files_seen[:games_per_task]) - train_games_files.extend(filtered_game_files_unseen[:games_per_task]) - - return train_games_files, test_games_files - - -class GeneralTALESEnv: - # A general env wrapper such that the train/test files gotten from the above functions can easily just be plugged into an env and ran. - # This returns a 'fake' batch env that will always deterministically cycle through the provided env file/seeds unless explicitly told to randomize (for training) - # TODO: implement for Scienceworld and Jericho - def __init__(self, env_name, split, *args, **kwargs): - self.env_name = env_name - self.split = split - self.env_idx = 0 - self.kwargs = kwargs - self.args = args - self.game_files = None - if env_name == "textworld": - self.train_envs, self.test_envs = get_textworld_env_splits(**kwargs) - if split == "train": - self.game_files = self.train_envs - else: - self.game_files = self.test_envs - self.env = textworld_env.TextWorldEnv( - self.game_files[self.env_idx], *args, **kwargs - ) - elif env_name == "twx": - # Train/test in twx are just seed based. - self.game_files = twx_data.TASKS - self.env = twx_env.TextWorldExpressEnv( - game_name=self.game_files[self.env_idx][1], - game_params=self.game_files[self.env_idx][2], - admissible_commands=False, - split=split, - *args, - **kwargs, - ) - elif env_name == "alfworld": - self.train_envs, self.test_envs = get_alfworld_env_splits(**kwargs) - if split == "train": - self.game_files = self.train_envs - else: - self.game_files = self.test_envs - self.env = alfworld_env.ALFWorldEnv( - self.game_files[self.env_idx], *args, **kwargs - ) - else: - raise ValueError( - f"Unknown environment name: {env_name}, please choose from textworld, twx, or alfworld." - ) - - # Not sure if this is right, need to double check w/ Marc - def reset(self, *, seed=None, options=None): - return self.env.reset(seed=seed, options=options) - - def get_next_task(self, seed=None, options=None): - # Move to the next env in the list. - self.env_idx = (self.env_idx + 1) % len(self.game_files) - if self.env is not None: - self.env.close() - if self.env_name == "textworld": - self.env = textworld_env.TextWorldEnv( - self.game_files[self.env_idx], *self.args, **self.kwargs - ) - elif self.env_name == "twx": - self.env = twx_env.TextWorldExpressEnv( - game_name=self.game_files[self.env_idx][1], - game_params=self.game_files[self.env_idx][2], - *self.args, - **self.kwargs, - ) - elif self.env_name == "alfworld": - self.env = alfworld_env.ALFWorldEnv( - self.game_files[self.env_idx], *self.args, **self.kwargs - ) - else: - raise ValueError( - f"next_task not implemented for env {self.env_name}, only for textworld and alfworld." - ) - return self.reset(seed=seed, options=options) - - def step(self, action): - return self.env.step(action) From 5be4915dfd64fab7de0cb9ce83534edb4fe3ac7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 22 Oct 2025 12:59:30 -0700 Subject: [PATCH 4/4] Update and merge Readmes. Remove Jericho unofficial train split --- README.md | 33 ++++++++++++++++++++++++++++++++- tales/README.md | 31 ------------------------------- tales/jericho/jericho_data.py | 29 ----------------------------- 3 files changed, 32 insertions(+), 61 deletions(-) delete mode 100644 tales/README.md diff --git a/README.md b/README.md index dac18d0..8d73a09 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,6 @@ We provide a pre-built docker image at An example script can be found in the scripts folder. ## 2. Getting Started -0. For training details, please see the README.md in the tales folder. 1. Run benchmark evaluation on all the games for the specified random agent: @@ -142,6 +141,38 @@ You can then use this agent by specifying the path to the file and the class nam > [!NOTE] > See the [agents folder](https://github.com/microsoft/tale-suite/tree/main/agents) for more concrete examples. +## 5. Training Your Language Agents on TALES +TALES offers both train splits and test splits, the latter of which make up the games all models in our technical report were evaluated on. + +The following is an example of how to import desired environments and allow an agent to play through them. + +Note that importing the relevant framework automatically registers all environments in that framework with gym. +You can individually import the frameworks if you want to only evaluate on them one at a time. +For now, we do not include a jericho train split. + +``` +import gymnasium as gym +from tales import * + +# Training splits +train_envs = [env_spec.id for env_spec in gym.envs.registry.values() if "tales/" in env_spec.id and 'train' in env_spec.id] + +# Testing splits +envs = [env_spec.id for env_spec in gym.envs.registry.values() if "tales/" in env_spec.id and 'train' not in env_spec.id] + +train_env = gym.make( + train_envs[0], + disable_env_checker=True, + admissible_commands=True, +) + +test_env = gym.make( + envs[0], + disable_env_checker=True, + admissible_commands=True, +) +``` + ## Citation ``` @article{cui2025tales, diff --git a/tales/README.md b/tales/README.md deleted file mode 100644 index 0684dca..0000000 --- a/tales/README.md +++ /dev/null @@ -1,31 +0,0 @@ -# Training and Testing on TALES -TALES offers both train splits and test splits, the latter of which make up the games all models in our technical report were evaluated on. - -The following is an example of how to import desired environments and allow an agent to play through them. - -Note that importing the relevant framework automatically registers all environments in that framework with gym. -You can individually import the frameworks if you want to only evaluate on them one at a time. -For now, we do not include a jericho train split. - -``` -import gymnasium as gym -from tales import * - -# Training splits -train_envs = [env_spec.id for env_spec in gym.envs.registry.values() if "tales/" in env_spec.id and 'train' in env_spec.id] - -# Testing splits -envs = [env_spec.id for env_spec in gym.envs.registry.values() if "tales/" in env_spec.id and 'train' not in env_spec.id] - -train_env = gym.make( - train_envs[0], - disable_env_checker=True, - admissible_commands=True, -) - -test_env = gym.make( - envs[0], - disable_env_checker=True, - admissible_commands=True, -) -``` \ No newline at end of file diff --git a/tales/jericho/jericho_data.py b/tales/jericho/jericho_data.py index fbc438a..a757302 100644 --- a/tales/jericho/jericho_data.py +++ b/tales/jericho/jericho_data.py @@ -8,35 +8,6 @@ GAMES_URLS = "https://github.com/BYU-PCCL/z-machine-games/raw/master/jericho-game-suite" TALES_CACHE_JERICHO = pjoin(TALES_CACHE_HOME, "jericho") -JERICHO_TRAIN_GAMES = [ - "loose", - "karn", - "ballyhoo", - "zork2", - "adventureland", - "omniquest", - "weapon", - "905", - "wishbringer", - "night", - "tryst205", - "zork3", - "murdac", - "afflicted", - "moonlit", - "dragon", - "reverb", - "jewel", - "enter", - "snacktime", - "enchanter", - "acorncourt", - "huntdark", - "gold", - "yomomma", - "inhumane", - "zenon", -] with open(importlib_files("tales") / "jericho" / "games.json") as f: