Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions packages/datacommons-api/datacommons_api/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from datacommons_api.app import app
from datacommons_api.core.config import get_config, initialize_config
from datacommons_api.core.logging import get_logger, setup_logging
from datacommons_db.session import get_session, initialize_db
from datacommons_db.session import get_session
from datacommons_api.services.graph_service import GraphService

setup_logging()
Expand Down Expand Up @@ -56,16 +56,6 @@ def start(
gcp_spanner_database_name=gcp_spanner_database_name,
)

# Initialize the database
logger.info("Initializing database...")
logger.info("GCP Project ID: %s", config.GCP_PROJECT_ID)
logger.info("GCP Spanner Instance ID: %s", config.GCP_SPANNER_INSTANCE_ID)
logger.info("GCP Spanner Database Name: %s", config.GCP_SPANNER_DATABASE_NAME)
initialize_db(
config.GCP_PROJECT_ID,
config.GCP_SPANNER_INSTANCE_ID,
config.GCP_SPANNER_DATABASE_NAME,
)
logger.info("Starting API server...")
uvicorn.run(
"datacommons_api.app:app",
Expand Down
35 changes: 1 addition & 34 deletions packages/datacommons-db/datacommons_db/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@

import logging

from sqlalchemy import Engine, create_engine, inspect
from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker

from datacommons_db.models.base import Base

logger = logging.getLogger(__name__)

REQUIRED_TABLES = ["Edge", "Node", "Observation"]


# DDL for Creating Property Graph
DDL_PROPERTY_GRAPH = """
Expand Down Expand Up @@ -93,34 +91,3 @@ def get_session(project_id: str, instance_id: str, database_name: str) -> Sessio
engine = get_engine(project_id, instance_id, database_name)
session = sessionmaker(bind=engine)
return session()


def initialize_db(project_id: str, instance_id: str, database_name: str):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this function removed, can you add some instructions to the README on how users can initialize Spanner?

It might include something like:
(1) Option 1: use Terraform to deploy to GCP, and give instructions for provisioning Spanner from there.
(2) Option 2: without Terraform, clone the https://github.com/datacommonsorg/import/ repo and use the import pipeline with DirectRunner to initialize a Spanner database.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but I will do it in a follow-up so I can reference the Terraform configs once they're checked in. Will do it tomorrow at the latest!

"""Initialize the Spanner database.

Args:
project_id: GCP project ID
instance_id: Cloud Spanner instance ID
database_name: Cloud Spanner database name
"""
engine = get_engine(project_id, instance_id, database_name)

# Check if database is empty by inspecting existing tables
inspector = inspect(engine)
existing_tables = inspector.get_table_names()

# Check if all required tables exist
missing_tables = [
table for table in REQUIRED_TABLES if table not in existing_tables
]
if missing_tables:
logger.warning(
"Missing required tables in database %s: %s", database_name, missing_tables
)

# Only create tables if database is completely empty
if not existing_tables or missing_tables:
# Import all models so they are properly initialized with the call to Base.metadata.create_all
logger.info("Creating tables %s in database %s", REQUIRED_TABLES, database_name)
Base.metadata.create_all(engine)
create_property_graph(engine)
Loading