Skip to content

Commit 21b66a4

Browse files
committed
feat(utils): Add support for nested state access in template injection
- Add nested state access support using dot notation in inject_session_state - Fix optional chaining error handling for better robustness - Add comprehensive test coverage for nested state templates Resolves: #575
1 parent cf21ca3 commit 21b66a4

File tree

2 files changed

+273
-20
lines changed

2 files changed

+273
-20
lines changed

src/google/adk/utils/instructions_utils.py

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
import re
19+
from typing import Any
1920

2021
from ..agents.readonly_context import ReadonlyContext
2122
from ..sessions.state import State
@@ -46,7 +47,11 @@ async def build_instruction(
4647
) -> str:
4748
return await inject_session_state(
4849
'You can inject a state variable like {var_name} or an artifact '
49-
'{artifact.file_name} into the instruction template.',
50+
'{artifact.file_name} into the instruction template.'
51+
'You can also inject a nested variable like {var_name.nested_var}.'
52+
'If a variable or nested attribute may be missing, append `?` to the '
53+
'path or attribute name for optional handling, e.g. '
54+
'{var_name.optional_nested_var?}.',
5055
readonly_context,
5156
)
5257
@@ -78,14 +83,52 @@ async def _async_sub(pattern, repl_async_fn, string) -> str:
7883
result.append(string[last_end:])
7984
return ''.join(result)
8085

