Skip to content

Commit

Permalink
move utils to common and migrate to get_env_key
Browse files Browse the repository at this point in the history
  • Loading branch information
samj committed Oct 9, 2024
1 parent 56bae51 commit c43f133
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 11 deletions.
5 changes: 3 additions & 2 deletions __main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
sys.path.append(str(base_dir))
from common.paths import backend_dir, venv_dir, cert_dir
from common.config import logging_config
from backend.utils import get_env_key

# check environment
from backend.env import check_env
Expand Down Expand Up @@ -42,8 +43,8 @@ def cleanup():
app = create_app()

# Define host and port
host = os.environ.get("PAIOS_HOST", "localhost")
port = int(os.environ.get("PAIOS_PORT", 8443))
host = get_env_key("PAIOS_HOST", "localhost")
port = int(get_env_key("PAIOS_PORT", 8443))

# Log connection details
logger.info(f"You can access pAI-OS at https://{host}:{port}.")
Expand Down
5 changes: 3 additions & 2 deletions backend/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
sys.path.append(str(base_dir))
from common.paths import backend_dir, venv_dir, cert_dir
from common.config import logging_config
from backend.utils import get_env_key

# check environment
from backend.env import check_env
Expand Down Expand Up @@ -43,8 +44,8 @@ def cleanup():
app = create_app()

# Define host and port
host = os.environ.get("PAIOS_HOST", "localhost")
port = int(os.environ.get("PAIOS_PORT", 8443))
host = get_env_key("PAIOS_HOST", "localhost")
port = int(get_env_key("PAIOS_PORT", 8443))

# Log connection details
logger.info(f"You can access pAI-OS at https://{host}:{port}.")
Expand Down
7 changes: 4 additions & 3 deletions backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from connexion.middleware import MiddlewarePosition
from starlette.middleware.cors import CORSMiddleware
from backend.db import init_db
from backend.utils import get_env_key

def create_backend_app():
# Initialize the database
Expand All @@ -19,9 +20,9 @@ def create_backend_app():
]

# Add PAIOS server URL if environment variables are set
paios_scheme = os.environ.get('PAIOS_SCHEME', 'https')
paios_host = os.environ.get('PAIOS_HOST', 'localhost')
paios_port = os.environ.get('PAIOS_PORT', '8443')
paios_scheme = get_env_key('PAIOS_SCHEME', 'https')
paios_host = get_env_key('PAIOS_HOST', 'localhost')
paios_port = get_env_key('PAIOS_PORT', '8443')

if paios_host:
paios_url = f"{paios_scheme}://{paios_host}"
Expand Down
2 changes: 1 addition & 1 deletion backend/encryption.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from cryptography.fernet import Fernet
from backend.utils import get_env_key
from common.utils import get_env_key

class Encryption:
_instance = None
Expand Down
2 changes: 1 addition & 1 deletion backend/managers/AbilitiesManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import signal
from common.paths import abilities_dir, abilities_data_dir, venv_bin_dir
from backend.utils import remove_null_fields
from common.utils import remove_null_fields
from enum import Enum
from pathlib import Path
from threading import Lock
Expand Down
2 changes: 1 addition & 1 deletion backend/managers/AuthManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from webauthn.helpers.cose import COSEAlgorithmIdentifier
from connexion.exceptions import Unauthorized
from backend.utils import get_env_key
from common.utils import get_env_key

# set up logging
from common.log import get_logger
Expand Down
2 changes: 1 addition & 1 deletion backend/managers/DownloadsManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pathlib import Path
from urllib.parse import urlparse
from common.paths import data_dir, downloads_dir
from backend.utils import filter_dict, remove_null_fields
from common.utils import filter_dict, remove_null_fields
from threading import Lock

class DownloadStatus(Enum):
Expand Down
41 changes: 41 additions & 0 deletions common/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
from dotenv import set_key
from common.paths import base_dir

# set up logging
from common.log import get_logger
logger = get_logger(__name__)

def get_env_key(key_name, default=None):
value = os.environ.get(key_name)
if not value:
# If default is a function, call it to get the value, otherwise use it as the value
if default is not None:
if callable(default):
value = default()
else:
value = str(default)
else:
raise ValueError(f"{key_name} is not set in the environment variables")
set_key(base_dir / '.env', key_name, value)
return value

# Returns dict with null fields removed (e.g., for OpenAPI spec compliant
# responses without having to set nullable: true)
def remove_null_fields(data):
if isinstance(data, dict):
return {k: remove_null_fields(v) for k, v in data.items() if v is not None}
elif isinstance(data, list):
return [remove_null_fields(item) for item in data if item is not None]
else:
return data

# Returns dict with only keys_to_include (e.g., for OpenAPI spec compliant
# responses without unexpected fields present)
def filter_dict(data, keys_to_include):
return {k: data[k] for k in keys_to_include if k in data}

# Converts a db result into a dict with named fields (e.g.,
# ["x", "y"], [1, 2] -> { "x": 1, "y": 2})
def zip_fields(fields, result):
return {field: result[i] for i, field in enumerate(fields)}

0 comments on commit c43f133

Please sign in to comment.