-
Notifications
You must be signed in to change notification settings - Fork 34
Add auth token to DCP requests #487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,8 @@ | |
| __pycache__/ | ||
| *.DS_Store | ||
| .vscode/settings.json | ||
| .env/ | ||
| .env | ||
| .venv/ | ||
| .idea/ | ||
| .vscode/ | ||
| *.iml | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,8 +22,12 @@ | |
| import sqlite3 | ||
| from typing import Any | ||
|
|
||
| from google.auth.exceptions import DefaultCredentialsError | ||
| import google.auth.transport.requests | ||
| from google.auth.transport.requests import AuthorizedSession | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| from google.cloud.sql.connector.connector import Connector | ||
| from google.cloud.sql.connector.connector import IPTypes | ||
| from google.oauth2 import id_token | ||
| import pandas as pd | ||
| from pyld import jsonld | ||
| from pymysql.connections import Connection | ||
|
|
@@ -415,6 +419,27 @@ class DataCommonsPlatformDb(Db): | |
|
|
||
| def __init__(self, config: dict) -> None: | ||
| self.url = config[FIELD_DB_PARAMS][DATA_COMMONS_PLATFORM_URL] | ||
| self.nodes_url = self.url + self.NODES_PATH | ||
|
|
||
| def _get_id_token(url): | ||
| # 1. Try to get default credentials | ||
| creds, _ = google.auth.default() | ||
| auth_req = google.auth.transport.requests.Request() | ||
|
|
||
| # 2. Refresh to ensure the token is loaded | ||
| creds.refresh(auth_req) | ||
|
|
||
| # 3. Check if the credentials already have an id_token (typical for local gcloud) | ||
| if hasattr(creds, 'id_token') and creds.id_token: | ||
| return creds.id_token | ||
|
|
||
| # 4. Fallback to fetching it (typical for Service Accounts/Cloud environments) | ||
| return google.oauth2.id_token.fetch_id_token(auth_req, url) | ||
| id_token = _get_id_token(self.url) | ||
|
|
||
| # 2. Make the authenticated request | ||
| self.headers = {"Authorization": f"Bearer {id_token}"} | ||
|
Comment on lines
+424
to
+441
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation fetches the ID token only once during object initialization and stores it in To fix this, the token acquisition logic should be executed before each request. I recommend moving the Example of the refactoring: class DataCommonsPlatformDb(Db):
# ...
def __init__(self, config: dict) -> None:
self.url = config[FIELD_DB_PARAMS][DATA_COMMONS_PLATFORM_URL]
self.nodes_url = self.url + self.NODES_PATH
def _get_auth_headers(self) -> dict[str, str]:
# This can be a nested function or another private method
def _get_id_token(url):
creds, _ = google.auth.default()
auth_req = google.auth.transport.requests.Request()
creds.refresh(auth_req)
if hasattr(creds, 'id_token') and creds.id_token:
return creds.id_token
return google.oauth2.id_token.fetch_id_token(auth_req, url)
token = _get_id_token(self.url)
return {"Authorization": f"Bearer {token}"}
def insert_triples(self, triples: list[Triple]):
# ...
# ... existing logic to create jsonld ...
# ...
response = requests.post(self.nodes_url, json=jsonld, headers=self._get_auth_headers())
# ...This ensures that a valid token is used for every API call, making the authentication robust against token expiration. |
||
|
|
||
|
|
||
| def maybe_clear_before_import(self): | ||
| # Not applicable for Data Commons Platform. | ||
|
|
@@ -430,8 +455,7 @@ def insert_triples(self, triples: list[Triple]): | |
| "Writing %s triples (%s nodes) to Data Commons Platform at [%s]", | ||
| len(triples), len(jsonld["@graph"]), self.url) | ||
| logging.info("Writing jsonld: %s", json.dumps(jsonld, indent=2)) | ||
| nodes_url = self.url + self.NODES_PATH | ||
| response = requests.post(nodes_url, json=jsonld) | ||
| response = requests.post(self.nodes_url, json=jsonld, headers=self.headers) | ||
| if response.status_code != 200: | ||
| # TODO: For now, we just log a warning, but we should raise an exception. | ||
| logging.warning("Failed to write triples to Data Commons Platform: %s", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -350,13 +350,20 @@ def test_get_datacommons_platform_config_from_env(self): | |
| } | ||
| }) | ||
|
|
||
| @mock.patch('requests.post') | ||
| @mock.patch('stats.db.AuthorizedSession') | ||
| @mock.patch('stats.db.id_token.fetch_id_token') | ||
| @mock.patch('stats.db.google.auth.transport.requests.Request') | ||
| @mock.patch.dict( | ||
| os.environ, { | ||
| "USE_DATA_COMMONS_PLATFORM": "true", | ||
| "DATA_COMMONS_PLATFORM_URL": "https://test_url" | ||
| }) | ||
| def test_insert_triples_into_datacommons_platform(self, mock_post): | ||
| def test_insert_triples_into_datacommons_platform(self, mock_auth_request, mock_fetch_id_token, mock_authorized_session): | ||
|
|
||
| mock_session_instance = mock.Mock() | ||
| mock_authorized_session.return_value = mock_session_instance | ||
| mock_post = mock_session_instance.post | ||
|
Comment on lines
+353
to
+365
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test seems to be incorrectly mocking As a result, To fix this, you should mock Example: @mock.patch('stats.db.requests.post')
@mock.patch('stats.db.google.auth.default')
def test_insert_triples_into_datacommons_platform(self, mock_auth_default, mock_post):
# 1. Setup mock credentials and token
mock_creds = mock.Mock()
mock_creds.id_token = "mock_id_token"
mock_auth_default.return_value = (mock_creds, "test-project")
# 2. Configure mock response for post
mock_post.return_value.status_code = 200
mock_post.return_value.text = "Success"
# 3. Get config and create db instance
config = get_datacommons_platform_config_from_env()
db = create_and_update_db(config)
# 4. Execute
db.insert_triples(_TRIPLES)
# 5. Assertions
expected_url = "https://test_url/nodes"
expected_headers = {"Authorization": "Bearer mock_id_token"}
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
self.assertEqual(args[0], expected_url)
self.assertEqual(kwargs['headers'], expected_headers)
# ... other assertions on json payload ... |
||
|
|
||
| config = get_datacommons_platform_config_from_env() | ||
| db = create_and_update_db(config) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DefaultCredentialsErroris imported but not used. Please remove this unused import to keep the code clean.