diff --git a/database/repositories/order_repository.py b/database/repositories/order_repository.py index 3bf7ee21..27acdc7f 100644 --- a/database/repositories/order_repository.py +++ b/database/repositories/order_repository.py @@ -111,9 +111,9 @@ async def get_orders(self, account_name: Optional[str] = None, async def get_active_orders(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, trading_pair: Optional[str] = None) -> List[Order]: - """Get active orders (SUBMITTED, OPEN, PARTIALLY_FILLED).""" + """Get active orders (SUBMITTED, OPEN, PARTIALLY_FILLED, PENDING_CANCEL).""" query = select(Order).where( - Order.status.in_(["SUBMITTED", "OPEN", "PARTIALLY_FILLED"]) + Order.status.in_(["SUBMITTED", "OPEN", "PARTIALLY_FILLED", "PENDING_CANCEL"]) ) # Apply filters diff --git a/main.py b/main.py index e668dab7..80c18105 100644 --- a/main.py +++ b/main.py @@ -24,13 +24,15 @@ def patched_save_to_yml(yml_path, cm): from hummingbot.client.config import config_helpers config_helpers.save_to_yml = patched_save_to_yml -from hummingbot.core.rate_oracle.rate_oracle import RateOracle +from hummingbot.core.rate_oracle.rate_oracle import RateOracle, RATE_ORACLE_SOURCES from hummingbot.core.gateway.gateway_http_client import GatewayHttpClient from hummingbot.client.config.client_config_map import GatewayConfigMap -from fastapi import Depends, FastAPI, HTTPException, status +from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.middleware.cors import CORSMiddleware +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse from hummingbot.data_feed.market_data_provider import MarketDataProvider from hummingbot.client.config.config_crypt import ETHKeyFileSecretManger @@ -40,6 +42,7 @@ def patched_save_to_yml(yml_path, cm): from services.docker_service import DockerService from services.gateway_service import GatewayService from services.market_data_feed_manager import MarketDataFeedManager +# from services.executor_service import ExecutorService from utils.bot_archiver import BotArchiver from routers import ( accounts, @@ -49,11 +52,13 @@ def patched_save_to_yml(yml_path, cm): connectors, controllers, docker, + # executors, gateway, gateway_swap, gateway_clmm, market_data, portfolio, + rate_oracle, scripts, trading ) @@ -107,10 +112,47 @@ async def lifespan(app: FastAPI): # Initialize MarketDataProvider with empty connectors (will use non-trading connectors) market_data_provider = MarketDataProvider(connectors={}) + # Read rate oracle configuration from conf_client.yml + from utils.file_system import FileSystemUtil + fs_util = FileSystemUtil() + + try: + conf_client_path = "credentials/master_account/conf_client.yml" + config_data = fs_util.read_yaml_file(conf_client_path) + + # Get rate_oracle_source configuration + rate_oracle_source_data = config_data.get("rate_oracle_source", {}) + source_name = rate_oracle_source_data.get("name", "binance") + + # Get global_token configuration + global_token_data = config_data.get("global_token", {}) + quote_token = global_token_data.get("global_token_name", "USDT") + + # Create rate source instance + if source_name in RATE_ORACLE_SOURCES: + rate_source = RATE_ORACLE_SOURCES[source_name]() + logging.info(f"Configured RateOracle with source: {source_name}, quote_token: {quote_token}") + else: + logging.warning(f"Unknown rate oracle source '{source_name}', defaulting to binance") + rate_source = RATE_ORACLE_SOURCES["binance"]() + source_name = "binance" + + # Initialize RateOracle with configured source and quote token + rate_oracle_instance = RateOracle.get_instance() + rate_oracle_instance.source = rate_source + rate_oracle_instance.quote_token = quote_token + + except FileNotFoundError: + logging.warning("conf_client.yml not found, using default RateOracle configuration (binance, USDT)") + rate_oracle_instance = RateOracle.get_instance() + except Exception as e: + logging.warning(f"Error reading conf_client.yml: {e}, using default RateOracle configuration") + rate_oracle_instance = RateOracle.get_instance() + # Initialize MarketDataFeedManager with lifecycle management market_data_feed_manager = MarketDataFeedManager( market_data_provider=market_data_provider, - rate_oracle=RateOracle.get_instance(), + rate_oracle=rate_oracle_instance, cleanup_interval=settings.market_data.cleanup_interval, feed_timeout=settings.market_data.feed_timeout ) @@ -139,6 +181,18 @@ async def lifespan(app: FastAPI): # Initialize database await accounts_service.ensure_db_initialized() + # # Initialize ExecutorService for running executors directly via API + # executor_service = ExecutorService( + # connector_manager=accounts_service.connector_manager, + # market_data_feed_manager=market_data_feed_manager, + # db_manager=accounts_service.db_manager, + # default_account="master_account", + # update_interval=1.0, + # max_retries=10 + # ) + # # Store reference in accounts_service for router access + # accounts_service._executor_service = executor_service + # Store services in app state app.state.bots_orchestrator = bots_orchestrator app.state.accounts_service = accounts_service @@ -146,11 +200,13 @@ async def lifespan(app: FastAPI): app.state.gateway_service = gateway_service app.state.bot_archiver = bot_archiver app.state.market_data_feed_manager = market_data_feed_manager + # app.state.executor_service = executor_service # Start services bots_orchestrator.start() accounts_service.start() market_data_feed_manager.start() + # executor_service.start() yield @@ -158,6 +214,9 @@ async def lifespan(app: FastAPI): bots_orchestrator.stop() await accounts_service.stop() + # Stop executor service + # await executor_service.stop() + # Stop market data feed manager (which will stop all feeds) market_data_feed_manager.stop() @@ -185,6 +244,31 @@ async def lifespan(app: FastAPI): allow_headers=["*"], ) + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """ + Custom handler for validation errors to log detailed error messages. + """ + # Build a readable error message from validation errors + error_messages = [] + for error in exc.errors(): + loc = " -> ".join(str(l) for l in error.get("loc", [])) + msg = error.get("msg", "Validation error") + error_messages.append(f"{loc}: {msg}") + + # Log the validation error with details + logging.warning( + f"Validation error on {request.method} {request.url.path}: {'; '.join(error_messages)}" + ) + + # Return standard FastAPI validation error response + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content={"detail": exc.errors()}, + ) + + logfire.configure(send_to_logfire="if-token-present", environment=settings.app.logfire_environment, service_name="hummingbot-api") logfire.instrument_fastapi(app) @@ -223,8 +307,10 @@ def auth_user( app.include_router(controllers.router, dependencies=[Depends(auth_user)]) app.include_router(scripts.router, dependencies=[Depends(auth_user)]) app.include_router(market_data.router, dependencies=[Depends(auth_user)]) +app.include_router(rate_oracle.router, dependencies=[Depends(auth_user)]) app.include_router(backtesting.router, dependencies=[Depends(auth_user)]) app.include_router(archived_bots.router, dependencies=[Depends(auth_user)]) +# app.include_router(executors.router, dependencies=[Depends(auth_user)]) @app.get("/") async def root(): diff --git a/models/__init__.py b/models/__init__.py index bb64637b..b0d04fb4 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -191,6 +191,20 @@ ExecutorsResponse, ) +# Rate Oracle models +from .rate_oracle import ( + RateOracleSourceEnum, + GlobalTokenConfig, + RateOracleSourceConfig, + RateOracleConfig, + RateOracleConfigResponse, + RateOracleConfigUpdateRequest, + RateOracleConfigUpdateResponse, + RateRequest, + RateResponse, + SingleRateResponse, +) + __all__ = [ # Bot orchestration models "BotAction", @@ -338,4 +352,15 @@ "TradeHistoryResponse", "OrderHistoryResponse", "ExecutorsResponse", + # Rate Oracle models + "RateOracleSourceEnum", + "GlobalTokenConfig", + "RateOracleSourceConfig", + "RateOracleConfig", + "RateOracleConfigResponse", + "RateOracleConfigUpdateRequest", + "RateOracleConfigUpdateResponse", + "RateRequest", + "RateResponse", + "SingleRateResponse", ] \ No newline at end of file diff --git a/models/rate_oracle.py b/models/rate_oracle.py new file mode 100644 index 00000000..3aaa2205 --- /dev/null +++ b/models/rate_oracle.py @@ -0,0 +1,114 @@ +""" +Pydantic models for the rate oracle router. + +These models define the request/response schemas for rate oracle configuration endpoints. +""" + +from typing import Optional, List, Dict +from enum import Enum +from pydantic import BaseModel, Field + + +class RateOracleSourceEnum(str, Enum): + """Available rate oracle sources.""" + BINANCE = "binance" + BINANCE_US = "binance_us" + COIN_GECKO = "coin_gecko" + COIN_CAP = "coin_cap" + KUCOIN = "kucoin" + ASCEND_EX = "ascend_ex" + GATE_IO = "gate_io" + COINBASE_ADVANCED_TRADE = "coinbase_advanced_trade" + CUBE = "cube" + DEXALOT = "dexalot" + HYPERLIQUID = "hyperliquid" + DERIVE = "derive" + TEGRO = "tegro" + + +class GlobalTokenConfig(BaseModel): + """Global token configuration for displaying values.""" + global_token_name: str = Field( + default="USDT", + description="The token to use as global quote (e.g., USDT, USD, BTC)" + ) + global_token_symbol: str = Field( + default="$", + description="Symbol to display for the global token" + ) + + +class RateOracleSourceConfig(BaseModel): + """Rate oracle source configuration.""" + name: RateOracleSourceEnum = Field( + default=RateOracleSourceEnum.BINANCE, + description="The rate oracle source to use for price data" + ) + + +class RateOracleConfig(BaseModel): + """Complete rate oracle configuration.""" + rate_oracle_source: RateOracleSourceConfig = Field( + default_factory=RateOracleSourceConfig, + description="Rate oracle source configuration" + ) + global_token: GlobalTokenConfig = Field( + default_factory=GlobalTokenConfig, + description="Global token configuration" + ) + + +class RateOracleConfigResponse(BaseModel): + """Response for rate oracle configuration GET endpoint.""" + rate_oracle_source: RateOracleSourceConfig = Field( + description="Current rate oracle source configuration" + ) + global_token: GlobalTokenConfig = Field( + description="Current global token configuration" + ) + available_sources: List[str] = Field( + description="List of available rate oracle sources" + ) + + +class RateOracleConfigUpdateRequest(BaseModel): + """Request model for updating rate oracle configuration.""" + rate_oracle_source: Optional[RateOracleSourceConfig] = Field( + default=None, + description="New rate oracle source configuration (optional)" + ) + global_token: Optional[GlobalTokenConfig] = Field( + default=None, + description="New global token configuration (optional)" + ) + + +class RateOracleConfigUpdateResponse(BaseModel): + """Response for rate oracle configuration update.""" + success: bool = Field(description="Whether the update was successful") + message: str = Field(description="Status message") + config: RateOracleConfig = Field(description="Updated configuration") + + +class RateRequest(BaseModel): + """Request for getting rates.""" + trading_pairs: List[str] = Field( + description="List of trading pairs to get rates for (e.g., ['BTC-USDT', 'ETH-USDT'])" + ) + + +class RateResponse(BaseModel): + """Response containing rates for trading pairs.""" + source: str = Field(description="Rate oracle source used") + quote_token: str = Field(description="Quote token used") + rates: Dict[str, Optional[float]] = Field( + description="Mapping of trading pairs to their rates (None if rate not found)" + ) + + +class SingleRateResponse(BaseModel): + """Response for a single trading pair rate.""" + trading_pair: str = Field(description="The trading pair") + rate: Optional[float] = Field(description="The rate (None if not found)") + source: str = Field(description="Rate oracle source used") + quote_token: str = Field(description="Quote token used") diff --git a/routers/connectors.py b/routers/connectors.py index 482f9b10..3c59b87f 100644 --- a/routers/connectors.py +++ b/routers/connectors.py @@ -1,12 +1,12 @@ -from typing import List, Optional, Dict +from typing import Dict, List, Optional -from fastapi import APIRouter, Depends, Request, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query, Request from hummingbot.client.settings import AllConnectorSettings -from services.accounts_service import AccountsService -from services.market_data_feed_manager import MarketDataFeedManager from deps import get_accounts_service from models import AddTokenRequest +from services.accounts_service import AccountsService +from services.market_data_feed_manager import MarketDataFeedManager router = APIRouter(tags=["Connectors"], prefix="/connectors") @@ -22,16 +22,19 @@ async def available_connectors(): return list(AllConnectorSettings.get_connector_settings().keys()) -@router.get("/{connector_name}/config-map", response_model=List[str]) +@router.get("/{connector_name}/config-map", response_model=Dict[str, dict]) async def get_connector_config_map(connector_name: str, accounts_service: AccountsService = Depends(get_accounts_service)): """ - Get configuration fields required for a specific connector. - + Get configuration fields required for a specific connector with type information. + Args: connector_name: Name of the connector to get config map for - + Returns: - List of configuration field names required for the connector + Dictionary mapping field names to their type information. + Each field contains: + - type: The expected data type (e.g., "str", "SecretStr", "int") + - required: Whether the field is required """ return accounts_service.get_connector_config_map(connector_name) diff --git a/routers/rate_oracle.py b/routers/rate_oracle.py new file mode 100644 index 00000000..64df979d --- /dev/null +++ b/routers/rate_oracle.py @@ -0,0 +1,332 @@ +""" +Rate Oracle router for managing rate oracle configuration and retrieving rates. + +Provides CRUD endpoints for rate_oracle_source and global_token configuration, +with persistence to conf_client.yml. +""" + +from typing import List +from decimal import Decimal + +from fastapi import APIRouter, Request, HTTPException +from hummingbot.core.rate_oracle.rate_oracle import RateOracle, RATE_ORACLE_SOURCES + +from models.rate_oracle import ( + RateOracleConfig, + RateOracleConfigResponse, + RateOracleConfigUpdateRequest, + RateOracleConfigUpdateResponse, + RateOracleSourceConfig, + GlobalTokenConfig, + RateRequest, + RateResponse, + SingleRateResponse, +) +from utils.file_system import FileSystemUtil + +router = APIRouter(tags=["Rate Oracle"], prefix="/rate-oracle") + +# Path to conf_client.yml relative to the FileSystemUtil base_path ("bots") +CONF_CLIENT_PATH = "credentials/master_account/conf_client.yml" + + +def get_rate_oracle(request: Request) -> RateOracle: + """Get RateOracle instance from the market data feed manager.""" + return request.app.state.market_data_feed_manager.rate_oracle + + +def get_file_system_util() -> FileSystemUtil: + """Get FileSystemUtil instance.""" + return FileSystemUtil() + + +@router.get("/sources", response_model=List[str]) +async def get_available_sources(): + """ + Get list of all available rate oracle sources. + + Returns: + List of available source names that can be configured + """ + return list(RATE_ORACLE_SOURCES.keys()) + + +@router.get("/config", response_model=RateOracleConfigResponse) +async def get_rate_oracle_config(request: Request): + """ + Get current rate oracle configuration. + + Returns the current rate_oracle_source and global_token settings, + along with the list of available sources. + + Returns: + Current rate oracle configuration and available sources + """ + try: + fs_util = get_file_system_util() + + # Read current config from file + config_data = fs_util.read_yaml_file(CONF_CLIENT_PATH) + + # Extract rate_oracle_source + rate_oracle_source_data = config_data.get("rate_oracle_source", {}) + source_name = rate_oracle_source_data.get("name", "binance") + + # Extract global_token + global_token_data = config_data.get("global_token", {}) + global_token_name = global_token_data.get("global_token_name", "USDT") + global_token_symbol = global_token_data.get("global_token_symbol", "$") + + return RateOracleConfigResponse( + rate_oracle_source=RateOracleSourceConfig(name=source_name), + global_token=GlobalTokenConfig( + global_token_name=global_token_name, + global_token_symbol=global_token_symbol + ), + available_sources=list(RATE_ORACLE_SOURCES.keys()) + ) + + except FileNotFoundError: + raise HTTPException( + status_code=404, + detail=f"Configuration file not found: {CONF_CLIENT_PATH}" + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error reading configuration: {str(e)}" + ) + + +@router.put("/config", response_model=RateOracleConfigUpdateResponse) +async def update_rate_oracle_config( + request: Request, + update_request: RateOracleConfigUpdateRequest +): + """ + Update rate oracle configuration. + + Updates rate_oracle_source and/or global_token settings. Changes are: + 1. Applied to the running RateOracle instance immediately + 2. Persisted to conf_client.yml + + Args: + update_request: Configuration updates to apply + + Returns: + Updated configuration with success status + """ + try: + fs_util = get_file_system_util() + rate_oracle = get_rate_oracle(request) + + # Read current config + config_data = fs_util.read_yaml_file(CONF_CLIENT_PATH) + + # Track if we made changes + changes_made = [] + + # Update rate_oracle_source if provided + if update_request.rate_oracle_source is not None: + new_source_name = update_request.rate_oracle_source.name.value + + # Validate source exists + if new_source_name not in RATE_ORACLE_SOURCES: + raise HTTPException( + status_code=400, + detail=f"Invalid rate oracle source: {new_source_name}. " + f"Available sources: {list(RATE_ORACLE_SOURCES.keys())}" + ) + + # Update config data + if "rate_oracle_source" not in config_data: + config_data["rate_oracle_source"] = {} + config_data["rate_oracle_source"]["name"] = new_source_name + + # Update running RateOracle instance + new_source_class = RATE_ORACLE_SOURCES[new_source_name] + rate_oracle.source = new_source_class() + + changes_made.append(f"rate_oracle_source updated to {new_source_name}") + + # Update global_token if provided + if update_request.global_token is not None: + if "global_token" not in config_data: + config_data["global_token"] = {} + + if update_request.global_token.global_token_name is not None: + config_data["global_token"]["global_token_name"] = update_request.global_token.global_token_name + # Update RateOracle quote token + rate_oracle.quote_token = update_request.global_token.global_token_name + changes_made.append(f"global_token_name updated to {update_request.global_token.global_token_name}") + + if update_request.global_token.global_token_symbol is not None: + config_data["global_token"]["global_token_symbol"] = update_request.global_token.global_token_symbol + changes_made.append(f"global_token_symbol updated to {update_request.global_token.global_token_symbol}") + + # Persist changes to file + if changes_made: + fs_util.dump_dict_to_yaml(CONF_CLIENT_PATH, config_data) + + # Build response + current_source = config_data.get("rate_oracle_source", {}).get("name", "binance") + current_global_token = config_data.get("global_token", {}) + + return RateOracleConfigUpdateResponse( + success=True, + message="; ".join(changes_made) if changes_made else "No changes made", + config=RateOracleConfig( + rate_oracle_source=RateOracleSourceConfig(name=current_source), + global_token=GlobalTokenConfig( + global_token_name=current_global_token.get("global_token_name", "USDT"), + global_token_symbol=current_global_token.get("global_token_symbol", "$") + ) + ) + ) + + except HTTPException: + raise + except FileNotFoundError: + raise HTTPException( + status_code=404, + detail=f"Configuration file not found: {CONF_CLIENT_PATH}" + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error updating configuration: {str(e)}" + ) + + +@router.post("/rates", response_model=RateResponse) +async def get_rates(request: Request, rate_request: RateRequest): + """ + Get rates for specified trading pairs. + + Uses the configured rate oracle source to fetch current rates. + + Args: + rate_request: List of trading pairs to get rates for + + Returns: + Rates for the requested trading pairs + """ + try: + rate_oracle = get_rate_oracle(request) + + rates = {} + for pair in rate_request.trading_pairs: + try: + rate = rate_oracle.get_pair_rate(pair) + rates[pair] = float(rate) if rate and rate != Decimal("0") else None + except Exception: + rates[pair] = None + + return RateResponse( + source=rate_oracle.source.name, + quote_token=rate_oracle.quote_token, + rates=rates + ) + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error fetching rates: {str(e)}" + ) + + +@router.get("/rate/{trading_pair}", response_model=SingleRateResponse) +async def get_single_rate(request: Request, trading_pair: str): + """ + Get rate for a single trading pair. + + Args: + trading_pair: Trading pair in format BASE-QUOTE (e.g., BTC-USDT) + + Returns: + Rate for the specified trading pair + """ + try: + rate_oracle = get_rate_oracle(request) + + rate = rate_oracle.get_pair_rate(trading_pair) + rate_value = float(rate) if rate and rate != Decimal("0") else None + + return SingleRateResponse( + trading_pair=trading_pair, + rate=rate_value, + source=rate_oracle.source.name, + quote_token=rate_oracle.quote_token + ) + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error fetching rate for {trading_pair}: {str(e)}" + ) + + +@router.get("/rate-async/{trading_pair}", response_model=SingleRateResponse) +async def get_rate_async(request: Request, trading_pair: str): + """ + Get rate for a trading pair using async fetch (direct from exchange). + + This bypasses the cached prices and fetches directly from the source. + Useful when cached data may be stale or not yet initialized. + + Args: + trading_pair: Trading pair in format BASE-QUOTE (e.g., BTC-USDT) + + Returns: + Rate for the specified trading pair + """ + try: + rate_oracle = get_rate_oracle(request) + + rate = await rate_oracle.rate_async(trading_pair) + rate_value = float(rate) if rate and rate != Decimal("0") else None + + return SingleRateResponse( + trading_pair=trading_pair, + rate=rate_value, + source=rate_oracle.source.name, + quote_token=rate_oracle.quote_token + ) + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error fetching async rate for {trading_pair}: {str(e)}" + ) + + +@router.get("/prices") +async def get_cached_prices(request: Request): + """ + Get all cached prices from the rate oracle. + + Returns the complete price dictionary that the rate oracle has fetched + from its configured source. + + Returns: + Dictionary of all cached prices + """ + try: + rate_oracle = get_rate_oracle(request) + + prices = rate_oracle.prices + # Convert Decimal to float for JSON serialization + float_prices = {pair: float(price) for pair, price in prices.items()} + + return { + "source": rate_oracle.source.name, + "quote_token": rate_oracle.quote_token, + "prices_count": len(float_prices), + "prices": float_prices + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error fetching cached prices: {str(e)}" + ) diff --git a/routers/trading.py b/routers/trading.py index adf8e700..8dc6220f 100644 --- a/routers/trading.py +++ b/routers/trading.py @@ -709,7 +709,7 @@ def _standardize_in_flight_order_response(order, account_name: str, connector_na status_mapping = { OrderState.PENDING_CREATE: "SUBMITTED", OrderState.OPEN: "OPEN", - OrderState.PENDING_CANCEL: "OPEN", # Still open until cancelled + OrderState.PENDING_CANCEL: "PENDING_CANCEL", # Cancellation in progress OrderState.CANCELED: "CANCELLED", OrderState.PARTIALLY_FILLED: "PARTIALLY_FILLED", OrderState.FILLED: "FILLED", diff --git a/services/accounts_service.py b/services/accounts_service.py index e12a98ee..82a1e6ae 100644 --- a/services/accounts_service.py +++ b/services/accounts_service.py @@ -59,9 +59,11 @@ def __init__(self, self.secrets_manager = ETHKeyFileSecretManger(settings.security.config_password) self.accounts_state = {} self.update_account_state_interval = account_update_interval * 60 + self.order_status_poll_interval = 60 # Poll order status every 1 minute self.default_quote = default_quote self.market_data_feed_manager = market_data_feed_manager self._update_account_state_task: Optional[asyncio.Task] = None + self._order_status_polling_task: Optional[asyncio.Task] = None # Database setup for account states and orders self.db_manager = AsyncDatabaseManager(settings.database.url) @@ -108,6 +110,10 @@ def start(self): # Start the update loop which will call check_all_connectors self._update_account_state_task = asyncio.create_task(self.update_account_state_loop()) + # Start order status polling loop (every 1 minute) + self._order_status_polling_task = asyncio.create_task(self.order_status_polling_loop()) + logger.info("Order status polling started (1 minute interval)") + # Start Gateway transaction poller if not self._gateway_poller_started: asyncio.create_task(self._start_gateway_poller()) @@ -135,6 +141,12 @@ async def stop(self): self._update_account_state_task = None logger.info("Stopped account state update loop") + # Stop the order status polling loop + if self._order_status_polling_task: + self._order_status_polling_task.cancel() + self._order_status_polling_task = None + logger.info("Stopped order status polling loop") + # Stop Gateway transaction poller if self._gateway_poller_started: try: @@ -167,6 +179,21 @@ async def update_account_state_loop(self): finally: await asyncio.sleep(self.update_account_state_interval) + async def order_status_polling_loop(self): + """ + Sync order state to database for all connectors at a frequent interval (1 minute). + + The connector's built-in _lost_orders_update_polling_loop already polls the exchange. + This loop just syncs that state to our database and cleans up closed orders. + """ + while True: + try: + await self.connector_manager.sync_order_state_to_database_for_all_connectors() + except Exception as e: + logger.error(f"Error syncing order state to database: {e}") + finally: + await asyncio.sleep(self.order_status_poll_interval) + async def dump_account_state(self): """ Save the current account state to the database. diff --git a/utils/connector_manager.py b/utils/connector_manager.py index 5b8e2280..d9ca6e29 100644 --- a/utils/connector_manager.py +++ b/utils/connector_manager.py @@ -105,13 +105,63 @@ def clear_cache(self, account_name: Optional[str] = None, connector_name: Option @staticmethod def get_connector_config_map(connector_name: str): """ - Get the connector config map for the specified connector. + Get the connector config map for the specified connector with type information. :param connector_name: The name of the connector. - :return: The connector config map. + :return: Dictionary mapping field names to their type information. """ + from typing import Literal, get_args, get_origin + connector_config = HummingbotAPIConfigAdapter(AllConnectorSettings.get_connector_config_keys(connector_name)) - return [key for key in connector_config.hb_config.__fields__.keys() if key != "connector"] + fields_info = {} + + for key, field in connector_config.hb_config.model_fields.items(): + if key == "connector": + continue + + # Get the type annotation + field_type = field.annotation + type_name = getattr(field_type, "__name__", str(field_type)) + allowed_values = None + + # Handle Optional and Literal types + origin = get_origin(field_type) + args = get_args(field_type) + + if origin is Literal: + # It's a Literal type, extract the allowed values + type_name = "Literal" + allowed_values = list(args) + elif origin is not None: + # Handle Union types (Optional is Union[X, None]) + if type(None) in args: + # It's an Optional type, get the actual type + actual_types = [arg for arg in args if arg is not type(None)] + if actual_types: + inner_type = actual_types[0] + inner_origin = get_origin(inner_type) + inner_args = get_args(inner_type) + + if inner_origin is Literal: + # Optional[Literal[...]] + type_name = "Literal" + allowed_values = list(inner_args) + else: + type_name = getattr(inner_type, "__name__", str(inner_type)) + else: + type_name = str(field_type) + + field_info = { + "type": type_name, + "required": field.is_required(), + } + + if allowed_values is not None: + field_info["allowed_values"] = allowed_values + + fields_info[key] = field_info + + return fields_info async def update_connector_keys(self, account_name: str, connector_name: str, keys: dict): """ @@ -122,7 +172,8 @@ async def update_connector_keys(self, account_name: str, connector_name: str, ke :param keys: Dictionary of API keys to update. :return: The updated connector instance. """ - BackendAPISecurity.login_account(account_name=account_name, secrets_manager=self.secrets_manager) + if not BackendAPISecurity.login_account(account_name=account_name, secrets_manager=self.secrets_manager): + raise ValueError(f"Failed to authenticate for account '{account_name}'. Password validation failed.") connector_config = HummingbotAPIConfigAdapter(AllConnectorSettings.get_connector_config_keys(connector_name)) for key, value in keys.items(): @@ -240,7 +291,7 @@ async def _create_and_initialize_connector(self, account_name: str, connector_na await self._start_connector_network(connector) # Perform initial update of connector state - await self._update_connector_state(connector, connector_name) + await self._update_connector_state(connector, connector_name, account_name) logger.info(f"Initialized connector {connector_name} for account {account_name}") return connector @@ -313,10 +364,14 @@ async def _stop_connector_network(self, connector: ConnectorBase): except Exception as e: logger.error(f"Error stopping connector network: {e}") - async def _update_connector_state(self, connector: ConnectorBase, connector_name: str): + async def _update_connector_state(self, connector: ConnectorBase, connector_name: str, account_name: str = None): """ Update connector state including balances, orders, positions, and trading rules. This function can be called both during initialization and periodically. + + :param connector: The connector instance + :param connector_name: The name of the connector + :param account_name: The name of the account (optional, used for order sync) """ try: # Update current timestamp @@ -324,18 +379,22 @@ async def _update_connector_state(self, connector: ConnectorBase, connector_name # Update balances await connector._update_balances() - + # Update trading rules await connector._update_trading_rules() - + # Update positions for perpetual connectors if "_perpetual" in connector_name: await connector._update_positions() - + # Update order status for in-flight orders if hasattr(connector, '_update_order_status') and connector.in_flight_orders: await connector._update_order_status() - + + # Sync updated order state to database and cleanup closed orders + if account_name: + await self._sync_orders_to_database(connector, account_name, connector_name) + logger.debug(f"Updated connector state for {connector_name}") except Exception as e: @@ -349,10 +408,32 @@ async def update_all_connector_states(self): for cache_key, connector in self._connector_cache.items(): account_name, connector_name = cache_key.split(":", 1) try: - await self._update_connector_state(connector, connector_name) + await self._update_connector_state(connector, connector_name, account_name) except Exception as e: logger.error(f"Error updating state for {account_name}/{connector_name}: {e}") + async def sync_order_state_to_database_for_all_connectors(self): + """ + Sync connector's in_flight_orders state to database for all connectors. + + The connector's built-in _lost_orders_update_polling_loop already polls the exchange + and updates in_flight_orders. This method just syncs that state to our database + and cleans up closed orders. Called every minute. + """ + for cache_key, connector in self._connector_cache.items(): + account_name, connector_name = cache_key.split(":", 1) + try: + # Only process if there are in-flight orders + if not connector.in_flight_orders: + continue + + # Sync connector state to database and cleanup closed orders + await self._sync_orders_to_database(connector, account_name, connector_name) + logger.debug(f"Synced order state to DB for {account_name}/{connector_name}") + + except Exception as e: + logger.error(f"Error syncing order state for {account_name}/{connector_name}: {e}") + async def _load_existing_orders_from_database(self, connector: ConnectorBase, account_name: str, connector_name: str): """ Load existing active orders from database and add them to connector's in_flight_orders. @@ -395,6 +476,85 @@ async def _load_existing_orders_from_database(self, connector: ConnectorBase, ac except Exception as e: logger.error(f"Error loading existing orders from database for {account_name}/{connector_name}: {e}") + def _map_order_state_to_status(self, order_state: OrderState) -> str: + """ + Map Hummingbot OrderState to database status string. + + :param order_state: The OrderState enum value from Hummingbot + :return: Database status string + """ + status_mapping = { + OrderState.PENDING_CREATE: "SUBMITTED", + OrderState.OPEN: "OPEN", + OrderState.PENDING_CANCEL: "PENDING_CANCEL", + OrderState.CANCELED: "CANCELLED", + OrderState.PARTIALLY_FILLED: "PARTIALLY_FILLED", + OrderState.FILLED: "FILLED", + OrderState.FAILED: "FAILED", + OrderState.PENDING_APPROVAL: "SUBMITTED", + OrderState.APPROVED: "SUBMITTED", + OrderState.CREATED: "SUBMITTED", + OrderState.COMPLETED: "FILLED", + } + return status_mapping.get(order_state, "SUBMITTED") + + async def _sync_orders_to_database(self, connector: ConnectorBase, account_name: str, connector_name: str): + """ + Sync connector's in_flight_orders state to database and cleanup closed orders. + + This method ensures that the database reflects the current state of orders + as reported by the exchange, and removes terminal orders from in_flight_orders. + + :param connector: The connector instance + :param account_name: The name of the account + :param connector_name: The name of the connector + """ + if not self.db_manager: + return + + terminal_states = [OrderState.FILLED, OrderState.CANCELED, OrderState.FAILED, OrderState.COMPLETED] + orders_to_remove = [] + + # Create a copy of keys to iterate safely while potentially modifying the dict + order_ids = list(connector.in_flight_orders.keys()) + + for client_order_id in order_ids: + order = connector.in_flight_orders.get(client_order_id) + if not order: + continue + + try: + # Import OrderRepository dynamically to avoid circular imports + from database import OrderRepository + + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + db_order = await order_repo.get_order_by_client_id(client_order_id) + + if db_order: + # Map connector state to database status + new_status = self._map_order_state_to_status(order.current_state) + + # Only update if status changed + if db_order.status != new_status: + await order_repo.update_order_status(client_order_id, new_status) + logger.info(f"Synced order {client_order_id} status: {db_order.status} -> {new_status}") + + # Mark terminal orders for removal from in_flight_orders + if order.current_state in terminal_states: + orders_to_remove.append(client_order_id) + + except Exception as e: + logger.error(f"Error syncing order {client_order_id} to database: {e}") + + # Remove terminal orders from in_flight_orders + for order_id in orders_to_remove: + connector.in_flight_orders.pop(order_id, None) + logger.debug(f"Removed closed order {order_id} from in_flight_orders") + + if orders_to_remove: + logger.info(f"Cleaned up {len(orders_to_remove)} terminal orders from {account_name}/{connector_name}") + def _convert_db_order_to_in_flight_order(self, order_record) -> InFlightOrder: """ Convert a database Order record to a Hummingbot InFlightOrder object. diff --git a/utils/security.py b/utils/security.py index c64bd823..2395a5b9 100644 --- a/utils/security.py +++ b/utils/security.py @@ -11,8 +11,8 @@ from hummingbot.client.config.security import Security from config import settings -from utils.hummingbot_api_config_adapter import HummingbotAPIConfigAdapter from utils.file_system import fs_util +from utils.hummingbot_api_config_adapter import HummingbotAPIConfigAdapter class BackendAPISecurity(Security): @@ -21,6 +21,9 @@ def login_account(cls, account_name: str, secrets_manager: BaseSecretsManager) - if not cls.validate_password(secrets_manager): return False cls.secrets_manager = secrets_manager + # Also set on parent Security class for hummingbot's ClientConfigAdapter methods + # that access Security.secrets_manager directly + Security.secrets_manager = secrets_manager cls.decrypt_all(account_name=account_name) return True