Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions src/google/adk/utils/instructions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from ..sessions.state import State

__all__ = [
'inject_session_state',
"inject_session_state",
]

logger = logging.getLogger('google_adk.' + __name__)
logger = logging.getLogger("google_adk." + __name__)


async def inject_session_state(
Expand Down Expand Up @@ -76,18 +76,23 @@ async def _async_sub(pattern, repl_async_fn, string) -> str:
result.append(replacement)
last_end = match.end()
result.append(string[last_end:])
return ''.join(result)
return "".join(result)

async def _replace_match(match) -> str:
var_name = match.group().lstrip('{').rstrip('}').strip()
matched_text = match.group()

if matched_text.startswith("{{") and matched_text.endswith("}}"):
return matched_text[1:-1]

var_name = matched_text.lstrip("{").rstrip("}").strip()
optional = False
if var_name.endswith('?'):
if var_name.endswith("?"):
optional = True
var_name = var_name.removesuffix('?')
if var_name.startswith('artifact.'):
var_name = var_name.removeprefix('artifact.')
var_name = var_name.removesuffix("?")
if var_name.startswith("artifact."):
var_name = var_name.removeprefix("artifact.")
if invocation_context.artifact_service is None:
raise ValueError('Artifact service is not initialized.')
raise ValueError("Artifact service is not initialized.")
artifact = await invocation_context.artifact_service.load_artifact(
app_name=invocation_context.session.app_name,
user_id=invocation_context.session.user_id,
Expand All @@ -97,31 +102,31 @@ async def _replace_match(match) -> str:
if artifact is None:
if optional:
logger.debug(
'Artifact %s not found, replacing with empty string', var_name
"Artifact %s not found, replacing with empty string", var_name
)
return ''
return ""
else:
raise KeyError(f'Artifact {var_name} not found.')
raise KeyError(f"Artifact {var_name} not found.")
return str(artifact)
else:
if not _is_valid_state_name(var_name):
return match.group()
if var_name in invocation_context.session.state:
value = invocation_context.session.state[var_name]
if value is None:
return ''
return ""
return str(value)
else:
if optional:
logger.debug(
'Context variable %s not found, replacing with empty string',
"Context variable %s not found, replacing with empty string",
var_name,
)
return ''
return ""
else:
raise KeyError(f'Context variable not found: `{var_name}`.')
raise KeyError(f"Context variable not found: `{var_name}`.")

return await _async_sub(r'{+[^{}]*}+', _replace_match, template)
return await _async_sub(r"{+[^{}]*}+", _replace_match, template)


def _is_valid_state_name(var_name):
Expand All @@ -138,12 +143,12 @@ def _is_valid_state_name(var_name):
Returns:
True if the variable name is a valid state name, False otherwise.
"""
parts = var_name.split(':')
parts = var_name.split(":")
if len(parts) == 1:
return var_name.isidentifier()

if len(parts) == 2:
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
if (parts[0] + ':') in prefixes:
if (parts[0] + ":") in prefixes:
return parts[1].isidentifier()
return False
71 changes: 71 additions & 0 deletions tests/unittests/utils/test_instructions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,74 @@ async def test_inject_session_state_with_optional_missing_state_returns_empty():
instruction_template, invocation_context
)
assert populated_instruction == "Optional value: "


@pytest.mark.asyncio
async def test_inject_session_state_with_double_brace_escaping():
instruction_template = "Example: {{user_id}}"
invocation_context = await _create_test_readonly_context()

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Example: {user_id}"


@pytest.mark.asyncio
async def test_inject_session_state_with_double_brace_escaping_and_normal_substitution():
instruction_template = "Hello {name}, example: {{variable}}"
invocation_context = await _create_test_readonly_context(
state={"name": "Alice"}
)

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Hello Alice, example: {variable}"


@pytest.mark.asyncio
async def test_inject_session_state_with_python_fstring_example():
instruction_template = """
Example Python code:
logger.error(f"User not found: {{user_id}}")
"""
invocation_context = await _create_test_readonly_context()

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
expected = """
Example Python code:
logger.error(f"User not found: {user_id}")
"""
assert populated_instruction == expected


@pytest.mark.asyncio
async def test_inject_session_state_with_typescript_template_literal():
instruction_template = """
Example TypeScript code:
console.log(`User: ${{userId}}`);
"""
invocation_context = await _create_test_readonly_context()

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
expected = """
Example TypeScript code:
console.log(`User: ${userId}`);
"""
assert populated_instruction == expected


@pytest.mark.asyncio
async def test_inject_session_state_with_multiple_double_brace_patterns():
instruction_template = "Examples: {{var1}}, {{var2}}, {{var3}}"
invocation_context = await _create_test_readonly_context()

populated_instruction = await instructions_utils.inject_session_state(
instruction_template, invocation_context
)
assert populated_instruction == "Examples: {var1}, {var2}, {var3}"
Loading