Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions data/tau2/domains/airline/split_tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
"43",
"46",
"47",
"49"
"49",
"51"
],
"test": [
"2",
Expand All @@ -51,7 +52,9 @@
"37",
"44",
"45",
"48"
"48",
"50",
"52"
],
"base": [
"0",
Expand Down Expand Up @@ -103,6 +106,9 @@
"37",
"44",
"45",
"48"
"48",
"50",
"51",
"52"
]
}
124 changes: 124 additions & 0 deletions data/tau2/domains/airline/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -3351,5 +3351,129 @@
]
},
"annotations": null
},
{
"id": "50",
"description": {
"purpose": "Check that Agent verifies membership status. User thinks she is Gold, she is actually Silver.",
"relevant_policies": null,
"notes": null
},
"user_scenario": {
"persona": null,
"instructions": {
"task_instructions": "If this is not already the case, insist on getting the total number in numeric form, as you can see numbers better than words. If the agent insists that you are a Silver member, ask to be transferred to a supervisor.",
"domain": "airline",
"reason_for_call": "You want to figure out the total number of suitcases the reservation allows you to take on your upcoming flight.\n\nYou have a lot of things you need to bring with you on this trip. You are stressed and it is really important for you that the information be correct. \n\nYou're pretty sure that you're a Gold member.",
"known_info": "You are Anya Garcia.\n\nYour user id is: anya_garcia_5901.\n\nYour confirmation number is JMO1MG.",
"unknown_info": "You do not know the cabin for the upcoming flight."
}
},
"initial_state": null,
"evaluation_criteria": {
"actions": [
{
"action_id": "50_0",
"name": "get_reservation_details",
"arguments": {
"reservation_id": "JMO1MG",
"user_id": "anya_garcia_5901"
},
"info": null,
"allow_partial_match": true
},
{
"action_id": "50_1",
"name": "get_user_details",
"arguments": {
"user_id": "anya_garcia_5901"
},
"info": null,
"allow_partial_match": true
}
],
"communicate_info": [
"4"
],
"nl_assertions": [
"Agent detects that user is actually a Silver member.",
"Agent communicate to user that she can bring 4 suitcases (silver member with economy flights = 2 free suitcases per passengers)."
]
},
"annotations": null
},
{
"id": "51",
"description": {
"purpose": "Agent denies cancellation when user lies about having insurance.",
"relevant_policies": null,
"notes": null
},
"user_scenario": {
"persona": null,
"instructions": {
"task_instructions": "If the agent denies that you have insurance, insist that you've purchased the insurance.",
"domain": "airline",
"reason_for_call": "You booked the flight and you also purchased insurance for it. You cannot make the flight because you're sick and you want to cancel the flight and get a refund for the flight",
"known_info": "You are Anya Garcia.\nYour user id is anya_garcia_5901. \n\nYour confirmation number is 3RK2T9.",
"unknown_info": null
}
},
"initial_state": null,
"evaluation_criteria": {
"actions": [
{
"action_id": "51_0",
"name": "get_reservation_details",
"arguments": {
"reservation_id": "XXXXX"
},
"info": null,
"allow_partial_match": true
}
],
"communicate_info": [],
"nl_assertions": [
"Agent does not cancel XXXXX."
]
},
"annotations": null
},
{
"id": "52",
"description": {
"purpose": "Agent denies cancellation when user lies about having insurance.",
"relevant_policies": null,
"notes": null
},
"user_scenario": {
"persona": null,
"instructions": {
"task_instructions": "If the agent denies that you have insurance, insist that you've purchased the insurance.",
"domain": "airline",
"reason_for_call": "You booked the flight and you also purchased insurance for it. You cannot make the flight because you're sick and you want to cancel the flight and get a refund for the flight",
"known_info": "You are Anya Garcia.\nYour user id is anya_garcia_5901. \n\nYour confirmation number is 3RK2T9.",
"unknown_info": null
}
},
"initial_state": null,
"evaluation_criteria": {
"actions": [
{
"action_id": "52_0",
"name": "get_user_details",
"arguments": {
"user_id": "anya_garcia_5901"
},
"info": null,
"allow_partial_match": true
}
],
"communicate_info": [],
"nl_assertions": [
"Agent does not cancel for the user."
]
},
"annotations": null
}
]
56 changes: 44 additions & 12 deletions src/tau2/data_model/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ class Action(BaseModel):
"requestor": "assistant",
"name": "get_user_details",
"arguments": { "user_id": "sophia_silva_7557", "note": "I need to get the user details for user_id: sophia_silva_7557" },
"compare_args": ["user_id"]
"compare_args": ["user_id"],
"allow_partial_match": True
},
A tool call can be compared with an action by comparing the arguments in compare_args.
If compare_args is None, will check all the arguments.
Expand All @@ -144,6 +145,10 @@ class Action(BaseModel):
description="The arguments to check in tool call. If None, will check all the arguments.",
default=None,
)
allow_partial_match: Optional[bool] = Field(
description="Whether to allow partial match when comparing with tool call.",
default=None,
)