86+
def _get_nested_value(obj: Any, path: str) -> Any:
87+
"""Retrieve nested value from an object based on dot-separated path."""
88+
parts = path.split('.')
89+
current = obj
90+
91+
for part in parts:
92+
if current is None:
93+
return None
94+
95+
optional = part.endswith('?')
96+
key = part[:-1] if optional else part
97+
98+
# Try dictionary access first
99+
if hasattr(current, '__getitem__'):
100+
try:
101+
current = current[key]
102+
continue
103+
except (KeyError, TypeError):
104+
# If dict access fails, fall through to try getattr
105+
# UNLESS it's a pure dict which definitely doesn't have attributes
106+
if isinstance(current, dict):
107+
if optional:
108+
return None
109+
raise KeyError(f"Key '{key}' not found in path '{path}'")
110+
pass
111+
112+
# Try attribute access
113+
try:
114+
current = getattr(current, key)
115+
except AttributeError:
116+
# Both dict access and attribute access failed.
117+
if optional:
118+
return None
119+
raise KeyError(f"Key '{key}' not found in path '{path}'")
120+
121+
return current
122+
81123
async def _replace_match(match) -> str:
82-
var_name = match.group().lstrip('{').rstrip('}').strip()
83-
optional = False
84-
if var_name.endswith('?'):
85-
optional = True
86-
var_name = var_name.removesuffix('?')
87-
if var_name.startswith('artifact.'):
88-
var_name = var_name.removeprefix('artifact.')
124+
full_path = match.group().lstrip('{').rstrip('}').strip()
125+
126+
if full_path.startswith('artifact.'):
127+
var_name = full_path.removeprefix('artifact.')
128+
optional = var_name.endswith('?')
129+
if optional:
130+
var_name = var_name[:-1]
131+
89132
if invocation_context.artifact_service is None:
90133
raise ValueError('Artifact service is not initialized.')
91134
artifact = await invocation_context.artifact_service.load_artifact(
@@ -104,22 +147,17 @@ async def _replace_match(match) -> str:
104147
raise KeyError(f'Artifact {var_name} not found.')
105148
return str(artifact)
106149
else:
107-
if not _is_valid_state_name(var_name):
150+
if not _is_valid_state_name(full_path.split('.')[0].removesuffix('?')):
108151
return match.group()
109-
if var_name in invocation_context.session.state:
110-
value = invocation_context.session.state[var_name]
152+
153+
try:
154+
value = _get_nested_value(invocation_context.session.state, full_path)
155+
111156
if value is None:
112157
return ''
113158
return str(value)
114-
else:
115-
if optional:
116-
logger.debug(
117-
'Context variable %s not found, replacing with empty string',
118-
var_name,
119-
)
120-
return ''
121-
else:
122-
raise KeyError(f'Context variable not found: `{var_name}`.')
159+
except KeyError as e:
160+
raise KeyError(f'Context variable not found: `{full_path}`.') from e
123161

124162
return await _async_sub(r'{+[^{}]*}+', _replace_match, template)
125163

tests/unittests/utils/test_instructions_utils.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,218 @@ async def test_inject_session_state_with_optional_missing_state_returns_empty():
267267
instruction_template, invocation_context
268268
)
269269
assert populated_instruction == "Optional value: "
270+
271+
272+
# Tests for nested state access feature
273+
@pytest.mark.asyncio
274+
async def test_inject_session_state_with_nested_dict_access():
275+
instruction_template = (
276+
"User name is {user.name} and role is {user.profile.role}"
277+
)
278+
invocation_context = await _create_test_readonly_context(
279+
state={
280+
"user": {
281+
"name": "Alice",
282+
"profile": {"role": "Engineer", "level": "Senior"},
283+
}
284+
}
285+
)
286+
287+
populated_instruction = await instructions_utils.inject_session_state(
288+
instruction_template, invocation_context
289+
)
290+
assert populated_instruction == "User name is Alice and role is Engineer"
291+
292+
293+
@pytest.mark.asyncio
294+
async def test_inject_session_state_with_deep_nested_access():
295+
instruction_template = "Deep value: {level1.level2.level3.value}"
296+
invocation_context = await _create_test_readonly_context(
297+
state={
298+
"level1": {
299+
"level2": {"level3": {"value": "deep_data", "other": "ignored"}}
300+
}
301+
}
302+
)
303+
304+
populated_instruction = await instructions_utils.inject_session_state(
305+
instruction_template, invocation_context
306+
)
307+
assert populated_instruction == "Deep value: deep_data"
308+
309+
310+
@pytest.mark.asyncio
311+
async def test_inject_session_state_with_optional_nested_access_existing():
312+
instruction_template = "Name: {user?.name} Role: {user?.profile?.role}"
313+
invocation_context = await _create_test_readonly_context(
314+
state={
315+
"user": {
316+
"name": "Bob",
317+
"profile": {"role": "Developer"},
318+
}
319+
}
320+
)
321+
322+
populated_instruction = await instructions_utils.inject_session_state(
323+
instruction_template, invocation_context
324+
)
325+
assert populated_instruction == "Name: Bob Role: Developer"
326+
327+
328+
@pytest.mark.asyncio
329+
async def test_inject_session_state_with_optional_nested_access_missing():
330+
instruction_template = "Name: {user?.name} Missing: {user?.missing?.field?}"
331+
invocation_context = await _create_test_readonly_context(
332+
state={"user": {"name": "Charlie"}}
333+
)
334+
335+
populated_instruction = await instructions_utils.inject_session_state(
336+
instruction_template, invocation_context
337+
)
338+
assert populated_instruction == "Name: Charlie Missing: "
339+
340+
341+
@pytest.mark.asyncio
342+
async def test_inject_session_state_with_optional_nested_missing_root():
343+
instruction_template = "Optional nested: {missing_root?.nested?.value?}"
344+
invocation_context = await _create_test_readonly_context(state={})
345+
346+
populated_instruction = await instructions_utils.inject_session_state(
347+
instruction_template, invocation_context
348+
)
349+
assert populated_instruction == "Optional nested: "
350+
351+
352+
@pytest.mark.asyncio
353+
async def test_inject_session_state_with_nested_none_value():
354+
instruction_template = "Value: {user.profile.role}"
355+
invocation_context = await _create_test_readonly_context(
356+
state={"user": {"profile": None}}
357+
)
358+
359+
# When a value in the path is None, it returns empty string
360+
populated_instruction = await instructions_utils.inject_session_state(
361+
instruction_template, invocation_context
362+
)
363+
assert populated_instruction == "Value: "
364+
365+
366+
@pytest.mark.asyncio
367+
async def test_inject_session_state_with_optional_nested_none_value():
368+
instruction_template = "Value: {user.profile?.role?}"
369+
invocation_context = await _create_test_readonly_context(
370+
state={"user": {"profile": None}}
371+
)
372+
373+
populated_instruction = await instructions_utils.inject_session_state(
374+
instruction_template, invocation_context
375+
)
376+
assert populated_instruction == "Value: "
377+
378+
379+
@pytest.mark.asyncio
380+
async def test_inject_session_state_with_missing_nested_key_raises_error():
381+
instruction_template = "Value: {user.profile.missing_key}"
382+
invocation_context = await _create_test_readonly_context(
383+
state={"user": {"profile": {"role": "Engineer"}}}
384+
)
385+
386+
with pytest.raises(
387+
KeyError, match="Context variable not found: `user.profile.missing_key`"
388+
):
389+
await instructions_utils.inject_session_state(
390+
instruction_template, invocation_context
391+
)
392+
393+
394+
@pytest.mark.asyncio
395+
async def test_inject_session_state_with_required_parent_missing_raises_error():
396+
"""Test that {user.profile?} raises error when 'user' (required) is missing.
397+
398+
This verifies that optional chaining is per-segment, not for the whole path.
399+
Even though 'profile?' is optional, 'user' is required and should raise error.
400+
"""
401+
instruction_template = "Value: {user.profile?}"
402+
invocation_context = await _create_test_readonly_context(state={})
403+
404+
with pytest.raises(
405+
KeyError, match="Context variable not found: `user.profile\\?`"
406+
):
407+
await instructions_utils.inject_session_state(
408+
instruction_template, invocation_context
409+
)
410+
411+
412+
@pytest.mark.asyncio
413+
async def test_inject_session_state_with_nested_and_prefixed_state():
414+
instruction_template = "User: {app:user.name} Temp: {temp:session.id}"
415+
invocation_context = await _create_test_readonly_context(
416+
state={
417+
"app:user": {"name": "Dana"},
418+
"temp:session": {"id": "session_123"},
419+
}
420+
)
421+
422+
populated_instruction = await instructions_utils.inject_session_state(
423+
instruction_template, invocation_context
424+
)
425+
assert populated_instruction == "User: Dana Temp: session_123"
426+
427+
428+
@pytest.mark.asyncio
429+
async def test_inject_session_state_with_mixed_nested_and_flat_state():
430+
instruction_template = (
431+
"Flat: {simple_key}, Nested: {user.name}, Deep: {config.app.version}"
432+
)
433+
invocation_context = await _create_test_readonly_context(
434+
state={
435+
"simple_key": "simple_value",
436+
"user": {"name": "Eve"},
437+
"config": {"app": {"version": "1.0.0"}},
438+
}
439+
)
440+
441+
populated_instruction = await instructions_utils.inject_session_state(
442+
instruction_template, invocation_context
443+
)
444+
assert populated_instruction == "Flat: simple_value, Nested: Eve, Deep: 1.0.0"
445+
446+
447+
@pytest.mark.asyncio
448+
async def test_inject_session_state_with_numeric_nested_values():
449+
instruction_template = "Age: {user.age}, Score: {user.metrics.score}"
450+
invocation_context = await _create_test_readonly_context(
451+
state={"user": {"age": 25, "metrics": {"score": 95.5}}}
452+
)
453+
454+
populated_instruction = await instructions_utils.inject_session_state(
455+
instruction_template, invocation_context
456+
)
457+
assert populated_instruction == "Age: 25, Score: 95.5"
458+
459+
460+
@pytest.mark.asyncio
461+
async def test_inject_session_state_with_nested_object_attribute_access():
462+
"""Test accessing attributes on objects (not just dicts)"""
463+
464+
class UserProfile:
465+
466+
def __init__(self):
467+
self.role = "Engineer"
468+
self.department = "Engineering"
469+
470+
class User:
471+
472+
def __init__(self):
473+
self.name = "Frank"
474+
self.profile = UserProfile()
475+
476+
instruction_template = "Name: {user.name}, Role: {user.profile.role}"
477+
invocation_context = await _create_test_readonly_context(
478+
state={"user": User()}
479+
)
480+
481+
populated_instruction = await instructions_utils.inject_session_state(
482+
instruction_template, invocation_context
483+
)
484+
assert populated_instruction == "Name: Frank, Role: Engineer"

0 commit comments

Comments
 (0)