Add auth token to DCP requests #487
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the security and operational integrity of interactions with the Data Commons Platform by integrating robust authentication. It ensures that all data submission requests are authorized using Google ID tokens, aligning with best practices for secure API communication. Additionally, it refines development environment setup by standardizing virtual environment naming.
Code Review
This pull request adds authentication to requests made to the Data Commons Platform by fetching and including an ID token. The changes are mostly in simple/stats/db.py and its corresponding test file.
My review has identified a critical issue with how the authentication token is handled: it's fetched only once and will expire, leading to authentication failures in long-running processes. I've also found that the updated test doesn't correctly mock the new implementation, and there are a couple of unused imports. I've provided detailed comments and suggestions for how to address these points.
def _get_id_token(url):
    # 1. Try to get default credentials
    creds, _ = google.auth.default()
    auth_req = google.auth.transport.requests.Request()

    # 2. Refresh to ensure the token is loaded
    creds.refresh(auth_req)

    # 3. Check if the credentials already have an id_token (typical for local gcloud)
    if hasattr(creds, 'id_token') and creds.id_token:
        return creds.id_token

    # 4. Fallback to fetching it (typical for Service Accounts/Cloud environments)
    return google.oauth2.id_token.fetch_id_token(auth_req, url)

id_token = _get_id_token(self.url)

# 2. Make the authenticated request
self.headers = {"Authorization": f"Bearer {id_token}"}
The current implementation fetches the ID token only once during object initialization and stores it in self.headers. ID tokens are short-lived (typically 1 hour) and will expire. This will cause requests to fail with authentication errors if the DataCommonsPlatformDb object lives longer than the token's validity period.
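For reference, the expiry of a given token can be inspected directly: Google ID tokens are JWTs and carry the standard `exp` claim (seconds since the Unix epoch). The following is a minimal sketch using only the standard library. `jwt_expiry` and `make_unsigned_token` are hypothetical helpers for illustration and are not part of this PR, and the signature is deliberately not verified, so this is suitable for debugging only.

```python
import base64
import json


def jwt_expiry(token: str) -> int:
    """Return the `exp` claim of a JWT without verifying its signature."""
    payload_b64 = token.split(".")[1]
    # JWT segments are base64url-encoded without padding; restore the padding.
    padded = payload_b64 + "=" * (-len(payload_b64) % 4)
    payload = json.loads(base64.urlsafe_b64decode(padded))
    return int(payload["exp"])


def make_unsigned_token(claims: dict) -> str:
    """Hypothetical helper: build a fake, unsigned JWT for local experiments."""

    def encode(obj: dict) -> str:
        raw = json.dumps(obj).encode("utf-8")
        return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")

    return f"{encode({'alg': 'none'})}.{encode(claims)}.signature"
```

Comparing `jwt_expiry(token)` with `time.time()` shows how long a token stored in `self.headers` remains usable.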
To fix this, the token acquisition logic should be executed before each request. I recommend moving the _get_id_token function and the header creation logic out of __init__ and into a new private method, e.g., _get_auth_headers(). Then, call this method inside insert_triples to get fresh headers for each request.
Example of the refactoring:
class DataCommonsPlatformDb(Db):
    # ...

    def __init__(self, config: dict) -> None:
        self.url = config[FIELD_DB_PARAMS][DATA_COMMONS_PLATFORM_URL]
        self.nodes_url = self.url + self.NODES_PATH

    def _get_auth_headers(self) -> dict[str, str]:
        # This can be a nested function or another private method
        def _get_id_token(url):
            creds, _ = google.auth.default()
            auth_req = google.auth.transport.requests.Request()
            creds.refresh(auth_req)
            if hasattr(creds, 'id_token') and creds.id_token:
                return creds.id_token
            return google.oauth2.id_token.fetch_id_token(auth_req, url)

        token = _get_id_token(self.url)
        return {"Authorization": f"Bearer {token}"}

    def insert_triples(self, triples: list[Triple]):
        # ...
        # ... existing logic to create jsonld ...
        # ...
        response = requests.post(self.nodes_url, json=jsonld, headers=self._get_auth_headers())
        # ...

This ensures that a valid token is used for every API call, making the authentication robust against token expiration.
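If fetching credentials on every request proves too slow, one possible variation is to cache the token and re-fetch it shortly before it expires. This is not what the PR does and goes beyond the suggestion above. It is a sketch under stated assumptions: `CachedTokenProvider` is a hypothetical name, the one-hour lifetime is the typical value mentioned in this comment rather than something read from the token, and `fetch_token` stands in for a call such as `_get_id_token(self.url)`.

```python
import time
from typing import Callable, Optional


class CachedTokenProvider:
    """Caches a bearer token and re-fetches it shortly before it expires."""

    def __init__(
        self,
        fetch_token: Callable[[], str],
        ttl_seconds: float = 3600.0,  # assumed token lifetime (typically 1 hour)
        refresh_margin_seconds: float = 300.0,  # refresh this long before expiry
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._fetch_token = fetch_token
        self._ttl = ttl_seconds
        self._margin = refresh_margin_seconds
        self._clock = clock
        self._token: Optional[str] = None
        self._expires_at = 0.0

    def headers(self) -> dict:
        """Return Authorization headers, fetching a new token only when needed."""
        now = self._clock()
        if self._token is None or now >= self._expires_at - self._margin:
            self._token = self._fetch_token()
            self._expires_at = now + self._ttl
        return {"Authorization": f"Bearer {self._token}"}
```

The injectable `clock` keeps the refresh logic testable without waiting for real time to pass. The trade-off against per-request fetching is extra state to maintain in exchange for fewer credential refreshes.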
  @mock.patch('stats.db.AuthorizedSession')
  @mock.patch('stats.db.id_token.fetch_id_token')
  @mock.patch('stats.db.google.auth.transport.requests.Request')
  @mock.patch.dict(
      os.environ, {
          "USE_DATA_COMMONS_PLATFORM": "true",
          "DATA_COMMONS_PLATFORM_URL": "https://test_url"
      })
- def test_insert_triples_into_datacommons_platform(self, mock_post):
+ def test_insert_triples_into_datacommons_platform(self, mock_auth_request, mock_fetch_id_token, mock_authorized_session):

      mock_session_instance = mock.Mock()
      mock_authorized_session.return_value = mock_session_instance
      mock_post = mock_session_instance.post
This test seems to be incorrectly mocking stats.db.AuthorizedSession. The implementation in stats/db.py uses requests.post directly with a manually constructed Authorization header; it does not use AuthorizedSession.
As a result, mock_post (which is mock_authorized_session.return_value.post) is never called, and the assertions on it will not test the actual behavior. The test is not validating that the POST request is made correctly.
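The effect can be reproduced in isolation. The sketch below uses a hypothetical stand-in namespace, `fake_db`, rather than the real `stats.db` module. Patching a name that the code under test never touches leaves the mock uncalled, while patching the function that is actually invoked records the call.

```python
import types
from unittest import mock

# Hypothetical stand-in for a module exposing two ways to send a request.
fake_db = types.SimpleNamespace(
    post=lambda url, **kwargs: "real post",
    AuthorizedSession=lambda: types.SimpleNamespace(
        post=lambda url, **kwargs: "real session post"),
)


def insert_triples(url: str) -> str:
    # Mirrors the reviewed code path: calls post directly and
    # never goes through AuthorizedSession.
    return fake_db.post(url, headers={"Authorization": "Bearer token"})


def run_with_patch(target_name: str) -> mock.MagicMock:
    """Patch one attribute of fake_db, run the code, and return the mock."""
    with mock.patch.object(fake_db, target_name) as patched:
        insert_triples("https://test_url/nodes")
    return patched


wrong_target = run_with_patch("AuthorizedSession")
right_target = run_with_patch("post")
```

Here `wrong_target.return_value.post.called` is `False`, whereas `right_target.called` is `True` and `right_target.call_args` holds the URL and headers that were sent.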
To fix this, you should mock stats.db.requests.post instead and verify that it's called with the expected URL, JSON payload, and authentication headers. You will also need to mock the authentication flow to provide a predictable token for your header assertion.
Example:
@mock.patch('stats.db.requests.post')
@mock.patch('stats.db.google.auth.default')
def test_insert_triples_into_datacommons_platform(self, mock_auth_default, mock_post):
    # 1. Setup mock credentials and token
    mock_creds = mock.Mock()
    mock_creds.id_token = "mock_id_token"
    mock_auth_default.return_value = (mock_creds, "test-project")

    # 2. Configure mock response for post
    mock_post.return_value.status_code = 200
    mock_post.return_value.text = "Success"

    # 3. Get config and create db instance
    config = get_datacommons_platform_config_from_env()
    db = create_and_update_db(config)

    # 4. Execute
    db.insert_triples(_TRIPLES)

    # 5. Assertions
    expected_url = "https://test_url/nodes"
    expected_headers = {"Authorization": "Bearer mock_id_token"}
    mock_post.assert_called_once()
    args, kwargs = mock_post.call_args
    self.assertEqual(args[0], expected_url)
    self.assertEqual(kwargs['headers'], expected_headers)
    # ... other assertions on json payload ...

import sqlite3
from typing import Any

from google.auth.exceptions import DefaultCredentialsError
import google.auth.transport.requests
from google.auth.transport.requests import AuthorizedSession