diff --git a/fedn/network/storage/dbconnection.py b/fedn/network/storage/dbconnection.py index dc1352a8d..44168e3d1 100644 --- a/fedn/network/storage/dbconnection.py +++ b/fedn/network/storage/dbconnection.py @@ -5,6 +5,8 @@ DatabaseConnection: A singleton class for managing database connections and stores. """ +import os + import pymongo from pymongo.database import Database from sqlalchemy import create_engine @@ -88,8 +90,9 @@ def __new__(cls, *, force_create_new: bool = False) -> "DatabaseConnection": def _init_connection(self) -> None: statestore_config = get_statestore_config() network_id = get_network_config() + statestore_type = os.environ.get("FEDN_STATESTORE_TYPE", statestore_config["type"]) - if statestore_config["type"] == "MongoDB": + if statestore_type == "MongoDB": mdb: Database = self._setup_mongo(statestore_config, network_id) client_store = MongoDBClientStore(mdb, "network.clients") @@ -103,7 +106,7 @@ def _init_connection(self) -> None: session_store = MongoDBSessionStore(mdb, "control.sessions") analytic_store = MongoDBAnalyticStore(mdb, "control.analytics") - elif statestore_config["type"] in ["SQLite", "PostgreSQL"]: + elif statestore_type in ["SQLite", "PostgreSQL"]: Session = self._setup_sql(statestore_config) # noqa: N806 client_store = SQLClientStore(Session) @@ -144,18 +147,18 @@ def _setup_mongo(self, statestore_config: dict, network_id: str) -> "DatabaseCon return mdb def _setup_sql(self, statestore_config: dict) -> "DatabaseConnection": - if statestore_config["type"] == "SQLite": + statestore_type = os.environ.get("FEDN_STATESTORE_TYPE", statestore_config.get("type", "")) + if statestore_type == "SQLite": sqlite_config = statestore_config["sqlite_config"] dbname = sqlite_config["dbname"] engine = create_engine(f"sqlite:///{dbname}", echo=False) - elif statestore_config["type"] == "PostgreSQL": - postgres_config = statestore_config["postgres_config"] - username = postgres_config["username"] - password = postgres_config["password"] - host = postgres_config["host"] - port = postgres_config["port"] - - engine = create_engine(f"postgresql://{username}:{password}@{host}:{port}/fedn_db", echo=False) + elif statestore_type == "PostgreSQL": + username = os.environ.get("FEDN_STATESTORE_USERNAME", statestore_config.get("postgres_config", {}).get("username")) + password = os.environ.get("FEDN_STATESTORE_PASSWORD", statestore_config.get("postgres_config", {}).get("password")) + host = os.environ.get("FEDN_STATESTORE_HOST", statestore_config.get("postgres_config", {}).get("host")) + port = os.environ.get("FEDN_STATESTORE_PORT", statestore_config.get("postgres_config", {}).get("port")) + dbname = os.environ.get("FEDN_STATESTORE_DBNAME", statestore_config.get("postgres_config", {}).get("dbname")) + engine = create_engine(f"postgresql://{username}:{password}@{host}:{port}/{dbname}", echo=False) Session = sessionmaker(engine) # noqa: N806