diff --git a/.gitignore b/.gitignore index 66262ad7..8c591086 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,8 @@ __pycache__/ *.DS_Store .vscode/settings.json -.env/ +.env +.venv/ .idea/ .vscode/ *.iml diff --git a/run_test.sh b/run_test.sh index 829019af..c06dde4a 100755 --- a/run_test.sh +++ b/run_test.sh @@ -19,22 +19,22 @@ set -e # Fixes lint function run_lint_fix { echo -e "#### Fixing Python code" - python3 -m venv .env - source .env/bin/activate + python3 -m venv .venv + source .venv/bin/activate pip3 install yapf==0.40.2 -q if ! command -v isort &> /dev/null then pip3 install isort -q fi - yapf -r -i -p --style='{based_on_style: google, indent_width: 2}' simple/ -e=*pb2.py -e=**/.env/** - isort simple/ --skip-glob=*pb2.py --skip-glob=**/.env/** --profile google + yapf -r -i -p --style='{based_on_style: google, indent_width: 2}' simple/ -e=*pb2.py -e=**/.venv/** + isort simple/ --skip-glob=*pb2.py --skip-glob=**/.venv/** --profile google deactivate } # Lint test function run_lint_test { - python3 -m venv .env - source .env/bin/activate + python3 -m venv .venv + source .venv/bin/activate pip3 install yapf==0.40.2 -q if ! command -v isort &> /dev/null then @@ -42,13 +42,13 @@ function run_lint_test { fi echo -e "#### Checking Python style" - if ! yapf --recursive --diff --style='{based_on_style: google, indent_width: 2}' -p simple/ -e=*pb2.py -e=**/.env/**; then + if ! yapf --recursive --diff --style='{based_on_style: google, indent_width: 2}' -p simple/ -e=*pb2.py -e=**/.venv/**; then echo "Fix Python lint errors by running ./run_test.sh -f" exit 1 fi echo -e "#### Checking Python import order" - if ! isort simple/ -c --skip-glob=*pb2.py --skip-glob=**/.env/** --profile google; then + if ! isort simple/ -c --skip-glob=*pb2.py --skip-glob=**/.venv/** --profile google; then echo "Fix Python import sort orders by running ./run_test.sh -f" exit 1 fi @@ -72,8 +72,8 @@ function py_test { # Do not use Cloud SQL. export USE_CLOUDSQL=false - python3 -m venv .env - source .env/bin/activate + python3 -m venv .venv + source .venv/bin/activate cd simple pip3 install -r requirements.txt -q diff --git a/simple/requirements.txt b/simple/requirements.txt index a06f59fe..2ad20545 100644 --- a/simple/requirements.txt +++ b/simple/requirements.txt @@ -27,6 +27,7 @@ requests==2.31.0 rdflib==7.4.0 s2sphere==0.2.5 six==1.16.0 +google-auth==2.49.0 tomli==2.0.1 tzdata==2023.3 urllib3==1.26.20 diff --git a/simple/stats/db.py b/simple/stats/db.py index 22e50fcb..846f77f8 100644 --- a/simple/stats/db.py +++ b/simple/stats/db.py @@ -22,8 +22,12 @@ import sqlite3 from typing import Any +from google.auth.exceptions import DefaultCredentialsError +import google.auth.transport.requests +from google.auth.transport.requests import AuthorizedSession from google.cloud.sql.connector.connector import Connector from google.cloud.sql.connector.connector import IPTypes +from google.oauth2 import id_token import pandas as pd from pyld import jsonld from pymysql.connections import Connection @@ -415,6 +419,27 @@ class DataCommonsPlatformDb(Db): def __init__(self, config: dict) -> None: self.url = config[FIELD_DB_PARAMS][DATA_COMMONS_PLATFORM_URL] + self.nodes_url = self.url + self.NODES_PATH + + def _get_id_token(url): + # 1. Try to get default credentials + creds, _ = google.auth.default() + auth_req = google.auth.transport.requests.Request() + + # 2. Refresh to ensure the token is loaded + creds.refresh(auth_req) + + # 3. Check if the credentials already have an id_token (typical for local gcloud) + if hasattr(creds, 'id_token') and creds.id_token: + return creds.id_token + + # 4. Fallback to fetching it (typical for Service Accounts/Cloud environments) + return google.oauth2.id_token.fetch_id_token(auth_req, url) + id_token = _get_id_token(self.url) + + # 2. Make the authenticated request + self.headers = {"Authorization": f"Bearer {id_token}"} + def maybe_clear_before_import(self): # Not applicable for Data Commons Platform. @@ -430,8 +455,7 @@ def insert_triples(self, triples: list[Triple]): "Writing %s triples (%s nodes) to Data Commons Platform at [%s]", len(triples), len(jsonld["@graph"]), self.url) logging.info("Writing jsonld: %s", json.dumps(jsonld, indent=2)) - nodes_url = self.url + self.NODES_PATH - response = requests.post(nodes_url, json=jsonld) + response = requests.post(self.nodes_url, json=jsonld, headers=self.headers) if response.status_code != 200: # TODO: For now, we just log a warning, but we should raise an exception. logging.warning("Failed to write triples to Data Commons Platform: %s", diff --git a/simple/tests/stats/db_test.py b/simple/tests/stats/db_test.py index 46f1b014..47fb0d82 100644 --- a/simple/tests/stats/db_test.py +++ b/simple/tests/stats/db_test.py @@ -350,13 +350,20 @@ def test_get_datacommons_platform_config_from_env(self): } }) - @mock.patch('requests.post') + @mock.patch('stats.db.AuthorizedSession') + @mock.patch('stats.db.id_token.fetch_id_token') + @mock.patch('stats.db.google.auth.transport.requests.Request') @mock.patch.dict( os.environ, { "USE_DATA_COMMONS_PLATFORM": "true", "DATA_COMMONS_PLATFORM_URL": "https://test_url" }) - def test_insert_triples_into_datacommons_platform(self, mock_post): + def test_insert_triples_into_datacommons_platform(self, mock_auth_request, mock_fetch_id_token, mock_authorized_session): + + mock_session_instance = mock.Mock() + mock_authorized_session.return_value = mock_session_instance + mock_post = mock_session_instance.post + config = get_datacommons_platform_config_from_env() db = create_and_update_db(config)