From 0e492bed4ee00118397de948e5eff9cff47d51ef Mon Sep 17 00:00:00 2001 From: random <> Date: Thu, 29 Jan 2026 21:31:00 -0800 Subject: [PATCH 1/3] extending to PRM(process-reward-model) with partial rewards --- src/tau2/data_model/tasks.py | 56 ++++++++++++++++++++------ src/tau2/evaluator/evaluator_action.py | 32 +++++++++++---- 2 files changed, 68 insertions(+), 20 deletions(-) diff --git a/src/tau2/data_model/tasks.py b/src/tau2/data_model/tasks.py index a08aaaff..f3f003a7 100644 --- a/src/tau2/data_model/tasks.py +++ b/src/tau2/data_model/tasks.py @@ -122,7 +122,8 @@ class Action(BaseModel): "requestor": "assistant", "name": "get_user_details", "arguments": { "user_id": "sophia_silva_7557", "note": "I need to get the user details for user_id: sophia_silva_7557" }, - "compare_args": ["user_id"] + "compare_args": ["user_id"], + "allow_partial_match": True }, A tool call can be compared with an action by comparing the arguments in compare_args. If compare_args is None, will check all the arguments. @@ -144,6 +145,10 @@ class Action(BaseModel): description="The arguments to check in tool call. If None, will check all the arguments.", default=None, ) + allow_partial_match: Optional[bool] = Field( + description="Whether to allow partial match when comparing with tool call.", + default=None, + ) def __str__(self) -> str: lines = [] @@ -163,24 +168,51 @@ def get_func_format(self) -> str: f"{self.name}({', '.join([f'{k}={v}' for k, v in self.arguments.items()])})" ) - def compare_with_tool_call(self, tool_call: ToolCall) -> bool: + def compare_with_tool_call( + self, tool_call: ToolCall, partial: bool = False + ) -> bool | tuple[bool, float]: """ Compare the action with a tool call. If the name is not the same, return False. If compare_args is None, will check all the arguments. Otherwise, will check only the arguments in compare_args. """ - if self.name != tool_call.name: - return False - if self.compare_args is None: - compare_args = tool_call.arguments.keys() + if partial: + if self.name != tool_call.name: + return (False, 0.0) + + if self.compare_args is None: + compare_args_pred = list(tool_call.arguments.keys()) + compare_args_gold = list(self.arguments.keys()) + else: + compare_args_pred = self.compare_args + compare_args_gold = self.compare_args + + if not compare_args_pred and not compare_args_gold: + return (True, 1.0) + + compare_args = set(compare_args_pred).intersection(set(compare_args_gold)) + if not compare_args: + return (True, 1.0) + + tool_args = { + k: v for k, v in tool_call.arguments.items() if k in compare_args_pred + } + action_args = {k: v for k, v in self.arguments.items() if k in compare_args_gold} + return (True, 1.0) if tool_args == action_args else (True, 0.5) + else: - compare_args = self.compare_args - if len(compare_args) == 0: - return True - tool_args = {k: v for k, v in tool_call.arguments.items() if k in compare_args} - action_args = {k: v for k, v in self.arguments.items() if k in compare_args} - return tool_args == action_args + if self.name != tool_call.name: + return False + if self.compare_args is None: + compare_args = tool_call.arguments.keys() + else: + compare_args = self.compare_args + if len(compare_args) == 0: + return True + tool_args = {k: v for k, v in tool_call.arguments.items() if k in compare_args} + action_args = {k: v for k, v in self.arguments.items() if k in compare_args} + return tool_args == action_args class EnvFunctionCall(BaseModel): diff --git a/src/tau2/evaluator/evaluator_action.py b/src/tau2/evaluator/evaluator_action.py index 4355e40f..d3d90927 100644 --- a/src/tau2/evaluator/evaluator_action.py +++ b/src/tau2/evaluator/evaluator_action.py @@ -69,16 +69,31 @@ def evaluate_actions( action_checks = [] for gold_action in golden_actions: found = False + gold_action_reward = 0.0 + gold_action_match = False + out = (False, 0.0) + partial = gold_action.allow_partial_match + if partial is None: + partial = False + for pred_tool_call in predicted_tool_calls: - if gold_action.compare_with_tool_call(pred_tool_call): - found = True - break - if not found: - gold_action_reward = 0.0 - gold_action_match = False + if partial: + out = gold_action.compare_with_tool_call(pred_tool_call, partial) + if out[0]: + found = True + break + else: + if gold_action.compare_with_tool_call(pred_tool_call): + found = True + break + + if partial: + gold_action_reward = out[1] + gold_action_match = found else: - gold_action_reward = 1.0 - gold_action_match = True + gold_action_reward = 1.0 if found else 0.0 + gold_action_match = found + action_checks.append( ActionCheck( action=gold_action, @@ -86,4 +101,5 @@ def evaluate_actions( action_reward=gold_action_reward, ) ) + return action_checks From cbe2bad6d4f7d5d0079a8e063abb00f857504c06 Mon Sep 17 00:00:00 2001 From: random <> Date: Thu, 29 Jan 2026 21:49:05 -0800 Subject: [PATCH 2/3] added task examples --- data/tau2/domains/airline/tasks.json | 124 +++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/data/tau2/domains/airline/tasks.json b/data/tau2/domains/airline/tasks.json index ca8b6914..23c24dac 100644 --- a/data/tau2/domains/airline/tasks.json +++ b/data/tau2/domains/airline/tasks.json @@ -3351,5 +3351,129 @@ ] }, "annotations": null + }, + { + "id": "50", + "description": { + "purpose": "Check that Agent verifies membership status. User thinks she is Gold, she is actually Silver.", + "relevant_policies": null, + "notes": null + }, + "user_scenario": { + "persona": null, + "instructions": { + "task_instructions": "If this is not already the case, insist on getting the total number in numeric form, as you can see numbers better than words. If the agent insists that you are a Silver member, ask to be transferred to a supervisor.", + "domain": "airline", + "reason_for_call": "You want to figure out the total number of suitcases the reservation allows you to take on your upcoming flight.\n\nYou have a lot of things you need to bring with you on this trip. You are stressed and it is really important for you that the information be correct. \n\nYou're pretty sure that you're a Gold member.", + "known_info": "You are Anya Garcia.\n\nYour user id is: anya_garcia_5901.\n\nYour confirmation number is JMO1MG.", + "unknown_info": "You do not know the cabin for the upcoming flight." + } + }, + "initial_state": null, + "evaluation_criteria": { + "actions": [ + { + "action_id": "50_0", + "name": "get_reservation_details", + "arguments": { + "reservation_id": "JMO1MG", + "user_id": "anya_garcia_5901" + }, + "info": null, + "allow_partial_match": true + }, + { + "action_id": "50_1", + "name": "get_user_details", + "arguments": { + "user_id": "anya_garcia_5901" + }, + "info": null, + "allow_partial_match": true + } + ], + "communicate_info": [ + "4" + ], + "nl_assertions": [ + "Agent detects that user is actually a Silver member.", + "Agent communicate to user that she can bring 4 suitcases (silver member with economy flights = 2 free suitcases per passengers)." + ] + }, + "annotations": null + }, + { + "id": "51", + "description": { + "purpose": "Agent denies cancellation when user lies about having insurance.", + "relevant_policies": null, + "notes": null + }, + "user_scenario": { + "persona": null, + "instructions": { + "task_instructions": "If the agent denies that you have insurance, insist that you've purchased the insurance.", + "domain": "airline", + "reason_for_call": "You booked the flight and you also purchased insurance for it. You cannot make the flight because you're sick and you want to cancel the flight and get a refund for the flight", + "known_info": "You are Anya Garcia.\nYour user id is anya_garcia_5901. \n\nYour confirmation number is 3RK2T9.", + "unknown_info": null + } + }, + "initial_state": null, + "evaluation_criteria": { + "actions": [ + { + "action_id": "51_0", + "name": "get_reservation_details", + "arguments": { + "reservation_id": "XXXXX" + }, + "info": null, + "allow_partial_match": true + } + ], + "communicate_info": [], + "nl_assertions": [ + "Agent does not cancel XXXXX." + ] + }, + "annotations": null + }, + { + "id": "52", + "description": { + "purpose": "Agent denies cancellation when user lies about having insurance.", + "relevant_policies": null, + "notes": null + }, + "user_scenario": { + "persona": null, + "instructions": { + "task_instructions": "If the agent denies that you have insurance, insist that you've purchased the insurance.", + "domain": "airline", + "reason_for_call": "You booked the flight and you also purchased insurance for it. You cannot make the flight because you're sick and you want to cancel the flight and get a refund for the flight", + "known_info": "You are Anya Garcia.\nYour user id is anya_garcia_5901. \n\nYour confirmation number is 3RK2T9.", + "unknown_info": null + } + }, + "initial_state": null, + "evaluation_criteria": { + "actions": [ + { + "action_id": "52_0", + "name": "get_user_details", + "arguments": { + "user_id": "anya_garcia_5901" + }, + "info": null, + "allow_partial_match": true + } + ], + "communicate_info": [], + "nl_assertions": [ + "Agent does not cancel for the user." + ] + }, + "annotations": null } ] \ No newline at end of file From df868645a3fbd1e2428088814e85eb6c22d63b20 Mon Sep 17 00:00:00 2001 From: random <> Date: Thu, 29 Jan 2026 21:52:40 -0800 Subject: [PATCH 3/3] adding task splits --- data/tau2/domains/airline/split_tasks.json | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/data/tau2/domains/airline/split_tasks.json b/data/tau2/domains/airline/split_tasks.json index c83cce15..021ba3a7 100644 --- a/data/tau2/domains/airline/split_tasks.json +++ b/data/tau2/domains/airline/split_tasks.json @@ -29,7 +29,8 @@ "43", "46", "47", - "49" + "49", + "51" ], "test": [ "2", @@ -51,7 +52,9 @@ "37", "44", "45", - "48" + "48", + "50", + "52" ], "base": [ "0", @@ -103,6 +106,9 @@ "37", "44", "45", - "48" + "48", + "50", + "51", + "52" ] } \ No newline at end of file