diff --git a/app/config/__init__.py b/app/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/config/config.py b/app/config/config.py new file mode 100644 index 0000000..48cc74a --- /dev/null +++ b/app/config/config.py @@ -0,0 +1,197 @@ +from typing import Dict, Any, Optional + +from app.config.environment import get_environment_config, EnvironmentConfig +from app.config.loader import load_and_validate_config, get_legacy_config_dict, ConfigValidationError +from app.config.validation import ImpulseConfig +from app.logging import logger + + +class UnifiedConfig: + """ + Unified configuration that combines environment and validated application config. + Uses existing ImpulseConfig as source of truth for application configuration. + """ + + def __init__(self, env: EnvironmentConfig, app: ImpulseConfig, legacy: Dict[str, Any]): + self.env = env + self.app = app + self.legacy = legacy + + self.INCIDENT_ACTUAL_VERSION = 'v0.4' + self.check_updates = True + + @property + def settings(self) -> Dict[str, Any]: + return self.legacy + + @property + def incident(self): + return self.app.incident + + @property + def experimental(self): + return self.app.experimental + + @property + def application(self): + return self.app.application + + @property + def ui_config(self): + return self.app.ui + + @property + def slack_bot_user_oauth_token(self) -> str: + return self.env.slack_bot_user_oauth_token + + @property + def slack_verification_token(self) -> str: + return self.env.slack_verification_token + + @property + def mattermost_access_token(self) -> str: + return self.env.mattermost_access_token + + @property + def telegram_bot_token(self) -> str: + return self.env.telegram_bot_token + + @property + def data_path(self) -> str: + return self.env.data_path + + @property + def config_path(self) -> str: + return self.env.config_path + + @property + def incidents_path(self) -> str: + return self.env.incidents_path + + @property + def provider_sync_interval(self) -> int: + return self.env.provider_sync_interval + + @property + def provider_max_events(self) -> int: + return self.env.provider_max_events + + @property + def provider_days_to_sync(self) -> int: + return self.env.provider_days_to_sync + + @property + def provider_service_account_file(self) -> str: + return self.env.provider_service_account_file + + @property + def cors_allowed_origins(self) -> list: + return self.env.cors_allowed_origins + + +_config: Optional[UnifiedConfig] = None + + +def get_config() -> UnifiedConfig: + """ + Get the global configuration instance. + """ + global _config + if _config is None: + _config = load_unified_config() + return _config + + +def load_unified_config(config_path: Optional[str] = None, exit_on_error: bool = True) -> UnifiedConfig: + """ + Load and create unified configuration from environment and YAML file. + + Args: + config_path: Optional path to configuration file. Uses env CONFIG_PATH if not provided. + exit_on_error: If True, exit process on validation errors. If False, raise exception. + + Returns: + UnifiedConfig: Complete configuration object + + Raises: + ConfigValidationError: If configuration loading or validation fails + SystemExit: If configuration is invalid and exit_on_error is True + """ + try: + env_config = get_environment_config() + + if config_path is None: + config_path = env_config.config_file_path + + validated_config, raw_config = load_and_validate_config(config_path) + + legacy_config = get_legacy_config_dict(validated_config) + + return UnifiedConfig( + env=env_config, + app=validated_config, + legacy=legacy_config + ) + + except ConfigValidationError as e: + error_msg = (f"{e}\n" + f"Please check your impulse.yml file and fix any validation errors.\n" + f"Documentation: https://docs.impulse.bot/latest/config_file/") + if exit_on_error: + logger.error(error_msg) + raise SystemExit(1) + else: + logger.warning(error_msg) + raise + except Exception as e: + error_msg = f"Failed to load configuration: {e}" + if exit_on_error: + logger.error(error_msg) + raise SystemExit(1) + else: + logger.warning(error_msg) + raise + + +def reload_config(config_path: Optional[str] = None) -> bool: + """ + Reload configuration from file with graceful error handling. + If validation fails, keeps the current configuration and logs a warning. + + Args: + config_path: Optional path to configuration file. Uses env CONFIG_PATH if not provided. + + Returns: + bool: True if reload was successful, False if failed and kept old config + """ + global _config + current_config = _config + + try: + new_config = load_unified_config(config_path, exit_on_error=False) + if new_config.app.application.type == current_config.app.application.type: + _config = new_config + logger.info("Configuration reloaded successfully") + return True + else: + logger.warning("Application type changed, keeping current configuration") + return False + + except ConfigValidationError as e: + logger.warning("Configuration validation failed, keeping current configuration") + _config = current_config + return False + except Exception as e: + logger.warning(f"Configuration reload failed, keeping current configuration: {e}") + _config = current_config + return False + + +def force_reload_config(config_path: Optional[str] = None) -> UnifiedConfig: + """ + Force reload configuration from file (original behavior). + Useful for testing or when you want the process to exit on validation errors. + """ + global _config + _config = load_unified_config(config_path) + return _config diff --git a/app/config/environment.py b/app/config/environment.py new file mode 100644 index 0000000..70be810 --- /dev/null +++ b/app/config/environment.py @@ -0,0 +1,126 @@ +import os +from typing import List +from pydantic import BaseModel, Field, field_validator +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + + +class EnvironmentConfig(BaseModel): + """Environment-based configuration loaded from environment variables""" + + # Authentication tokens and secrets + slack_bot_user_oauth_token: str = Field( + default_factory=lambda: os.getenv('SLACK_BOT_USER_OAUTH_TOKEN', ''), + description="Slack Bot User OAuth Token" + ) + slack_verification_token: str = Field( + default_factory=lambda: os.getenv('SLACK_VERIFICATION_TOKEN', ''), + description="Slack Verification Token" + ) + mattermost_access_token: str = Field( + default_factory=lambda: os.getenv('MATTERMOST_ACCESS_TOKEN', ''), + description="Mattermost Access Token" + ) + telegram_bot_token: str = Field( + default_factory=lambda: os.getenv('TELEGRAM_BOT_TOKEN', ''), + description="Telegram Bot Token" + ) + + # Paths + data_path: str = Field( + default_factory=lambda: os.getenv('DATA_PATH', './data'), + description="Path to data directory" + ) + config_path: str = Field( + default_factory=lambda: os.getenv('CONFIG_PATH', './'), + description="Path to configuration directory" + ) + + # Provider settings (for Google Calendar integration) + provider_sync_interval: int = Field( + default_factory=lambda: int(os.getenv('CHAIN_PROVIDER_SYNC_INTERVAL_SECONDS', '60')), + description="Provider sync interval in seconds" + ) + provider_max_events: int = Field( + default_factory=lambda: int(os.getenv('CHAIN_PROVIDER_MAX_EVENTS', '10')), + description="Maximum events to sync from provider" + ) + provider_days_to_sync: int = Field( + default_factory=lambda: int(os.getenv('CHAIN_PROVIDER_DAYS_TO_SYNC', '7')), + description="Number of days to sync from provider" + ) + provider_service_account_file: str = Field( + default_factory=lambda: os.getenv('GOOGLE_SERVICE_ACCOUNT_FILE', './key.json'), + description="Path to Google service account file" + ) + + # CORS configuration + cors_allowed_origins: List[str] = Field( + default_factory=lambda: os.getenv('CORS_ALLOWED_ORIGINS', 'http://localhost:5000').split(','), + description="Comma-separated list of allowed CORS origins" + ) + + # Logging + log_level: str = Field( + default_factory=lambda: os.getenv('LOG_LEVEL', 'INFO'), + description="Logging level" + ) + + @field_validator('provider_sync_interval', 'provider_max_events', 'provider_days_to_sync') + @classmethod + def validate_positive_integers(cls, v): + """Validate that provider settings are positive integers""" + if v <= 0: + raise ValueError("Provider configuration values must be positive integers") + return v + + @field_validator('cors_allowed_origins') + @classmethod + def validate_cors_origins(cls, v): + """Clean up CORS origins by removing whitespace""" + return [origin.strip() for origin in v if origin.strip()] + + @field_validator('log_level') + @classmethod + def validate_log_level(cls, v): + """Validate log level is valid""" + valid_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] + if v.upper() not in valid_levels: + raise ValueError(f"Log level must be one of: {', '.join(valid_levels)}") + return v.upper() + + @property + def incidents_path(self) -> str: + """Computed property for incidents path""" + return f"{self.data_path}/incidents" + + @property + def config_file_path(self) -> str: + """Computed property for config file path""" + return f"{self.config_path}/impulse.yml" + + +# Global instance - created once and reused +_env_config: EnvironmentConfig = None + + +def get_environment_config() -> EnvironmentConfig: + """Get the singleton instance of environment configuration""" + global _env_config + if _env_config is None: + _env_config = EnvironmentConfig() + return _env_config + + +# Convenience function for common environment variables +def get_messenger_token(messenger_type: str) -> str: + """Get the appropriate token based on messenger type""" + env_config = get_environment_config() + token_map = { + 'slack': env_config.slack_bot_user_oauth_token, + 'mattermost': env_config.mattermost_access_token, + 'telegram': env_config.telegram_bot_token, + } + return token_map.get(messenger_type, '') diff --git a/app/config/loader.py b/app/config/loader.py new file mode 100644 index 0000000..c9bcbc7 --- /dev/null +++ b/app/config/loader.py @@ -0,0 +1,168 @@ +import os +import yaml +from typing import Dict, Any, Tuple +from pydantic import ValidationError + +from app.config.validation import ImpulseConfig, validate_config +from app.logging import logger + +class ConfigValidationError(Exception): + """Custom exception for configuration validation errors""" + + def __init__(self, message: str, validation_errors: list = None): + super().__init__(message) + self.validation_errors = validation_errors or [] + + +def load_and_validate_config(config_path: str = None) -> Tuple[ImpulseConfig, Dict[str, Any]]: + """ + Load and validate Impulse configuration from YAML file. + + Args: + config_path: Path to the configuration file. If None, uses CONFIG_PATH env var + + Returns: + tuple: (validated_config, raw_config_dict) + + Raises: + FileNotFoundError: If config file doesn't exist + yaml.YAMLError: If YAML parsing fails + ConfigValidationError: If validation fails + """ + # Determine config path + if config_path is None: + config_path = os.getenv('CONFIG_PATH', './') + '/impulse.yml' + + # Check if file exists + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + # Load YAML file + try: + with open(config_path, 'r', encoding='utf-8') as file: + raw_config = yaml.safe_load(file) + except yaml.YAMLError as e: + raise ConfigValidationError(f"YAML parsing failed: {e}") + except Exception as e: + raise ConfigValidationError(f"Failed to read config file: {e}") + + if raw_config is None: + raise ConfigValidationError("Configuration file is empty") + + # Validate configuration + try: + validated_config = validate_config(raw_config) + return validated_config, raw_config + except ValidationError as e: + error_details = [] + for error in e.errors(): + loc = " -> ".join(str(x) for x in error['loc']) + error_details.append(f" {loc}: {error['msg']}") + + error_message = f"Configuration validation failed:\n" + "\n".join(error_details) + raise ConfigValidationError(error_message, e.errors()) + + +def get_legacy_config_dict(validated_config: ImpulseConfig) -> Dict[str, Any]: + """ + Convert validated Pydantic config to legacy dictionary format for backward compatibility. + + Args: + validated_config: Validated Pydantic configuration object + + Returns: + dict: Configuration in legacy format + """ + # Convert to dict and restructure for legacy code + config_dict = validated_config.model_dump() + + # Legacy format adjustments + legacy_config = { + 'application': config_dict['application'], + 'incident': config_dict.get('incident', {}), + 'experimental': config_dict.get('experimental', {}), + 'route': config_dict['route'], + 'ui': config_dict.get('ui', {}), + 'webhooks': config_dict.get('webhooks', {}) + } + + # Ensure template_files is properly handled in application section + if 'template_files' not in legacy_config['application'] or legacy_config['application']['template_files'] is None: + legacy_config['application']['template_files'] = {} + + # Ensure incident defaults + if 'incident' not in legacy_config or legacy_config['incident'] is None: + legacy_config['incident'] = {} + + incident_defaults = { + 'alerts_firing_notifications': False, + 'alerts_resolved_notifications': False, + 'timeouts': { + 'firing': '6h', + 'unknown': '6h', + 'resolved': '12h' + } + } + + for key, default_value in incident_defaults.items(): + if key not in legacy_config['incident']: + legacy_config['incident'][key] = default_value + + # Ensure experimental defaults + if 'experimental' not in legacy_config or legacy_config['experimental'] is None: + legacy_config['experimental'] = {} + + experimental_defaults = { + 'recreate_chain': False + } + + for key, default_value in experimental_defaults.items(): + if key not in legacy_config['experimental']: + legacy_config['experimental'][key] = default_value + + return legacy_config + + +def format_validation_errors(errors: list) -> str: + """ + Format validation errors for user-friendly display. + + Args: + errors: List of validation error messages + + Returns: + Formatted error string + """ + if not errors: + return "No validation errors" + + formatted_errors = [] + for i, error in enumerate(errors, 1): + formatted_errors.append(f"{i}. {error}") + + return f"Configuration validation errors:\n" + "\n".join(formatted_errors) + + +def validate_config_and_show_errors(config_path: str = None) -> ImpulseConfig: + """ + Validate configuration and show user-friendly errors on failure. + + Args: + config_path: Path to the configuration file + + Returns: + Validated configuration object + + Raises: + SystemExit: If validation fails (after showing errors) + """ + try: + validated_config, _ = load_and_validate_config(config_path) + return validated_config + except FileNotFoundError as e: + logger.error(f"\nConfiguration file not found: {e}") + raise SystemExit(1) + except yaml.YAMLError as e: + logger.error(f"\nYAML parsing error: {e}") + raise SystemExit(1) + \ No newline at end of file diff --git a/app/config/validation.py b/app/config/validation.py new file mode 100644 index 0000000..3d71f78 --- /dev/null +++ b/app/config/validation.py @@ -0,0 +1,431 @@ +from typing import Dict, List, Optional, Union, Any +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal +from pydantic import BaseModel, Field, field_validator, model_validator +from enum import Enum +import re +import yaml +import os +from pathlib import Path + + +class ApplicationType(str, Enum): + """Supported application types""" + SLACK = "slack" + MATTERMOST = "mattermost" + TELEGRAM = "telegram" + + +class ChainType(str, Enum): + """Supported chain types""" + SCHEDULE = "schedule" + CLOUD = "cloud" + + +class CloudProvider(str, Enum): + """Supported cloud providers""" + GOOGLE = "google" + + +class DatetimeFormat(str, Enum): + """Supported datetime formats""" + ABSOLUTE = "absolute" + RELATIVE = "relative" + + +class SortOrder(str, Enum): + """Supported sort orders""" + ASC = "asc" + DESC = "desc" + NONE = "none" + + +class TelegramUser(BaseModel): + """Telegram user configuration""" + id: int = Field(..., description="User ID") + name: Optional[str] = Field(None, description="User display name") + username: Optional[str] = Field(None, description="Username") + + +class SlackUser(BaseModel): + """Slack user configuration""" + id: str = Field(..., description="User ID") + + +class MattermostUser(BaseModel): + """Mattermost user configuration""" + id: str = Field(..., description="User ID") + + +class TelegramChannel(BaseModel): + """Telegram channel configuration""" + id: int = Field(..., description="Channel ID") + name: Optional[str] = Field(None, description="Channel name") + + +class SlackChannel(BaseModel): + """Slack channel configuration""" + id: str = Field(..., description="Channel ID") + + +class MattermostChannel(BaseModel): + """Mattermost channel configuration""" + id: str = Field(..., description="Channel ID") + + +class SimpleChainStep(BaseModel): + """Base chain step""" + user: Optional[str] = Field(None, description="User to notify") + user_group: Optional[str] = Field(None, description="User group to notify") + webhook: Optional[str] = Field(None, description="Webhook to call") + chain: Optional[str] = Field(None, description="Nested chain to execute") + wait: Optional[str] = Field(None, description="Wait duration (e.g., '5m', '1h')") + + @model_validator(mode='after') + def validate_step_type(self): + """Validate that exactly one step type is specified""" + fields = [self.user, self.user_group, self.webhook, self.chain, self.wait] + non_none_fields = [f for f in fields if f is not None] + + if len(non_none_fields) != 1: + raise ValueError("Exactly one of user, user_group, webhook, chain, or wait must be specified") + + return self + + @field_validator('wait') + @classmethod + def validate_wait_format(cls, v): + """Validate wait duration format""" + if v is None: + return v + + # Check format like "5m", "1h", "30s", "2d" + if not re.match(r'^\d+[smhd]$', v): + raise ValueError("Wait duration must be in format like '5m', '1h', '30s', or '2d'") + + return v + + +class ScheduleMatcherExpression(BaseModel): + """Schedule matcher expression - fully flexible""" + expr: str = Field(..., description="Custom expression") + start_time: Any = Field(..., description="Start time in any format") + duration: Any = Field(..., description="Duration in any format") + + +class ScheduleEntry(BaseModel): + """Schedule entry configuration""" + matcher: Optional[List[ScheduleMatcherExpression]] = Field(None, description="List of matcher expressions") + steps: List[SimpleChainStep] = Field(..., description="Chain steps") + + +class SimpleChain(BaseModel): + """Simple chain configuration - just a list of steps""" + pass # This will be handled as List[SimpleChainStep] directly + + +class ScheduleChain(BaseModel): + """Schedule chain configuration""" + type: Literal[ChainType.SCHEDULE] = Field(..., description="Chain type") + timezone: Optional[str] = Field("UTC", description="Timezone") + schedule: List[ScheduleEntry] = Field(..., description="Schedule entries") + + +class CloudChain(BaseModel): + """Cloud chain configuration""" + type: Literal[ChainType.CLOUD] = Field(..., description="Chain type") + provider: CloudProvider = Field(..., description="Cloud provider") + calendar_id: str = Field(..., description="Calendar ID") + default_steps: Optional[List[SimpleChainStep]] = Field(None, description="Default steps") + + +class UserGroup(BaseModel): + """User group configuration""" + users: List[str] = Field(..., description="List of user names") + + +class TemplateFiles(BaseModel): + """Template files configuration""" + status_icons: Optional[str] = Field(None, description="Status icons template path") + header: Optional[str] = Field(None, description="Header template path") + body: Optional[str] = Field(None, description="Body template path") + + def get(self, key: str, default: str = None) -> str: + return getattr(self, key) or default + + +class ApplicationConfig(BaseModel): + """Application configuration""" + type: ApplicationType = Field(..., description="Application type") + channels: Dict[str, Union[SlackChannel, MattermostChannel, TelegramChannel]] = Field(..., description="Channel definitions") + users: Dict[str, Union[SlackUser, MattermostUser, TelegramUser]] = Field(..., description="User definitions") + admin_users: List[str] = Field(..., description="Admin users") + user_groups: Optional[Dict[str, UserGroup]] = Field(None, description="User groups") + chains: Optional[Dict[str, Any]] = Field(None, description="Chain definitions") + + # Type-specific fields + address: Optional[str] = Field(None, description="Mattermost server address") + team: Optional[str] = Field(None, description="Mattermost team name") + impulse_address: Optional[str] = Field(None, description="Impulse callback address") + template_files: Optional[TemplateFiles] = Field(TemplateFiles(status_icons=None, header=None, body=None), description="Template files") + + @model_validator(mode='after') + def validate_type_specific_fields(self): + """Validate type-specific required fields""" + if self.type == ApplicationType.MATTERMOST: + if not self.address: + raise ValueError("'address' is required for Mattermost") + if not self.team: + raise ValueError("'team' is required for Mattermost") + if not self.impulse_address: + raise ValueError("'impulse_address' is required for Mattermost") + + elif self.type == ApplicationType.TELEGRAM: + if not self.impulse_address: + raise ValueError("'impulse_address' is required for Telegram") + + return self + + @field_validator('admin_users') + @classmethod + def validate_admin_users_exist(cls, v, info): + """Validate that admin users exist in users""" + if 'users' in info.data and info.data['users']: + for admin_user in v: + if admin_user not in info.data['users']: + raise ValueError(f"Admin user '{admin_user}' not found in users") + return v + + @field_validator('chains') + @classmethod + def validate_chains_structure_and_references(cls, v, info): + """Validate chain structure and references""" + if v is None: + return v + + users = info.data.get('users', {}) + user_groups = info.data.get('user_groups', {}) + + def validate_chain_steps(steps): + if not isinstance(steps, list): + return + for step in steps: + if isinstance(step, dict): + if 'user' in step and step['user'] and users and step['user'] not in users: + raise ValueError(f"User '{step['user']}' in chain not found in users") + if 'user_group' in step and step['user_group'] and user_groups and step['user_group'] not in user_groups: + raise ValueError(f"User group '{step['user_group']}' in chain not found in user_groups") + if 'chain' in step and step['chain'] and step['chain'] not in v: + raise ValueError(f"Nested chain '{step['chain']}' not found in chains") + + validated_chains = {} + + for chain_name, chain_config in v.items(): + if isinstance(chain_config, list): + # Simple chain - validate steps + validated_steps = [] + for step in chain_config: + validated_steps.append(SimpleChainStep(**step)) + validated_chains[chain_name] = validated_steps + validate_chain_steps(chain_config) + + elif isinstance(chain_config, dict): + if chain_config.get('type') == 'schedule': + # Schedule chain + validated_chains[chain_name] = ScheduleChain(**chain_config) + if 'schedule' in chain_config: + for schedule_entry in chain_config['schedule']: + if isinstance(schedule_entry, dict) and 'steps' in schedule_entry: + validate_chain_steps(schedule_entry['steps']) + + elif chain_config.get('type') == 'cloud': + # Cloud chain + validated_chains[chain_name] = CloudChain(**chain_config) + if 'default_steps' in chain_config: + validate_chain_steps(chain_config['default_steps']) + + else: + raise ValueError(f"Unknown chain type for chain '{chain_name}': {chain_config.get('type')}") + + return validated_chains + + +class IncidentTimeouts(BaseModel): + """Incident timeout configuration""" + firing: Optional[str] = Field("6h", description="Firing timeout") + unknown: Optional[str] = Field("6h", description="Unknown timeout") + resolved: Optional[str] = Field("12h", description="Resolved timeout") + + def get(self, key: str) -> str: + return getattr(self, key) or None + + +class IncidentConfig(BaseModel): + """Incident configuration""" + alerts_firing_notifications: Optional[bool] = Field(False, description="Enable firing notifications") + alerts_resolved_notifications: Optional[bool] = Field(False, description="Enable resolved notifications") + timeouts: Optional[IncidentTimeouts] = Field(None, description="Incident timeouts") + + +class ExperimentalConfig(BaseModel): + """Experimental configuration""" + recreate_chain: Optional[bool] = Field(False, description="Recreate chain on new alerts") + + +class RouteConfig(BaseModel): + """Route configuration""" + channel: str = Field(..., description="Default channel") + chain: Optional[str] = Field(None, description="Default chain") + matchers: Optional[List[str]] = Field(None, description="Route matchers") + routes: Optional[List['RouteConfig']] = Field(None, description="Nested routes") + + +class UIColumn(BaseModel): + """UI column configuration""" + name: str = Field(..., description="Column name") + header: str = Field(..., description="Column header") + value: str = Field(..., description="Column value path") + type: Optional[str] = Field("string", description="Column type (string, datetime, link, etc.)") + visible: Optional[bool] = Field(True, description="Column visibility") + url: Optional[str] = Field(None, description="URL for link type") + format: Optional[str] = Field("relative", description="Datetime format (absolute, relative)") + + @model_validator(mode='after') + def validate_link_type(self): + """Validate link type requirements""" + if self.type == "link" and not self.url: + raise ValueError("'url' is required when type is 'link'") + return self + + +class UIConfig(BaseModel): + """UI configuration""" + columns: List[UIColumn] = Field(..., description="Column configurations") + colors: Optional[Dict[str, Dict[str, str]]] = Field({None}, description="Color configurations") + filters: Optional[List[str]] = Field(None, description="Default filters") + sorting: Optional[List[Dict[str, Union[str, List[str]]]]] = Field(None, description="Sort rules") + + @field_validator('sorting') + @classmethod + def validate_sorting_format(cls, v): + """Validate sorting format""" + if v is None: + return v + + for sort_rule in v: + if not isinstance(sort_rule, dict): + raise ValueError("Each sorting rule must be a dictionary") + + # Each dictionary should have exactly one column name as key + column_keys = [k for k in sort_rule.keys() if k != 'order'] + if len(column_keys) != 1: + raise ValueError("Each sorting rule must have exactly one column name as key") + + column_name = column_keys[0] + sort_order = sort_rule[column_name] + + if sort_order not in ['asc', 'desc', 'none']: + raise ValueError(f"Sort order must be 'asc', 'desc', or 'none', got: {sort_order}") + + # If sort order is 'none', order field is required + if sort_order == 'none' and 'order' not in sort_rule: + raise ValueError("'order' field is required when sort order is 'none'") + + return v + + +class WebhookConfig(BaseModel): + """Webhook configuration""" + url: str = Field(..., description="Webhook URL") + data: Optional[Dict[str, Any]] = Field(None, description="Webhook data") + auth: Optional[str] = Field(None, description="HTTP Basic Auth") + + +class ImpulseConfig(BaseModel): + """Main Impulse configuration""" + application: ApplicationConfig = Field(..., description="Application configuration") + incident: Optional[IncidentConfig] = Field(None, description="Incident configuration") + experimental: Optional[ExperimentalConfig] = Field(None, description="Experimental configuration") + route: RouteConfig = Field(..., description="Route configuration") + ui: Optional[UIConfig] = Field(None, description="UI configuration") + webhooks: Optional[Dict[str, WebhookConfig]] = Field(None, description="Webhook configurations") + + @model_validator(mode='after') + def validate_route_channel_exists(self): + """Validate that route channels exist in application channels""" + def validate_route_channels(route_config): + if route_config.channel not in self.application.channels: + raise ValueError(f"Route channel '{route_config.channel}' not found in application channels") + + if route_config.routes: + for nested_route in route_config.routes: + validate_route_channels(nested_route) + + validate_route_channels(self.route) + return self + + @model_validator(mode='after') + def validate_route_chain_exists(self): + """Validate that route chains exist in application chains""" + if not self.application.chains: + return self + + chains = self.application.chains + + def validate_route_chain(route_config): + if hasattr(route_config, 'chain') and route_config.chain: + if route_config.chain not in chains: + raise ValueError(f"Route chain '{route_config.chain}' not found in application chains") + + if hasattr(route_config, 'routes') and route_config.routes: + for nested_route in route_config.routes: + validate_route_chain(nested_route) + + validate_route_chain(self.route) + return self + + +# Update forward references +RouteConfig.model_rebuild() + + +def validate_config(config_dict: dict) -> ImpulseConfig: + """ + Validate configuration dictionary using Pydantic models. + + Args: + config_dict: Dictionary containing configuration data + + Returns: + ImpulseConfig: Validated configuration object + + Raises: + pydantic.ValidationError: If validation fails + """ + return ImpulseConfig(**config_dict) + + +def validate_config_file(config_path: str) -> ImpulseConfig: + """ + Load and validate configuration from YAML file. + + Args: + config_path: Path to the YAML configuration file + + Returns: + ImpulseConfig: Validated configuration object + + Raises: + FileNotFoundError: If config file doesn't exist + yaml.YAMLError: If YAML parsing fails + pydantic.ValidationError: If validation fails + """ + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + with open(config_path, 'r', encoding='utf-8') as file: + config_dict = yaml.safe_load(file) + + return validate_config(config_dict) diff --git a/app/im/application.py b/app/im/application.py index 4fef148..7a3eb01 100644 --- a/app/im/application.py +++ b/app/im/application.py @@ -4,24 +4,27 @@ import aiohttp from aiohttp import ClientTimeout, ClientSession from aiohttp_retry import ExponentialRetry, RetryClient +from typing import Union, Dict from app.im.chain.chain_factory import ChainFactory from app.im.groups import generate_user_groups from app.im.template import JinjaTemplate, notification_user, notification_user_group, update_status from app.logging import logger +from app.config.validation import ApplicationConfig, MattermostUser, SlackUser, TelegramUser + class Application(ABC): - def __init__(self, app_config, channels, default_channel): + def __init__(self, app_config: ApplicationConfig, channels, default_channel): self.http = None # Will be initialized async - self.type = app_config['type'] + self.type = app_config.type self.url = self.get_url(app_config) self.public_url = None # Will be set in async initialization self.team = self.get_team_name(app_config) self._app_config = app_config # Store for async initialization - self.chains = ChainFactory.generate(app_config.get('chains', dict())) - self.templates = app_config.get('template_files', dict()) + self.chains = ChainFactory.generate(app_config.chains) + self.templates = app_config.template_files self.body_template, self.header_template, self.status_icons_template = self.generate_template() # Application-specific parameters @@ -38,9 +41,9 @@ def __init__(self, app_config, channels, default_channel): self.admin_users = None # Will be initialized async # Store config for async initialization - self._users_config = app_config['users'] - self._user_groups_config = app_config.get('user_groups') - self._admin_users_config = app_config['admin_users'] + self._users_config = app_config.users + self._user_groups_config = app_config.user_groups + self._admin_users_config = app_config.admin_users async def initialize_async(self): """Initialize async components after object creation""" @@ -59,18 +62,18 @@ async def close(self): if self.http: await self.http.close() - def get_url(self, app_config): + def get_url(self, app_config: ApplicationConfig): return self._get_url(app_config) - def get_team_name(self, app_config): + def get_team_name(self, app_config: ApplicationConfig): return self._get_team_name(app_config) - async def _generate_users(self, users_dict): + async def _generate_users(self, users_dict: Dict[str, Union[SlackUser, MattermostUser, TelegramUser]]): logger.info(f'Creating users') users = dict() for name, user_info in users_dict.items(): - if user_info.get('id') is not None: + if user_info.id is not None: user_details = await self.get_user_details(user_info) if not user_details['exists']: logger.warning(f'.. user {name} not found in {self.type.capitalize()} and will not be notified') @@ -207,16 +210,16 @@ def _markdown_links_to_native_format(self, text): pass @abstractmethod - def _get_url(self, app_config): + def _get_url(self, app_config: ApplicationConfig): pass @abstractmethod - def _get_public_url(self, app_config): + def _get_public_url(self, app_config: ApplicationConfig): """Get the public URL of the application to share with users.""" pass @abstractmethod - def _get_team_name(self, app_config): + def _get_team_name(self, app_config: ApplicationConfig): pass @abstractmethod @@ -260,7 +263,7 @@ async def _update_thread(self, id_, payload): pass @abstractmethod - async def get_user_details(self, user_details): + async def get_user_details(self, user_info: Union[SlackUser, MattermostUser, TelegramUser]): """Fetch user-specific details (ID, name, etc.) from the system. Must be implemented by subclasses.""" pass diff --git a/app/im/chain/google_calendar_chain.py b/app/im/chain/google_calendar_chain.py index cbbb104..47d5046 100644 --- a/app/im/chain/google_calendar_chain.py +++ b/app/im/chain/google_calendar_chain.py @@ -12,12 +12,15 @@ from app.im.chain.schedule_chain import ScheduleChain from app.logging import logger from app.tools import HTMLTextExtractor -from config import provider_sync_interval, provider_max_events, provider_days_to_sync, provider_service_account_file +from app.config.config import get_config class GoogleCalendarChain(ScheduleChain): def __init__(self, name, config: dict): super().__init__(name) + + # Get environment configuration + self._app_config = get_config() self.calendar_id = config.get('calendar_id') if not self.calendar_id: @@ -126,7 +129,7 @@ def cleanup(self) -> None: def _load_credentials(self) -> None: """Load service account credentials from JSON file.""" try: - with open(provider_service_account_file, 'r') as f: + with open(self._app_config.provider_service_account_file, 'r') as f: self.credentials = json.load(f) # Validate required fields required_fields = ['client_email', 'private_key', 'token_uri'] @@ -134,9 +137,9 @@ def _load_credentials(self) -> None: if field not in self.credentials: raise ValueError(f"Missing required field '{field}' in service account file") except FileNotFoundError: - raise ValueError(f"Service account file {provider_service_account_file} not found") + raise ValueError(f"Service account file {self._app_config.provider_service_account_file} not found") except json.JSONDecodeError: - raise ValueError(f"Invalid JSON in service account file {provider_service_account_file}") + raise ValueError(f"Invalid JSON in service account file {self._app_config.provider_service_account_file}") def _get_access_token(self) -> str: """Get access token using JWT with retry logic.""" @@ -223,12 +226,12 @@ def _fetch_events(self) -> List[Dict[str, Any]]: token = self._get_access_token() date_from = datetime.datetime.utcnow() - date_to = date_from + datetime.timedelta(days=provider_days_to_sync) + date_to = date_from + datetime.timedelta(days=self._app_config.provider_days_to_sync) params = { 'timeMin': date_from.isoformat() + 'Z', 'timeMax': date_to.isoformat() + 'Z', - 'maxResults': provider_max_events, + 'maxResults': self._app_config.provider_max_events, 'singleEvents': 'true', 'orderBy': 'startTime' } @@ -302,19 +305,16 @@ async def _sync_calendar(self) -> None: """Periodically sync calendar events with error recovery.""" while True: try: - # First sync the timezone calendar_timezone = self._get_calendar_timezone() self._update_timezone(calendar_timezone) - # Then sync events events = self._fetch_events() self._update_schedule(events) logger.debug(f"Synced {len(events)} events from Google Calendar") except Exception as e: logger.error(f"Error syncing Google Calendar: {str(e)}") - # Add exponential backoff for errors - await asyncio.sleep(min(provider_sync_interval * 2, 300)) # Max 5 minutes + await asyncio.sleep(min(self._app_config.provider_sync_interval * 2, 300)) # Max 5 minutes continue - await asyncio.sleep(provider_sync_interval) + await asyncio.sleep(self._app_config.provider_sync_interval) diff --git a/app/im/channels.py b/app/im/channels.py index 31e9a3c..9c75e16 100644 --- a/app/im/channels.py +++ b/app/im/channels.py @@ -1,17 +1,19 @@ from app.logging import logger +from typing import Dict, Union +from app.config.validation import SlackChannel, MattermostChannel, TelegramChannel -def check_channels(channels_list, channels_config, default_channel): +def check_channels(channels_list, channels_config: Dict[str, Union[SlackChannel, MattermostChannel, TelegramChannel]], default_channel): logger.info(f'Checking all channels defined') for channel in channels_list: if channel not in channels_config.keys(): logger.warning(f'.. channel {channel} not defined. Using default channel instead') - channels_config[channel] = {'id': channels_config.get(default_channel)['id']} + channels_config[channel] = {'id': channels_config.get(default_channel).id} else: if 'id' not in channels_config[channel]: logger.warning(f'.. channel \'{channel}\' has no \'id\'. Using default channel instead') - channels_config[channel] = {'id': channels_config.get(default_channel)['id']} - elif channels_config[channel].get('id') is None: + channels_config[channel] = {'id': channels_config.get(default_channel).id} + elif channels_config[channel].id is None: logger.warning(f'.. channel {channel} \'id\' is empty. Using default channel instead') - channels_config[channel] = {'id': channels_config.get(default_channel)['id']} + channels_config[channel] = {'id': channels_config.get(default_channel).id} return channels_config diff --git a/app/im/helpers.py b/app/im/helpers.py index 76b16b2..da4aa0e 100644 --- a/app/im/helpers.py +++ b/app/im/helpers.py @@ -1,10 +1,11 @@ from app.im.mattermost.mattermost_application import MattermostApplication from app.im.slack.slack_application import SlackApplication from app.im.telegram.telegram_application import TelegramApplication +from app.config.validation import ApplicationConfig -def get_application(app_config, channels, default_channel): - app_type = app_config['type'] +def get_application(app_config: ApplicationConfig, channels, default_channel): + app_type = app_config.type if app_type == 'slack': return SlackApplication(app_config, channels, default_channel) elif app_type == 'mattermost': diff --git a/app/im/mattermost/config.py b/app/im/mattermost/config.py index 9197c1f..bb77ad2 100644 --- a/app/im/mattermost/config.py +++ b/app/im/mattermost/config.py @@ -1,11 +1,5 @@ from jinja2 import Environment -from config import mattermost_access_token - -mattermost_headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {mattermost_access_token}', -} buttons = { # styles: good, warning, danger, default, primary, and success 'chain': { diff --git a/app/im/mattermost/mattermost_application.py b/app/im/mattermost/mattermost_application.py index c31f158..4b94d53 100644 --- a/app/im/mattermost/mattermost_application.py +++ b/app/im/mattermost/mattermost_application.py @@ -5,22 +5,27 @@ from app.im.application import Application from app.im.colors import status_colors -from app.im.mattermost.config import (mattermost_headers, mattermost_request_delay, mattermost_bold_text, +from app.im.mattermost.config import (mattermost_request_delay, mattermost_bold_text, mattermost_env, mattermost_admins_template_string) from app.im.mattermost.threads import mattermost_get_create_thread_payload, mattermost_get_update_payload, \ mattermost_get_button_update_payload from app.im.mattermost.user import User from app.logging import logger +from app.config.config import get_config +from app.config.validation import ApplicationConfig, MattermostUser class MattermostApplication(Application): - def __init__(self, app_config, channels, default_channel): + def __init__(self, app_config: ApplicationConfig, channels, default_channel): super().__init__(app_config, channels, default_channel) def _initialize_specific_params(self): self.post_message_url = f'{self.url}/api/v4/posts' - self.headers = mattermost_headers + self.headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {get_config().mattermost_access_token}', + } self.post_delay = mattermost_request_delay self.thread_id_key = 'id' @@ -39,18 +44,18 @@ async def _get_channels(self, team): logger.error(f'Failed to retrieve channel list: {e}') return {} - def _get_url(self, app_config): - return app_config['address'] + def _get_url(self, app_config: ApplicationConfig): + return app_config.address - def _get_public_url(self, app_config): - return app_config['address'] + def _get_public_url(self, app_config: ApplicationConfig): + return app_config.address - def _get_team_name(self, app_config): + def _get_team_name(self, app_config: ApplicationConfig): logger.info(f'Get {self.type.capitalize()} team name') - return app_config['team'] + return app_config.team - async def get_user_details(self, user_details): - id_ = user_details.get('id') if user_details is not None else None + async def get_user_details(self, user_info: MattermostUser): + id_ = user_info.id if user_info is not None else None if id_ is not None: async with self.http.get(f'{self.url}/api/v4/users/{id_}?user_id={id_}', headers=self.headers) as response: data = await response.json() @@ -161,7 +166,7 @@ def update_thread_payload(self, channel_id, id_, body, header, status_icons, sta async def _update_thread(self, id_, payload): async with self.http.put( f'{self.url}/api/v4/posts/{id_}', - headers=mattermost_headers, + headers=self.headers, json=payload ) as response: await asyncio.sleep(self.post_delay) diff --git a/app/im/mattermost/threads.py b/app/im/mattermost/threads.py index 51251c5..bd3613e 100644 --- a/app/im/mattermost/threads.py +++ b/app/im/mattermost/threads.py @@ -1,6 +1,6 @@ from app.im.colors import status_colors from app.im.mattermost.config import buttons -from config import application +from app.config.config import get_config def chain_attrs(chain_enabled, status): @@ -18,6 +18,7 @@ def chain_attrs(chain_enabled, status): def mattermost_get_button_update_payload(body, header, status_icons, status, chain_enabled, status_enabled): + config = get_config() payload = { 'update': { 'message': f'{status_icons} {header}', @@ -34,7 +35,7 @@ def mattermost_get_button_update_payload(body, header, status_icons, status, cha "name": chain_attrs(chain_enabled, status)[0], "style": chain_attrs(chain_enabled, status)[1], "integration": { - "url": f"{application.get('impulse_address')}/app", + "url": f"{config.application.impulse_address}/app", "context": { "action": "chain" } @@ -48,7 +49,7 @@ def mattermost_get_button_update_payload(body, header, status_icons, status, cha "style": buttons['status']['enabled']['style'] if status_enabled else buttons['status']['disabled']['style'], "integration": { - "url": f"{application.get('impulse_address')}/app", + "url": f"{config.application.impulse_address}/app", "context": { "action": "status" } @@ -65,6 +66,7 @@ def mattermost_get_button_update_payload(body, header, status_icons, status, cha def mattermost_get_update_payload(channel_id, thread_id, body, header, status_icons, status, chain_enabled, status_enabled): + config = get_config() payload = { 'channel_id': channel_id, 'id': thread_id, @@ -82,7 +84,7 @@ def mattermost_get_update_payload(channel_id, thread_id, body, header, status_ic "name": chain_attrs(chain_enabled, status)[0], "style": chain_attrs(chain_enabled, status)[1], "integration": { - "url": f"{application.get('impulse_address')}/app", + "url": f"{config.application.impulse_address}/app", "context": { "action": "chain" } @@ -96,7 +98,7 @@ def mattermost_get_update_payload(channel_id, thread_id, body, header, status_ic "style": buttons['status']['enabled']['style'] if status_enabled else buttons['status']['disabled']['style'], "integration": { - "url": f"{application.get('impulse_address')}/app", + "url": f"{config.application.impulse_address}/app", "context": { "action": "status" } @@ -111,6 +113,7 @@ def mattermost_get_update_payload(channel_id, thread_id, body, header, status_ic def mattermost_get_create_thread_payload(channel_id, body, header, status_icons, status): + config = get_config() payload = { 'channel_id': channel_id, 'message': f'{status_icons} {header}', @@ -127,7 +130,7 @@ def mattermost_get_create_thread_payload(channel_id, body, header, status_icons, "name": buttons['chain']['takeit']['text'], "style": buttons['chain']['takeit']['style'], "integration": { - "url": f"{application.get('impulse_address')}/app", + "url": f"{config.application.impulse_address}/app", "context": { "action": "chain" } @@ -139,7 +142,7 @@ def mattermost_get_create_thread_payload(channel_id, body, header, status_icons, "name": buttons['status']['enabled']['text'], "style": buttons['status']['enabled']['style'], "integration": { - "url": f"{application.get('impulse_address')}/app", + "url": f"{config.application.impulse_address}/app", "context": { "action": "status" } diff --git a/app/im/slack/config.py b/app/im/slack/config.py index 43dae83..d79ce25 100644 --- a/app/im/slack/config.py +++ b/app/im/slack/config.py @@ -1,11 +1,5 @@ from jinja2 import Environment -from config import slack_bot_user_oauth_token - -slack_headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {slack_bot_user_oauth_token}', -} buttons = { # styles: normal, danger, primary 'chain': { diff --git a/app/im/slack/slack_application.py b/app/im/slack/slack_application.py index fea9623..379c507 100644 --- a/app/im/slack/slack_application.py +++ b/app/im/slack/slack_application.py @@ -7,42 +7,46 @@ from app.im.application import Application from app.im.colors import status_colors from app.im.slack import reformat_message -from app.im.slack.config import slack_headers, slack_request_delay, slack_bold_text, slack_env, \ +from app.im.slack.config import slack_request_delay, slack_bold_text, slack_env, \ slack_admins_template_string from app.im.slack.threads import slack_get_create_thread_payload, slack_get_update_payload from app.im.slack.user import User from app.logging import logger -from config import slack_verification_token +from app.config.config import get_config +from app.config.validation import ApplicationConfig, SlackUser class SlackApplication(Application): - def __init__(self, app_config, channels, default_channel): + def __init__(self, app_config: ApplicationConfig, channels, default_channel): super().__init__(app_config, channels, default_channel) def _initialize_specific_params(self): self.post_message_url = f'{self.url}/api/chat.postMessage' - self.headers = slack_headers + self.headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {get_config().slack_bot_user_oauth_token}', + } self.post_delay = slack_request_delay self.thread_id_key = 'ts' - def _get_url(self, app_config): + def _get_url(self, app_config: ApplicationConfig): return 'https://slack.com' - async def _get_public_url(self, app_config): + async def _get_public_url(self, app_config: ApplicationConfig): async with self.http.get( f'https://slack.com/api/auth.test', - headers=slack_headers + headers=self.headers ) as response: await asyncio.sleep(slack_request_delay) json_ = await response.json() return json_.get('url') - def _get_team_name(self, app_config): + def _get_team_name(self, app_config: ApplicationConfig): return None - async def get_user_details(self, user_details): - id_ = user_details.get('id') if user_details is not None else None + async def get_user_details(self, user_info: SlackUser): + id_ = user_info.id if user_info is not None else None if id_ is not None: async with self.http.get(f'{self.url}/api/users.info?user={id_}', headers=self.headers) as response: data = await response.json() @@ -99,7 +103,8 @@ async def send_message(self, channel_id, text, attachment): return response_json.get('ts') async def buttons_handler(self, payload, incidents, queue_, route): - if payload.get('token') != slack_verification_token: + config = get_config() + if payload.get('token') != config.slack_verification_token: logger.error(f'Unauthorized request to \'/slack\'') return JSONResponse({}, status_code=401) @@ -153,7 +158,7 @@ def update_thread_payload(self, channel_id, id_, body, header, status_icons, sta async def _update_thread(self, id_, payload): async with self.http.post( f'{self.url}/api/chat.update', - headers=slack_headers, + headers=self.headers, json=payload ) as response: await asyncio.sleep(self.post_delay) diff --git a/app/im/telegram/telegram_application.py b/app/im/telegram/telegram_application.py index 9946b81..cda0fca 100644 --- a/app/im/telegram/telegram_application.py +++ b/app/im/telegram/telegram_application.py @@ -7,7 +7,8 @@ from app.im.telegram.config import buttons from app.im.telegram.user import User from app.logging import logger -from config import telegram_bot_token, application +from app.config.config import get_config +from app.config.validation import ApplicationConfig, TelegramUser class TelegramApplication(Application): @@ -18,11 +19,11 @@ class TelegramApplication(Application): '5408906741125490282': '🏁' # closed } - def __init__(self, app_config, channels, users): + def __init__(self, app_config: ApplicationConfig, channels, users): super().__init__(app_config, channels, users) def _initialize_specific_params(self): - self.url += telegram_bot_token + self.url += get_config().telegram_bot_token self.post_message_url = self.url + '/sendMessage' self.headers = {'Content-Type': 'application/json'} self.post_delay = 0.15 @@ -33,13 +34,13 @@ async def initialize_async(self): await super().initialize_async() await self._setup_webhook() - def _get_url(self, app_config): + def _get_url(self, app_config: ApplicationConfig): return 'https://api.telegram.org/bot' - def _get_public_url(self, app_config): + def _get_public_url(self, app_config: ApplicationConfig): return 'https://api.telegram.org/bot' - def _get_team_name(self, app_config): + def _get_team_name(self, app_config: ApplicationConfig): return None def get_notification_destinations(self): @@ -244,8 +245,8 @@ async def _update_thread(self, id_, payload): def _markdown_links_to_native_format(self, text): return text - async def get_user_details(self, user_details): - id_ = user_details.get('id') if user_details is not None else None + async def get_user_details(self, user_info: TelegramUser): + id_ = user_info.id if user_info is not None else None exists = False if id_ is not None: try: @@ -269,10 +270,11 @@ def create_user(self, name, user_details): ) async def _setup_webhook(self): + config = get_config() try: async with self.http.post( f'{self.url}/setWebhook', - params={'url': f"{application.get('impulse_address')}/app"}, + params={'url': f"{config.application.impulse_address}/app"}, headers=self.headers ) as response: await asyncio.sleep(self.post_delay) diff --git a/app/incident/incident.py b/app/incident/incident.py index 996e0c7..340a6f6 100644 --- a/app/incident/incident.py +++ b/app/incident/incident.py @@ -10,7 +10,7 @@ from app.tools import NoAliasDumper from app.ui.websocket import incident_ws from app.utils import get_attr_by_key_chain, normalize_param, filter_dict_keys -from config import incidents_path, incident, INCIDENT_ACTUAL_VERSION +from app.config.config import get_config from app.logging import logger @@ -36,7 +36,7 @@ class Incident: status_enabled: bool = False updated: datetime = field(default_factory=datetime.utcnow) created: datetime = field(default_factory=datetime.utcnow) - version: str = INCIDENT_ACTUAL_VERSION + version: str = get_config().INCIDENT_ACTUAL_VERSION uuid: str = field(init=False) ts: str = field(default='') link: str = field(default='') @@ -148,14 +148,15 @@ def set_next_status(self): return self.update_status(new_status) @classmethod - def load(cls, dump_file: str, config: IncidentConfig): + def load(cls, dump_file: str, incident_config: IncidentConfig): + config = get_config() with open(dump_file, 'r') as f: content = yaml.load(f, Loader=yaml.CLoader) incident_ = cls( last_state=content.get('last_state'), status=content.get('status'), channel_id=content.get('channel_id'), - config=config, + config=incident_config, chain=content.get('chain', []), chain_enabled=content.get('chain_enabled', False), status_enabled=content.get('status_enabled', False), @@ -164,12 +165,13 @@ def load(cls, dump_file: str, config: IncidentConfig): created=content.get('created'), assigned_user_id=content.get('assigned_user_id', ''), assigned_user=content.get('assigned_user', ''), - version=content.get('version', INCIDENT_ACTUAL_VERSION) + version=content.get('version', config.INCIDENT_ACTUAL_VERSION) ) - incident_.set_thread(content.get('ts'), config.application_url) + incident_.set_thread(content.get('ts'), incident_config.application_url) return incident_ def dump(self): + config = get_config() data = { "chain_enabled": self.chain_enabled, "chain": self.chain, @@ -185,7 +187,7 @@ def dump(self): "assigned_user": self.assigned_user, "version": self.version } - with open(f'{incidents_path}/{self.uuid}.yml', 'w') as f: + with open(f'{config.incidents_path}/{self.uuid}.yml', 'w') as f: yaml.dump(data, f, NoAliasDumper, default_flow_style=False) # Schedule async websocket update import asyncio @@ -258,7 +260,9 @@ def update_status(self, status: str) -> bool: now = datetime.utcnow() self.updated = now if status != 'closed': - self.status_update_datetime = now + unix_sleep_to_timedelta(incident['timeouts'].get(status)) + config = get_config() + timeout_value = config.incident.timeouts.get(status) + self.status_update_datetime = now + unix_sleep_to_timedelta(timeout_value) if self.status != status: self.set_status(status) logger.debug(f'Incident {self.uuid} status updated to {status}') diff --git a/app/incident/incidents.py b/app/incident/incidents.py index e5c8470..f6cd4ad 100644 --- a/app/incident/incidents.py +++ b/app/incident/incidents.py @@ -5,7 +5,7 @@ from app.incident.incident import Incident, IncidentConfig from app.logging import logger from app.ui.websocket import incident_ws -from config import incidents_path, INCIDENT_ACTUAL_VERSION +from app.config.config import get_config class Incidents: @@ -23,10 +23,11 @@ def add(self, incident: Incident): self.by_uuid[incident.uuid] = incident def del_by_uuid(self, uuid_: str): + config = get_config() incident = self.by_uuid.pop(uuid_, None) if incident: try: - os.remove(f'{incidents_path}/{uuid_}.yml') + os.remove(f'{config.incidents_path}/{uuid_}.yml') logger.info(f'Incident {uuid_} closed. Link: {incident.link}') except FileNotFoundError: logger.error(f'Failed to delete incident file for uuid: {uuid_}. File not found.') @@ -50,28 +51,29 @@ def get_table(self, params): @classmethod def create_or_load(cls, application_type, application_url, application_team): + config = get_config() # Ensure the incidents directory exists or create it - if not os.path.exists(incidents_path): + if not os.path.exists(config.incidents_path): logger.info('Creating incidents directory') - os.makedirs(incidents_path) + os.makedirs(config.incidents_path) logger.info('Loading existing incidents') incidents = cls([]) # Walk through the directory and load each incident - for path, directories, files in os.walk(incidents_path): + for path, directories, files in os.walk(config.incidents_path): for filename in files: - config = IncidentConfig( + incident_config = IncidentConfig( application_type=application_type, application_url=application_url, application_team=application_team ) incident_ = Incident.load( - dump_file=f'{incidents_path}/{filename}', - config=config + dump_file=f'{config.incidents_path}/{filename}', + incident_config=incident_config ) - if incident_.version != INCIDENT_ACTUAL_VERSION: + if incident_.version != config.INCIDENT_ACTUAL_VERSION: cls.update_incident(incident_) incidents.add(incident_) @@ -79,6 +81,7 @@ def create_or_load(cls, application_type, application_url, application_team): @staticmethod def update_incident(incident: Incident): - logger.info(f'Updating incident with uuid {incident.uuid} to version {INCIDENT_ACTUAL_VERSION}') - incident.version = INCIDENT_ACTUAL_VERSION + config = get_config() + logger.info(f'Updating incident with uuid {incident.uuid} to version {config.INCIDENT_ACTUAL_VERSION}') + incident.version = config.INCIDENT_ACTUAL_VERSION incident.dump() diff --git a/app/logging.py b/app/logging.py index 81bd1da..376b7fc 100644 --- a/app/logging.py +++ b/app/logging.py @@ -1,6 +1,5 @@ import logging - -from config import log_level +import os class CustomFormatter(logging.Formatter): @@ -39,10 +38,10 @@ def configure_uvicorn_logging(): """Configure uvicorn and FastAPI loggers to use our custom formatter and appropriate levels""" uvicorn_access_logger = logging.getLogger("uvicorn.access") uvicorn_access_logger.setLevel(logging.WARNING) - + uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.setLevel(logging.WARNING) - + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error"]: logger_obj = logging.getLogger(logger_name) logger_obj.handlers.clear() @@ -52,5 +51,13 @@ def configure_uvicorn_logging(): logger_obj.propagate = False -log_level = getattr(logging, log_level.upper(), logging.INFO) +try: + from app.config.environment import get_environment_config + env_config = get_environment_config() + log_level = getattr(logging, env_config.log_level, logging.INFO) +except ImportError: + # Fallback for cases where the config system isn't available yet + log_level = os.getenv('LOG_LEVEL', default='INFO') + log_level = getattr(logging, log_level.upper(), logging.INFO) + logger = create_logger('main_logger', log_level) diff --git a/app/queue/handlers/alert_handler.py b/app/queue/handlers/alert_handler.py index 3c8fc71..6f54a55 100644 --- a/app/queue/handlers/alert_handler.py +++ b/app/queue/handlers/alert_handler.py @@ -5,7 +5,7 @@ from app.logging import logger from app.queue.handlers.base_handler import BaseHandler from app.time import unix_sleep_to_timedelta -from config import INCIDENT_ACTUAL_VERSION, incident, experimental +from app.config.config import get_config class AlertHandler(BaseHandler): @@ -33,14 +33,17 @@ async def handle(self, alert_state): await self._handle_update(incident_.uuid, incident_, alert_state) async def _handle_create(self, alert_state): + config = get_config() + channel_name, chain_name = self.route.get_route(alert_state) channel = self.app.channels[channel_name] status = alert_state['status'] updated_datetime = datetime.utcnow() - status_update_datetime = datetime.utcnow() + unix_sleep_to_timedelta(incident['timeouts'].get(status)) + timeout_value = config.incident.timeouts.get(status) + status_update_datetime = datetime.utcnow() + unix_sleep_to_timedelta(timeout_value) - config = IncidentConfig( + incident_config = IncidentConfig( application_type=self.app.type, application_url=self.app.url, application_team=self.app.team @@ -49,7 +52,7 @@ async def _handle_create(self, alert_state): last_state=alert_state, status=status, channel_id=channel['id'], - config=config, + config=incident_config, chain=[], chain_enabled=True, status_enabled=True, @@ -57,7 +60,7 @@ async def _handle_create(self, alert_state): status_update_datetime=status_update_datetime, assigned_user_id="", assigned_user="", - version=INCIDENT_ACTUAL_VERSION + version=config.INCIDENT_ACTUAL_VERSION ) await self._create_thread(incident_, alert_state) incident_.dump() @@ -74,6 +77,8 @@ async def _handle_create(self, alert_state): await self.queue.recreate(status, incident_.uuid, incident_.chain) async def _handle_update(self, uuid_, incident_, alert_state): + config = get_config() + is_new_firing_alerts_added = False is_some_firing_alerts_removed = False prev_status = incident_.status @@ -86,10 +91,13 @@ async def _handle_update(self, uuid_, incident_, alert_state): await self.queue.recreate(alert_state.get('status'), uuid_, incident_.get_chain()) # Check new alerts firing or old alerts resolved - chain_recreate = experimental.get('recreate_chain', False) - if incident.get('alerts_firing_notifications') or chain_recreate: + chain_recreate = config.experimental.recreate_chain + alerts_firing_notifications = config.incident.alerts_firing_notifications + alerts_resolved_notifications = config.incident.alerts_resolved_notifications + + if alerts_firing_notifications or chain_recreate: is_new_firing_alerts_added = incident_.is_new_firing_alerts_added(alert_state) - if incident.get('alerts_resolved_notifications'): + if alerts_resolved_notifications: is_some_firing_alerts_removed = incident_.is_some_firing_alerts_removed(alert_state) is_status_updated, is_state_updated = incident_.update_state(alert_state) diff --git a/app/ui/table_config.py b/app/ui/table_config.py index 0da1eef..a46b863 100644 --- a/app/ui/table_config.py +++ b/app/ui/table_config.py @@ -1,4 +1,4 @@ -from config import ui_config +from app.config.config import get_config def get_all_ui_config(): @@ -23,6 +23,9 @@ def get_incident_table_config(): Returns: dict: The configuration for the incidents table. """ + config = get_config() + ui_config = config.ui_config + tabulator_config = [{ 'title': '', 'field': 'indicator', @@ -30,36 +33,37 @@ def get_incident_table_config(): 'headerSort': False }] - for field in ui_config.get('columns', []): - field_name = field['name'] - field_type = field.get('type') - field_visible = field.get('visible', True) - if field_type == 'link': - tabulator_config.append({ - 'title': field['header'], - 'field': field_name, - 'type': field_type, - 'urlField': f'{field_name}Url', - }) - tabulator_config.append({ - 'title': f'{field["header"]}Url', - 'field': f'{field_name}Url', - 'visible': False, - }) - elif field_type == 'datetime': - tabulator_config.append({ - 'title': field['header'], - 'field': field_name, - 'type': field_type, - 'formatType': field.get('format', 'relative'), - }) - else: - tabulator_config.append({ - 'title': field['header'], - 'field': field_name, - 'type': field_type, - 'visible': field_visible, - }) + if ui_config and ui_config.columns: + for field in ui_config.columns: + field_name = field.name + field_type = field.type or 'string' + field_visible = field.visible if field.visible is not None else True + if field_type == 'link': + tabulator_config.append({ + 'title': field.header, + 'field': field_name, + 'type': field_type, + 'urlField': f'{field_name}Url', + }) + tabulator_config.append({ + 'title': f'{field.header}Url', + 'field': f'{field_name}Url', + 'visible': False, + }) + elif field_type == 'datetime': + tabulator_config.append({ + 'title': field.header, + 'field': field_name, + 'type': field_type, + 'formatType': field.format or 'relative', + }) + else: + tabulator_config.append({ + 'title': field.header, + 'field': field_name, + 'type': field_type, + 'visible': field_visible, + }) return tabulator_config @@ -71,24 +75,28 @@ def get_incident_table_sorting(): Returns: list: A list of Tabulator-compatible sorting configurations. """ + config = get_config() + ui_config = config.ui_config + tabulator_sorting = [] - for rule in ui_config.get("sorting", []): - for column, sort_config in rule.items(): - if column == "order": - continue - - sorting_rule = {"column": column} - - if isinstance(sort_config, str): - if sort_config in ["asc", "desc"]: - sorting_rule["direction"] = sort_config - if "order" in rule: + if ui_config and ui_config.sorting: + for rule in ui_config.sorting: + for column, sort_config in rule.items(): + if column == "order": + continue + + sorting_rule = {"column": column} + + if isinstance(sort_config, str): + if sort_config in ["asc", "desc"]: + sorting_rule["direction"] = sort_config + if "order" in rule: + sorting_rule["order"] = rule["order"] + elif sort_config == "none" and "order" in rule: sorting_rule["order"] = rule["order"] - elif sort_config == "none" and "order" in rule: - sorting_rule["order"] = rule["order"] - - tabulator_sorting.append(sorting_rule) + + tabulator_sorting.append(sorting_rule) return tabulator_sorting @@ -99,7 +107,12 @@ def get_incident_table_colors(): Returns: dict: The colors for the incidents table. """ - return ui_config.get('colors', {}) + config = get_config() + ui_config = config.ui_config + + if ui_config and ui_config.colors: + return ui_config.colors + return {} def get_incident_table_filters(): """ @@ -108,4 +121,9 @@ def get_incident_table_filters(): Returns: dict: The filters for the incidents table. """ - return ui_config.get('filters', []) + config = get_config() + ui_config = config.ui_config + + if ui_config and ui_config.filters: + return ui_config.filters + return [] diff --git a/app/ui/websocket.py b/app/ui/websocket.py index e8d7e25..246655a 100644 --- a/app/ui/websocket.py +++ b/app/ui/websocket.py @@ -2,8 +2,8 @@ import json from typing import Set from fastapi import WebSocket -from config import ui_config from app.logging import logger +from app.config.config import get_config class AsyncIncidentWS: @@ -11,7 +11,7 @@ class AsyncIncidentWS: def __init__(self): self.connections: Set[WebSocket] = set() - self.table_config = ui_config + self.table_config = get_config().ui_config async def connect(self, websocket: WebSocket): """Accept a new WebSocket connection""" @@ -75,12 +75,12 @@ async def handle_request_data(self, websocket: WebSocket, incidents): def _get_values(self): """Get values mapping for table configuration""" values_map = {} - for field in self.table_config.get('columns', []): - if field.get('type') == 'link': - values_map[field['name']] = field['value'] - values_map[f'{field["name"]}Url'] = field['url'] + for field in self.table_config.columns: + if field.type == 'link': + values_map[field.name] = field.value + values_map[f'{field.name}Url'] = field.url else: - values_map[field['name']] = field['value'] + values_map[field.name] = field.value return values_map diff --git a/config.py b/config.py index 4339dcf..b67ea1c 100644 --- a/config.py +++ b/config.py @@ -1,43 +1,37 @@ -import os - -import yaml -from dotenv import load_dotenv - -load_dotenv() - -slack_bot_user_oauth_token = os.getenv('SLACK_BOT_USER_OAUTH_TOKEN') -slack_verification_token = os.getenv('SLACK_VERIFICATION_TOKEN') -mattermost_access_token = os.getenv('MATTERMOST_ACCESS_TOKEN') -telegram_bot_token = os.getenv('TELEGRAM_BOT_TOKEN') -data_path = os.getenv('DATA_PATH', default='./data') -config_path = os.getenv('CONFIG_PATH', default='./') -log_level = os.getenv('LOG_LEVEL', default='INFO') -provider_sync_interval = int(os.getenv('CHAIN_PROVIDER_SYNC_INTERVAL_SECONDS', default=60)) -provider_max_events = int(os.getenv('CHAIN_PROVIDER_MAX_EVENTS', default=10)) -provider_days_to_sync = int(os.getenv('CHAIN_PROVIDER_DAYS_TO_SYNC', default=7)) -provider_service_account_file = os.getenv('GOOGLE_SERVICE_ACCOUNT_FILE', default="./key.json") - -# Get CORS allowed origins from environment variable, default to localhost -cors_allowed_origins = os.getenv('CORS_ALLOWED_ORIGINS', 'http://localhost:5000').split(',') - -incidents_path = data_path + '/incidents' -INCIDENT_ACTUAL_VERSION = 'v0.4' - -with open(f'{config_path}/impulse.yml', 'r') as file: - try: - settings = yaml.safe_load(file) - - incident = dict() - incident['alerts_firing_notifications'] = settings.get('incident', {}).get('alerts_firing_notifications', False) - incident['alerts_resolved_notifications'] = settings.get('incident', {}).get('alerts_resolved_notifications', False) - incident['timeouts'] = dict() - incident['timeouts']['firing'] = settings.get('incident', {}).get('timeouts', {}).get('firing', '6h') - incident['timeouts']['unknown'] = settings.get('incident', {}).get('timeouts', {}).get('unknown', '6h') - incident['timeouts']['resolved'] = settings.get('incident', {}).get('timeouts', {}).get('resolved', '12h') - - experimental = settings.get('experimental', {}) - check_updates = True - application = settings.get('application') - ui_config = settings.get('ui', {}) - except yaml.YAMLError as e: - print(f"Error reading YAML file: {e}") +""" +Impulse Configuration Module + +This module provides backward-compatible access to both environment and application configuration. +Environment configuration (secrets, paths, etc.) is loaded from environment variables. +Application configuration (business logic) is loaded from YAML files and validated. +""" + +from app.config.config import get_config + +# Load the unified configuration +config = get_config() + +# Legacy environment variables (for backward compatibility) +slack_bot_user_oauth_token = config.slack_bot_user_oauth_token +slack_verification_token = config.slack_verification_token +mattermost_access_token = config.mattermost_access_token +telegram_bot_token = config.telegram_bot_token +data_path = config.data_path +config_path = config.config_path +provider_sync_interval = config.provider_sync_interval +provider_max_events = config.provider_max_events +provider_days_to_sync = config.provider_days_to_sync +provider_service_account_file = config.provider_service_account_file +cors_allowed_origins = config.cors_allowed_origins +incidents_path = config.incidents_path + +# Legacy application configuration (for backward compatibility) +settings = config.settings +incident = config.incident +experimental = config.experimental +application = config.application +ui_config = config.ui_config + +# Constants +INCIDENT_ACTUAL_VERSION = config.INCIDENT_ACTUAL_VERSION +check_updates = config.check_updates diff --git a/main.py b/main.py index 86510d6..95cf83d 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,8 @@ import asyncio import json +import argparse +import sys +import signal from contextlib import asynccontextmanager from datetime import datetime @@ -18,39 +21,83 @@ from app.ui.table_config import get_all_ui_config from app.ui.websocket import incident_ws from app.webhook import generate_webhooks -from config import settings, check_updates, application, ui_config +from app.config.config import get_config, reload_config + + +def setup_sighup_handler(): + """Setup only SIGHUP signal handler for configuration reloading, preserving other Uvicorn handlers""" + + def handle_sighup(signum, frame): + """Handle SIGHUP signal to reload configuration""" + try: + logger.info("Received SIGHUP signal, reloading configuration...") + success = reload_config() + if success: + logger.info("Configuration reload completed successfully") + except Exception as e: + logger.error(f"Error in SIGHUP signal handler: {e}") + logger.warning( + "Configuration reload aborted due to unexpected error, continuing with current configuration") + + if hasattr(signal, 'SIGHUP'): + signal.signal(signal.SIGHUP, handle_sighup) + logger.info("SIGHUP signal handler registered for configuration reloading (overriding Uvicorn)") + else: + logger.info("SIGHUP signal not available on this platform") + + +def validate_config_only(): + """Validate configuration and exit""" + try: + config = get_config() + logger.info("Configuration validation successful!\n" + f"Application type: {config.app.application.type}\n" + f"Channels configured: {len(config.app.application.channels)}\n" + f"Users configured: {len(config.app.application.users)}") + if config.app.incident: + logger.info(f"Incident config: Success") + if config.app.ui: + logger.info(f"UI config: Success") + if config.app.route: + logger.info(f"Route config: Success") + sys.exit(0) + except SystemExit as e: + if e.code != 0: + logger.error("Configuration validation failed!") + sys.exit(1) + except Exception as e: + logger.error(f"Configuration validation failed: {e}") + sys.exit(1) @asynccontextmanager async def lifespan(fastapi_app: FastAPI): """Manage application lifecycle""" - # Initialize components - route_dict = settings.get('route') - webhooks_dict = settings.get('webhooks') + config = get_config() + + route_dict = config.settings.get('route') + webhooks_dict = config.settings.get('webhooks') route = generate_route(route_dict) - channels = check_channels(route.get_uniq_channels(), application['channels'], route.channel) - messenger = get_application(application, channels, route.channel) + channels = check_channels(route.get_uniq_channels(), config.application.channels, route.channel) + messenger = get_application(config.application, channels, route.channel) await messenger.initialize_async() webhooks = generate_webhooks(webhooks_dict) incidents = Incidents.create_or_load(messenger.type, messenger.public_url, messenger.team) - # Create async queue and manager - queue = await AsyncQueue.recreate_queue(incidents, check_updates) + queue = await AsyncQueue.recreate_queue(incidents, config.check_updates) queue_manager = AsyncQueueManager(queue, messenger, incidents, webhooks, route) - # Attach to app state fastapi_app.state.queue = queue fastapi_app.state.queue_manager = queue_manager fastapi_app.state.incidents = incidents fastapi_app.state.messenger = messenger fastapi_app.state.webhooks = webhooks fastapi_app.state.route = route + fastapi_app.state.config = config - # Start background queue processing await queue_manager.start_processing() - # Start periodic update check asyncio.create_task(periodic_update_check(fastapi_app)) logger.info('IMPulse started!') @@ -60,11 +107,9 @@ async def lifespan(fastapi_app: FastAPI): if fastapi_app.state.queue_manager: await fastapi_app.state.queue_manager.stop_processing() - # Close HTTP session if hasattr(fastapi_app.state.messenger, 'close'): await fastapi_app.state.messenger.close() - # Cleanup chains if hasattr(fastapi_app.state.messenger, 'chains'): for chain in fastapi_app.state.messenger.chains.values(): if hasattr(chain, 'cleanup'): @@ -74,15 +119,14 @@ async def lifespan(fastapi_app: FastAPI): async def periodic_update_check(fastapi_app: FastAPI): - """Periodic task to check for updates (replaces APScheduler)""" + """Periodically check for updates""" while True: + await asyncio.sleep(24 * 60 * 60) try: - await asyncio.sleep(24 * 60 * 60) # 24 hours - await fastapi_app.state.queue.put(datetime.utcnow(), 'check_update', None, None) - except asyncio.CancelledError: - break + if hasattr(fastapi_app.state, 'queue'): + await fastapi_app.state.queue.check_for_updates() except Exception as e: - logger.error(f"Error in periodic update check: {e}") + logger.error(f"Error during periodic update check: {e}") app = FastAPI( @@ -94,7 +138,7 @@ async def periodic_update_check(fastapi_app: FastAPI): redoc_url=None ) -if ui_config: +if get_config().ui_config: app.mount("/static", StaticFiles(directory="static"), name="static") templates = Jinja2Templates(directory="templates") @@ -134,7 +178,6 @@ async def handle_app_buttons(request: Request): else: payload = await request.json() - # Note: This needs to be made async in the messenger implementation return await request.app.state.messenger.buttons_handler( payload, request.app.state.incidents, @@ -165,7 +208,6 @@ async def websocket_endpoint(websocket: WebSocket): try: while True: - # Wait for messages from client data = await websocket.receive_text() try: @@ -184,15 +226,33 @@ async def websocket_endpoint(websocket: WebSocket): await incident_ws.disconnect(websocket) -if __name__ == "__main__": - import uvicorn +def parse_arguments(): + """Parse command line arguments""" + parser = argparse.ArgumentParser(description="IMPulse - Incident Management Platform") + parser.add_argument( + '--check', + action='store_true', + help='Validate configuration and exit' + ) + return parser.parse_args() - configure_uvicorn_logging() - uvicorn.run( - "main:app", - host="0.0.0.0", - port=5000, - reload=True, - log_level="warning" - ) +if __name__ == "__main__": + args = parse_arguments() + + if args.check: + validate_config_only() + else: + setup_sighup_handler() + + import uvicorn + + configure_uvicorn_logging() + + uvicorn.run( + "main:app", + host="0.0.0.0", + port=5000, + reload=True, + log_level="warning" + ) diff --git a/requirements.txt b/requirements.txt index 5eca87c..dabffb1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ requests~=2.32.3 aiohttp~=3.12.10 aiohttp-retry~=2.8.3 Jinja2~=3.1.3 +pydantic~=2.11.7 PyYAML~=6.0.1 python-dotenv~=1.0.1 python-multipart~=0.0.20