Skip to content

Commit 86aa697

Browse files
committed
fix: refresh expired OAuth tokens for older chat sessions
1 parent 69fa777 commit 86aa697

2 files changed

Lines changed: 170 additions & 4 deletions

File tree

src/google/adk/a2a/executor/a2a_agent_executor.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from datetime import timezone
1919
import inspect
2020
import logging
21+
import os
22+
import time
23+
import httpx
2124
from typing import Awaitable
2225
from typing import Callable
2326
from typing import Optional
@@ -32,7 +35,7 @@
3235
from a2a.types import TaskState
3336
from a2a.types import TaskStatus
3437
from a2a.types import TaskStatusUpdateEvent
35-
from a2a.types import TextPart
38+
from a2a.types import Part
3639
from google.adk.platform import time as platform_time
3740
from google.adk.platform import uuid as platform_uuid
3841
from google.adk.runners import Runner
@@ -187,7 +190,7 @@ async def execute(
187190
message=Message(
188191
message_id=platform_uuid.new_uuid(),
189192
role=Role.agent,
190-
parts=[TextPart(text=str(e))],
193+
parts=[Part(text=str(e))],
191194
),
192195
),
193196
context_id=context.context_id,
@@ -213,9 +216,9 @@ async def _handle_request(
213216
self._config.a2a_part_converter,
214217
)
215218

216-
# ensure the session exists
219+
# ensure the session exists modify this code
217220
session = await self._prepare_session(context, run_request, runner)
218-
221+
await self._refresh_token_if_expired(session, runner)
219222
# create invocation context
220223
invocation_context = runner._new_invocation_context(
221224
session=session,
@@ -321,7 +324,51 @@ async def _handle_request(
321324
self._config.execute_interceptors,
322325
)
323326
await event_queue.enqueue_event(final_event)
327+
async def _refresh_token_if_expired(self, session, runner: Runner):
328+
state = session.state
329+
if not state:
330+
return
331+
332+
refresh_token = state.get("refresh_token")
333+
expires_at = state.get("expires_at", 0)
334+
335+
if not refresh_token:
336+
return
337+
338+
now = int(time.time())
339+
if now < expires_at:
340+
return
341+
342+
logger.info("OAuth token expired, refreshing...")
343+
344+
async with httpx.AsyncClient() as client:
345+
resp = await client.post(
346+
"https://oauth2.googleapis.com/token",
347+
data={
348+
"client_id": os.environ["GOOGLE_CLIENT_ID"],
349+
"client_secret": os.environ["GOOGLE_CLIENT_SECRET"],
350+
"refresh_token": refresh_token,
351+
"grant_type": "refresh_token",
352+
},
353+
)
354+
355+
if resp.status_code != 200:
356+
logger.error("OAuth token refresh failed: %s", resp.text)
357+
return
358+
359+
tokens = resp.json()
360+
state["access_token"] = tokens["access_token"]
361+
state["expires_at"] = now + tokens.get("expires_in", 3600)
362+
state["refresh_token"] = tokens.get("refresh_token", state.get("refresh_token"))
363+
364+
await runner.session_service.update_session(
365+
app_name=runner.app_name,
366+
user_id=session.user_id,
367+
session_id=session.id,
368+
state=state,
369+
)
324370

371+
logger.info("OAuth token refreshed successfully.")
325372
async def _prepare_session(
326373
self,
327374
context: RequestContext,
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import pytest
2+
import time
3+
import os
4+
import httpx
5+
from unittest.mock import AsyncMock, MagicMock, patch
6+
7+
8+
# Standalone copy of the method — no executor import needed
9+
async def _refresh_token_if_expired(session, runner):
10+
state = session.state
11+
if not state:
12+
return
13+
14+
refresh_token = state.get("refresh_token")
15+
expires_at = state.get("expires_at", 0)
16+
17+
if not refresh_token:
18+
return
19+
20+
now = int(time.time())
21+
if now < expires_at:
22+
return
23+
24+
async with httpx.AsyncClient() as client:
25+
resp = await client.post(
26+
"https://oauth2.googleapis.com/token",
27+
data={
28+
"client_id": os.environ["GOOGLE_CLIENT_ID"],
29+
"client_secret": os.environ["GOOGLE_CLIENT_SECRET"],
30+
"refresh_token": refresh_token,
31+
"grant_type": "refresh_token",
32+
},
33+
)
34+
35+
if resp.status_code != 200:
36+
return
37+
38+
tokens = resp.json()
39+
state["access_token"] = tokens["access_token"]
40+
state["expires_at"] = now + tokens.get("expires_in", 3600)
41+
state["refresh_token"] = tokens.get("refresh_token", state.get("refresh_token"))
42+
43+
await runner.session_service.update_session(
44+
app_name=runner.app_name,
45+
user_id=session.user_id,
46+
session_id=session.id,
47+
state=state,
48+
)
49+
50+
51+
@pytest.mark.asyncio
52+
async def test_token_not_expired_skips_refresh():
53+
"""Token still valid — refresh should NOT be called."""
54+
session = MagicMock()
55+
session.state = {
56+
"access_token": "valid_token",
57+
"refresh_token": "refresh_token",
58+
"expires_at": int(time.time()) + 9999,
59+
}
60+
runner = MagicMock()
61+
runner.session_service.update_session = AsyncMock()
62+
63+
await _refresh_token_if_expired(session, runner)
64+
65+
runner.session_service.update_session.assert_not_called()
66+
print("PASS — valid token, no refresh triggered")
67+
68+
69+
@pytest.mark.asyncio
70+
async def test_expired_token_triggers_refresh():
71+
"""Token is expired — refresh SHOULD be called."""
72+
session = MagicMock()
73+
session.state = {
74+
"access_token": "old_token",
75+
"refresh_token": "my_refresh_token",
76+
"expires_at": int(time.time()) - 100,
77+
}
78+
session.user_id = "user123"
79+
session.id = "session123"
80+
81+
runner = MagicMock()
82+
runner.app_name = "test_app"
83+
runner.session_service.update_session = AsyncMock()
84+
85+
mock_response = MagicMock()
86+
mock_response.status_code = 200
87+
mock_response.json.return_value = {
88+
"access_token": "new_token",
89+
"expires_in": 3600,
90+
}
91+
92+
mock_client_instance = MagicMock()
93+
mock_client_instance.post = AsyncMock(return_value=mock_response)
94+
95+
with patch("httpx.AsyncClient") as mock_client:
96+
mock_client.return_value.__aenter__ = AsyncMock(
97+
return_value=mock_client_instance
98+
)
99+
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)
100+
101+
with patch.dict("os.environ", {
102+
"GOOGLE_CLIENT_ID": "test_client_id",
103+
"GOOGLE_CLIENT_SECRET": "test_secret",
104+
}):
105+
await _refresh_token_if_expired(session, runner)
106+
107+
runner.session_service.update_session.assert_called_once()
108+
assert session.state["access_token"] == "new_token"
109+
print("PASS — expired token was refreshed")
110+
111+
112+
@pytest.mark.asyncio
113+
async def test_no_refresh_token_skips_refresh():
114+
"""No refresh_token in state — should skip silently."""
115+
session = MagicMock()
116+
session.state = {
117+
"access_token": "some_token",
118+
"expires_at": int(time.time()) - 100,
119+
}

0 commit comments

Comments
 (0)