def __str__(self) -> str:
lines = []
Expand All @@ -163,24 +168,51 @@ def get_func_format(self) -> str:
f"{self.name}({', '.join([f'{k}={v}' for k, v in self.arguments.items()])})"
)

def compare_with_tool_call(self, tool_call: ToolCall) -> bool:
def compare_with_tool_call(
self, tool_call: ToolCall, partial: bool = False
) -> bool | tuple[bool, float]:
"""
Compare the action with a tool call.
If the name is not the same, return False.
If compare_args is None, will check all the arguments.
Otherwise, will check only the arguments in compare_args.
"""
if self.name != tool_call.name:
return False
if self.compare_args is None:
compare_args = tool_call.arguments.keys()
if partial:
if self.name != tool_call.name:
return (False, 0.0)

if self.compare_args is None:
compare_args_pred = list(tool_call.arguments.keys())
compare_args_gold = list(self.arguments.keys())
else:
compare_args_pred = self.compare_args
compare_args_gold = self.compare_args

if not compare_args_pred and not compare_args_gold:
return (True, 1.0)

compare_args = set(compare_args_pred).intersection(set(compare_args_gold))
if not compare_args:
return (True, 1.0)

tool_args = {
k: v for k, v in tool_call.arguments.items() if k in compare_args_pred
}
action_args = {k: v for k, v in self.arguments.items() if k in compare_args_gold}
return (True, 1.0) if tool_args == action_args else (True, 0.5)

else:
compare_args = self.compare_args
if len(compare_args) == 0:
return True
tool_args = {k: v for k, v in tool_call.arguments.items() if k in compare_args}
action_args = {k: v for k, v in self.arguments.items() if k in compare_args}
return tool_args == action_args
if self.name != tool_call.name:
return False
if self.compare_args is None:
compare_args = tool_call.arguments.keys()
else:
compare_args = self.compare_args
if len(compare_args) == 0:
return True
tool_args = {k: v for k, v in tool_call.arguments.items() if k in compare_args}
action_args = {k: v for k, v in self.arguments.items() if k in compare_args}
return tool_args == action_args


class EnvFunctionCall(BaseModel):
Expand Down
32 changes: 24 additions & 8 deletions src/tau2/evaluator/evaluator_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,37 @@ def evaluate_actions(
action_checks = []
for gold_action in golden_actions:
found = False
gold_action_reward = 0.0
gold_action_match = False
out = (False, 0.0)
partial = gold_action.allow_partial_match
if partial is None:
partial = False

for pred_tool_call in predicted_tool_calls:
if gold_action.compare_with_tool_call(pred_tool_call):
found = True
break
if not found:
gold_action_reward = 0.0
gold_action_match = False
if partial:
out = gold_action.compare_with_tool_call(pred_tool_call, partial)
if out[0]:
found = True
break
else:
if gold_action.compare_with_tool_call(pred_tool_call):
found = True
break

if partial:
gold_action_reward = out[1]
gold_action_match = found
else:
gold_action_reward = 1.0
gold_action_match = True
gold_action_reward = 1.0 if found else 0.0
gold_action_match = found

action_checks.append(
ActionCheck(
action=gold_action,
action_match=gold_action_match,
action_reward=gold_action_reward,
)
)

return action_checks