3 changes: 2 additions & 1 deletion .gitignore
@@ -2,7 +2,8 @@
 __pycache__/
 *.DS_Store
 .vscode/settings.json
-.env/
+.env
+.venv/
 .idea/
 .vscode/
 *.iml
20 changes: 10 additions & 10 deletions run_test.sh
@@ -19,36 +19,36 @@ set -e
 # Fixes lint
 function run_lint_fix {
   echo -e "#### Fixing Python code"
-  python3 -m venv .env
-  source .env/bin/activate
+  python3 -m venv .venv
+  source .venv/bin/activate
   pip3 install yapf==0.40.2 -q
   if ! command -v isort &> /dev/null
   then
     pip3 install isort -q
   fi
-  yapf -r -i -p --style='{based_on_style: google, indent_width: 2}' simple/ -e=*pb2.py -e=**/.env/**
-  isort simple/ --skip-glob=*pb2.py --skip-glob=**/.env/** --profile google
+  yapf -r -i -p --style='{based_on_style: google, indent_width: 2}' simple/ -e=*pb2.py -e=**/.venv/**
+  isort simple/ --skip-glob=*pb2.py --skip-glob=**/.venv/** --profile google
   deactivate
 }

 # Lint test
 function run_lint_test {
-  python3 -m venv .env
-  source .env/bin/activate
+  python3 -m venv .venv
+  source .venv/bin/activate
   pip3 install yapf==0.40.2 -q
   if ! command -v isort &> /dev/null
   then
     pip3 install isort -q
   fi

   echo -e "#### Checking Python style"
-  if ! yapf --recursive --diff --style='{based_on_style: google, indent_width: 2}' -p simple/ -e=*pb2.py -e=**/.env/**; then
+  if ! yapf --recursive --diff --style='{based_on_style: google, indent_width: 2}' -p simple/ -e=*pb2.py -e=**/.venv/**; then
     echo "Fix Python lint errors by running ./run_test.sh -f"
     exit 1
   fi

   echo -e "#### Checking Python import order"
-  if ! isort simple/ -c --skip-glob=*pb2.py --skip-glob=**/.env/** --profile google; then
+  if ! isort simple/ -c --skip-glob=*pb2.py --skip-glob=**/.venv/** --profile google; then
     echo "Fix Python import sort orders by running ./run_test.sh -f"
     exit 1
   fi
@@ -72,8 +72,8 @@ function py_test {
   # Do not use Cloud SQL.
   export USE_CLOUDSQL=false

-  python3 -m venv .env
-  source .env/bin/activate
+  python3 -m venv .venv
+  source .venv/bin/activate

   cd simple
   pip3 install -r requirements.txt -q
1 change: 1 addition & 0 deletions simple/requirements.txt
@@ -27,6 +27,7 @@ requests==2.31.0
 rdflib==7.4.0
 s2sphere==0.2.5
 six==1.16.0
+google-auth==2.49.0
 tomli==2.0.1
 tzdata==2023.3
 urllib3==1.26.20
28 changes: 26 additions & 2 deletions simple/stats/db.py
@@ -22,8 +22,12 @@
 import sqlite3
 from typing import Any

+from google.auth.exceptions import DefaultCredentialsError
Reviewer comment (medium): DefaultCredentialsError is imported but not used. Please remove this unused import to keep the code clean.

+import google.auth.transport.requests
+from google.auth.transport.requests import AuthorizedSession
Reviewer comment (medium): AuthorizedSession is imported but not used. Please remove this unused import. The test test_insert_triples_into_datacommons_platform incorrectly mocks this, but the implementation uses requests.post.

 from google.cloud.sql.connector.connector import Connector
 from google.cloud.sql.connector.connector import IPTypes
+from google.oauth2 import id_token
 import pandas as pd
 from pyld import jsonld
 from pymysql.connections import Connection
@@ -415,6 +419,27 @@ class DataCommonsPlatformDb(Db):

   def __init__(self, config: dict) -> None:
     self.url = config[FIELD_DB_PARAMS][DATA_COMMONS_PLATFORM_URL]
+    self.nodes_url = self.url + self.NODES_PATH
+
+    def _get_id_token(url):
+      # 1. Try to get default credentials.
+      creds, _ = google.auth.default()
+      auth_req = google.auth.transport.requests.Request()
+
+      # 2. Refresh to ensure the token is loaded.
+      creds.refresh(auth_req)
+
+      # 3. Check if the credentials already have an id_token (typical for local gcloud).
+      if hasattr(creds, 'id_token') and creds.id_token:
+        return creds.id_token
+
+      # 4. Fall back to fetching it (typical for Service Accounts/Cloud environments).
+      return google.oauth2.id_token.fetch_id_token(auth_req, url)
+
+    id_token = _get_id_token(self.url)
+
+    # Build the authenticated request headers.
+    self.headers = {"Authorization": f"Bearer {id_token}"}
Reviewer comment on lines +424 to +441 (critical): The current implementation fetches the ID token only once during object initialization and stores it in self.headers. ID tokens are short-lived (typically 1 hour) and will expire. This will cause requests to fail with authentication errors if the DataCommonsPlatformDb object lives longer than the token's validity period.

To fix this, the token acquisition logic should be executed before each request. I recommend moving the _get_id_token function and the header creation logic out of __init__ and into a new private method, e.g., _get_auth_headers(). Then, call this method inside insert_triples to get fresh headers for each request.

Example of the refactoring:

```python
class DataCommonsPlatformDb(Db):
  # ...

  def __init__(self, config: dict) -> None:
    self.url = config[FIELD_DB_PARAMS][DATA_COMMONS_PLATFORM_URL]
    self.nodes_url = self.url + self.NODES_PATH

  def _get_auth_headers(self) -> dict[str, str]:
    # This can be a nested function or another private method.
    def _get_id_token(url):
      creds, _ = google.auth.default()
      auth_req = google.auth.transport.requests.Request()
      creds.refresh(auth_req)
      if hasattr(creds, 'id_token') and creds.id_token:
        return creds.id_token
      return google.oauth2.id_token.fetch_id_token(auth_req, url)

    token = _get_id_token(self.url)
    return {"Authorization": f"Bearer {token}"}

  def insert_triples(self, triples: list[Triple]):
    # ... existing logic to create jsonld ...
    response = requests.post(self.nodes_url, json=jsonld, headers=self._get_auth_headers())
    # ...
```

This ensures that a valid token is used for every API call, making the authentication robust against token expiration.
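The reviewer's suggested refactor fetches a fresh ID token on every request. A middle-ground sketch (purely illustrative, not part of this PR) caches the token and refetches it shortly before expiry. The TokenCache class, its default TTL, and the refresh margin below are all assumptions; in practice fetch_fn would wrap the google.auth token logic shown in the review comment.

```python
import time


class TokenCache:
  """Caches an ID token and refetches it shortly before it expires.

  Hypothetical helper, not part of this PR; fetch_fn would wrap the
  google.auth token logic (e.g. _get_id_token).
  """

  def __init__(self, fetch_fn, ttl_seconds=3600, refresh_margin=300):
    self._fetch_fn = fetch_fn  # Callable returning a fresh token string.
    self._ttl = ttl_seconds  # Assumed token lifetime (ID tokens: ~1 hour).
    self._margin = refresh_margin  # Refetch this many seconds before expiry.
    self._token = None
    self._expires_at = 0.0

  def get(self, now=None):
    """Returns the cached token, refetching if missing or near expiry."""
    now = time.time() if now is None else now
    if self._token is None or now >= self._expires_at - self._margin:
      self._token = self._fetch_fn()
      self._expires_at = now + self._ttl
    return self._token
```

insert_triples could then build headers as {"Authorization": f"Bearer {cache.get()}"} without paying a token round trip on every call.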
   def maybe_clear_before_import(self):
     # Not applicable for Data Commons Platform.
@@ -430,8 +455,7 @@ def insert_triples(self, triples: list[Triple]):
         "Writing %s triples (%s nodes) to Data Commons Platform at [%s]",
         len(triples), len(jsonld["@graph"]), self.url)
     logging.info("Writing jsonld: %s", json.dumps(jsonld, indent=2))
-    nodes_url = self.url + self.NODES_PATH
-    response = requests.post(nodes_url, json=jsonld)
+    response = requests.post(self.nodes_url, json=jsonld, headers=self.headers)
     if response.status_code != 200:
       # TODO: For now, we just log a warning, but we should raise an exception.
       logging.warning("Failed to write triples to Data Commons Platform: %s",
11 changes: 9 additions & 2 deletions simple/tests/stats/db_test.py
@@ -350,13 +350,20 @@ def test_get_datacommons_platform_config_from_env(self):
       }
   })

-  @mock.patch('requests.post')
+  @mock.patch('stats.db.AuthorizedSession')
+  @mock.patch('stats.db.id_token.fetch_id_token')
+  @mock.patch('stats.db.google.auth.transport.requests.Request')
   @mock.patch.dict(
       os.environ, {
           "USE_DATA_COMMONS_PLATFORM": "true",
           "DATA_COMMONS_PLATFORM_URL": "https://test_url"
       })
-  def test_insert_triples_into_datacommons_platform(self, mock_post):
+  def test_insert_triples_into_datacommons_platform(
+      self, mock_auth_request, mock_fetch_id_token, mock_authorized_session):
+
+    mock_session_instance = mock.Mock()
+    mock_authorized_session.return_value = mock_session_instance
+    mock_post = mock_session_instance.post
Reviewer comment on lines +353 to +365 (high): This test seems to be incorrectly mocking stats.db.AuthorizedSession. The implementation in stats/db.py uses requests.post directly with a manually constructed Authorization header; it does not use AuthorizedSession.

As a result, mock_post (which is mock_authorized_session.return_value.post) is never called, and the assertions on it will not test the actual behavior. The test is not validating that the POST request is made correctly.

To fix this, you should mock stats.db.requests.post instead and verify that it's called with the expected URL, JSON payload, and authentication headers. You will also need to mock the authentication flow to provide a predictable token for your header assertion.

Example:

```python
  @mock.patch('stats.db.requests.post')
  @mock.patch('stats.db.google.auth.default')
  def test_insert_triples_into_datacommons_platform(self, mock_auth_default, mock_post):
    # 1. Set up mock credentials and token.
    mock_creds = mock.Mock()
    mock_creds.id_token = "mock_id_token"
    mock_auth_default.return_value = (mock_creds, "test-project")

    # 2. Configure mock response for post.
    mock_post.return_value.status_code = 200
    mock_post.return_value.text = "Success"

    # 3. Get config and create db instance.
    config = get_datacommons_platform_config_from_env()
    db = create_and_update_db(config)

    # 4. Execute.
    db.insert_triples(_TRIPLES)

    # 5. Assertions.
    expected_url = "https://test_url/nodes"
    expected_headers = {"Authorization": "Bearer mock_id_token"}
    mock_post.assert_called_once()
    args, kwargs = mock_post.call_args
    self.assertEqual(args[0], expected_url)
    self.assertEqual(kwargs['headers'], expected_headers)
    # ... other assertions on json payload ...
```

     config = get_datacommons_platform_config_from_env()
     db = create_and_update_db(config)

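One subtlety the test above depends on: stacked @mock.patch decorators inject mock arguments bottom-up, so the patch written closest to the function becomes the first mock parameter. A minimal self-contained illustration (the demo function and patched targets here are arbitrary examples, unrelated to this PR):

```python
import os
from unittest import mock


@mock.patch('os.path.exists')  # Outermost patch: injected as the last argument.
@mock.patch('os.getcwd')  # Innermost patch: injected as the first argument.
def demo(mock_getcwd, mock_exists):
  mock_getcwd.return_value = "/tmp/fake"
  mock_exists.return_value = True
  # Both calls hit the mocks while the patches are active.
  return os.getcwd(), os.path.exists("/anything")


print(demo())  # ('/tmp/fake', True)
```

Matching the parameter order to the decorator stack this way is what makes mock_authorized_session line up with @mock.patch('stats.db.AuthorizedSession') in the test.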