Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions fedn/network/storage/dbconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
DatabaseConnection: A singleton class for managing database connections and stores.
"""

import os

import pymongo
from pymongo.database import Database
from sqlalchemy import create_engine
Expand Down Expand Up @@ -88,8 +90,9 @@ def __new__(cls, *, force_create_new: bool = False) -> "DatabaseConnection":
def _init_connection(self) -> None:
statestore_config = get_statestore_config()
network_id = get_network_config()
statestore_type = os.environ.get("FEDN_STATESTORE_TYPE", statestore_config["type"])

if statestore_config["type"] == "MongoDB":
if statestore_type == "MongoDB":
mdb: Database = self._setup_mongo(statestore_config, network_id)

client_store = MongoDBClientStore(mdb, "network.clients")
Expand All @@ -103,7 +106,7 @@ def _init_connection(self) -> None:
session_store = MongoDBSessionStore(mdb, "control.sessions")
analytic_store = MongoDBAnalyticStore(mdb, "control.analytics")

elif statestore_config["type"] in ["SQLite", "PostgreSQL"]:
elif statestore_type in ["SQLite", "PostgreSQL"]:
Session = self._setup_sql(statestore_config) # noqa: N806

client_store = SQLClientStore(Session)
Expand Down Expand Up @@ -144,18 +147,18 @@ def _setup_mongo(self, statestore_config: dict, network_id: str) -> "DatabaseCon
return mdb

def _setup_sql(self, statestore_config: dict) -> "DatabaseConnection":
if statestore_config["type"] == "SQLite":
statestore_type = os.environ.get("FEDN_STATESTORE_TYPE", statestore_config.get("type", ""))
if statestore_type == "SQLite":
sqlite_config = statestore_config["sqlite_config"]
dbname = sqlite_config["dbname"]
engine = create_engine(f"sqlite:///{dbname}", echo=False)
elif statestore_config["type"] == "PostgreSQL":
postgres_config = statestore_config["postgres_config"]
username = postgres_config["username"]
password = postgres_config["password"]
host = postgres_config["host"]
port = postgres_config["port"]

engine = create_engine(f"postgresql://{username}:{password}@{host}:{port}/fedn_db", echo=False)
elif statestore_type == "PostgreSQL":
username = os.environ.get("FEDN_STATESTORE_USERNAME", statestore_config.get("postgres_config", {}).get("username"))
password = os.environ.get("FEDN_STATESTORE_PASSWORD", statestore_config.get("postgres_config", {}).get("password"))
host = os.environ.get("FEDN_STATESTORE_HOST", statestore_config.get("postgres_config", {}).get("host"))
port = os.environ.get("FEDN_STATESTORE_PORT", statestore_config.get("postgres_config", {}).get("port"))
dbname = os.environ.get("FEDN_STATESTORE_DBNAME", statestore_config.get("postgres_config", {}).get("dbname"))
engine = create_engine(f"postgresql://{username}:{password}@{host}:{port}/{dbname}", echo=False)

Session = sessionmaker(engine) # noqa: N806

Expand Down
Loading