diff --git a/README.md b/README.md
index 9a804ef..8d73a09 100644
--- a/README.md
+++ b/README.md
@@ -141,6 +141,38 @@ You can then use this agent by specifying the path to the file and the class nam
 > [!NOTE]
 > See the [agents folder](https://github.com/microsoft/tale-suite/tree/main/agents) for more concrete examples.
 
+## 5. Training Your Language Agents on TALES
+TALES offers both train splits and test splits, the latter of which make up the games all models in our technical report were evaluated on.
+
+The following is an example of how to import desired environments and allow an agent to play through them.
+
+Note that importing the relevant framework automatically registers all environments in that framework with gym.
+You can individually import the frameworks if you want to only evaluate on them one at a time.
+For now, we do not include a Jericho train split.
+
+```python
+import gymnasium as gym
+from tales import *
+
+# Training splits
+train_envs = [env_spec.id for env_spec in gym.envs.registry.values() if "tales/" in env_spec.id and 'train' in env_spec.id]
+
+# Testing splits
+envs = [env_spec.id for env_spec in gym.envs.registry.values() if "tales/" in env_spec.id and 'train' not in env_spec.id]
+
+train_env = gym.make(
+    train_envs[0],
+    disable_env_checker=True,
+    admissible_commands=True,
+)
+
+test_env = gym.make(
+    envs[0],
+    disable_env_checker=True,
+    admissible_commands=True,
+)
+```
+
 ## Citation
 ```
 @article{cui2025tales,
diff --git a/tales/alfworld/__init__.py b/tales/alfworld/__init__.py
index e63d292..d3b7178 100644
--- a/tales/alfworld/__init__.py
+++ b/tales/alfworld/__init__.py
@@ -4,9 +4,14 @@
 from .alfworld_env import ALFWorldTask
 
 environments = []
+train_environments = []
 
 for split in ["seen", "unseen"]:
     for task_type in TASK_TYPES:
+        gamefiles = sorted(alfworld_data.get_alfworld_game(task_type, split))
+        train_gamefiles = gamefiles[1:]
+        test_gamefiles = [gamefiles[0]]
+
         task_name = task_type.replace("_", " ").title().replace(" ", "")
         env_name = f"ALFWorld{task_name}{split.title()}"
         environments.append([env_name, "v0"])
@@ -14,7 +19,21 @@
         gym.register(
             id=f"tales/{env_name}-v0",
             entry_point="tales.alfworld:ALFWorldTask",
-            kwargs={"task_type": task_type, "split": split},
+            kwargs={
+                "all_gamefiles": test_gamefiles,
+                "start_gamefile": test_gamefiles[0],
+            },
+        )
+
+        train_env_name = env_name + "_train"
+        train_environments.append([train_env_name, "v0"])
+        gym.register(
+            id=f"tales/{train_env_name}-v0",
+            entry_point="tales.alfworld:ALFWorldTask",
+            kwargs={
+                "all_gamefiles": train_gamefiles,
+                "start_gamefile": train_gamefiles[0],
+            },
         )
 
 
diff --git a/tales/alfworld/alfworld_env.py b/tales/alfworld/alfworld_env.py
index 380885b..9499ff4 100644
--- a/tales/alfworld/alfworld_env.py
+++ b/tales/alfworld/alfworld_env.py
@@ -50,9 +50,9 @@ def step(self, action):
 
 
 class ALFWorldTask(ALFWorldEnv):
-    def __init__(self, task_type, split, *args, **kwargs):
-        self.gamefiles = sorted(alfworld_data.get_alfworld_game(task_type, split))
-        super().__init__(self.gamefiles[0], *args, **kwargs)
+    def __init__(self, all_gamefiles, start_gamefile, *args, **kwargs):
+        self.gamefiles = all_gamefiles
+        super().__init__(start_gamefile, *args, **kwargs)
 
     def reset(self, *, seed=None, options=None):
         if seed is not None:
diff --git a/tales/get_env_splits.py b/tales/get_env_splits.py
deleted file mode 100644
index 2bfb632..0000000
--- a/tales/get_env_splits.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# This is literally just a wrapper to get the train and test-time splits. Is 99.99% just building on Marc's existing code.
-import glob
-from os.path import join as pjoin
-
-from tales.alfworld import alfworld_data, alfworld_env
-from tales.textworld import textworld_data, textworld_env
-from tales.textworld_express import twx_data, twx_env
-
-
-def get_textworld_env_splits(
-    difficulties=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], games_per_difficulty=1
-):
-    # Returns a list of envs for training and test splits for Textworld-Cookingworld:
-    # For training, we let the user specify difficulties and how many games per difficulty to include.
-    # For testing, we use all difficulties from 1 to 10, and use one game each, similar to the evaluation in the original paper.
-    textworld_data.prepare_twcooking_data()  # make sure the data is ready
-
-    # Training split:
-    # Get the game files:
-    train_games_files = []
-    for diff in difficulties:
-        all_games = sorted(textworld_data.get_cooking_game(diff, split="train"))
-        train_games_files.extend(all_games[:games_per_difficulty])
-
-    # Testing split:
-    test_games_files = []
-    for i in range(1, 11):
-        # Just get one game per difficulty for testing.
-        # This is similar to the evaluation in the original paper.
-        all_games = sorted(textworld_data.get_cooking_game(i, split="test"))
-        test_games_files.append(all_games[0])
-
-    return train_games_files, test_games_files
-
-
-def get_alfworld_env_splits(games_per_task=2):
-    # For alfworld, we just generate the test split first and then condition the train split to not have the same files as the text split.
-    alfworld_data.prepare_alfworld_data()  # make sure the data is ready
-    test_games_files = []
-    for task in alfworld_data.TASK_TYPES:
-        game_files_seen = sorted(
-            glob.glob(
-                pjoin(
-                    alfworld_data.TALES_CACHE_ALFWORLD_VALID_SEEN,
-                    f"{task}*",
-                    "**",
-                    "*.tw-pddl",
-                )
-            )
-        )
-        game_files_unseen = sorted(
-            glob.glob(
-                pjoin(
-                    alfworld_data.TALES_CACHE_ALFWORLD_VALID_UNSEEN,
-                    f"{task}*",
-                    "**",
-                    "*.tw-pddl",
-                )
-            )
-        )
-        # The test split always only takes the first game file in the split.
-        test_games_files.extend(game_files_seen[[0]])
-        test_games_files.extend(game_files_unseen[[0]])
-
-    # Assert we have the right number of files.
-    assert len(test_games_files) == 2 * len(alfworld_data.TASK_TYPES)
-
-    # Now, get the training split.
-    # We want to make sure that the training split does not have any files that are in the test split.
-    train_games_files = []
-    for task in alfworld_data.TASK_TYPES:
-        game_files_seen = sorted(
-            glob.glob(
-                pjoin(
-                    alfworld_data.TALES_CACHE_ALFWORLD_VALID_SEEN,
-                    f"{task}*",
-                    "**",
-                    "*.tw-pddl",
-                )
-            )
-        )
-        game_files_unseen = sorted(
-            glob.glob(
-                pjoin(
-                    alfworld_data.TALES_CACHE_ALFWORLD_VALID_UNSEEN,
-                    f"{task}*",
-                    "**",
-                    "*.tw-pddl",
-                )
-            )
-        )
-        # Remove any files that are in the test split.
-        filtered_game_files_seen = [
-            f for f in game_files_seen if not any(s in f for s in test_games_files)
-        ]
-        filtered_game_files_unseen = [
-            f for f in game_files_unseen if not any(s in f for s in test_games_files)
-        ]
-
-        # Now get the requested number of games per task type
-        train_games_files.extend(filtered_game_files_seen[:games_per_task])
-        train_games_files.extend(filtered_game_files_unseen[:games_per_task])
-
-    return train_games_files, test_games_files
-
-
-class GeneralTALESEnv:
-    # A general env wrapper such that the train/test files gotten from the above functions can easily just be plugged into an env and ran.
-    # This returns a 'fake' batch env that will always deterministically cycle through the provided env file/seeds unless explicitly told to randomize (for training)
-    # TODO: implement for Scienceworld and Jericho
-    def __init__(self, env_name, split, *args, **kwargs):
-        self.env_name = env_name
-        self.split = split
-        self.env_idx = 0
-        self.kwargs = kwargs
-        self.args = args
-        self.game_files = None
-        if env_name == "textworld":
-            self.train_envs, self.test_envs = get_textworld_env_splits(**kwargs)
-            if split == "train":
-                self.game_files = self.train_envs
-            else:
-                self.game_files = self.test_envs
-            self.env = textworld_env.TextWorldEnv(
-                self.game_files[self.env_idx], *args, **kwargs
-            )
-        elif env_name == "twx":
-            # Train/test in twx are just seed based.
-            self.game_files = twx_data.TASKS
-            self.env = twx_env.TextWorldExpressEnv(
-                game_name=self.game_files[self.env_idx][1],
-                game_params=self.game_files[self.env_idx][2],
-                admissible_commands=False,
-                split=split,
-                *args,
-                **kwargs,
-            )
-        elif env_name == "alfworld":
-            self.train_envs, self.test_envs = get_alfworld_env_splits(**kwargs)
-            if split == "train":
-                self.game_files = self.train_envs
-            else:
-                self.game_files = self.test_envs
-            self.env = alfworld_env.ALFWorldEnv(
-                self.game_files[self.env_idx], *args, **kwargs
-            )
-        else:
-            raise ValueError(
-                f"Unknown environment name: {env_name}, please choose from textworld, twx, or alfworld."
-            )
-
-    # Not sure if this is right, need to double check w/ Marc
-    def reset(self, *, seed=None, options=None):
-        return self.env.reset(seed=seed, options=options)
-
-    def get_next_task(self, seed=None, options=None):
-        # Move to the next env in the list.
-        self.env_idx = (self.env_idx + 1) % len(self.game_files)
-        if self.env is not None:
-            self.env.close()
-        if self.env_name == "textworld":
-            self.env = textworld_env.TextWorldEnv(
-                self.game_files[self.env_idx], *self.args, **self.kwargs
-            )
-        elif self.env_name == "twx":
-            self.env = twx_env.TextWorldExpressEnv(
-                game_name=self.game_files[self.env_idx][1],
-                game_params=self.game_files[self.env_idx][2],
-                *self.args,
-                **self.kwargs,
-            )
-        elif self.env_name == "alfworld":
-            self.env = alfworld_env.ALFWorldEnv(
-                self.game_files[self.env_idx], *self.args, **self.kwargs
-            )
-        else:
-            raise ValueError(
-                f"next_task not implemented for env {self.env_name}, only for textworld and alfworld."
-            )
-        return self.reset(seed=seed, options=options)
-
-    def step(self, action):
-        return self.env.step(action)
diff --git a/tales/scienceworld/__init__.py b/tales/scienceworld/__init__.py
index cdfb549..dbf5b31 100644
--- a/tales/scienceworld/__init__.py
+++ b/tales/scienceworld/__init__.py
@@ -3,6 +3,7 @@
 from .scienceworld_env import TASK_NAMES, ScienceWorldEnv
 
 environments = []
+train_environments = []
 
 for task_name in TASK_NAMES:
     env_name = f"ScienceWorld{task_name.title().replace('-', '')}"
@@ -11,7 +12,15 @@
     gym.register(
         id=f"tales/{env_name}-v0",
         entry_point="tales.scienceworld:ScienceWorldEnv",
-        kwargs={"task_name": task_name},
+        kwargs={"task_name": task_name, "split": "test"},
+    )
+
+    train_env_name = env_name + "_train"
+    train_environments.append([train_env_name, "v0"])
+    gym.register(
+        id=f"tales/{train_env_name}-v0",
+        entry_point="tales.scienceworld:ScienceWorldEnv",
+        kwargs={"task_name": task_name, "split": "train"},
     )
 
 
diff --git a/tales/scienceworld/scienceworld_env.py b/tales/scienceworld/scienceworld_env.py
index 64e2dbf..6572e0a 100644
--- a/tales/scienceworld/scienceworld_env.py
+++ b/tales/scienceworld/scienceworld_env.py
@@ -9,12 +9,14 @@
 
 
 class ScienceWorldEnv(gym.Env):
-    def __init__(self, task_name, admissible_commands=False, *args, **kwargs):
+    def __init__(
+        self, task_name, admissible_commands=False, split="test", *args, **kwargs
+    ):
         self.task_name = task_name
         self.admissible_commands = admissible_commands
         self.env = scienceworld.ScienceWorldEnv(self.task_name, envStepLimit=np.inf)
         self.variations = scienceworld_data.get_variations(
-            self.task_name, split="test", env=self.env
+            self.task_name, split=split, env=self.env
         )
         self.variation = self.variations[0]
 
diff --git a/tales/textworld/__init__.py b/tales/textworld/__init__.py
index 146e1da..cd5835b 100644
--- a/tales/textworld/__init__.py
+++ b/tales/textworld/__init__.py
@@ -4,16 +4,28 @@
 from .textworld_env import TextWorldEnv, TWCookingEnv
 
 environments = []
+train_environments = []
 
 # TWCookingEnv
 for difficulty in range(1, 10 + 1):
+    gamefiles = sorted(textworld_data.get_cooking_game(difficulty))
+    train_gamefiles = gamefiles[1:]
+    test_gamefiles = [gamefiles[0]]
     env_name = f"TWCookingLevel{difficulty}"
     environments.append([env_name, "v0"])
 
     gym.register(
         id=f"tales/{env_name}-v0",
         entry_point="tales.textworld:TWCookingEnv",
-        kwargs={"difficulty": difficulty},
+        kwargs={"all_gamefiles": test_gamefiles, "start_gamefile": test_gamefiles[0]},
+    )
+
+    train_env_name = env_name + "_train"
+    train_environments.append([train_env_name, "v0"])
+    gym.register(
+        id=f"tales/{train_env_name}-v0",
+        entry_point="tales.textworld:TWCookingEnv",
+        kwargs={"all_gamefiles": train_gamefiles, "start_gamefile": train_gamefiles[0]},
     )
 
 
diff --git a/tales/textworld/textworld_env.py b/tales/textworld/textworld_env.py
index 324240c..30bbbe6 100644
--- a/tales/textworld/textworld_env.py
+++ b/tales/textworld/textworld_env.py
@@ -34,9 +34,9 @@ def step(self, action):
 
 
 class TWCookingEnv(TextWorldEnv):
-    def __init__(self, difficulty, *args, **kwargs):
-        self.gamefiles = sorted(textworld_data.get_cooking_game(difficulty))
-        super().__init__(self.gamefiles[0], *args, **kwargs)
+    def __init__(self, all_gamefiles, start_gamefile, *args, **kwargs):
+        self.gamefiles = all_gamefiles
+        super().__init__(start_gamefile, *args, **kwargs)
 
     def reset(self, *, seed=None, options=None):
         if seed is not None:
diff --git a/tales/textworld_express/__init__.py b/tales/textworld_express/__init__.py
index a2e73d0..ff22589 100644
--- a/tales/textworld_express/__init__.py
+++ b/tales/textworld_express/__init__.py
@@ -3,6 +3,7 @@
 from .twx_env import TASKS, TextWorldExpressEnv
 
 environments = []
+train_environments = []
 
 for task_name, game_name, game_params in TASKS:
     env_name = f"TWX{task_name}"
@@ -11,7 +12,15 @@
     gym.register(
         id=f"tales/{env_name}-v0",
         entry_point="tales.textworld_express:TextWorldExpressEnv",
-        kwargs={"game_name": game_name, "game_params": game_params},
+        kwargs={"game_name": game_name, "game_params": game_params, "split": "test"},
+    )
+
+    train_env_name = env_name + "_train"
+    train_environments.append([train_env_name, "v0"])
+    gym.register(
+        id=f"tales/{train_env_name}-v0",
+        entry_point="tales.textworld_express:TextWorldExpressEnv",
+        kwargs={"game_name": game_name, "game_params": game_params, "split": "train"},
     )
 
 